from torch.nn.modules import loss
# from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
import scipy.io as scio
import numpy as np

from .base import TrainBase
from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity, find_indices
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
import clip
# from transformers import BertModel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def clamp(delta, clean_imgs):
    """Project a perturbation so that the perturbed image stays in [0, 1].

    Args:
        delta: perturbation tensor, same shape as ``clean_imgs``.
        clean_imgs: the unperturbed images.

    Returns:
        The adjusted perturbation such that ``clean_imgs + delta`` lies in [0, 1].
    """
    clamp_imgs = (delta.data + clean_imgs.data).clamp(0, 1)
    clamp_delta = clamp_imgs - clean_imgs.data
    return clamp_delta


class Trainer(TrainBase):
    """Targeted adversarial-attack 'trainer' against a frozen CLIP victim model.

    Despite the name, no model weights are trained: ``train_epoch`` crafts
    targeted adversarial images with PGD-style updates and evaluates the
    attack via cross-modal retrieval mAP.
    """

    def __init__(self, rank=0):
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        # Per-class mean/variance of CLIP image embeddings over the train set,
        # keyed by the stringified integer label vector.
        image_mean, image_var = self.generate_mapping()
        self.image_mean = image_mean
        self.image_var = image_var
        # NOTE(review): ``rank`` doubles as the torch device id here — confirm
        # against TrainBase's expectations.
        self.device = rank

    def _init_model(self):
        """Load the frozen CLIP victim model in eval mode, in float32."""
        self.logger.info("init model.")
        model_clip, preprocess = clip.load(self.args.victim, device=device)
        self.model = model_clip
        self.model.eval()
        self.model.float()

    def _init_dataset(self):
        """Build train/query/retrieval splits and their data loaders."""
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
        self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
        self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                                            indexFile=self.args.index_file,
                                                            labelFile=self.args.label_file,
                                                            maxWords=self.args.max_words,
                                                            imageResolution=self.args.resolution,
                                                            query_num=self.args.query_num,
                                                            train_num=self.args.train_num,
                                                            seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
        self.train_loader = DataLoader(
            dataset=train_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.query_loader = DataLoader(
            dataset=query_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.retrieval_loader = DataLoader(
            dataset=retrieval_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.train_data = train_data

    def generate_mapping(self):
        """Compute per-class mean and variance of CLIP image embeddings.

        Returns:
            (image_representation, image_var_representation): two dicts keyed by
            the stringified integer label vector, mapping to the class centroid
            and the per-dimension variance of the image embeddings.
        """
        image_train = []
        label_train = []
        # FIX: inference only — run under no_grad so we do not accumulate
        # autograd graphs over the whole training set.
        with torch.no_grad():
            for image, text, label, index in self.train_loader:
                image = image.to(device, non_blocking=True)
                temp_image = self.model.encode_image(image)
                image_train.append(temp_image.cpu().detach().numpy())
                label_train.append(label.detach().numpy())
        image_train = np.concatenate(image_train, axis=0)
        label_train = np.concatenate(label_train, axis=0)
        label_unique = np.unique(label_train, axis=0)
        # find_indices presumably returns the row indices whose label equals
        # label_unique[i] — TODO confirm against utils.find_indices.
        image_centroids = np.stack([image_train[find_indices(label_train, label_unique[i])].mean(axis=0)
                                    for i in range(len(label_unique))], axis=0)
        image_var = np.stack([image_train[find_indices(label_train, label_unique[i])].var(axis=0)
                              for i in range(len(label_unique))], axis=0)

        image_representation = {}
        image_var_representation = {}
        for i, centroid in enumerate(label_unique):
            image_representation[str(centroid.astype(int))] = image_centroids[i]
            image_var_representation[str(centroid.astype(int))] = image_var[i]
        return image_representation, image_var_representation

    def target_adv(self, image, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, positive_var,
                   beta=10, epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
        """Craft a targeted adversarial perturbation with iterative signed-gradient descent.

        The loss combines a log-ratio term built from class-conditional
        Gaussian statistics (mean/var of target and source classes) with a
        cosine-distance triplet term weighted by ``beta``.

        Args:
            image: clean input images.
            negetive_code: CLIP embedding of the clean (source-class) images.
            negetive_mean / negative_var: source-class embedding statistics.
            positive_code: CLIP embedding of the target-class images.
            positive_mean / positive_var: target-class embedding statistics.
            beta: weight of the triplet term.
            epsilon: L-inf budget of the perturbation.
            alpha: step size per iteration.
            num_iter: number of attack iterations.
            temperature: softening temperature of the log-ratio term.

        Returns:
            (delta, adv_code): the final detached perturbation and the CLIP
            embedding of the perturbed images.
        """
        delta = torch.zeros_like(image, requires_grad=True)
        for i in range(num_iter):
            self.model.zero_grad()
            anchor = self.model.encode_image(image + delta)
            loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code,
                                                        distance_function=nn.CosineSimilarity())
            negative_dist = (anchor - negetive_mean) ** 2 / negative_var
            positive_dist = (anchor - positive_mean) ** 2 / positive_var
            negatives = torch.exp(negative_dist / temperature)
            positives = torch.exp(positive_dist / temperature)
            loss = torch.log(positives / (positives + negatives)).mean() + beta * loss1
            # FIX: retain_graph=True was unnecessary — a fresh graph is built
            # every iteration, so retaining it only wasted memory.
            loss.backward()
            # Signed-gradient descent step (we minimize the loss), then project
            # back into the valid image range and the L-inf ball.
            delta.data = delta - alpha * delta.grad.detach().sign()
            delta.data = clamp(delta, image).clamp(-epsilon, epsilon)
            delta.grad.zero_()
        adv_code = self.model.encode_image(image + delta)
        return delta.detach(), adv_code

    def train_epoch(self, epoch=None):
        """Run the targeted attack over the train loader and save retrieval results.

        Args:
            epoch: optional epoch index (accepted for compatibility with
                ``train()``, which passes it; unused here).
        """
        self.change_state(mode="valid")
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        # FIX: the directory was never created before scio.savemat wrote into it.
        os.makedirs(save_dir, exist_ok=True)
        all_loss = 0
        times = 0
        adv_codes = []
        adv_label = []
        for image, text, label, index in self.train_loader:
            self.global_step += 1
            times += 1
            print(times)
            # FIX: ``image.float()`` returned a new tensor that was discarded;
            # assign it back so the conversion takes effect.
            image = image.float()
            image = image.to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            # Source-class ("negative") statistics for this batch's true labels.
            negetive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
            negative_var = np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
            negetive_mean = torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
            negative_var = torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
            negetive_code = self.model.encode_image(image)

            # Targeted sample: draw a random batch whose labels become the
            # attack targets (seeded per batch for reproducibility).
            np.random.seed(times)
            select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
            target_dataset = data.Subset(self.train_data, select_index)
            target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
            target_image, _, target_label, _ = next(iter(target_subset))
            target_image = target_image.to(self.rank, non_blocking=True)
            positive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
            positive_var = np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
            positive_mean = torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
            positive_var = torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
            positive_code = self.model.encode_image(target_image)

            delta, adv_code = self.target_adv(image, negetive_code, negetive_mean, negative_var,
                                              positive_code, positive_mean, positive_var)
            adv_codes.append(adv_code.cpu().detach().numpy())
            # mAP is measured against the *target* labels (targeted attack).
            adv_label.append(target_label.numpy())

        # NOTE(review): original indentation was lost in the patch; the
        # evaluation below is assumed to follow the batch loop — confirm.
        adv_img = np.concatenate(adv_codes)
        adv_labels = np.concatenate(adv_label)

        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)

        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        retrieval_labels = self.retrieval_labels.numpy()

        mAP_t = cal_map(adv_img, adv_labels, retrieval_txt, retrieval_labels)
        self.logger.info(f">>>>>> MAP_t: {mAP_t}")
        result_dict = {
            'adv_img': adv_img,
            'r_txt': retrieval_txt,
            'adv_l': adv_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"),
                     result_dict)
        self.logger.info(">>>>>> save all data!")

    def train(self):
        """Run the attack for ``args.epochs`` epochs, validating and saving each epoch."""
        self.logger.info("Start train.")

        for epoch in range(self.args.epochs):
            # FIX: train_epoch previously took no arguments, so this call
            # raised TypeError; it now accepts the epoch index.
            self.train_epoch(epoch)
            # NOTE(review): valid/save_model are presumably inherited from
            # TrainBase — confirm they exist there.
            self.valid(epoch)
            self.save_model(epoch)

        self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")

    def make_hash_code(self, code: list) -> torch.Tensor:
        """Convert a list of per-bit logit tensors into a {-1, +1} hash code.

        Args:
            code: list of tensors, stacked to (bits, batch, choices); argmax
                over the last dim picks each bit, with class 0 mapped to -1.

        Returns:
            Float tensor of shape (batch, bits) with entries in {-1, 1, 2, ...}
            (entries other than 0 are kept as their argmax index).
        """
        code = torch.stack(code)
        code = code.permute(1, 0, 2)
        hash_code = torch.argmax(code, dim=-1)
        hash_code[torch.where(hash_code == 0)] = -1
        hash_code = hash_code.float()
        return hash_code

    def get_code(self, data_loader, length: int):
        """Encode every sample in ``data_loader`` with the CLIP victim model.

        Args:
            data_loader: loader yielding (image, text, label, index) batches.
            length: total number of samples (size of the output buffers).

        Returns:
            (img_buffer, text_buffer): (length, output_dim) float tensors on
            ``self.rank``, indexed by each sample's dataset index.
        """
        img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)

        for image, text, label, index in tqdm(data_loader):
            image = image.to(self.device, non_blocking=True)
            text = text.to(self.device, non_blocking=True)
            index = index.numpy()
            with torch.no_grad():
                image_feature = self.model.encode_image(image)
                text_features = self.model.encode_text(text)
            img_buffer[index, :] = image_feature.detach()
            text_buffer[index, :] = text_features.detach()

        return img_buffer, text_buffer

    def valid_attack(self, adv_images, texts, adv_labels):
        """Placeholder: only ensures the adversarial PR-curve directory exists."""
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

    def test(self, mode_name="i2t"):
        """Evaluate clean (unattacked) cross-modal retrieval and save codes to .mat."""
        self.logger.info("Valid Clean.")
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)

        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAPi2t = cal_map(query_img, query_labels, retrieval_txt, retrieval_labels)
        mAPt2i = cal_map(query_txt, query_labels, retrieval_img, retrieval_labels)
        # FIX: the original updated max_mapt2i from mAPi2t (copy-paste mixup);
        # track each direction with its own mAP.
        self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
        """Save query/retrieval codes and labels to a mode-specific .mat file."""
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()

        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
        self.logger.info(f">>>>>> save best {mode_name} data!")