This commit is contained in:
leewlving 2024-06-21 10:12:04 +08:00
parent 9539ba5aed
commit 5b84447a3d
8 changed files with 292 additions and 51 deletions

View File

@ -9,6 +9,10 @@ import random
from PIL import Image from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from model.simple_tokenizer import SimpleTokenizer as Tokenizer from model.simple_tokenizer import SimpleTokenizer as Tokenizer
import re
import os
import numpy as np
from torchvision import transforms
@ -62,7 +66,28 @@ class BaseDataset(Dataset):
return image return image
def pre_caption(self, caption):
    """Normalize a raw caption.

    Lowercases, strips common punctuation, maps '-' and '/' to spaces,
    replaces the '<person>' placeholder with 'person', collapses runs of
    whitespace, and truncates to at most ``self.maxWords`` words.
    """
    # Drop punctuation, lowercase, and normalize separators/placeholders.
    cleaned = re.sub(r"([,.'!?\"()*#:;~])", '', caption.lower())
    cleaned = cleaned.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
    # Collapse any run of 2+ whitespace chars (incl. newlines) to one space.
    cleaned = re.sub(r"\s{2,}", ' ', cleaned)
    cleaned = cleaned.rstrip('\n').strip(' ')
    # Enforce the word budget.
    words = cleaned.split(' ')
    if len(words) > self.maxWords:
        cleaned = ' '.join(words[:self.maxWords])
    return cleaned
def _load_text(self, index: int): def _load_text(self, index: int):
# print(self.captions[index])
# captions = self.pre_caption(self.captions[index])
# captions =whitespace_clean(self.captions[index]).lower()
captions = self.captions[index] captions = self.captions[index]
use_cap = captions[random.randint(0, len(captions) - 1)] use_cap = captions[random.randint(0, len(captions) - 1)]
@ -79,7 +104,7 @@ class BaseDataset(Dataset):
caption.append(0) caption.append(0)
caption = torch.tensor(caption) caption = torch.tensor(caption)
return caption return captions
def _load_label(self, index: int) -> torch.Tensor: def _load_label(self, index: int) -> torch.Tensor:
label = self.labels[index] label = self.labels[index]
@ -101,3 +126,175 @@ class BaseDataset(Dataset):
return image, caption, label, index return image, caption, label, index
# def default_loader(path):
# return Image.open(path).convert('RGB')
# class IaprDataset(Dataset):
# def __init__(self, args,txt,transform=None, loader=default_loader):
# self.transform = transform
# self.loader = loader
# name_label = []
# for line in open(txt):
# line = line.strip('\n').split()
# label = list(map(int, np.array(line[len(line)-255:]))) # the last 255 binary digits are the label; the first 2912 are the bag-of-words binary encoding
# tem = re.split('[/.]', line[0])
# file_name, sample_name = tem[0], tem[1]
# name_label.append([file_name, sample_name, label])
# # # print('label = ', label)
# # print('file_name = %s, sample_name = %s' %(file_name, sample_name))
# # label_list = np.where(label=='1')
# # print('label_list = ', label_list)
# self.name_label = name_label
# self.image_dir=args.image_dir
# self.text_dir = args.text_dir
# def __getitem__(self, index):
# words = self.name_label[index] # words = [file_name, sample_name, label]
# # print('words = ', words[0:2])
# img_path = os.path.join(self.image_dir, words[0], words[1]+'.jpg')
# text_path = os.path.join(self.text_dir, words[0], words[1]+'.txt')
# # img
# img = self.loader(img_path)
# if self.transform is not None:
# img = self.transform(img)
# # text
# text = 'None'
# for line in open(text_path):
# text = '[CLS]' + line + '[SEP]'
# # label
# label = torch.LongTensor(words[2])
# # image, caption, label, index
# return img, text, label, index
# def __len__(self):
# return len(self.name_label)
# Default image preprocessing: resize + center-crop to 224x224, convert to a
# tensor, and normalize with the standard ImageNet channel mean/std.
default_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
class MScocoDataset(Dataset):
    """MS-COCO cross-modal dataset yielding (image, caption, label, index)."""

    def __init__(self, data_path, img_filename, text_filename, label_filename, transform=None):
        self.data_path = data_path
        # Fall back to the module-level default preprocessing pipeline.
        self.transform = default_transform if transform is None else transform
        # One relative image path per line.
        with open(os.path.join(data_path, img_filename), 'r') as f:
            self.imgs = [line.strip() for line in f]
        # One caption per line, aligned with the image list.
        with open(os.path.join(data_path, text_filename), 'r') as f:
            self.texts = [line.replace('\n', '') for line in f.readlines()]
        # Multi-hot label matrix, one row per sample.
        self.labels = np.genfromtxt(os.path.join(data_path, label_filename), dtype=np.int32)

    def __getitem__(self, index):
        image = Image.open(os.path.join(self.data_path, self.imgs[index]))
        image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        text = self.texts[index]
        label = torch.from_numpy(self.labels[index]).float()
        # (image, caption, label, index)
        return image, text, label, index

    def __len__(self):
        return len(self.imgs)
