add
This commit is contained in:
parent
9539ba5aed
commit
5b84447a3d
199
dataset/base.py
199
dataset/base.py
|
|
@ -9,6 +9,10 @@ import random
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||||
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -62,7 +66,28 @@ class BaseDataset(Dataset):
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
def pre_caption(self, caption):
|
||||||
|
caption = re.sub(
|
||||||
|
r"([,.'!?\"()*#:;~])",
|
||||||
|
'',
|
||||||
|
caption.lower(),
|
||||||
|
).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
|
||||||
|
caption = re.sub(
|
||||||
|
r"\s{2,}",
|
||||||
|
' ',
|
||||||
|
caption,
|
||||||
|
)
|
||||||
|
caption = caption.rstrip('\n')
|
||||||
|
caption = caption.strip(' ')
|
||||||
|
caption_words = caption.split(' ')
|
||||||
|
if len(caption_words)>self.maxWords:
|
||||||
|
caption = ' '.join(caption_words[:self.maxWords])
|
||||||
|
return caption
|
||||||
|
|
||||||
def _load_text(self, index: int):
|
def _load_text(self, index: int):
|
||||||
|
# print(self.captions[index])
|
||||||
|
# captions = self.pre_caption(self.captions[index])
|
||||||
|
# captions =whitespace_clean(self.captions[index]).lower()
|
||||||
captions = self.captions[index]
|
captions = self.captions[index]
|
||||||
use_cap = captions[random.randint(0, len(captions) - 1)]
|
use_cap = captions[random.randint(0, len(captions) - 1)]
|
||||||
|
|
||||||
|
|
@ -79,7 +104,7 @@ class BaseDataset(Dataset):
|
||||||
caption.append(0)
|
caption.append(0)
|
||||||
caption = torch.tensor(caption)
|
caption = torch.tensor(caption)
|
||||||
|
|
||||||
return caption
|
return captions
|
||||||
|
|
||||||
def _load_label(self, index: int) -> torch.Tensor:
|
def _load_label(self, index: int) -> torch.Tensor:
|
||||||
label = self.labels[index]
|
label = self.labels[index]
|
||||||
|
|
@ -101,3 +126,175 @@ class BaseDataset(Dataset):
|
||||||
|
|
||||||
return image, caption, label, index
|
return image, caption, label, index
|
||||||
|
|
||||||
|
# def default_loader(path):
|
||||||
|
# return Image.open(path).convert('RGB')
|
||||||
|
|
||||||
|
# class IaprDataset(Dataset):
|
||||||
|
# def __init__(self, args,txt,transform=None, loader=default_loader):
|
||||||
|
# self.transform = transform
|
||||||
|
# self.loader = loader
|
||||||
|
|
||||||
|
# name_label = []
|
||||||
|
# for line in open(txt):
|
||||||
|
# line = line.strip('\n').split()
|
||||||
|
# label = list(map(int, np.array(line[len(line)-255:]))) #后255个二进制码是label的,前2912个是单词在词袋中的二进制编码
|
||||||
|
# tem = re.split('[/.]', line[0])
|
||||||
|
# file_name, sample_name = tem[0], tem[1]
|
||||||
|
# name_label.append([file_name, sample_name, label])
|
||||||
|
# # # print('label = ', label)
|
||||||
|
# # print('file_name = %s, sample_name = %s' %(file_name, sample_name))
|
||||||
|
# # label_list = np.where(label=='1')
|
||||||
|
# # print('label_list = ', label_list)
|
||||||
|
# self.name_label = name_label
|
||||||
|
# self.image_dir=args.image_dir
|
||||||
|
# self.text_dir = args.text_dir
|
||||||
|
|
||||||
|
|
||||||
|
# def __getitem__(self, index):
|
||||||
|
# words = self.name_label[index] # words = [file_name, sample_name, label]
|
||||||
|
# # print('words = ', words[0:2])
|
||||||
|
|
||||||
|
# img_path = os.path.join(self.image_dir, words[0], words[1]+'.jpg')
|
||||||
|
# text_path = os.path.join(self.text_dir, words[0], words[1]+'.txt')
|
||||||
|
# # img
|
||||||
|
# img = self.loader(img_path)
|
||||||
|
# if self.transform is not None:
|
||||||
|
# img = self.transform(img)
|
||||||
|
# # text
|
||||||
|
# text = 'None'
|
||||||
|
# for line in open(text_path):
|
||||||
|
# text = '[CLS]' + line + '[SEP]'
|
||||||
|
|
||||||
|
# # label
|
||||||
|
# label = torch.LongTensor(words[2])
|
||||||
|
# # image, caption, label, index
|
||||||
|
# return img, text, label, index
|
||||||
|
|
||||||
|
|
||||||
|
# def __len__(self):
|
||||||
|
# return len(self.name_label)
|
||||||
|
|
||||||
|
default_transform = transforms.Compose([
|
||||||
|
transforms.Resize(224),
|
||||||
|
transforms.CenterCrop(224),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
class MScocoDataset(Dataset):
|
||||||
|
def __init__(self, data_path, img_filename, text_filename, label_filename, transform=None):
|
||||||
|
self.data_path = data_path
|
||||||
|
|
||||||
|
if transform is None:
|
||||||
|
self.transform = default_transform
|
||||||
|
else:
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
img_filepath = os.path.join(data_path, img_filename)
|
||||||
|
with open(img_filepath, 'r') as f:
|
||||||
|
self.imgs = [x.strip() for x in f]
|
||||||
|
|
||||||
|
text_filepath = os.path.join(data_path, text_filename)
|
||||||
|
with open(text_filepath, 'r') as f:
|
||||||
|
self.texts = f.readlines()
|
||||||
|
self.texts = [i.replace('\n', '') for i in self.texts]
|
||||||
|
|
||||||
|
label_filepath = os.path.join(data_path, label_filename)
|
||||||
|
self.labels = np.genfromtxt(label_filepath, dtype=np.int32)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
img = Image.open(os.path.join(self.data_path, self.imgs[index]))
|
||||||
|
img = img.convert('RGB')
|
||||||
|
|
||||||
|
if self.transform is not None:
|
||||||
|
img = self.transform(img)
|
||||||
|
|
||||||
|
label = torch.from_numpy(self.labels[index]).float()
|
||||||
|
text = self.texts[index]
|
||||||
|
# image, caption, label, index
|
||||||
|
|
||||||
|
return img, text, label, index
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.imgs)
|
||||||
|
|
||||||
|
class MirflickrDataset(Dataset):
|
||||||
|
def __init__(self, root_dir, transform=None):
|
||||||
|
self.root_dir = root_dir
|
||||||
|
file_path = os.path.join(root_dir, "mirflickr25k_annotations_v080")
|
||||||
|
file_list = os.listdir(file_path)
|
||||||
|
file_list = [item for item in file_list if "_r1" not in item and "README" not in item]
|
||||||
|
self.class_index = {}
|
||||||
|
for i, item in enumerate(file_list):
|
||||||
|
self.class_index.update({item: i})
|
||||||
|
self.label_dict = {}
|
||||||
|
for path_id in file_list:
|
||||||
|
path = os.path.join(file_path, path_id)
|
||||||
|
with open(path, "r") as f:
|
||||||
|
for item in f:
|
||||||
|
item = item.strip()
|
||||||
|
if item not in self.label_dict:
|
||||||
|
label = np.zeros(len(file_list))
|
||||||
|
label[self.class_index[path_id]] = 1
|
||||||
|
self.label_dict.update({item: label})
|
||||||
|
else:
|
||||||
|
# print()
|
||||||
|
self.label_dict[item][self.class_index[path_id]] = 1
|
||||||
|
self.captions_dict = {}
|
||||||
|
captions_path = os.path.join(root_dir, "mirflickr/meta/tags")
|
||||||
|
captions_list = os.listdir(captions_path)
|
||||||
|
for item in captions_list:
|
||||||
|
id_ = item.split(".")[0].replace("tags", "")
|
||||||
|
caption = ""
|
||||||
|
with open(os.path.join(captions_path, item), "r") as f:
|
||||||
|
for word in f.readlines():
|
||||||
|
caption += word.strip() + " "
|
||||||
|
caption = caption.strip()
|
||||||
|
self.captions_dict.update({id_: caption})
|
||||||
|
if transform is None:
|
||||||
|
self.transform = default_transform
|
||||||
|
else:
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
label=self.label_dict[index]
|
||||||
|
PATH = os.path.join(self.root_dir, "mirflickr")
|
||||||
|
img=Image.open(os.path.join(PATH, "im" + index + ".jpg"))
|
||||||
|
img = img.convert('RGB')
|
||||||
|
if self.transform is not None:
|
||||||
|
img = self.transform(img)
|
||||||
|
|
||||||
|
# label = torch.from_numpy(self.labels[index]).float()
|
||||||
|
text = self.captions_dict[index]
|
||||||
|
# image, caption, label, index
|
||||||
|
|
||||||
|
return img, text, label, index
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(list(self.label_dict.keys()))
|
||||||
|
|
||||||
|
|
||||||
|
class NusWideDataset(Dataset):
|
||||||
|
def __init__(self, root_dir, transform=None):
|
||||||
|
self.root_dir = root_dir
|
||||||
|
imageListFile = os.path.join(root_dir, "/Low-Level-Features/ImageList/Imagelist.txt")
|
||||||
|
labelPath = os.path.join(root_dir, "/nuswide/Groundtruth/AllLabels")
|
||||||
|
textFile = os.path.join(root_dir, "/Low-Level-Features/NUS_WID_Tags/All_Tags.txt")
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
label=self.label_dict[index]
|
||||||
|
PATH = os.path.join(self.root_dir, "mirflickr")
|
||||||
|
img=Image.open(os.path.join(PATH, "im" + index + ".jpg"))
|
||||||
|
img = img.convert('RGB')
|
||||||
|
if self.transform is not None:
|
||||||
|
img = self.transform(img)
|
||||||
|
|
||||||
|
# label = torch.from_numpy(self.labels[index]).float()
|
||||||
|
text = self.captions_dict[index]
|
||||||
|
# image, caption, label, index
|
||||||
|
|
||||||
|
return img, text, label, index
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(list(self.label_dict.keys()))
|
||||||
|
|
@ -1,8 +1,12 @@
|
||||||
|
|
||||||
from .base import BaseDataset
|
from .base import BaseDataset, MirflickrDataset, MScocoDataset , IaprDataset
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.io as scio
|
import scipy.io as scio
|
||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import json
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
|
def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
|
||||||
|
|
@ -45,8 +49,13 @@ def dataloader(captionFile: str,
|
||||||
with open(captionFile, "r") as f:
|
with open(captionFile, "r") as f:
|
||||||
captions = f.readlines()
|
captions = f.readlines()
|
||||||
captions = np.asarray([[item.strip()] for item in captions])
|
captions = np.asarray([[item.strip()] for item in captions])
|
||||||
|
elif captionFile.endswith("json"):
|
||||||
|
with open(captionFile, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
captions=data["caption"]
|
||||||
|
# captions = captions[0] if captions.shape[0] == 1 else captions
|
||||||
else:
|
else:
|
||||||
raise ValueError("the format of 'captionFile' doesn't support, only support [txt, mat] format.")
|
raise ValueError("the format of 'captionFile' doesn't support, only support [txt, json, mat] format.")
|
||||||
if not npy:
|
if not npy:
|
||||||
indexs = scio.loadmat(indexFile)["index"]
|
indexs = scio.loadmat(indexFile)["index"]
|
||||||
else:
|
else:
|
||||||
|
|
@ -64,5 +73,56 @@ def dataloader(captionFile: str,
|
||||||
retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
|
retrieval_data = BaseDataset(captions=split_captions[2], indexs=split_indexs[2], labels=split_labels[2], maxWords=maxWords, imageResolution=imageResolution, is_train=False, npy=npy)
|
||||||
|
|
||||||
return train_data, query_data, retrieval_data
|
return train_data, query_data, retrieval_data
|
||||||
|
|
||||||
|
def get_dataset_filename(split):
|
||||||
|
filename = {
|
||||||
|
'train': ('cm_train_imgs.txt', 'cm_train_txts.txt', 'cm_train_labels.txt'),
|
||||||
|
'test': ('cm_test_imgs.txt', 'cm_test_txts.txt', 'cm_test_labels.txt'),
|
||||||
|
'db': ('cm_database_imgs.txt', 'cm_database_txts.txt', 'cm_database_labels.txt')
|
||||||
|
}
|
||||||
|
|
||||||
|
return filename[split]
|
||||||
|
|
||||||
|
def cross_modal_dataset(args ):
|
||||||
|
transform_test = transforms.Compose([
|
||||||
|
transforms.Resize((args.resolution, args.resolution)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||||
|
])
|
||||||
|
if args.dataset == 'flickr25k':
|
||||||
|
dataset=MirflickrDataset(root_dir=args.flickr25k_root)
|
||||||
|
elif args.dataset == 'coco':
|
||||||
|
img_name, text_name, label_name = get_dataset_filename('train')
|
||||||
|
|
||||||
|
dataset=MScocoDataset(data_path=args.coco_root,img_filename=args.coco_img_root,
|
||||||
|
text_filename=args.coco_txt_root,label_filename=args.coco_label_root,transform=transform_test)
|
||||||
|
elif args.dataset == 'iapr':
|
||||||
|
train_file = os.path.join(args.iapr_root, 'iapr_train')
|
||||||
|
test_file = os.path.join(args.root, 'iapr_test')
|
||||||
|
retrieval_file = os.path.join(args.root, 'iapr_retrieval')
|
||||||
|
train_set=IaprDataset(args,txt=train_file, transform=transform_test)
|
||||||
|
test_set =IaprDataset(args,txt=test_file, transform=transform_test)
|
||||||
|
db_set=IaprDataset(args,txt=retrieval_file, transform=transform_test)
|
||||||
|
train_loader = Dataset.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
||||||
|
test_loader = Dataset.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
|
||||||
|
db_loader = Dataset.DataLoader(db_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not support.")
|
||||||
|
|
||||||
|
# if args.dataset == 'iapr':
|
||||||
|
|
||||||
|
# elif args.dataset == 'coco':
|
||||||
|
# img_name, text_name, label_name = get_dataset_filename('train')
|
||||||
|
|
||||||
|
# pass
|
||||||
|
# else:
|
||||||
|
# test_set, db_set=torch.utils.data.random_split(dataset,[0.25, 0.75])
|
||||||
|
# train_set,_=torch.utils.data.random_split(dataset,[0.65, 0.35])
|
||||||
|
# train_loader = Dataset.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
|
||||||
|
# test_loader = Dataset.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
|
||||||
|
# db_loader = Dataset.DataLoader(db_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
|
||||||
|
|
||||||
|
return train_loader, test_loader, db_loader , train_set, test_set, db_set
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,6 @@ captions = {"caption": captions}
|
||||||
scio.savemat(os.path.join(root_dir, "mat/index.mat"), index)
|
scio.savemat(os.path.join(root_dir, "mat/index.mat"), index)
|
||||||
with open(os.path.join(root_dir, "mat/caption.json"), 'w', encoding='utf-8') as f:
|
with open(os.path.join(root_dir, "mat/caption.json"), 'w', encoding='utf-8') as f:
|
||||||
json.dump(captions, f, ensure_ascii=False)
|
json.dump(captions, f, ensure_ascii=False)
|
||||||
# scio.savemat(os.path.join(root_dir, "mat/caption.mat"), captions)
|
|
||||||
scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)
|
scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
2
main.py
2
main.py
|
|
@ -4,7 +4,7 @@ from train.text_train import Trainer
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
engine=Trainer()
|
engine=Trainer()
|
||||||
# engine.test()
|
engine.test()
|
||||||
engine.train_epoch()
|
engine.train_epoch()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ from PIL import Image
|
||||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from .model import build_model
|
from .model import build_model
|
||||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,9 +15,11 @@ from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similari
|
||||||
from utils.calc_utils import cal_map, cal_pr
|
from utils.calc_utils import cal_map, cal_pr
|
||||||
from dataset.dataloader import dataloader
|
from dataset.dataloader import dataloader
|
||||||
import clip
|
import clip
|
||||||
|
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
||||||
# from transformers import BertModel
|
# from transformers import BertModel
|
||||||
|
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
# tokenizer=Tokenizer()
|
||||||
|
|
||||||
def clamp(delta, clean_imgs):
|
def clamp(delta, clean_imgs):
|
||||||
|
|
||||||
|
|
@ -95,6 +97,8 @@ class Trainer(TrainBase):
|
||||||
label_train=[]
|
label_train=[]
|
||||||
for image, text, label, index in self.train_loader:
|
for image, text, label, index in self.train_loader:
|
||||||
text=text.to(device, non_blocking=True)
|
text=text.to(device, non_blocking=True)
|
||||||
|
|
||||||
|
text=clip.tokenize(text)
|
||||||
# print(self.model.vocab_size)
|
# print(self.model.vocab_size)
|
||||||
temp_text=self.model.encode_text(text)
|
temp_text=self.model.encode_text(text)
|
||||||
text_train.append(temp_text.cpu().detach().numpy())
|
text_train.append(temp_text.cpu().detach().numpy())
|
||||||
|
|
@ -113,10 +117,10 @@ class Trainer(TrainBase):
|
||||||
return text_representation, text_var_representation
|
return text_representation, text_var_representation
|
||||||
|
|
||||||
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var
|
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var
|
||||||
,epsilon=0.03125, alpha=3/255, num_iter=1500):
|
,epsilon=0.03125, alpha=3/255):
|
||||||
|
|
||||||
delta = torch.zeros_like(image,requires_grad=True)
|
delta = torch.zeros_like(image,requires_grad=True)
|
||||||
for i in range(num_iter):
|
for i in range(self.args.epochs):
|
||||||
self.model.zero_grad()
|
self.model.zero_grad()
|
||||||
anchor=self.model.encode_image(image+delta)
|
anchor=self.model.encode_image(image+delta)
|
||||||
loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity())
|
loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity())
|
||||||
|
|
|
||||||
|
|
@ -13,13 +13,16 @@ from .base import TrainBase
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
||||||
from utils.calc_utils import cal_map, cal_pr
|
from utils.calc_utils import cal_map, cal_pr
|
||||||
from dataset.dataloader import dataloader
|
from dataset.dataloader import cross_modal_dataset
|
||||||
import clip
|
import clip
|
||||||
import copy
|
import copy
|
||||||
from model.bert_tokenizer import BertTokenizer
|
from model.bert_tokenizer import BertTokenizer
|
||||||
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
||||||
from transformers import BertForMaskedLM
|
from transformers import BertForMaskedLM
|
||||||
# from transformers import BertModel
|
# from transformers import BertModel
|
||||||
|
import ftfy
|
||||||
|
import regex as re
|
||||||
|
import html
|
||||||
|
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
@ -51,6 +54,17 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again',
|
||||||
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·']
|
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·']
|
||||||
filter_words = set(filter_words)
|
filter_words = set(filter_words)
|
||||||
|
|
||||||
|
def basic_clean(text):
|
||||||
|
text = ftfy.fix_text(text)
|
||||||
|
text = html.unescape(html.unescape(text))
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def whitespace_clean(text):
|
||||||
|
text = re.sub(r'\s+', ' ', text)
|
||||||
|
text = text.strip()
|
||||||
|
return text
|
||||||
|
|
||||||
def get_bpe_substitues(substitutes, tokenizer, mlm_model):
|
def get_bpe_substitues(substitutes, tokenizer, mlm_model):
|
||||||
# substitutes L, k
|
# substitutes L, k
|
||||||
# device = mlm_model.device
|
# device = mlm_model.device
|
||||||
|
|
@ -139,46 +153,10 @@ class Trainer(TrainBase):
|
||||||
def _init_dataset(self):
|
def _init_dataset(self):
|
||||||
self.logger.info("init dataset.")
|
self.logger.info("init dataset.")
|
||||||
self.logger.info(f"Using {self.args.dataset} dataset.")
|
self.logger.info(f"Using {self.args.dataset} dataset.")
|
||||||
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
|
train_loader, test_loader, db_loader , train_set, test_set, db_set=cross_modal_dataset(self.args)
|
||||||
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
|
self.train_loader=train_loader
|
||||||
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
|
self.test_loader=test_loader
|
||||||
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
|
self.retrieval_loader=db_loader
|
||||||
indexFile=self.args.index_file,
|
|
||||||
labelFile=self.args.label_file,
|
|
||||||
maxWords=self.args.max_words,
|
|
||||||
imageResolution=self.args.resolution,
|
|
||||||
query_num=self.args.query_num,
|
|
||||||
train_num=self.args.train_num,
|
|
||||||
seed=self.args.seed)
|
|
||||||
self.train_labels = train_data.get_all_label()
|
|
||||||
self.query_labels = query_data.get_all_label()
|
|
||||||
self.retrieval_labels = retrieval_data.get_all_label()
|
|
||||||
self.args.retrieval_num = len(self.retrieval_labels)
|
|
||||||
self.logger.info(f"query shape: {self.query_labels.shape}")
|
|
||||||
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
|
|
||||||
self.train_loader = DataLoader(
|
|
||||||
dataset=train_data,
|
|
||||||
batch_size=self.args.batch_size,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
shuffle=True
|
|
||||||
)
|
|
||||||
self.query_loader = DataLoader(
|
|
||||||
dataset=query_data,
|
|
||||||
batch_size=self.args.batch_size,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
shuffle=True
|
|
||||||
)
|
|
||||||
self.retrieval_loader = DataLoader(
|
|
||||||
dataset=retrieval_data,
|
|
||||||
batch_size=self.args.batch_size,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
pin_memory=True,
|
|
||||||
shuffle=True
|
|
||||||
)
|
|
||||||
self.train_data=train_data
|
|
||||||
|
|
||||||
|
|
||||||
def generate_mapping(self):
|
def generate_mapping(self):
|
||||||
image_train=[]
|
image_train=[]
|
||||||
|
|
@ -186,6 +164,7 @@ class Trainer(TrainBase):
|
||||||
for image, text, label, index in self.train_loader:
|
for image, text, label, index in self.train_loader:
|
||||||
image=image.to(device, non_blocking=True)
|
image=image.to(device, non_blocking=True)
|
||||||
# print(self.model.vocab_size)
|
# print(self.model.vocab_size)
|
||||||
|
text = self.clip_tokenizer.tokenize(text)
|
||||||
temp_image=self.model.encode_image(image)
|
temp_image=self.model.encode_image(image)
|
||||||
image_train.append(temp_image.cpu().detach().numpy())
|
image_train.append(temp_image.cpu().detach().numpy())
|
||||||
label_train.append(label.detach().numpy())
|
label_train.append(label.detach().numpy())
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ def get_args():
|
||||||
parser.add_argument("--save-dir", type=str, default="./result/64-bit")
|
parser.add_argument("--save-dir", type=str, default="./result/64-bit")
|
||||||
parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
|
parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
|
||||||
parser.add_argument("--pretrained", type=str, default="")
|
parser.add_argument("--pretrained", type=str, default="")
|
||||||
parser.add_argument("--dataset", type=str, default="flickr25k", help="choise from [coco, mirflckr25k, nuswide]")
|
parser.add_argument("--dataset", type=str, default="iapr", help="choise from [coco, flckr25k, iapr]")
|
||||||
parser.add_argument("--index-file", type=str, default="index.mat")
|
parser.add_argument("--index-file", type=str, default="index.mat")
|
||||||
parser.add_argument("--caption-file", type=str, default="caption.mat")
|
parser.add_argument("--caption-file", type=str, default="caption.mat")
|
||||||
parser.add_argument("--label-file", type=str, default="label.mat")
|
parser.add_argument("--label-file", type=str, default="label.mat")
|
||||||
|
|
@ -21,7 +21,7 @@ def get_args():
|
||||||
parser.add_argument("--beta", type=float, default=10.0)
|
parser.add_argument("--beta", type=float, default=10.0)
|
||||||
parser.add_argument("--txt-dim", type=int, default=1024)
|
parser.add_argument("--txt-dim", type=int, default=1024)
|
||||||
parser.add_argument("--output-dim", type=int, default=512)
|
parser.add_argument("--output-dim", type=int, default=512)
|
||||||
parser.add_argument("--epochs", type=int, default=100)
|
parser.add_argument("--epochs", type=int, default=500)
|
||||||
parser.add_argument("--max-words", type=int, default=77)
|
parser.add_argument("--max-words", type=int, default=77)
|
||||||
parser.add_argument("--resolution", type=int, default=224)
|
parser.add_argument("--resolution", type=int, default=224)
|
||||||
parser.add_argument("--batch-size", type=int, default=8)
|
parser.add_argument("--batch-size", type=int, default=8)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue