This commit is contained in:
leewlving 2024-06-21 10:12:04 +08:00
parent 9539ba5aed
commit 5b84447a3d
8 changed files with 292 additions and 51 deletions

View File

@ -9,6 +9,10 @@ import random
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
import re
import os
import numpy as np
from torchvision import transforms
@ -62,7 +66,28 @@ class BaseDataset(Dataset):
return image
def pre_caption(self, caption):
    # strip punctuation, lower-case, and normalize separators
    caption = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        caption.lower(),
    ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
    # collapse runs of whitespace into a single space
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')
    # truncate to at most maxWords words
    caption_words = caption.split(' ')
    if len(caption_words) > self.maxWords:
        caption = ' '.join(caption_words[:self.maxWords])
    return caption
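# Illustrative sketch of pre_caption's effect (example values assumed, not
# part of the commit): with maxWords=8,
#     "Riding a horse, near <person>!"
# is cleaned to
#     "riding a horse near person"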
def _load_text(self, index: int):
# print(self.captions[index])
# captions = self.pre_caption(self.captions[index])
# captions =whitespace_clean(self.captions[index]).lower()
captions = self.captions[index]
use_cap = captions[random.randint(0, len(captions) - 1)]
@ -79,7 +104,7 @@ class BaseDataset(Dataset):
caption.append(0)
caption = torch.tensor(caption)
return caption
return captions
def _load_label(self, index: int) -> torch.Tensor:
label = self.labels[index]
@ -101,3 +126,175 @@ class BaseDataset(Dataset):
return image, caption, label, index
def default_loader(path):
    return Image.open(path).convert('RGB')

class IaprDataset(Dataset):
    def __init__(self, args, txt, transform=None, loader=default_loader):
        self.transform = transform
        self.loader = loader
        name_label = []
        for line in open(txt):
            line = line.strip('\n').split()
            # the last 255 binary codes are the label; the first 2912 are the
            # bag-of-words binary encoding of the text
            label = list(map(int, np.array(line[len(line) - 255:])))
            tem = re.split('[/.]', line[0])
            file_name, sample_name = tem[0], tem[1]
            name_label.append([file_name, sample_name, label])
        self.name_label = name_label
        self.image_dir = args.image_dir
        self.text_dir = args.text_dir

    def __getitem__(self, index):
        # words = [file_name, sample_name, label]
        words = self.name_label[index]
        img_path = os.path.join(self.image_dir, words[0], words[1] + '.jpg')
        text_path = os.path.join(self.text_dir, words[0], words[1] + '.txt')
        # image
        img = self.loader(img_path)
        if self.transform is not None:
            img = self.transform(img)
        # text
        text = 'None'
        for line in open(text_path):
            text = '[CLS]' + line + '[SEP]'
        # label
        label = torch.LongTensor(words[2])
        # image, caption, label, index
        return img, text, label, index

    def __len__(self):
        return len(self.name_label)
default_transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
class MScocoDataset(Dataset):
    def __init__(self, data_path, img_filename, text_filename, label_filename, transform=None):
        self.data_path = data_path
        if transform is None:
            self.transform = default_transform
        else:
            self.transform = transform
        img_filepath = os.path.join(data_path, img_filename)
        with open(img_filepath, 'r') as f:
            self.imgs = [x.strip() for x in f]
        text_filepath = os.path.join(data_path, text_filename)
        with open(text_filepath, 'r') as f:
            self.texts = [line.rstrip('\n') for line in f]
        label_filepath = os.path.join(data_path, label_filename)
        self.labels = np.genfromtxt(label_filepath, dtype=np.int32)

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.data_path, self.imgs[index]))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.from_numpy(self.labels[index]).float()
        text = self.texts[index]
        # image, caption, label, index
        return img, text, label, index

    def __len__(self):
        return len(self.imgs)
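# Hedged usage sketch (the data path is an assumption; the cm_* filenames come
# from get_dataset_filename in dataset/dataloader.py):
#     dataset = MScocoDataset(data_path='./data/coco',
#                             img_filename='cm_train_imgs.txt',
#                             text_filename='cm_train_txts.txt',
#                             label_filename='cm_train_labels.txt')
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
#     img, text, label, index = next(iter(loader))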
class MirflickrDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # each annotation file is one class; build a multi-hot label vector per image id
        file_path = os.path.join(root_dir, "mirflickr25k_annotations_v080")
        file_list = os.listdir(file_path)
        file_list = [item for item in file_list if "_r1" not in item and "README" not in item]
        self.class_index = {}
        for i, item in enumerate(file_list):
            self.class_index.update({item: i})
        self.label_dict = {}
        for path_id in file_list:
            path = os.path.join(file_path, path_id)
            with open(path, "r") as f:
                for item in f:
                    item = item.strip()
                    if item not in self.label_dict:
                        label = np.zeros(len(file_list))
                        label[self.class_index[path_id]] = 1
                        self.label_dict.update({item: label})
                    else:
                        self.label_dict[item][self.class_index[path_id]] = 1
        # one space-joined caption per image id, built from the tag files
        self.captions_dict = {}
        captions_path = os.path.join(root_dir, "mirflickr/meta/tags")
        captions_list = os.listdir(captions_path)
        for item in captions_list:
            id_ = item.split(".")[0].replace("tags", "")
            caption = ""
            with open(os.path.join(captions_path, item), "r") as f:
                for word in f.readlines():
                    caption += word.strip() + " "
            caption = caption.strip()
            self.captions_dict.update({id_: caption})
        # a stable id order, so the integer index a DataLoader passes can be
        # mapped to the string image ids used as dictionary keys
        self.ids = sorted(self.label_dict.keys())
        if transform is None:
            self.transform = default_transform
        else:
            self.transform = transform

    def __getitem__(self, index):
        # map the integer index to the string image id
        id_ = self.ids[index]
        label = torch.from_numpy(self.label_dict[id_]).float()
        path = os.path.join(self.root_dir, "mirflickr")
        img = Image.open(os.path.join(path, "im" + id_ + ".jpg"))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        text = self.captions_dict[id_]
        # image, caption, label, index
        return img, text, label, index

    def __len__(self):
        return len(self.ids)
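# Hedged usage sketch (root path assumed): labels are multi-hot over the
# annotation classes, e.g. an image tagged with two classes gets 1s at those
# two class indexes and 0s elsewhere.
#     dataset = MirflickrDataset(root_dir='./data/mirflickr25k')
#     img, text, label, index = dataset[0]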
class NusWideDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # the original sub-paths began with '/', which makes os.path.join
        # discard root_dir entirely; they must be relative paths
        self.image_list_file = os.path.join(root_dir, "Low-Level-Features/ImageList/Imagelist.txt")
        self.label_path = os.path.join(root_dir, "nuswide/Groundtruth/AllLabels")
        self.text_file = os.path.join(root_dir, "Low-Level-Features/NUS_WID_Tags/All_Tags.txt")
        if transform is None:
            self.transform = default_transform
        else:
            self.transform = transform
        # parsing the three files above into label_dict / captions_dict is
        # still unimplemented; until then this dataset is empty and the
        # __getitem__ below (copied from MirflickrDataset) cannot succeed
        self.label_dict = {}
        self.captions_dict = {}
        self.ids = sorted(self.label_dict.keys())

    def __getitem__(self, index):
        # copied from MirflickrDataset and not yet adapted to the NUS-WIDE
        # image layout
        id_ = self.ids[index]
        label = torch.from_numpy(self.label_dict[id_]).float()
        path = os.path.join(self.root_dir, "mirflickr")
        img = Image.open(os.path.join(path, "im" + id_ + ".jpg"))
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        text = self.captions_dict[id_]
        # image, caption, label, index
        return img, text, label, index

    def __len__(self):
        return len(self.ids)

View File

@ -1,8 +1,12 @@
from .base import BaseDataset
from .base import BaseDataset, MirflickrDataset, MScocoDataset, IaprDataset
import os
import numpy as np
import scipy.io as scio
import torch
import torchvision.transforms as transforms
import json
from torch.utils.data import Dataset, DataLoader
def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=None):
@ -45,8 +49,13 @@ def dataloader(captionFile: str,
with open(captionFile, "r") as f:
captions = f.readlines()
captions = np.asarray([[item.strip()] for item in captions])
elif captionFile.endswith("json"):
with open(captionFile, "r") as f:
data = json.load(f)
captions=data["caption"]
# captions = captions[0] if captions.shape[0] == 1 else captions
else:
raise ValueError("the format of 'captionFile' doesn't support, only support [txt, mat] format.")
raise ValueError("the format of 'captionFile' doesn't support, only support [txt, json, mat] format.")
if not npy:
indexs = scio.loadmat(indexFile)["index"]
else:
@ -65,4 +74,55 @@ def dataloader(captionFile: str,
return train_data, query_data, retrieval_data
def get_dataset_filename(split):
filename = {
'train': ('cm_train_imgs.txt', 'cm_train_txts.txt', 'cm_train_labels.txt'),
'test': ('cm_test_imgs.txt', 'cm_test_txts.txt', 'cm_test_labels.txt'),
'db': ('cm_database_imgs.txt', 'cm_database_txts.txt', 'cm_database_labels.txt')
}
return filename[split]
def cross_modal_dataset(args):
    transform_test = transforms.Compose([
        transforms.Resize((args.resolution, args.resolution)),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])
    if args.dataset == 'flickr25k':
        dataset = MirflickrDataset(root_dir=args.flickr25k_root, transform=transform_test)
    elif args.dataset == 'coco':
        img_name, text_name, label_name = get_dataset_filename('train')
        dataset = MScocoDataset(data_path=args.coco_root, img_filename=args.coco_img_root,
                                text_filename=args.coco_txt_root, label_filename=args.coco_label_root,
                                transform=transform_test)
    elif args.dataset == 'iapr':
        train_file = os.path.join(args.iapr_root, 'iapr_train')
        test_file = os.path.join(args.iapr_root, 'iapr_test')
        retrieval_file = os.path.join(args.iapr_root, 'iapr_retrieval')
        train_set = IaprDataset(args, txt=train_file, transform=transform_test)
        test_set = IaprDataset(args, txt=test_file, transform=transform_test)
        db_set = IaprDataset(args, txt=retrieval_file, transform=transform_test)
    else:
        raise ValueError(f"dataset '{args.dataset}' is not supported.")
    if args.dataset != 'iapr':
        # iapr ships pre-split files; the other datasets are split randomly here
        test_set, db_set = torch.utils.data.random_split(dataset, [0.25, 0.75])
        train_set, _ = torch.utils.data.random_split(dataset, [0.65, 0.35])
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
    db_loader = DataLoader(db_set, batch_size=args.batch_size, shuffle=False, num_workers=4)
    return train_loader, test_loader, db_loader, train_set, test_set, db_set
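# Hedged usage sketch (get_args comes from utils, as the trainers use it):
#     args = get_args()
#     train_loader, test_loader, db_loader, train_set, test_set, db_set = \
#         cross_modal_dataset(args)
#     for image, text, label, index in train_loader:
#         ...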

View File

@ -76,7 +76,6 @@ captions = {"caption": captions}
scio.savemat(os.path.join(root_dir, "mat/index.mat"), index)
with open(os.path.join(root_dir, "mat/caption.json"), 'w', encoding='utf-8') as f:
json.dump(captions, f, ensure_ascii=False)
# scio.savemat(os.path.join(root_dir, "mat/caption.mat"), captions)
scio.savemat(os.path.join(root_dir, "mat/label.mat"), labels)

View File

@ -4,7 +4,7 @@ from train.text_train import Trainer
if __name__ == "__main__":
engine = Trainer()
# engine.test()
engine.test()
engine.train_epoch()

View File

@ -9,6 +9,8 @@ from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

View File

@ -15,9 +15,11 @@ from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similari
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
import clip
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
# from transformers import BertModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# tokenizer=Tokenizer()
def clamp(delta, clean_imgs):
@ -95,6 +97,8 @@ class Trainer(TrainBase):
label_train=[]
for image, text, label, index in self.train_loader:
# tokenize the raw caption strings first; only the resulting tensor can be moved to the device
text = clip.tokenize(text)
text = text.to(device, non_blocking=True)
# print(self.model.vocab_size)
temp_text=self.model.encode_text(text)
text_train.append(temp_text.cpu().detach().numpy())
@ -113,10 +117,10 @@ class Trainer(TrainBase):
return text_representation, text_var_representation
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var
,epsilon=0.03125, alpha=3/255, num_iter=1500):
,epsilon=0.03125, alpha=3/255):
delta = torch.zeros_like(image, requires_grad=True)
for i in range(num_iter):
for i in range(self.args.epochs):
self.model.zero_grad()
anchor = self.model.encode_image(image + delta)
# the triplet loss expects a distance; raw cosine similarity would invert the objective
loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code, distance_function=lambda x, y: 1 - F.cosine_similarity(x, y))

View File

@ -13,13 +13,16 @@ from .base import TrainBase
from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
from dataset.dataloader import cross_modal_dataset
import clip
import copy
from model.bert_tokenizer import BertTokenizer
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
from transformers import BertForMaskedLM
# from transformers import BertModel
import ftfy
import regex as re
import html
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@ -51,6 +54,17 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again',
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '', '·']
filter_words = set(filter_words)
def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()

def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text
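# Illustrative example of the two cleaners combined (values assumed):
#     whitespace_clean(basic_clean("Caf&eacute;  au\tlait "))
#     -> "Café au lait"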
def get_bpe_substitues(substitutes, tokenizer, mlm_model):
# substitutes L, k
# device = mlm_model.device
@ -139,46 +153,10 @@ class Trainer(TrainBase):
def _init_dataset(self):
self.logger.info("init dataset.")
self.logger.info(f"Using {self.args.dataset} dataset.")
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
indexFile=self.args.index_file,
labelFile=self.args.label_file,
maxWords=self.args.max_words,
imageResolution=self.args.resolution,
query_num=self.args.query_num,
train_num=self.args.train_num,
seed=self.args.seed)
self.train_labels = train_data.get_all_label()
self.query_labels = query_data.get_all_label()
self.retrieval_labels = retrieval_data.get_all_label()
self.args.retrieval_num = len(self.retrieval_labels)
self.logger.info(f"query shape: {self.query_labels.shape}")
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
self.train_loader = DataLoader(
dataset=train_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.query_loader = DataLoader(
dataset=query_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.retrieval_loader = DataLoader(
dataset=retrieval_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.train_data=train_data
train_loader, test_loader, db_loader, train_set, test_set, db_set = cross_modal_dataset(self.args)
self.train_loader = train_loader
self.test_loader = test_loader
self.retrieval_loader = db_loader
def generate_mapping(self):
image_train=[]
@ -186,6 +164,7 @@ class Trainer(TrainBase):
for image, text, label, index in self.train_loader:
image=image.to(device, non_blocking=True)
# print(self.model.vocab_size)
text = self.clip_tokenizer.tokenize(text)
temp_image=self.model.encode_image(image)
image_train.append(temp_image.cpu().detach().numpy())
label_train.append(label.detach().numpy())

View File

@ -9,7 +9,7 @@ def get_args():
parser.add_argument("--save-dir", type=str, default="./result/64-bit")
parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
parser.add_argument("--pretrained", type=str, default="")
parser.add_argument("--dataset", type=str, default="flickr25k", help="choise from [coco, mirflckr25k, nuswide]")
parser.add_argument("--dataset", type=str, default="iapr", help="choise from [coco, flckr25k, iapr]")
parser.add_argument("--index-file", type=str, default="index.mat")
parser.add_argument("--caption-file", type=str, default="caption.mat")
parser.add_argument("--label-file", type=str, default="label.mat")
@ -21,7 +21,7 @@ def get_args():
parser.add_argument("--beta", type=float, default=10.0)
parser.add_argument("--txt-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=512)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--epochs", type=int, default=500)
parser.add_argument("--max-words", type=int, default=77)
parser.add_argument("--resolution", type=int, default=224)
parser.add_argument("--batch-size", type=int, default=8)