class MirflickrDataset(Dataset):
    """MIRFLICKR-25K dataset yielding (image, caption, label, index).

    Labels are multi-hot vectors built from the annotation files under
    ``mirflickr25k_annotations_v080``; captions are the space-joined user
    tags from ``mirflickr/meta/tags``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        file_path = os.path.join(root_dir, "mirflickr25k_annotations_v080")
        file_list = os.listdir(file_path)
        # Skip the relevance ("_r1") annotation variants and the README.
        file_list = [item for item in file_list if "_r1" not in item and "README" not in item]
        # Map each annotation file (= one class) to a label-vector column.
        self.class_index = {item: i for i, item in enumerate(file_list)}
        self.label_dict = {}
        for path_id in file_list:
            path = os.path.join(file_path, path_id)
            with open(path, "r") as f:
                for item in f:
                    item = item.strip()
                    if item not in self.label_dict:
                        self.label_dict[item] = np.zeros(len(file_list))
                    self.label_dict[item][self.class_index[path_id]] = 1
        self.captions_dict = {}
        captions_path = os.path.join(root_dir, "mirflickr/meta/tags")
        for item in os.listdir(captions_path):
            # "tags123.txt" -> "123"; assumed to match the annotation ids —
            # TODO(review): confirm against the dataset layout.
            id_ = item.split(".")[0].replace("tags", "")
            caption = ""
            with open(os.path.join(captions_path, item), "r") as f:
                for word in f.readlines():
                    caption += word.strip() + " "
            self.captions_dict[id_] = caption.strip()
        # BUGFIX: __getitem__ receives an integer position from the sampler,
        # but the dicts above are keyed by image-id strings (the original
        # indexed label_dict with the int and concatenated it into the file
        # name, which raises). Keep a sorted id list so integer positions map
        # deterministically to ids.
        self.ids = sorted(self.label_dict.keys())
        if transform is None:
            self.transform = default_transform
        else:
            self.transform = transform

    def __getitem__(self, index):
        id_ = self.ids[index]
        # Float tensor for consistency with MScocoDataset (the conversion was
        # previously commented out).
        label = torch.from_numpy(self.label_dict[id_]).float()
        PATH = os.path.join(self.root_dir, "mirflickr")
        img = Image.open(os.path.join(PATH, "im" + id_ + ".jpg"))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        text = self.captions_dict[id_]
        # (image, caption, label, index)
        return img, text, label, index

    def __len__(self):
        return len(self.ids)
class NusWideDataset(Dataset):
    """NUS-WIDE dataset skeleton yielding (image, caption, label, index).

    NOTE(review): this class is an unfinished copy of MirflickrDataset —
    ``__getitem__`` still reads mirflickr-style attributes (``label_dict``,
    ``captions_dict``) that ``__init__`` never builds, and points at the
    mirflickr image folder. Path handling is fixed below; the actual
    NUS-WIDE loading logic still needs to be implemented.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # BUGFIX: os.path.join discards every earlier component when a later
        # one starts with "/", so the original joins silently ignored
        # root_dir. Also keep the paths as attributes instead of dropping
        # them in locals, and remember the transform for __getitem__.
        self.image_list_file = os.path.join(root_dir, "Low-Level-Features/ImageList/Imagelist.txt")
        self.label_path = os.path.join(root_dir, "nuswide/Groundtruth/AllLabels")
        self.text_file = os.path.join(root_dir, "Low-Level-Features/NUS_WID_Tags/All_Tags.txt")
        if transform is None:
            self.transform = default_transform
        else:
            self.transform = transform

    def __getitem__(self, index):
        # TODO(review): copied from MirflickrDataset; label_dict and
        # captions_dict are never populated for NUS-WIDE and the "mirflickr"
        # image folder is wrong for this dataset.
        label = self.label_dict[index]
        PATH = os.path.join(self.root_dir, "mirflickr")
        img = Image.open(os.path.join(PATH, "im" + index + ".jpg"))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        text = self.captions_dict[index]
        # (image, caption, label, index)
        return img, text, label, index

    def __len__(self):
        return len(list(self.label_dict.keys()))

