深度學習總結——用自己的數據集微調CLIP

CLIP概述

CLIP(Contrastive Language-Image Pretraining)是由OpenAI開發的一種深度學習模型,用于將圖像和自然語言文本進行聯合編碼。它采用了多模態學習的方法,使得模型能夠理解圖像和文本之間的語義關系。

它的核心思想是將圖像和文本視為同等重要的輸入,并通過聯合訓練來學習它們之間的聯系。CLIP模型使用了一個共享的編碼器,它將圖像和文本分別映射到一個共享的特征空間中。通過將圖像和文本的編碼向量進行比較,模型能夠判斷它們之間的相似性和相關性。

它在訓練過程中使用了對比損失函數,以鼓勵模型將相關的圖像和文本對編碼得更接近,而將不相關的圖像和文本對編碼得更遠。這使得CLIP模型能夠具有良好的泛化能力,能夠在訓練過程中學習到通用的圖像和文本理解能力。

它的整體流程如下:
在這里插入圖片描述

它展現了強大的zero-shot能力,在許多視覺與語言任務中表現出色,如圖像分類、圖像生成描述、圖像問答等。它的多模態能力使得CLIP模型能夠在圖像和文本之間建立強大的語義聯系,為各種應用場景提供了更全面的理解和分析能力。

正是因為它出色的zero-shot能力,因此訓練的模型本身就含有很多可以利用的知識,因此在一些任務上,如分類任務,caption任務,可以嘗試在自己的數據集上微調CLIP,或許通過這種操作就能獲得不錯的性能。但是目前如何微調CLIP網上并沒有看到很詳細的介紹,因此我整理了相關的知識并在此記錄。
參考鏈接

微調代碼

第三方庫

  • clip-by-openai
  • torch

下面以我做的圖像分類任務為例,介紹相關的步驟。

步驟介紹

1.構建數據集

構建自己的數據集,每次迭代返回的數據包括:RGB圖像和圖像的標簽(a photo of {label})
代碼示例如下:

import os
from PIL import Image
import numpy as np
import clip
class YourDataset(Dataset):def __init__(self,img_root,meta_root,is_train,preprocess):# 1.根目錄(根據自己的情況更改)self.img_root = img_rootself.meta_root = meta_root# 2.訓練圖片和測試圖片地址(根據自己的情況更改)self.train_set_file = os.path.join(meta_root,'train.txt')self.test_set_file = os.path.join(meta_root,'test.txt')# 3.訓練 or 測試(根據自己的情況更改)self.is_train = is_train# 4.處理圖像self.img_process = preprocess# 5.獲得數據(根據自己的情況更改)self.samples = []self.sam_labels = []# 5.1 訓練還是測試數據集self.read_file = ""if is_train:self.read_file = self.train_set_fileelse:self.read_file = self.test_set_file# 5.2 獲得所有的樣本(根據自己的情況更改)with open(self.read_file,'r') as f:for line in f:img_path = os.path.join(self.img_root,line.strip() + '.jpg')label = line.strip().split('/')[0]label = label.replace("_"," ")label = "a photo of " + labelself.samples.append(img_path)self.sam_labels.append(label)# 轉換為tokenself.tokens = clip.tokenize(self.sam_labels)def __len__(self):return len(self.samples)def __getitem__(self, idx):img_path = self.samples[idx]token = self.tokens[idx]# 加載圖像image = Image.open(img_path).convert('RGB')# 對圖像進行轉換image = self.img_process(image)return image,token

2.加載預訓練CLIP模型和相關配置

首先使用第三方庫加載預訓練的CLIP模型,會返回一個CLIP模型和一個圖像預處理函數preprocess,這將用于之后的數據加載過程。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net, preprocess = clip.load("RN50",device=device,jit=False)

然后初始化優化器,損失函數,需要注意的是,如果剛開始你的損失很大或者出現異常,可以調整優化器的學習率和其他參數來進行調整,通常是調整的更小會有效果。

