diff --git a/train/hash_train.py b/train/hash_train.py index b3b13a1..3e3f6bc 100644 --- a/train/hash_train.py +++ b/train/hash_train.py @@ -1,454 +1,364 @@ -from torch.nn.modules import loss -from model.hash_model import DCMHT as DCMHT -import os -from tqdm import tqdm -import torch -import torch.nn as nn -from torch.utils.data import DataLoader -import scipy.io as scio -import numpy as np - -from .base import TrainBase -from torch.optim import Adam -import torch.nn.functional as F -# from model.optimization import BertAdam -# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss -from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices -from utils.calc_utils import calc_map_k_matrix as calc_map_k -from dataset.dataloader import dataloader -import open_clip -# from transformers import BertModel - -def clamp(delta, clean_imgs): - - clamp_imgs = (delta.data + clean_imgs.data).clamp(0, 1) - clamp_delta = clamp_imgs - clean_imgs.data - - return clamp_delta - - - -class Trainer(TrainBase): - - def __init__(self, - rank=0): - args = get_args() - super(Trainer, self).__init__(args, rank) - self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) - text_mean_representation, text_var_representation=self.generate_mapping() - self.text_mean=text_mean_representation - self.text_var=text_var_representation - # self.run() - - def _init_model(self): - self.logger.info("init model.") - # self.generator=Generator() - # linear = False - # if self.args.hash_layer == "linear": - # linear = True - # self.bert=BertModel.from_pretrained("bert-base-cased", output_hidden_states=True).to(self.rank) - # self.bert.eval() - # self.logger.info("ViT+GPT!") - # HashModel = DCMHT - # if self.args.victim_model == 'JDSH': - # from model.JDSH import TxtNet, ImgNet - # # self.img_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, - # # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) - # self.img_model=ImgNet(code_len=self.args.output_dim).to(self.rank) - # self.txt_model=TxtNet(code_len=self.args.output_dim, txt_feat_len=self.args.txt_dim).to(self.rank) - # path=os.path.join(self.args.checkpoints,self.args.victim_model+'/'+str(self.args.output_dim)+'_'+self.args.dataset+'latest.pth') - # checkpoint=torch.load(path) - # self.img_model.load_state_dict(torch.load(checkpoint['ImgNet'], map_location=f"cuda:{self.rank}")) - # self.txt_model.load_state_dict(torch.load(checkpoint['TxtNet'], map_location=f"cuda:{self.rank}")) - # self.img_model.eval() - # self.txt_model.eval() - # elif self.args.victim_model == 'DJSRH': - # self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, - # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) - # self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) - # elif self.args.victim_model == 'SSAH': - # self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, - # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) - # self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) - # elif self.args.victim_model == 'DCHUC': - # self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, - # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) - # self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) - - # if self.args.pretrained != "" and os.path.exists(self.args.pretrained): - # self.logger.info("load pretrained model.") - # self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) - model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=self.device) - self.model= model_clip - self.model.eval() - self.model.float() - self.optimizer =Adam(self.model.visual.parameters,lr=self.args.lr ,betas=[0.9,0.98] ) - # self.optimizer = BertAdam([ - # {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr}, - # {'params': self.model.image_hash.parameters(), 'lr': self.args.lr}, - # {'params': self.model.text_hash.parameters(), 'lr': self.args.lr} - # ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine', - # b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs, - # weight_decay=self.args.weight_decay, max_grad_norm=1.0) - - # print(self.model) - - def _init_dataset(self): - self.logger.info("init dataset.") - self.logger.info(f"Using {self.args.dataset} dataset.") - self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file) - self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file) - self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file) - train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file, - indexFile=self.args.index_file, - labelFile=self.args.label_file, - maxWords=self.args.max_words, - imageResolution=self.args.resolution, - query_num=self.args.query_num, - train_num=self.args.train_num, - seed=self.args.seed) - self.train_labels = train_data.get_all_label() - self.query_labels = query_data.get_all_label() - self.retrieval_labels = retrieval_data.get_all_label() - # self.args.retrieval_num = len(self.retrieval_labels) - self.logger.info(f"query shape: {self.query_labels.shape}") - self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}") - self.train_loader = DataLoader( - dataset=train_data, - batch_size=self.args.batch_size, - num_workers=self.args.num_workers, - pin_memory=True, - shuffle=True - ) - self.query_loader = DataLoader( - dataset=query_data, - batch_size=self.args.batch_size, - num_workers=self.args.num_workers, - pin_memory=True, - shuffle=True - ) - self.retrieval_loader = DataLoader( - dataset=retrieval_data, - batch_size=self.args.batch_size, - num_workers=self.args.num_workers, - pin_memory=True, - shuffle=True - ) - def generate_mapping(self): - text_train=[] - label_train=[] - # image_train=[] - # self.change_state(mode="valid") - for image, text, label, index in self.train_loader: - # image=image.to(self.device, non_blocking=True) - text=text.to(self.device, non_blocking=True) - temp_text=self.model.encode_text(text) - # temp_image=self.model.encode_image(image) - # image_train.append(temp_image.cpu().detach().numpy()) - text_train.append(temp_text.cpu().detach().numpy()) - label_train.append(label.detach().numpy()) - text_train=np.concatenate(text_train, axis=0) - # image_train=np.concatenate(image_train, axis=0) - label_train=np.concatenate(label_train, axis=0) - label_unipue=np.unique(label_train,axis=0) - # image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) - text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) - text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0) - - text_mean_representation = {} - text_var_representation = {} - for i, centroid in enumerate(label_unipue): - text_mean_representation[centroid.tobytes()] = text_centroids[i] - text_var_representation[centroid.tobytes()]= text_var[i] - return text_mean_representation, text_var_representation - - def target_adv(self, image, positive, positive_mean,positive_var, negative, negative_mean, negative_var, - epsilon=0.03125, alpha=3/255, num_iter=100): - - delta = torch.zeros_like(image,requires_grad=True) - # clean_output = self.model.encode_image(image) - one=torch.zeros_like(positive) - alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) - for i in range(num_iter): - self.model.zero_grad() - anchor=self.model.encode_image(image+delta) - loss1=alienation_loss(anchor, positive, negative) - loss=loss1 + self.args.beta * self.distribution_loss(anchor,positive_mean,positive_var,negative_mean, negative_var) - - loss.backward(retain_graph=True) - delta.data = delta - alpha * delta.grad.detach().sign() - delta.data =clamp(delta, image).clamp(-epsilon, epsilon) - delta.grad.zero_() - - return delta.detach() - - def train_epoch(self): - self.change_state(mode="valid") - # self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) - all_loss = 0 - times = 0 - save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve") - os.makedirs(save_dir, exist_ok=True) - # query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) - retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) - adv_images=[] - adv_labels=[] - # target_texts=[] - for image, text, label, index in self.train_loader: - self.global_step += 1 - times += 1 - image.float() - if self.args.dataset not in ["flickr25k", "coco", "nuswide"]: - label = torch.ones([image.shape[0]], dtype=torch.int) - label = label.diag() - image = image.to(self.rank, non_blocking=True) - text = text.to(self.rank, non_blocking=True) - index = index.numpy() - # image_anchor=self.image_representation(label.detach().cpu().numpy()) - negetive_mean=self.text_mean(label.detach().cpu().numpy()) - negetive_var=self.text_var(label.detach().cpu().numpy()) - # negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0) - negetive_code=self.model.encode_text(text) - target_label=label.flip(dims=[0]) - # target_image_anchor=self.image_representation(target_label.detach().cpu().numpy()) - positive_mean=self.text_mean(target_label.detach().cpu().numpy()) - positive_var=self.text_var(target_label.detach().cpu().numpy()) - # positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0) - positive_code=self.model.encode_text(text.flip(dims=[0])) - # print("text shape:", text.shape) - # index = index.numpy() - # print(text.shape) - delta=self.target_adv(image,positive_code,torch.from_numpy(positive_mean).to(self.rank, non_blocking=True), torch.from_numpy(positive_var).to(self.rank, non_blocking=True), - negetive_code, torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True), torch.from_numpy(negetive_var).to(self.rank, non_blocking=True)) - adv_image=delta+image - adv_images.append(self.model.encode_image(adv_image)) - adv_labels.append(target_label) - # target_texts.append(self.model.encode_text(text)) - adv_image=torch.cat(adv_image).to(self.device) - adv_labels=torch.cat(adv_labels).to(self.device) - - mAPi2t = calc_map_k(adv_images, retrieval_txt, adv_labels, self.retrieval_labels, None, self.rank) - mAPt2t = calc_map_k(adv_images, retrieval_img, adv_labels, self.retrieval_labels, None, self.rank) - self.logger.info(f">>>>>> t-MAP(i->t): {mAPi2t}, t-MAP(t->t): {mAPt2t}") - adv_images = adv_images.cpu().detach().numpy() - # query_txt = query_txt.cpu().detach().numpy() - retrieval_img = retrieval_img.cpu().detach().numpy() - retrieval_txt = retrieval_txt.cpu().detach().numpy() - adv_labels = adv_labels.numpy() - retrieval_labels = self.retrieval_labels.numpy() - - result_dict = { - 'adv_img': adv_images, - # 'q_txt': query_txt, - 'r_img': retrieval_img, - 'r_txt': retrieval_txt, - 'adv_l': adv_labels, - 'r_l': retrieval_labels - } - scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict) - self.logger.info(">>>>>> save all data!") - # self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}") - # return adv_images, texts, adv_labels - - - - def train(self): - self.logger.info("Start train.") - self.valid() - self.train_epoch() - # self.valid() - # for epoch in range(self.args.epochs): - # self.train_epoch() - # self.valid(epoch) - # self.save_model(epoch) - - self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") - - def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor): - - s = torch.matmul(a, b.t()) - b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s))) - - return b_loss - - def distribution_loss(self, x: torch.Tensor, positive_mean,positive_var, negative_mean, negative_var): - """ - """ - norm_fun= lambda mean, var, x: 50- torch.mean(torch.exp(-(x-mean) **2 /(2*var)) /(2* torch.pi * var)) - positive_distribution=norm_fun(positive_mean,positive_var,x) - negative_distribution=norm_fun(negative_mean,negative_var,x) - # alienation_loss=nn.MarginRankingLoss() - - return F.margin_ranking_loss(positive_distribution,negative_distribution,1) - - - def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05): - - # $\vartheta$ - vartheta = self.args.vartheta - if self.args.sim_threshold != 0: - threshold = self.args.sim_threshold - similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b) - - positive_similarity = similarity * label_sim - # 只要cosine为负值的全都算为计算正确了,因为优化到2确实很难。 - negative_similarity = similarity * (1 - label_sim) - - if self.args.similarity_function == "cosine": - positive_similarity = positive_similarity.clip(threshold) - threshold - negative_similarity = negative_similarity.clip(max=1.) - negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity - elif self.args.similarity_function == "euclidean": - # 有euclidean距离可知,当有一半长度的hash码不同时,其negative_similarity距离应该是长度(concat操作将outputdim翻倍),所以这里clip掉认为认定的值 - # 人为认定的最大值是一半长度的hash码不同。 - max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5 - negative_similarity = negative_similarity.clip(max=max_value) - negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity - - if self.args.loss_type == "l1": - positive_loss = positive_similarity.mean() - negative_loss = negative_similarity.mean() - elif self.args.loss_type == "l2": - positive_loss = torch.pow(positive_similarity, 2).mean() - negative_loss = torch.pow(negative_similarity, 2).mean() - else: - raise ValueError("argument of loss_type is not support.") - - return similarity, positive_loss, negative_loss - - def make_hash_code(self, code: list) -> torch.Tensor: - - code = torch.stack(code) - # print(code.shape) - code = code.permute(1, 0, 2) - hash_code = torch.argmax(code, dim=-1) - hash_code[torch.where(hash_code == 0)] = -1 - hash_code = hash_code.float() - - return hash_code - - def get_code(self, data_loader, length: int): - - img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) - text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) - - for image, text, label, index in tqdm(data_loader): - image = image.to(self.rank, non_blocking=True) - text = text.to(self.rank, non_blocking=True) - index = index.numpy() - image_hash=self.model.encode_image(image) - # text_feat=self.bert(text)[0] - text_hash=self.model.encode_text(text) - img_buffer[index, :] = image_hash.data - text_buffer[index, :] = text_hash.data - - return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) - - def get_adv_code(self, adv_data_list,text_list): - - img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank) - text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank) - - for i in tqdm(range(len(adv_data_list))): - image = adv_data_list[i].to(self.rank, non_blocking=True) - text = text_list[i].to(self.rank, non_blocking=True) - # index = index.numpy() - image_hash=self.img_model(image) - text_feat=self.bert(text)[0] - text_hash=self.txt_model(text_feat) - # text_hash = self.make_hash_code(text_hash) - # image_hash.to(self.rank) - # text_hash.to(self.rank) - img_buffer[i, :] = image_hash.data - text_buffer[i, :] = text_hash.data - - return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) - - - def valid_attack(self,adv_images, texts, adv_labels): - save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve") - os.makedirs(save_dir, exist_ok=True) - - - - def test(self, mode_name="i2t"): - if self.args.pretrained == "": - raise RuntimeError("test step must load a model! please set the --pretrained argument.") - self.change_state(mode="valid") - save_dir = os.path.join(self.args.save_dir, "PR_cruve") - os.makedirs(save_dir, exist_ok=True) - query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) - retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) - mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) - # print("map map") - mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) - mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) - mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) - self.max_mapt2i = max(self.max_mapt2i, mAPt2i) - self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}") - - query_img = query_img.cpu().detach().numpy() - query_txt = query_txt.cpu().detach().numpy() - retrieval_img = retrieval_img.cpu().detach().numpy() - retrieval_txt = retrieval_txt.cpu().detach().numpy() - query_labels = self.query_labels.numpy() - retrieval_labels = self.retrieval_labels.numpy() - - result_dict = { - 'q_img': query_img, - 'q_txt': query_txt, - 'r_img': retrieval_img, - 'r_txt': retrieval_txt, - 'q_l': query_labels, - 'r_l': retrieval_labels - } - scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) - self.logger.info(">>>>>> save all data!") - - - def valid(self, epoch): - self.logger.info("Valid.") - # self.change_state(mode="valid") - query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) - retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) - # print("get all code") - mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) - # print("map map") - mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) - mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) - mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) - if self.max_mapi2t < mAPi2t: - self.best_epoch_i = epoch - self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t") - self.max_mapi2t = max(self.max_mapi2t, mAPi2t) - if self.max_mapt2i < mAPt2i: - self.best_epoch_t = epoch - self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i") - self.max_mapt2i = max(self.max_mapt2i, mAPt2i) - self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \ - MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}") - - def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"): - - save_dir = os.path.join(self.args.save_dir, "PR_cruve") - os.makedirs(save_dir, exist_ok=True) - - query_img = query_img.cpu().detach().numpy() - query_txt = query_txt.cpu().detach().numpy() - retrieval_img = retrieval_img.cpu().detach().numpy() - retrieval_txt = retrieval_txt.cpu().detach().numpy() - query_labels = self.query_labels.numpy() - retrieval_labels = self.retrieval_labels.numpy() - - result_dict = { - 'q_img': query_img, - 'q_txt': query_txt, - 'r_img': retrieval_img, - 'r_txt': retrieval_txt, - 'q_l': query_labels, - 'r_l': retrieval_labels - } - scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) - self.logger.info(f">>>>>> save best {mode_name} data!") - - +from torch.nn.modules import loss +from model.hash_model import DCMHT as DCMHT +import os +from tqdm import tqdm +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import scipy.io as scio +import numpy as np + +from .base import TrainBase +from model.optimization import BertAdam +# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss +from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices +from utils.calc_utils import calc_map_k_matrix as calc_map_k +from dataset.dataloader import dataloader +import open_clip +# from transformers import BertModel + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +def clamp(delta, clean_imgs): + + clamp_imgs = (delta.data + clean_imgs.data).clamp(0, 1) + clamp_delta = clamp_imgs - clean_imgs.data + + return clamp_delta + +class Trainer(TrainBase): + + def __init__(self, + rank=0): + args = get_args() + super(Trainer, self).__init__(args, rank) + self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) + text_representation, text_representation=self.generate_mapping() + self.image_representation=text_representation + self.text_representation=text_representation + self.device=rank + # self.run() + + def _init_model(self): + self.logger.info("init model.") + model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device) + self.model= model_clip + self.model.eval() + self.model.float() + + def _init_dataset(self): + self.logger.info("init dataset.") + self.logger.info(f"Using {self.args.dataset} dataset.") + self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file) + self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file) + self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file) + train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file, + indexFile=self.args.index_file, + labelFile=self.args.label_file, + maxWords=self.args.max_words, + imageResolution=self.args.resolution, + query_num=self.args.query_num, + train_num=self.args.train_num, + seed=self.args.seed) + self.train_labels = train_data.get_all_label() + self.query_labels = query_data.get_all_label() + self.retrieval_labels = retrieval_data.get_all_label() + self.args.retrieval_num = len(self.retrieval_labels) + self.logger.info(f"query shape: {self.query_labels.shape}") + self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}") + self.train_loader = DataLoader( + dataset=train_data, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=True, + shuffle=True + ) + self.query_loader = DataLoader( + dataset=query_data, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=True, + shuffle=True + ) + self.retrieval_loader = DataLoader( + dataset=retrieval_data, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=True, + shuffle=True + ) + + + + def generate_mapping(self): + text_train=[] + label_train=[] + for image, text, label, index in self.train_loader: + text=text.to(device, non_blocking=True) + # print(self.model.vocab_size) + temp_text=self.model.encode_text(text) + text_train.append(temp_text.cpu().detach().numpy()) + label_train.append(label.detach().numpy()) + text_train=np.concatenate(text_train, axis=0) + label_train=np.concatenate(label_train, axis=0) + label_unipue=np.unique(label_train,axis=0) + text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) + text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0) + + text_representation = {} + text_var_representation = {} + for i, centroid in enumerate(label_unipue): + text_representation[centroid.tobytes()] = text_centroids[i] + text_var_representation[centroid.tobytes()]= text_var[i] + return text_representation, text_var_representation + + def target_adv(self, image, positive, negative, + epsilon=0.03125, alpha=3/255, num_iter=100): + + delta = torch.zeros_like(image,requires_grad=True) + one=torch.zeros_like(positive) + alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) + for i in range(num_iter): + self.model.zero_grad() + anchor=self.model.encode_image(image+delta) + loss=alienation_loss(anchor, positive, negative) + + + loss.backward(retain_graph=True) + delta.data = delta - alpha * delta.grad.detach().sign() + delta.data =clamp(delta, image).clamp(-epsilon, epsilon) + delta.grad.zero_() + + return delta.detach() + + def train_epoch(self, epoch): + self.change_state(mode="valid") + self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) + all_loss = 0 + times = 0 + adv_images=[] + adv_labels=[] + texts=[] + for image, text, label, index in self.train_loader: + self.global_step += 1 + times += 1 + image.float() + if self.args.dataset not in ["flickr25k", "coco", "nuswide"]: + label = torch.ones([image.shape[0]], dtype=torch.int) + label = label.diag() + image = image.to(self.rank, non_blocking=True) + text = text.to(self.rank, non_blocking=True) + index = index.numpy() + image_anchor=self.image_representation(label.detach().cpu().numpy()) + text_anchor=self.text_representation(label.detach().cpu().numpy()) + negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0) + target_label=label.flip(dims=[0]) + target_image_anchor=self.image_representation(target_label.detach().cpu().numpy()) + target_text_anchor=self.text_representation(target_label.detach().cpu().numpy()) + positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0) + delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True), + torch.from_numpy(negetive_code).to(self.rank, non_blocking=True)) + adv_image=delta+image + adv_images.append(adv_image) + adv_labels.append(target_label) + texts.append(text) + return adv_images, texts, adv_labels + + + + def train(self): + self.logger.info("Start train.") + + for epoch in range(self.args.epochs): + self.train_epoch(epoch) + self.valid(epoch) + self.save_model(epoch) + + self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") + + def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor): + + s = torch.matmul(a, b.t()) + b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s))) + + return b_loss + + def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor): + """ + """ + kl_divergence = torch.mean(a * torch.log(a / (b + 0.001))) + print("mean", torch.mean(a - b)) + print("kl", kl_divergence) + return kl_divergence + + + def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05): + + # $\vartheta$ + vartheta = self.args.vartheta + if self.args.sim_threshold != 0: + threshold = self.args.sim_threshold + similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b) + + positive_similarity = similarity * label_sim + # 只要cosine为负值的全都算为计算正确了,因为优化到2确实很难。 + negative_similarity = similarity * (1 - label_sim) + + if self.args.similarity_function == "cosine": + positive_similarity = positive_similarity.clip(threshold) - threshold + negative_similarity = negative_similarity.clip(max=1.) + negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity + elif self.args.similarity_function == "euclidean": + # 有euclidean距离可知,当有一半长度的hash码不同时,其negative_similarity距离应该是长度(concat操作将outputdim翻倍),所以这里clip掉认为认定的值 + # 人为认定的最大值是一半长度的hash码不同。 + max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5 + negative_similarity = negative_similarity.clip(max=max_value) + negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity + + if self.args.loss_type == "l1": + positive_loss = positive_similarity.mean() + negative_loss = negative_similarity.mean() + elif self.args.loss_type == "l2": + positive_loss = torch.pow(positive_similarity, 2).mean() + negative_loss = torch.pow(negative_similarity, 2).mean() + else: + raise ValueError("argument of loss_type is not support.") + + return similarity, positive_loss, negative_loss + + def make_hash_code(self, code: list) -> torch.Tensor: + + code = torch.stack(code) + # print(code.shape) + code = code.permute(1, 0, 2) + hash_code = torch.argmax(code, dim=-1) + hash_code[torch.where(hash_code == 0)] = -1 + hash_code = hash_code.float() + + return hash_code + + def get_code(self, data_loader, length: int): + + img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) + text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) + + for image, text, label, index in tqdm(data_loader): + image = image.to(self.rank, non_blocking=True) + text = text.to(self.rank, non_blocking=True) + index = index.numpy() + image_hash=self.img_model(image) + text_feat=self.bert(text)[0] + text_hash=self.txt_model(text_feat) + img_buffer[index, :] = image_hash.data + text_buffer[index, :] = text_hash.data + + return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) + + def get_adv_code(self, adv_data_list,text_list): + + img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank) + text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank) + + for i in tqdm(range(len(adv_data_list))): + image = adv_data_list[i].to(self.rank, non_blocking=True) + text = text_list[i].to(self.rank, non_blocking=True) + # index = index.numpy() + image_hash=self.img_model(image) + text_feat=self.bert(text)[0] + text_hash=self.txt_model(text_feat) + # text_hash = self.make_hash_code(text_hash) + # image_hash.to(self.rank) + # text_hash.to(self.rank) + img_buffer[i, :] = image_hash.data + text_buffer[i, :] = text_hash.data + + return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) + + + def valid_attack(self,adv_images, texts, adv_labels): + save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve") + os.makedirs(save_dir, exist_ok=True) + + + + def test(self, mode_name="i2t"): + if self.args.pretrained == "": + raise RuntimeError("test step must load a model! please set the --pretrained argument.") + self.change_state(mode="valid") + save_dir = os.path.join(self.args.save_dir, "PR_cruve") + os.makedirs(save_dir, exist_ok=True) + query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) + retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) + mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) + # print("map map") + mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) + mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) + mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) + self.max_mapt2i = max(self.max_mapt2i, mAPt2i) + self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}") + + query_img = query_img.cpu().detach().numpy() + query_txt = query_txt.cpu().detach().numpy() + retrieval_img = retrieval_img.cpu().detach().numpy() + retrieval_txt = retrieval_txt.cpu().detach().numpy() + query_labels = self.query_labels.numpy() + retrieval_labels = self.retrieval_labels.numpy() + + result_dict = { + 'q_img': query_img, + 'q_txt': query_txt, + 'r_img': retrieval_img, + 'r_txt': retrieval_txt, + 'q_l': query_labels, + 'r_l': retrieval_labels + } + scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) + self.logger.info(">>>>>> save all data!") + + + def valid(self, epoch): + self.logger.info("Valid.") + self.change_state(mode="valid") + query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) + retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) + # print("get all code") + mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) + # print("map map") + mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) + mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) + mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) + if self.max_mapi2t < mAPi2t: + self.best_epoch_i = epoch + self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t") + self.max_mapi2t = max(self.max_mapi2t, mAPi2t) + if self.max_mapt2i < mAPt2i: + self.best_epoch_t = epoch + self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i") + self.max_mapt2i = max(self.max_mapt2i, mAPt2i) + self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \ + MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}") + + def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"): + + save_dir = os.path.join(self.args.save_dir, "PR_cruve") + os.makedirs(save_dir, exist_ok=True) + + query_img = query_img.cpu().detach().numpy() + query_txt = query_txt.cpu().detach().numpy() + retrieval_img = retrieval_img.cpu().detach().numpy() + retrieval_txt = retrieval_txt.cpu().detach().numpy() + query_labels = self.query_labels.numpy() + retrieval_labels = self.retrieval_labels.numpy() + + result_dict = { + 'q_img': query_img, + 'q_txt': query_txt, + 'r_img': retrieval_img, + 'r_txt': retrieval_txt, + 'q_l': query_labels, + 'r_l': retrieval_labels + } + scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) + self.logger.info(f">>>>>> save best {mode_name} data!") + +