View File

@ -1,8 +1,12 @@
from .base import BaseDataset from .base import BaseDataset, MirflickrDataset, MScocoDataset , IaprDataset
import os import os
import numpy as np import numpy as np
import scipy.io as scio import scipy.io as scio
import torch
import torchvision.transforms as transforms
import json
from torch.utils.data import Dataset
def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None): def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
@ -45,8 +49,13 @@ def dataloader(captionFile: str,
with open(captionFile, "r") as f: with open(captionFile, "r") as f:
captions = f.readlines() captions = f.readlines()
captions = np.asarray([[item.strip()] for item in captions]) captions = np.asarray([[item.strip()] for item in captions])
elif captionFile.endswith("json"):
with open(captionFile, "r") as f:
data = json.load(f)
captions=data["caption"]
# captions = captions[0] if captions.shape[0] == 1 else captions
else: else:
raise ValueError("the format of 'captionFile' doesn't support, only support [txt, mat] format.") raise ValueError("the format of 'captionFile' doesn't support, only support [txt, json, mat] format.")
if not npy: if not npy:
indexs = scio.loadmat(indexFile)["index"] indexs = scio.loadmat(indexFile)["index"]
else: else:
@ -64,5 +73,56 @@ def dataloader(captionFile: str,
retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy) retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
return train_data, query_data, retrieval_data return train_data, query_data, retrieval_data
def get_dataset_filename(split):
    """Return the (image-list, text-list, label-list) filenames for a split.

    ``split`` must be one of 'train', 'test', or 'db'; any other value
    raises KeyError.
    """
    prefix = {'train': 'cm_train', 'test': 'cm_test', 'db': 'cm_database'}[split]
    return (prefix + '_imgs.txt', prefix + '_txts.txt', prefix + '_labels.txt')
def cross_modal_dataset(args):
    """Build loaders and splits for the configured cross-modal dataset.

    Returns (train_loader, test_loader, db_loader, train_set, test_set,
    db_set). Raises ValueError for unsupported ``args.dataset`` values.
    """
    # Local import: the module-level import only brings in Dataset.
    from torch.utils.data import DataLoader

    transform_test = transforms.Compose([
        transforms.Resize((args.resolution, args.resolution)),
        transforms.ToTensor(),
        # CLIP's published normalization statistics.
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711))
    ])
    if args.dataset == 'flickr25k':
        dataset = MirflickrDataset(root_dir=args.flickr25k_root)
        # BUGFIX: this branch previously built no splits/loaders and crashed
        # with NameError at the return. Split as the commented-out code
        # intended: 25%/75% query/database, 65% of the data for training.
        test_set, db_set = torch.utils.data.random_split(dataset, [0.25, 0.75])
        train_set, _ = torch.utils.data.random_split(dataset, [0.65, 0.35])
    elif args.dataset == 'coco':
        dataset = MScocoDataset(data_path=args.coco_root, img_filename=args.coco_img_root,
                                text_filename=args.coco_txt_root, label_filename=args.coco_label_root,
                                transform=transform_test)
        # BUGFIX: same missing-split problem as the flickr25k branch.
        test_set, db_set = torch.utils.data.random_split(dataset, [0.25, 0.75])
        train_set, _ = torch.utils.data.random_split(dataset, [0.65, 0.35])
    elif args.dataset == 'iapr':
        # BUGFIX: the test/retrieval paths previously used args.root while
        # the train path used args.iapr_root; use args.iapr_root throughout.
        train_file = os.path.join(args.iapr_root, 'iapr_train')
        test_file = os.path.join(args.iapr_root, 'iapr_test')
        retrieval_file = os.path.join(args.iapr_root, 'iapr_retrieval')
        train_set = IaprDataset(args, txt=train_file, transform=transform_test)
        test_set = IaprDataset(args, txt=test_file, transform=transform_test)
        db_set = IaprDataset(args, txt=retrieval_file, transform=transform_test)
    else:
        raise ValueError("Not support.")
    # BUGFIX: DataLoader lives in torch.utils.data; Dataset.DataLoader does
    # not exist and raised AttributeError on the iapr branch.
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
    db_loader = DataLoader(db_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
    return train_loader, test_loader, db_loader, train_set, test_set, db_set

View File

@ -76,7 +76,6 @@ captions = {"caption": captions}
scio.savemat(os.path.join(root_dir, "mat/index.mat"), index) scio.savemat(os.path.join(root_dir, "mat/index.mat"), index)
with open(os.path.join(root_dir, "mat/caption.json"), 'w', encoding='utf-8') as f: with open(os.path.join(root_dir, "mat/caption.json"), 'w', encoding='utf-8') as f:
json.dump(captions, f, ensure_ascii=False) json.dump(captions, f, ensure_ascii=False)
# scio.savemat(os.path.join(root_dir, "mat/caption.mat"), captions)
scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels) scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)