optimizer = optim.Adam(net.parameters(), lr=1e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# 創建損失函數
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

3.加載數據

該步驟主要是調用第一步中創建的類,然后使用DataLoader函數加載自己的數據集。
代碼如下:

your_dataset = YourDataset(img_root= '/images',meta_root= '/meta',is_train=True,preprocess=preprocess)
dataset_size_your = len(your_dataset)
your_dataloader = DataLoader(your_dataset,batch_size=4,shuffle=True,num_workers=4,pin_memory=False)

4.開始訓練

訓練代碼按照模板來寫即可,總共要訓練epoches次,每次要將一個數據集里面的所有數據都訓練一次,然后在每次訓練完成的時候保存模型,這里分為兩種:

  • 保存模型的參數
  • 保存模型的參數、優化器、迭代次數

該部分的代碼如下:

phase = "train"
model_name = "your model name"
ckt_gap = 4
epoches = 30
for epoch in range(epoches):scheduler.step()total_loss = 0batch_num = 0# 使用混合精度,占用顯存更小with torch.cuda.amp.autocast(enabled=True):for images,label_tokens in your_dataloader:# 將圖片和標簽token轉移到device設備images = images.to(device)label_tokens = label_tokens.to(device)batch_num += 1# 優化器梯度清零optimizer.zero_grad()with torch.set_grad_enabled(phase == "train"):logits_per_image, logits_per_text = net(images, label_tokens)ground_truth = torch.arange(len(images),dtype=torch.long,device=device)cur_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2total_loss += cur_lossif phase == "train":cur_loss.backward()if device == "cpu":optimizer.step()else:optimizer.step()clip.model.convert_weights(net) if batch_num % 4 == 0:logger.info('{} epoch:{} loss:{}'.format(phase,epoch,cur_loss))epoch_loss = total_loss / dataset_size_yourtorch.save(net.state_dict(),f"{model_name}_epoch_{epoch}.pth")logger.info(f"weights_{epoch} saved")if epoch % ckt_gap == 0:checkpoint_path = f"{model_name}_ckt.pth"checkpoint = {'it': epoch,'network': net.state_dict(),'optimizer': optimizer.state_dict(),'scheduler': scheduler.state_dict()}torch.save(checkpoint, checkpoint_path)logger.info(f"checkpoint_{epoch} saved")logger.info('{} Loss: {:.4f}'.format(phase, epoch_loss))

全部代碼

import os
from PIL import Image
import numpy as np
import clip
from loguru import logger
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nnclass YourDataset(Dataset):def __init__(self,img_root,meta_root,is_train,preprocess):# 1.根目錄(根據自己的情況更改)self.img_root = img_rootself.meta_root = meta_root# 2.訓練圖片和測試圖片地址(根據自己的情況更改)self.train_set_file = os.path.join(meta_root,'train.txt')self.test_set_file = os.path.join(meta_root,'test.txt')# 3.訓練 or 測試(根據自己的情況更改)self.is_train = is_train# 4.處理圖像self.img_process = preprocess# 5.獲得數據(根據自己的情況更改)self.samples = []self.sam_labels = []# 5.1 訓練還是測試數據集self.read_file = ""if is_train:self.read_file = self.train_set_fileelse:self.read_file = self.test_set_file# 5.2 獲得所有的樣本(根據自己的情況更改)with open(self.read_file,'r') as f:for line in f:img_path = os.path.join(self.img_root,line.strip() + '.jpg')label = line.strip().split('/')[0]label = label.replace("_"," ")label = "photo if " + labelself.samples.append(img_path)self.sam_labels.append(label)# 轉換為tokenself.tokens = clip.tokenize(self.sam_labels)def __len__(self):return len(self.samples)def __getitem__(self, idx):img_path = self.samples[idx]token = self.tokens[idx]# 加載圖像image = Image.open(img_path).convert('RGB')# 對圖像進行轉換image = self.img_process(image)return image,token
# 創建模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net, preprocess = clip.load("RN50",device=device,jit=False)optimizer = optim.Adam(net.parameters(), lr=1e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# 創建損失函數
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# 加載數據集
your_dataset = YourDataset(img_root= '/images',meta_root= '/meta',is_train=True,preprocess=preprocess)
dataset_size_your = len(your_dataset)
your_dataloader = DataLoader(your_dataset,batch_size=4,shuffle=True,num_workers=4,pin_memory=False)phase = "train"
model_name = "your model name"
ckt_gap = 4
for epoch in range(st,args.epoches):scheduler.step()total_loss = 0batch_num = 0# 使用混合精度,占用顯存更小with torch.cuda.amp.autocast(enabled=True):for images,label_tokens in your_dataloader:# 將圖片和標簽token轉移到device設備images = images.to(device)label_tokens = label_tokens.to(device)batch_num += 1# 優化器梯度清零optimizer.zero_grad()with torch.set_grad_enabled(phase == "train"):logits_per_image, logits_per_text = net(images, label_tokens)ground_truth = torch.arange(len(images),dtype=torch.long,device=device)cur_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2total_loss += cur_lossif phase == "train":cur_loss.backward()if device == "cpu":optimizer.step()else:optimizer.step()clip.model.convert_weights(net) if batch_num % 4 == 0:logger.info('{} epoch:{} loss:{}'.format(phase,epoch,cur_loss))epoch_loss = total_loss / dataset_size_food101torch.save(net.state_dict(),f"{model_name}_epoch_{epoch}.pth")logger.info(f"weights_{epoch} saved")if epoch % ckt_gap == 0:checkpoint_path = f"{model_name}_ckt.pth"checkpoint = {'it': epoch,'network': net.state_dict(),'optimizer': optimizer.state_dict(),'scheduler': scheduler.state_dict()}torch.save(checkpoint, checkpoint_path)logger.info(f"checkpoint_{epoch} saved")logger.info('{} Loss: {:.4f}'.format(phase, epoch_loss))

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。如若轉載,請注明出處:https://dhexx.cn/hk/4627913.html

如若內容造成侵權/違法違規/事實不符,請聯系我的編程經驗分享網進行投訴反饋,一經查實,立即刪除!


相關文章:

  • 阻塞隊列和生產者-消費者模式
  • 硬鏈接與符號鏈接
  • cesium vue教程目錄導航
  • 【走進Linux的世界】Linux---基本指令(2)
  • JavaScript藍橋杯------學海無涯
  • JavaScript基本語法(二)
  • 輔助駕駛功能開發-功能規范篇(25)-1-全景影像AVM規范
  • Python3內置關鍵字大全
  • C++ RapidJSON使用詳解
  • 【CSDN如何獲得鐵粉】
  • 什么情形下應該使用BFF?帶你了解BFF的優勢,即服務于前端的后端
  • 【讀書筆記】《平凡的世界》- 路遙
  • 在線程中執行任務
  • 90. Python列表推導式
  • 【C/C++】基礎知識之動態申請內存空間new-delete
  • 網絡支付存在的風險有什么
  • 在編程中,代理、委托、回調、鉤子、句柄、打樁的區別
  • 實戰Windows Chrome 0day
  • 阻塞方法與中斷方法
  • Vue-CLI + Vue3 + Vue-Router4 實現tabbar小案例
  • 【C++】 作用域(::)和命名空間(namespace)使用的注意事項
  • 文件與文件系統的打包、壓縮、備份
  • 工業相機丟包排查步驟
  • LinkedBlockingQueue阻塞隊列
  • JDK8新特性,記錄常用的知識點
  • 【極海APM32F4xx Tiny】學習筆記04-移植FreeRTOS
  • JavaSE常用API
  • mysql數據類型有哪幾種
  • xxjob代碼執行過程
  • k8s 配置service失敗
  • 第1節:vue cesium 概述(含網站地址+視頻)
  • 【001 Linux內核】內核鏡像格式有幾種?分別有什么區別?
  • 地震segy數據高效讀寫庫cigsegy在windows系統的編譯
  • NginxFoundation
  • 【HISI IC萌新虛擬項目】Package Process Unit項目全流程目錄
  • 前后端分離項目之登錄頁面(前后端請求、響應和連接數據庫)
  • 20230603----重返學習-react路由導航-路由傳參-react-router-dom的V6版本
  • 《Python編程從入門到實踐》學習筆記06字典
  • 數據庫四種事務隔離級別的區別以及可能出現的問題
  • ClassLoader源碼
  • AIGC技術研究與應用 ---- 下一代人工智能:新范式!新生產力!(2.3-大模型發展歷程 之 多模態)
  • 信息論與編碼 SCUEC DDDD 期末復習
  • 一文搞懂編程界中最基礎最常見【必知必會】的十一個算法,再也別說你只是聽說過【建議收藏+關注】
  • Linux 實驗三 Linux C開發工具的使用
  • 基于 Linux 通信架構的 Thread Pool A 線程池分析
  • python homework完成回合擲骰子游戲
  • 恒容容器放氣的瞬時流量的計算與合金氫化物放氫流量曲線的計算
  • solr快速上手:配置從mysql同步數據(五)
  • 【深入淺出Spring原理及實戰】「夯實基礎系列」360全方位滲透和探究SpringMVC的核心原理和運作機制(總體框架原理篇)
  • 「SQL面試題庫」 No_87 學生們參加各科測試的次數