diff --git a/dataset/base.py b/dataset/base.py
index c489de4..7d405f4 100644
--- a/dataset/base.py
+++ b/dataset/base.py
@@ -9,6 +9,10 @@ import random
 from PIL import Image
 from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
 from model.simple_tokenizer import SimpleTokenizer as Tokenizer
+import re
+import os
+import numpy as np
+from torchvision import transforms
@@ -62,7 +66,28 @@ class BaseDataset(Dataset):
 
         return image
 
+    def pre_caption(self, caption):
+        # strip punctuation, lowercase, and normalise the <person> placeholder
+        caption = re.sub(
+            r"([,.'!?\"()*#:;~])",
+            '',
+            caption.lower(),
+        ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
+        caption = re.sub(
+            r"\s{2,}",
+            ' ',
+            caption,
+        )
+        caption = caption.rstrip('\n')
+        caption = caption.strip(' ')
+        # truncate to at most maxWords words
+        caption_words = caption.split(' ')
+        if len(caption_words) > self.maxWords:
+            caption = ' '.join(caption_words[:self.maxWords])
+        return caption
+
     def _load_text(self, index: int):
+        # captions = self.pre_caption(self.captions[index])
+        # captions = whitespace_clean(self.captions[index]).lower()
         captions = self.captions[index]
         use_cap = captions[random.randint(0, len(captions) - 1)]
@@ -79,7 +104,7 @@ class BaseDataset(Dataset):
             caption.append(0)
 
         caption = torch.tensor(caption)
-        return caption
+        # return the selected raw caption string; tokenization now happens in
+        # the training loop via clip.tokenize
+        return use_cap
 
     def _load_label(self, index: int) -> torch.Tensor:
         label = self.labels[index]
@@ -101,3 +126,175 @@ class BaseDataset(Dataset):
 
         return image, caption, label, index
 
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+
+class IaprDataset(Dataset):
+    def __init__(self, args, txt, transform=None, loader=default_loader):
+        self.transform = transform
+        self.loader = loader
+
+        name_label = []
+        for line in open(txt):
+            line = line.strip('\n').split()
+            # the last 255 entries of each line are the binary label vector;
+            # the first 2912 are the bag-of-words binary encoding of the caption
+            label = list(map(int, np.array(line[len(line) - 255:])))
+            tem = re.split('[/.]', line[0])
+            file_name, sample_name = tem[0], tem[1]
+            name_label.append([file_name, sample_name, label])
+        self.name_label = name_label
+        self.image_dir = args.image_dir
+        self.text_dir = args.text_dir
+
+    def __getitem__(self, index):
+        words = self.name_label[index]  # words = [file_name, sample_name, label]
+        img_path = os.path.join(self.image_dir, words[0], words[1] + '.jpg')
+        text_path = os.path.join(self.text_dir, words[0], words[1] + '.txt')
+        # image
+        img = self.loader(img_path)
+        if self.transform is not None:
+            img = self.transform(img)
+        # text: wrap the single-line caption in BERT-style special tokens
+        text = 'None'
+        for line in open(text_path):
+            text = '[CLS]' + line + '[SEP]'
+        # label
+        label = torch.LongTensor(words[2])
+        return img, text, label, index
+
+    def __len__(self):
+        return len(self.name_label)
+
+
+default_transform = transforms.Compose([
+    transforms.Resize(224),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
+
+class MScocoDataset(Dataset):
+    def __init__(self, data_path, img_filename, text_filename, label_filename, transform=None):
+        self.data_path = data_path
+        if transform is None:
+            self.transform = default_transform
+        else:
+            self.transform = transform
+
+        img_filepath = os.path.join(data_path, img_filename)
+        with open(img_filepath, 'r') as f:
+            self.imgs = [x.strip() for x in f]
+
+        text_filepath = os.path.join(data_path, text_filename)
+        with open(text_filepath, 'r') as f:
+            self.texts = [line.rstrip('\n') for line in f]
+
+        label_filepath = os.path.join(data_path, label_filename)
+        self.labels = np.genfromtxt(label_filepath, dtype=np.int32)
+
+    def __getitem__(self, index):
+        img = Image.open(os.path.join(self.data_path, self.imgs[index]))
+        img = img.convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+        label = torch.from_numpy(self.labels[index]).float()
+        text = self.texts[index]
+        # (image, caption, label, index), matching BaseDataset
+        return img, text, label, index
+
+    def __len__(self):
+        return len(self.imgs)
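+
+# Assumed layout of the three MS-COCO list files (filenames are illustrative;
+# they follow get_dataset_filename in dataset/dataloader.py):
+#   cm_train_imgs.txt    one image path per line, relative to data_path
+#   cm_train_txts.txt    one caption per line, aligned with the image list
+#   cm_train_labels.txt  one whitespace-separated multi-hot label row per line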
+
+
+class MirflickrDataset(Dataset):
+    def __init__(self, root_dir, transform=None):
+        self.root_dir = root_dir
+        file_path = os.path.join(root_dir, "mirflickr25k_annotations_v080")
+        file_list = os.listdir(file_path)
+        file_list = [item for item in file_list if "_r1" not in item and "README" not in item]
+        self.class_index = {}
+        for i, item in enumerate(file_list):
+            self.class_index[item] = i
+        # build a multi-hot label vector per image id from the per-class files
+        self.label_dict = {}
+        for path_id in file_list:
+            path = os.path.join(file_path, path_id)
+            with open(path, "r") as f:
+                for item in f:
+                    item = item.strip()
+                    if item not in self.label_dict:
+                        label = np.zeros(len(file_list))
+                        label[self.class_index[path_id]] = 1
+                        self.label_dict[item] = label
+                    else:
+                        self.label_dict[item][self.class_index[path_id]] = 1
+        # only images with at least one label appear in label_dict, so map the
+        # 0-based dataset index onto a sorted list of image ids
+        self.ids = sorted(self.label_dict.keys(), key=int)
+        self.captions_dict = {}
+        captions_path = os.path.join(root_dir, "mirflickr/meta/tags")
+        for item in os.listdir(captions_path):
+            id_ = item.split(".")[0].replace("tags", "")
+            caption = ""
+            with open(os.path.join(captions_path, item), "r") as f:
+                for word in f.readlines():
+                    caption += word.strip() + " "
+            self.captions_dict[id_] = caption.strip()
+        if transform is None:
+            self.transform = default_transform
+        else:
+            self.transform = transform
+
+    def __getitem__(self, index):
+        id_ = self.ids[index]
+        label = torch.from_numpy(self.label_dict[id_]).float()
+        path = os.path.join(self.root_dir, "mirflickr")
+        img = Image.open(os.path.join(path, "im" + id_ + ".jpg"))
+        img = img.convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+        text = self.captions_dict[id_]
+        return img, text, label, index
+
+    def __len__(self):
+        return len(self.ids)
+
+
+class NusWideDataset(Dataset):
+    def __init__(self, root_dir, transform=None):
+        self.root_dir = root_dir
+        # the sub-paths must stay relative: os.path.join discards root_dir as
+        # soon as a later component starts with '/'
+        self.image_list_file = os.path.join(root_dir, "Low-Level-Features/ImageList/Imagelist.txt")
+        self.label_path = os.path.join(root_dir, "nuswide/Groundtruth/AllLabels")
+        self.text_file = os.path.join(root_dir, "Low-Level-Features/NUS_WID_Tags/All_Tags.txt")
+        if transform is None:
+            self.transform = default_transform
+        else:
+            self.transform = transform
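+
+    # Assumed NUS-WIDE layout (standard distribution; adjust to the local copy):
+    #   Low-Level-Features/ImageList/Imagelist.txt    one relative image path per line
+    #   Groundtruth/AllLabels/Labels_<concept>.txt    one 0/1 flag per image for that concept
+    #   Low-Level-Features/NUS_WID_Tags/All_Tags.txt  "<image id> <tag> <tag> ..." per line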
+
+    def __getitem__(self, index):
+        # FIXME: this body is still the MirflickrDataset loader; label_dict and
+        # captions_dict are never built for NUS-WIDE, so it will not run yet
+        label = self.label_dict[index]
+        path = os.path.join(self.root_dir, "mirflickr")
+        img = Image.open(os.path.join(path, "im" + index + ".jpg"))
+        img = img.convert('RGB')
+        if self.transform is not None:
+            img = self.transform(img)
+        text = self.captions_dict[index]
+        return img, text, label, index
+
+    def __len__(self):
+        return len(self.label_dict)
\ No newline at end of file
diff --git a/dataset/dataloader.py b/dataset/dataloader.py
index 2ab4ac7..43aa436 100644
--- a/dataset/dataloader.py
+++ b/dataset/dataloader.py
@@ -1,8 +1,12 @@
-from .base import BaseDataset
+from .base import BaseDataset, MirflickrDataset, MScocoDataset, IaprDataset
 import os
 import numpy as np
 import scipy.io as scio
+import torch
+import torchvision.transforms as transforms
+import json
+from torch.utils.data import DataLoader
 
 
 def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
@@ -45,8 +49,13 @@ def dataloader(captionFile: str,
         with open(captionFile, "r") as f:
             captions = f.readlines()
         captions = np.asarray([[item.strip()] for item in captions])
+    elif captionFile.endswith("json"):
+        with open(captionFile, "r") as f:
+            data = json.load(f)
+        captions = data["caption"]
     else:
-        raise ValueError("the format of 'captionFile' doesn't support, only support [txt, mat] format.")
+        raise ValueError("unsupported 'captionFile' format, only [txt, json, mat] are supported.")
     if not npy:
         indexs = scio.loadmat(indexFile)["index"]
     else:
@@ -64,5 +73,56 @@ def dataloader(captionFile: str,
     retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2],
                                  maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
 
     return train_data, query_data, retrieval_data
+
+
+def get_dataset_filename(split):
+    filename = {
+        'train': ('cm_train_imgs.txt', 'cm_train_txts.txt', 'cm_train_labels.txt'),
+        'test': ('cm_test_imgs.txt', 'cm_test_txts.txt', 'cm_test_labels.txt'),
+        'db': ('cm_database_imgs.txt', 'cm_database_txts.txt', 'cm_database_labels.txt')
+    }
+    return filename[split]
+
+
+def cross_modal_dataset(args):
+    transform_test = transforms.Compose([
+        transforms.Resize((args.resolution, args.resolution)),
+        transforms.ToTensor(),
+        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                             (0.26862954, 0.26130258, 0.27577711))
+    ])
+    if args.dataset == 'flickr25k':
+        dataset = MirflickrDataset(root_dir=args.flickr25k_root)
+        # no pre-defined split: carve query/database sets out of the full set
+        # (train deliberately overlaps the database set)
+        test_set, db_set = torch.utils.data.random_split(dataset, [0.25, 0.75])
+        train_set, _ = torch.utils.data.random_split(dataset, [0.65, 0.35])
+    elif args.dataset == 'coco':
+        img_name, text_name, label_name = get_dataset_filename('train')
+        dataset = MScocoDataset(data_path=args.coco_root, img_filename=args.coco_img_root,
+                                text_filename=args.coco_txt_root, label_filename=args.coco_label_root,
+                                transform=transform_test)
+        test_set, db_set = torch.utils.data.random_split(dataset, [0.25, 0.75])
+        train_set, _ = torch.utils.data.random_split(dataset, [0.65, 0.35])
+    elif args.dataset == 'iapr':
+        train_file = os.path.join(args.iapr_root, 'iapr_train')
+        test_file = os.path.join(args.iapr_root, 'iapr_test')
+        retrieval_file = os.path.join(args.iapr_root, 'iapr_retrieval')
+        train_set = IaprDataset(args, txt=train_file, transform=transform_test)
+        test_set = IaprDataset(args, txt=test_file, transform=transform_test)
+        db_set = IaprDataset(args, txt=retrieval_file, transform=transform_test)
+    else:
+        raise ValueError(f"dataset '{args.dataset}' is not supported.")
+
+    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
+    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
+    db_loader = DataLoader(db_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
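+
+# Minimal usage sketch (assumes args provides dataset, batch_size, resolution
+# and the *_root paths referenced above):
+#   train_loader, test_loader, db_loader, train_set, test_set, db_set = \
+#       cross_modal_dataset(get_args())
+#   for image, text, label, index in train_loader:
+#       tokens = clip.tokenize(text)  # text arrives as raw caption strings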
+
+    return train_loader, test_loader, db_loader, train_set, test_set, db_set
diff --git a/dataset/make_mirflickr25k.py b/dataset/make_mirflickr25k.py
index f3e0a68..b3e7b23 100644
--- a/dataset/make_mirflickr25k.py
+++ b/dataset/make_mirflickr25k.py
@@ -76,7 +76,6 @@ captions = {"caption": captions}
 scio.savemat(os.path.join(root_dir, "mat/index.mat"), index)
 with open(os.path.join(root_dir, "mat/caption.json"), 'w', encoding='utf-8') as f:
     json.dump(captions, f, ensure_ascii=False)
-# scio.savemat(os.path.join(root_dir, "mat/caption.mat"), captions)
 scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)
diff --git a/main.py b/main.py
index e8d63d6..c68339f 100644
--- a/main.py
+++ b/main.py
@@ -4,7 +4,7 @@ from train.text_train import Trainer
 
 if __name__ == "__main__":
     engine=Trainer()
-    # engine.test()
+    engine.test()
     engine.train_epoch()
diff --git a/model/clip.py b/model/clip.py
index 6ce5565..d66ff6f 100755
--- a/model/clip.py
+++ b/model/clip.py
@@ -9,6 +9,8 @@ from PIL import Image
 from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
 from tqdm import tqdm
 
+
+
 from .model import build_model
 from .simple_tokenizer import SimpleTokenizer as _Tokenizer
diff --git a/train/hash_train.py b/train/hash_train.py
index 7bab447..6d4b31c 100644
--- a/train/hash_train.py
+++ b/train/hash_train.py
@@ -15,9 +15,11 @@ from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity
 from utils.calc_utils import cal_map, cal_pr
 from dataset.dataloader import dataloader
 import clip
+from model.simple_tokenizer import SimpleTokenizer as Tokenizer
 # from transformers import BertModel
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+# tokenizer = Tokenizer()
 
 
 def clamp(delta, clean_imgs):
@@ -95,6 +97,8 @@ class Trainer(TrainBase):
         label_train=[]
         for image, text, label, index in self.train_loader:
-            text=text.to(device, non_blocking=True)
+            # loaders now yield raw caption strings, so tokenize first and
+            # move the resulting tensor to the device afterwards
+            text = clip.tokenize(text).to(device, non_blocking=True)
             # print(self.model.vocab_size)
             temp_text=self.model.encode_text(text)
             text_train.append(temp_text.cpu().detach().numpy())
@@ -113,10 +117,10 @@ class Trainer(TrainBase):
         return text_representation, text_var_representation
 
     def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var
-                 ,epsilon=0.03125, alpha=3/255, num_iter=1500):
+                 ,epsilon=0.03125, alpha=3/255):
 
         delta = torch.zeros_like(image,requires_grad=True)
-        for i in range(num_iter):
+        # run the attack for args.epochs steps instead of a hard-coded count
+        for i in range(self.args.epochs):
             self.model.zero_grad()
             anchor=self.model.encode_image(image+delta)
             loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity())
diff --git a/train/text_train.py b/train/text_train.py
index 6313c0e..a36c611 100644
--- a/train/text_train.py
+++ b/train/text_train.py
@@ -13,13 +13,16 @@ from .base import TrainBase
 from torch.nn import functional as F
 from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
 from utils.calc_utils import cal_map, cal_pr
-from dataset.dataloader import dataloader
+from dataset.dataloader import cross_modal_dataset
 import clip
 import copy
 from model.bert_tokenizer import BertTokenizer
 from model.simple_tokenizer import SimpleTokenizer as Tokenizer
 from transformers import BertForMaskedLM
 # from transformers import BertModel
+import ftfy
+import regex as re
+import html
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -51,6 +54,17 @@ filter_words = ['a', 'about', 'above', 'across', 'after',
                 'afterwards', 'again', '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·']
 filter_words = set(filter_words)
 
+# text-cleanup helpers mirrored from CLIP's simple tokenizer
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
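+
+# e.g. whitespace_clean(basic_clean("  A   man\n on a   horse ")) == "A man on a horse"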
+
 def get_bpe_substitues(substitutes, tokenizer, mlm_model):
     # substitutes L, k
     # device = mlm_model.device
@@ -139,46 +153,10 @@ class Trainer(TrainBase):
     def _init_dataset(self):
         self.logger.info("init dataset.")
         self.logger.info(f"Using {self.args.dataset} dataset.")
-        self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
-        self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
-        self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
-        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
-                                                            indexFile=self.args.index_file,
-                                                            labelFile=self.args.label_file,
-                                                            maxWords=self.args.max_words,
-                                                            imageResolution=self.args.resolution,
-                                                            query_num=self.args.query_num,
-                                                            train_num=self.args.train_num,
-                                                            seed=self.args.seed)
-        self.train_labels = train_data.get_all_label()
-        self.query_labels = query_data.get_all_label()
-        self.retrieval_labels = retrieval_data.get_all_label()
-        self.args.retrieval_num = len(self.retrieval_labels)
-        self.logger.info(f"query shape: {self.query_labels.shape}")
-        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
-        self.train_loader = DataLoader(
-            dataset=train_data,
-            batch_size=self.args.batch_size,
-            num_workers=self.args.num_workers,
-            pin_memory=True,
-            shuffle=True
-        )
-        self.query_loader = DataLoader(
-            dataset=query_data,
-            batch_size=self.args.batch_size,
-            num_workers=self.args.num_workers,
-            pin_memory=True,
-            shuffle=True
-        )
-        self.retrieval_loader = DataLoader(
-            dataset=retrieval_data,
-            batch_size=self.args.batch_size,
-            num_workers=self.args.num_workers,
-            pin_memory=True,
-            shuffle=True
-        )
-        self.train_data=train_data
-
+        train_loader, test_loader, db_loader, train_set, test_set, db_set = cross_modal_dataset(self.args)
+        self.train_loader = train_loader
+        self.test_loader = test_loader
+        self.retrieval_loader = db_loader
 
     def generate_mapping(self):
         image_train=[]
@@ -186,6 +164,7 @@ class Trainer(TrainBase):
         label_train=[]
         for image, text, label, index in self.train_loader:
             image=image.to(device, non_blocking=True)
             # print(self.model.vocab_size)
+            # captions arrive as raw strings; tokenize before any text encoding
+            text = self.clip_tokenizer.tokenize(text)
             temp_image=self.model.encode_image(image)
             image_train.append(temp_image.cpu().detach().numpy())
             label_train.append(label.detach().numpy())
diff --git a/utils/get_args.py b/utils/get_args.py
index 6322a07..ba40894 100644
--- a/utils/get_args.py
+++ b/utils/get_args.py
@@ -9,7 +9,7 @@ def get_args():
     parser.add_argument("--save-dir", type=str, default="./result/64-bit")
     parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
     parser.add_argument("--pretrained", type=str, default="")
-    parser.add_argument("--dataset", type=str, default="flickr25k", help="choise from [coco, mirflckr25k, nuswide]")
+    parser.add_argument("--dataset", type=str, default="iapr", help="choose from [coco, flickr25k, iapr]")
     parser.add_argument("--index-file", type=str, default="index.mat")
     parser.add_argument("--caption-file", type=str, default="caption.mat")
     parser.add_argument("--label-file", type=str, default="label.mat")
@@ -21,7 +21,7 @@ def get_args():
     parser.add_argument("--beta", type=float, default=10.0)
     parser.add_argument("--txt-dim", type=int, default=1024)
     parser.add_argument("--output-dim", type=int, default=512)
-    parser.add_argument("--epochs", type=int, default=100)
+    parser.add_argument("--epochs", type=int, default=500)
     parser.add_argument("--max-words", type=int, default=77)
     parser.add_argument("--resolution", type=int, default=224)
     parser.add_argument("--batch-size", type=int, default=8)
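+    # cross_modal_dataset additionally reads per-dataset root arguments; the
+    # names below match its usage, the defaults are illustrative placeholders
+    # for the local data layout
+    parser.add_argument("--flickr25k-root", type=str, default="./dataset/flickr25k")
+    parser.add_argument("--coco-root", type=str, default="./dataset/coco")
+    parser.add_argument("--coco-img-root", type=str, default="cm_train_imgs.txt")
+    parser.add_argument("--coco-txt-root", type=str, default="cm_train_txts.txt")
+    parser.add_argument("--coco-label-root", type=str, default="cm_train_labels.txt")
+    parser.add_argument("--iapr-root", type=str, default="./dataset/iapr")
+    parser.add_argument("--image-dir", type=str, default="./dataset/iapr/images")
+    parser.add_argument("--text-dir", type=str, default="./dataset/iapr/texts")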