View File

@ -4,7 +4,7 @@ from train.text_train import Trainer
if __name__ == "__main__": if __name__ == "__main__":
engine=Trainer() engine=Trainer()
# engine.test() engine.test()
engine.train_epoch() engine.train_epoch()

View File

@ -9,6 +9,8 @@ from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm from tqdm import tqdm
from .model import build_model from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer from .simple_tokenizer import SimpleTokenizer as _Tokenizer

View File

@ -15,9 +15,11 @@ from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similari
from utils.calc_utils import cal_map, cal_pr from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader from dataset.dataloader import dataloader
import clip import clip
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
# from transformers import BertModel # from transformers import BertModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# tokenizer=Tokenizer()
def clamp(delta, clean_imgs): def clamp(delta, clean_imgs):
@ -95,6 +97,8 @@ class Trainer(TrainBase):
label_train=[] label_train=[]
for image, text, label, index in self.train_loader: for image, text, label, index in self.train_loader:
text=text.to(device, non_blocking=True) text=text.to(device, non_blocking=True)
text=clip.tokenize(text)
# print(self.model.vocab_size) # print(self.model.vocab_size)
temp_text=self.model.encode_text(text) temp_text=self.model.encode_text(text)
text_train.append(temp_text.cpu().detach().numpy()) text_train.append(temp_text.cpu().detach().numpy())
@ -113,10 +117,10 @@ class Trainer(TrainBase):
return text_representation, text_var_representation return text_representation, text_var_representation
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var
,epsilon=0.03125, alpha=3/255, num_iter=1500): ,epsilon=0.03125, alpha=3/255):
delta = torch.zeros_like(image,requires_grad=True) delta = torch.zeros_like(image,requires_grad=True)
for i in range(num_iter): for i in range(self.args.epochs):
self.model.zero_grad() self.model.zero_grad()
anchor=self.model.encode_image(image+delta) anchor=self.model.encode_image(image+delta)
loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity()) loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity())

View File

@ -13,13 +13,16 @@ from .base import TrainBase
from torch.nn import functional as F from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import cal_map, cal_pr from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader from dataset.dataloader import cross_modal_dataset
import clip import clip
import copy import copy
from model.bert_tokenizer import BertTokenizer from model.bert_tokenizer import BertTokenizer
from model.simple_tokenizer import SimpleTokenizer as Tokenizer from model.simple_tokenizer import SimpleTokenizer as Tokenizer
from transformers import BertForMaskedLM from transformers import BertForMaskedLM
# from transformers import BertModel # from transformers import BertModel
import ftfy
import regex as re
import html
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@ -51,6 +54,17 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again',
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '', '·'] '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '', '·']
filter_words = set(filter_words) filter_words = set(filter_words)
def basic_clean(text):
    """Fix mojibake with ftfy, undo (double-)HTML-escaping, and trim."""
    fixed = ftfy.fix_text(text)
    # Unescape twice to handle doubly-encoded entities like "&amp;amp;".
    return html.unescape(html.unescape(fixed)).strip()
