from torch.nn.modules import loss
# from model.hash_model import DCMHT as DCMHT
import os

from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
import scipy.io as scio
import numpy as np

from .base import TrainBase
from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity, find_indices
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
import clip
# from transformers import BertModel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def clamp(delta, clean_imgs):
    """Project ``delta`` so that ``clean_imgs + delta`` stays inside [0, 1].

    Args:
        delta: perturbation tensor (same shape as ``clean_imgs``).
        clean_imgs: unperturbed images, assumed already scaled to [0, 1]
            — TODO confirm against the dataset's preprocessing.

    Returns:
        The adjusted perturbation such that adding it to ``clean_imgs``
        yields values clamped to the valid image range.
    """
    clamped_imgs = (delta.data + clean_imgs.data).clamp(0, 1)
    return clamped_imgs - clean_imgs.data


class Trainer(TrainBase):
    """Targeted adversarial-attack "trainer" against a CLIP victim model.

    Rather than training a model, this class crafts PGD-style image
    perturbations that pull image embeddings toward the text embedding
    of a randomly chosen target class, then evaluates retrieval mAP/PR
    of the adversarial codes against the retrieval set.
    """

    def __init__(self, rank=0):
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        # Per-label mean/variance of CLIP text features over the train set;
        # used as attack targets in target_adv().
        text_mean, text_var = self.generate_mapping()
        self.text_mean = text_mean
        self.text_var = text_var
        self.device = rank
        # self.run()

    def _init_model(self):
        """Load the (frozen) CLIP victim model in eval mode, fp32."""
        self.logger.info("init model.")
        model_clip, preprocess = clip.load(self.args.victim, device=device)
        self.model = model_clip
        self.model.eval()
        self.model.float()

    def _init_dataset(self):
        """Build train/query/retrieval splits and their DataLoaders."""
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
        self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
        self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                                            indexFile=self.args.index_file,
                                                            labelFile=self.args.label_file,
                                                            maxWords=self.args.max_words,
                                                            imageResolution=self.args.resolution,
                                                            query_num=self.args.query_num,
                                                            train_num=self.args.train_num,
                                                            seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
        self.train_loader = DataLoader(
            dataset=train_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.query_loader = DataLoader(
            dataset=query_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.retrieval_loader = DataLoader(
            dataset=retrieval_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.train_data = train_data

    def generate_mapping(self):
        """Compute per-label statistics of CLIP text features on the train set.

        Returns:
            (mean_map, var_map): two dicts keyed by ``str(label.astype(int))``
            (the stringified integer label vector), mapping to the per-label
            mean and variance of the encoded text features.
        """
        text_feats = []
        label_rows = []
        # Inference only: no_grad avoids building throwaway autograd graphs
        # over the entire training set (original code relied on .detach()).
        with torch.no_grad():
            for image, text, label, index in self.train_loader:
                text = text.to(device, non_blocking=True)
                encoded = self.model.encode_text(text)
                text_feats.append(encoded.cpu().detach().numpy())
                label_rows.append(label.detach().numpy())
        text_feats = np.concatenate(text_feats, axis=0)
        label_rows = np.concatenate(label_rows, axis=0)
        unique_labels = np.unique(label_rows, axis=0)
        text_centroids = np.stack(
            [text_feats[find_indices(label_rows, unique_labels[i])].mean(axis=0)
             for i in range(len(unique_labels))], axis=0)
        text_var = np.stack(
            [text_feats[find_indices(label_rows, unique_labels[i])].var(axis=0)
             for i in range(len(unique_labels))], axis=0)
        mean_map = {}
        var_map = {}
        for i, lab in enumerate(unique_labels):
            key = str(lab.astype(int))
            mean_map[key] = text_centroids[i]
            var_map[key] = text_var[i]
        return mean_map, var_map

    def target_adv(self, image, negetive_code, negetive_mean, negative_var,
                   positive_code, positive_mean, positive_var,
                   beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500,
                   temperature=0.05):
        """Craft a targeted PGD perturbation pulling images toward a target text.

        Minimizes a Mahalanobis-style contrastive term (push away from the
        ground-truth text statistics, pull toward the target's) plus ``beta``
        times a triplet loss, under an L-inf budget of ``epsilon``.

        Args:
            image: clean image batch (values assumed in [0, 1]).
            negetive_code / negetive_mean / negative_var: text features and
                per-label statistics of the images' TRUE labels (repelled).
            positive_code / positive_mean / positive_var: same for the
                randomly chosen TARGET labels (attracted).
            beta: weight of the triplet term.
            epsilon, alpha, num_iter: PGD budget, step size, iterations.
            temperature: softmax temperature of the contrastive term.

        Returns:
            (delta, adv_code): the detached perturbation and the victim
            model's embedding of the perturbed images.
        """
        delta = torch.zeros_like(image, requires_grad=True)
        for _ in range(num_iter):
            self.model.zero_grad()
            anchor = self.model.encode_image(image + delta)
            # NOTE(review): CosineSimilarity is a similarity, not a distance;
            # passed as distance_function this inverts the usual triplet
            # geometry — kept as-is, presumed intentional. Confirm.
            loss1 = F.triplet_margin_with_distance_loss(
                anchor, positive_code, negetive_code,
                distance_function=nn.CosineSimilarity())
            negative_dist = (anchor - negetive_mean) ** 2 / negative_var
            positive_dist = (anchor - positive_mean) ** 2 / positive_var
            negatives = torch.exp(negative_dist / temperature)
            positives = torch.exp(positive_dist / temperature)
            loss = torch.log(positives / (positives + negatives)).mean() + beta * loss1
            # Graph is rebuilt every iteration, so retain_graph is not needed
            # (the original retain_graph=True only wasted memory).
            loss.backward()
            delta.data = delta.data - alpha * delta.grad.detach().sign()
            delta.data = clamp(delta, image).clamp(-epsilon, epsilon)
            delta.grad.zero_()
        adv_code = self.model.encode_image(image + delta)
        return delta.detach(), adv_code

    def train_epoch(self, epoch=None):
        """Run the attack over the train set and save adversarial retrieval results.

        Args:
            epoch: current epoch index (accepted for compatibility with
                ``train()``, which passes it; unused here).
        """
        self.change_state(mode="valid")
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        # Bug fix: savemat below fails if the directory does not exist yet
        # (every other save path in this class calls makedirs).
        os.makedirs(save_dir, exist_ok=True)
        all_loss = 0
        times = 0
        adv_codes = []
        adv_label = []
        for image, text, label, index in self.train_loader:
            self.global_step += 1
            times += 1
            print(times)
            # Bug fix: Tensor.float() is not in-place; the original
            # `image.float()` discarded its result.
            image = image.float()
            image = image.to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            # Statistics of the TRUE labels act as negatives.
            negative_mean = np.stack([self.text_mean[str(i.astype(int))]
                                      for i in label.detach().cpu().numpy()])
            negative_var = np.stack([self.text_var[str(i.astype(int))]
                                     for i in label.detach().cpu().numpy()])
            negative_mean = torch.from_numpy(negative_mean).to(self.rank, non_blocking=True)
            negative_var = torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
            negative_code = self.model.encode_text(text)
            # Targeted sample: pick a random batch of target captions/labels
            # (seeded per-batch for reproducibility).
            np.random.seed(times)
            select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
            target_dataset = data.Subset(self.train_data, select_index)
            target_subset = torch.utils.data.DataLoader(target_dataset,
                                                        batch_size=self.args.batch_size)
            _, target_text, target_label, _ = next(iter(target_subset))
            target_text = target_text.to(self.rank, non_blocking=True)
            positive_mean = np.stack([self.text_mean[str(i.astype(int))]
                                      for i in target_label.detach().cpu().numpy()])
            positive_var = np.stack([self.text_var[str(i.astype(int))]
                                     for i in target_label.detach().cpu().numpy()])
            positive_mean = torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
            positive_var = torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
            positive_code = self.model.encode_text(target_text)
            delta, adv_code = self.target_adv(image, negative_code, negative_mean,
                                              negative_var, positive_code,
                                              positive_mean, positive_var)
            adv_codes.append(adv_code.cpu().detach().numpy())
            # Adversarial success is measured against the TARGET labels.
            adv_label.append(target_label.numpy())
        adv_img = np.concatenate(adv_codes)
        adv_labels = np.concatenate(adv_label)
        _, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAP_t = cal_map(adv_img, adv_labels, retrieval_txt, retrieval_labels)
        # pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
        pr_t = cal_pr(retrieval_txt, adv_img, retrieval_labels, adv_labels)
        self.logger.info(f">>>>>> MAP_t: {mAP_t}")
        result_dict = {
            'adv_img': adv_img,
            'r_txt': retrieval_txt,
            'adv_l': adv_labels,
            'r_l': retrieval_labels,
            # 'q_l':query_labels
            # 'pr': pr,
            'pr_t': pr_t
        }
        scio.savemat(os.path.join(save_dir,
                                  str(self.args.victim).replace("/", "_") + "-adv-"
                                  + self.args.dataset + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    def train(self):
        """Run the attack for ``args.epochs`` epochs with validation and saving."""
        self.logger.info("Start train.")
        for epoch in range(self.args.epochs):
            # Bug fix: train_epoch previously took no epoch argument, so this
            # call raised TypeError; it now accepts (and ignores) the epoch.
            self.train_epoch(epoch)
            self.valid(epoch)
            self.save_model(epoch)
        self.logger.info(
            f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, "
            f"mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")

    def make_hash_code(self, code: list) -> torch.Tensor:
        """Convert a list of per-bit logit tensors into a {-1, +1} hash code.

        Stacks the list, moves the batch axis first, takes the argmax over the
        last axis, then maps class 0 to -1 (argmax 1 stays +1).
        """
        code = torch.stack(code)
        # print(code.shape)
        code = code.permute(1, 0, 2)
        hash_code = torch.argmax(code, dim=-1)
        hash_code[torch.where(hash_code == 0)] = -1
        hash_code = hash_code.float()
        return hash_code

    def get_code(self, data_loader, length: int):
        """Encode every (image, text) pair in ``data_loader`` with the victim model.

        Args:
            data_loader: loader yielding (image, text, label, index) batches;
                ``index`` positions results into the output buffers.
            length: total number of samples (buffer size).

        Returns:
            (img_buffer, text_buffer): feature tensors of shape
            (length, args.output_dim) on ``self.rank``.
        """
        img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        for image, text, label, index in tqdm(data_loader):
            image = image.to(self.device, non_blocking=True)
            text = text.to(self.device, non_blocking=True)
            index = index.numpy()
            with torch.no_grad():
                image_feature = self.model.encode_image(image)
                text_features = self.model.encode_text(text)
            img_buffer[index, :] = image_feature.detach()
            text_buffer[index, :] = text_features.detach()
        return img_buffer, text_buffer  # img_buffer.to(self.rank), text_buffer.to(self.rank)

    def valid_attack(self, adv_images, texts, adv_labels):
        """Prepare the adversarial-PR output directory (stub; no evaluation yet)."""
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

    def test(self, mode_name="i2t"):
        """Evaluate clean (unattacked) cross-modal retrieval and save results."""
        self.logger.info("Valid Clean.")
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAPi2t = cal_map(query_img, query_labels, retrieval_txt, retrieval_labels)
        mAPt2i = cal_map(query_txt, query_labels, retrieval_img, retrieval_labels)
        pr_i2t = cal_pr(retrieval_txt, query_img, retrieval_labels, query_labels)
        pr_t2i = cal_pr(retrieval_img, query_txt, retrieval_labels, query_labels)
        # NOTE(review): this updates the t2i maximum with the i2t mAP —
        # looks like a direction mix-up; kept as-is pending confirmation.
        self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels,
            'pr_i2t': pr_i2t,
            'pr_t2i': pr_t2i
        }
        scio.savemat(os.path.join(save_dir,
                                  str(self.args.victim).replace("/", "_") + "-ours-"
                                  + self.args.dataset + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
        """Dump query/retrieval features and labels to a .mat file for analysis."""
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir,
                                  str(self.args.output_dim) + "-ours-" + self.args.dataset
                                  + "-" + mode_name + ".mat"), result_dict)
        self.logger.info(f">>>>>> save best {mode_name} data!")