def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and trim ends."""
    return re.sub(r'\s+', ' ', text).strip()
def get_bpe_substitues(substitutes, tokenizer, mlm_model): def get_bpe_substitues(substitutes, tokenizer, mlm_model):
# substitutes L, k # substitutes L, k
# device = mlm_model.device # device = mlm_model.device
@ -139,46 +153,10 @@ class Trainer(TrainBase):
def _init_dataset(self): def _init_dataset(self):
self.logger.info("init dataset.") self.logger.info("init dataset.")
self.logger.info(f"Using {self.args.dataset} dataset.") self.logger.info(f"Using {self.args.dataset} dataset.")
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file) train_loader, test_loader, db_loader , train_set, test_set, db_set=cross_modal_dataset(self.args)
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file) self.train_loader=train_loader
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file) self.test_loader=test_loader
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file, self.retrieval_loader=db_loader
indexFile=self.args.index_file,
labelFile=self.args.label_file,
maxWords=self.args.max_words,
imageResolution=self.args.resolution,
query_num=self.args.query_num,
train_num=self.args.train_num,
seed=self.args.seed)
self.train_labels = train_data.get_all_label()
self.query_labels = query_data.get_all_label()
self.retrieval_labels = retrieval_data.get_all_label()
self.args.retrieval_num = len(self.retrieval_labels)
self.logger.info(f"query shape: {self.query_labels.shape}")
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
self.train_loader = DataLoader(
dataset=train_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.query_loader = DataLoader(
dataset=query_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.retrieval_loader = DataLoader(
dataset=retrieval_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.train_data=train_data
def generate_mapping(self): def generate_mapping(self):
image_train=[] image_train=[]
@ -186,6 +164,7 @@ class Trainer(TrainBase):
for image, text, label, index in self.train_loader: for image, text, label, index in self.train_loader:
image=image.to(device, non_blocking=True) image=image.to(device, non_blocking=True)
# print(self.model.vocab_size) # print(self.model.vocab_size)
text = self.clip_tokenizer.tokenize(text)
temp_image=self.model.encode_image(image) temp_image=self.model.encode_image(image)
image_train.append(temp_image.cpu().detach().numpy()) image_train.append(temp_image.cpu().detach().numpy())
label_train.append(label.detach().numpy()) label_train.append(label.detach().numpy())

View File

@ -9,7 +9,7 @@ def get_args():
parser.add_argument("--save-dir", type=str, default="./result/64-bit") parser.add_argument("--save-dir", type=str, default="./result/64-bit")
parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.") parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
parser.add_argument("--pretrained", type=str, default="") parser.add_argument("--pretrained", type=str, default="")
parser.add_argument("--dataset", type=str, default="flickr25k", help="choise from [coco, mirflckr25k, nuswide]") parser.add_argument("--dataset", type=str, default="iapr", help="choise from [coco, flckr25k, iapr]")
parser.add_argument("--index-file", type=str, default="index.mat") parser.add_argument("--index-file", type=str, default="index.mat")
parser.add_argument("--caption-file", type=str, default="caption.mat") parser.add_argument("--caption-file", type=str, default="caption.mat")
parser.add_argument("--label-file", type=str, default="label.mat") parser.add_argument("--label-file", type=str, default="label.mat")
@ -21,7 +21,7 @@ def get_args():
parser.add_argument("--beta", type=float, default=10.0) parser.add_argument("--beta", type=float, default=10.0)
parser.add_argument("--txt-dim", type=int, default=1024) parser.add_argument("--txt-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=512) parser.add_argument("--output-dim", type=int, default=512)
parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--epochs", type=int, default=500)
parser.add_argument("--max-words", type=int, default=77) parser.add_argument("--max-words", type=int, default=77)
parser.add_argument("--resolution", type=int, default=224) parser.add_argument("--resolution", type=int, default=224)
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)