"""Targeted text-substitution attack against a CLIP retrieval model.

A BERT masked-language model proposes replacement words for the caption of
each training sample; the replacement that maximises ``self.adv_loss`` is kept.
The adversarial captions are then encoded with CLIP and evaluated with mAP.
"""
import copy
from torch.nn.modules import loss
# from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
import scipy.io as scio
import numpy as np

from .base import TrainBase
from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
import clip
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
from transformers import BertForMaskedLM
from model.bert_tokenizer import BertTokenizer
# from transformers import BertModel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Words and punctuation that are never chosen as attack targets or substitutes.
filter_words = [
    'a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost',
    'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another',
    'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as',
    'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides',
    'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn',
    "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere',
    'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for',
    'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence',
    'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his',
    'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's",
    'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn',
    "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn',
    "mustn't", 'my', 'myself', 'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next',
    'no', 'nobody', 'none', 'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on',
    'once', 'one', 'only', 'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out',
    'over', 'per', 'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn',
    "shouldn't", 'somehow', 'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll",
    'the', 'their', 'theirs', 'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby',
    'therefore', 'therein', 'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout', 'thru',
    'thus', 'to', 'too', 'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was',
    'wasn', "wasn't", 'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever',
    'where', 'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which',
    'while', 'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without',
    'won', "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're",
    "you've", 'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',',
    'b', '&', '!', '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·',
]
filter_words = set(filter_words)

# Upper bound on the number of multi-token candidates scored by the MLM.
_MAX_BPE_CANDIDATES = 24


def get_bpe_substitues(substitutes, tokenizer, mlm_model):
    """Rank multi-sub-token replacement words by masked-LM perplexity.

    Args:
        substitutes: 2-D tensor of token ids, one row per sub-token position
            and one column per candidate (only the first 12 rows and 4
            columns are used).
        tokenizer: object providing ``_convert_id_to_token`` and
            ``convert_tokens_to_string``.
        mlm_model: callable whose first output is a ``[N, L, vocab]`` tensor
            when given an ``[N, L]`` tensor of token ids.

    Returns:
        List of candidate strings, lowest perplexity first.
    """
    # substitutes L, k
    substitutes = substitutes[0:12, 0:4]  # maximum BPE candidates

    # find all possible candidates
    # Only the first _MAX_BPE_CANDIDATES combinations are ever scored, and
    # those are all extensions of the first _MAX_BPE_CANDIDATES prefixes, so
    # truncating after every level yields the same candidates without
    # materialising the full (up to 4**12) product.
    all_substitutes = []
    for i in range(substitutes.size(0)):
        level_ids = [int(c) for c in substitutes[i]]
        if not all_substitutes:
            all_substitutes = [[c] for c in level_ids]
        else:
            all_substitutes = [prefix + [c] for prefix in all_substitutes for c in level_ids]
        all_substitutes = all_substitutes[:_MAX_BPE_CANDIDATES]

    # all substitutes list of list of token-id (all candidates)
    c_loss = nn.CrossEntropyLoss(reduction='none')
    # Run on the device the MLM lives on; fall back to the module-level
    # device for models that do not expose a ``device`` attribute.
    model_device = getattr(mlm_model, "device", device)
    all_substitutes = torch.tensor(all_substitutes).to(model_device)  # [ N, L ]
    N, L = all_substitutes.size()
    with torch.no_grad():  # scores are only sorted, no gradient is needed
        word_predictions = mlm_model(all_substitutes)[0]  # N L vocab-size
        ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1))  # [ N*L ]
        ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1))  # N
    _, order = torch.sort(ppl)

    final_words = []
    for idx in order:
        tokens = [tokenizer._convert_id_to_token(int(t)) for t in all_substitutes[idx]]
        final_words.append(tokenizer.convert_tokens_to_string(tokens))
    return final_words


def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
    """Recover candidate replacement words from a ``[sub_len, k]`` id matrix.

    Args:
        substitutes: 2-D tensor of candidate token ids for one word.
        tokenizer: tokenizer used to turn ids back into tokens.
        mlm_model: masked-LM used to rank multi-sub-token candidates.
        use_bpe: multi-sub-token words are only handled when this equals 1.
        substitutes_score: scores aligned with ``substitutes``; when ``None``
            no score threshold is applied.
        threshold: for single-token words, stop at the first candidate whose
            score is below this value (``0`` disables the check).

    Returns:
        List of candidate strings (possibly empty).
    """
    # substitues L,k
    # from this matrix to recover a word
    words = []
    sub_len, k = substitutes.size()  # sub-len, k

    if sub_len == 0:
        return words

    if sub_len == 1:
        if substitutes_score is None:
            # No scores supplied: nothing to threshold against.
            return [tokenizer._convert_id_to_token(int(i)) for i in substitutes[0]]
        for (i, j) in zip(substitutes[0], substitutes_score[0]):
            if threshold != 0 and j < threshold:
                break
            words.append(tokenizer._convert_id_to_token(int(i)))
        return words

    if use_bpe == 1:
        words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
    return words


class Trainer(TrainBase):
    """Runs the targeted text attack and evaluates retrieval mAP.

    NOTE(review): ``self.logger``, ``self.rank``, ``self.global_step``,
    ``self.adv_loss``, ``self.valid``, ``self.save_model`` and the
    ``max_map*`` / ``best_epoch_*`` attributes are not defined in this file;
    they are presumably provided by ``TrainBase`` -- confirm.
    """

    def __init__(self, rank=0):
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        image_mean, image_var = self.generate_mapping()
        self.image_mean = image_mean
        self.image_var = image_var
        self.device = rank
        self.clip_tokenizer = Tokenizer()
        self.bert_tokenizer = BertTokenizer.from_pretrained(self.args.text_encoder, do_lower_case=True)
        # The masked-LM must live on the same device as the inputs fed to it.
        self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder).to(device)
        # self.run()

    def _init_model(self):
        """Load the victim CLIP model in eval mode with float32 weights."""
        self.logger.info("init model.")
        model_clip, preprocess = clip.load(self.args.victim, device=device)
        self.model = model_clip
        self.model.eval()
        self.model.float()

    def _init_dataset(self):
        """Resolve dataset paths and build the train/query/retrieval loaders."""
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
        self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
        self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                                            indexFile=self.args.index_file,
                                                            labelFile=self.args.label_file,
                                                            maxWords=self.args.max_words,
                                                            imageResolution=self.args.resolution,
                                                            query_num=self.args.query_num,
                                                            train_num=self.args.train_num,
                                                            seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
        self.train_loader = DataLoader(
            dataset=train_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.query_loader = DataLoader(
            dataset=query_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.retrieval_loader = DataLoader(
            dataset=retrieval_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.train_data = train_data

    def _tokenize(self, text):
        """Split ``text`` on spaces and map each word to its sub-token span.

        Returns:
            ``(words, sub_words, keys)`` where ``keys[i]`` is the
            ``[start, end)`` range of word ``i`` inside ``sub_words``.
        """
        words = text.split(' ')

        sub_words = []
        keys = []
        index = 0
        for word in words:
            sub = self.bert_tokenizer.tokenize(word)
            sub_words += sub
            keys.append([index, index + len(sub)])
            index += len(sub)

        return words, sub_words, keys

    def get_important_scores(self, text, origin_embeds, batch_size, max_length):
        """Score every word of ``text`` by how much masking it shifts the MLM.

        Each word is replaced by ``[UNK]`` in turn; the score is the summed
        KL divergence between the masked output distribution and the clean
        one.

        Args:
            text: caption string (split on single spaces).
            origin_embeds: clean reference tensor for this one sample; it is
                flattened to a single row before use.
            batch_size: number of masked variants per forward pass.
            max_length: padding/truncation length for the BERT tokenizer.

        Returns:
            1-D tensor with one score per word of ``text``.
        """
        masked_words = self._get_masked(text)
        masked_texts = [' '.join(words) for words in masked_words]  # list of text of masked words

        ref_device = self.ref_net.device
        origin_prob = origin_embeds.reshape(1, -1).to(ref_device).softmax(dim=-1)
        criterion = torch.nn.KLDivLoss(reduction='none')

        scores = []
        with torch.no_grad():  # scores are only ranked, never back-propagated
            for i in range(0, len(masked_texts), batch_size):
                masked_text_input = self.bert_tokenizer(masked_texts[i:i + batch_size], padding='max_length',
                                                        truncation=True, max_length=max_length,
                                                        return_tensors='pt').to(ref_device)
                # NOTE(review): the flattened MLM logits are used as the
                # "embedding" because `.logits` is the only tensor this file
                # reads from ref_net -- confirm this is the intended feature.
                masked_embed = self.ref_net(masked_text_input.input_ids,
                                            attention_mask=masked_text_input.attention_mask).logits.flatten(1)
                # Reduce per chunk so only the scores, not the full logits,
                # are kept across chunks.
                chunk_scores = criterion(masked_embed.log_softmax(dim=-1), origin_prob.expand_as(masked_embed))
                scores.append(chunk_scores.sum(dim=-1))

        if not scores:
            return torch.empty(0, device=ref_device)
        return torch.cat(scores, dim=0)

    def _get_masked(self, text):
        """Return one copy of the word list per word, with that word set to ``[UNK]``."""
        words = text.split(' ')
        len_text = len(words)
        masked_words = []
        for i in range(len_text):
            masked_words.append(words[0:i] + ['[UNK]'] + words[i + 1:])
        # list of words
        return masked_words

    def generate_mapping(self):
        """Compute per-label mean and variance of CLIP image features.

        Returns:
            Two dicts keyed by ``str(label.astype(int))`` for every unique
            label row in the training set: feature mean and feature variance.
        """
        image_train = []
        label_train = []
        with torch.no_grad():  # pure feature extraction
            for image, text, label, index in self.train_loader:
                # raw_text=[self.clip_tokenizer.decode(token) for token in text]
                image = image.to(device, non_blocking=True)
                # print(self.model.vocab_size)
                temp_image = self.model.encode_image(image)
                image_train.append(temp_image.cpu().detach().numpy())
                label_train.append(label.detach().numpy())
        image_train = np.concatenate(image_train, axis=0)
        label_train = np.concatenate(label_train, axis=0)
        label_unipue = np.unique(label_train, axis=0)

        image_representation = {}
        image_var_representation = {}
        for centroid in label_unipue:
            # Look the member rows up once and reuse them for both statistics.
            members = image_train[find_indices(label_train, centroid)]
            key = str(centroid.astype(int))
            image_representation[key] = members.mean(axis=0)
            image_var_representation[key] = members.var(axis=0)
        return image_representation, image_var_representation

    def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var,
                    positive_code,positive_mean,positive_var,
                    beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
        """Build adversarial captions by greedy word substitution.

        Args:
            texts: tokenised captions from the data loader; kept for
                interface compatibility, the attack itself works on
                ``raw_text``.
            raw_text: list of caption strings, one per sample.
            negetive_code, negetive_mean, negative_var: source-side tensors
                forwarded unchanged to ``self.adv_loss``.
            positive_code, positive_mean, positive_var: target-side tensors
                forwarded unchanged to ``self.adv_loss``.
            beta, epsilon, alpha, num_iter, temperature: unused here.

        Returns:
            List of adversarial caption strings aligned with ``raw_text``.
        """
        ref_device = self.ref_net.device
        bert_inputs = self.bert_tokenizer(raw_text, padding='max_length', truncation=True,
                                          max_length=self.args.max_words, return_tensors='pt').to(ref_device)
        with torch.no_grad():  # logits are only used for top-k and ranking
            mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits
        word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.args.topk, -1)
        # NOTE(review): if the tokenizer prepends [CLS], position 0 is not the
        # first word and the `keys` offsets below are shifted by one -- confirm.

        # clean state: reuse the forward pass above, one flattened row per sample
        clean_embeds = mlm_logits.flatten(1)

        final_adverse = []
        # alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
        # print(texts)
        # Iterate the caption strings: they are what BERT was fed above, so
        # index `i` lines up with `word_predictions`.
        for i, text in enumerate(raw_text):
            important_scores = self.get_important_scores(text, clean_embeds[i], self.args.batch_size,
                                                         self.args.max_words)
            list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)

            words, sub_words, keys = self._tokenize(text)
            final_words = copy.deepcopy(words)
            change = 0

            for top_index in list_of_index:
                if change >= self.args.num_perturbation:
                    break

                word_pos = top_index[0]
                tgt_word = words[word_pos]
                if tgt_word in filter_words:
                    continue
                start, end = keys[word_pos]
                # Skip words whose sub-tokens start beyond the padded BERT length.
                if start > self.args.max_words - 2:
                    continue

                substitutes = word_predictions[i, start:end]  # L, k
                word_pred_scores = word_pred_scores_all[i, start:end]

                # The ids come from the BERT vocabulary, so decode with the BERT tokenizer.
                substitutes = get_substitues(substitutes, self.bert_tokenizer, self.ref_net, 1, word_pred_scores,
                                             self.args.threshold_pred_score)

                replace_texts = [' '.join(final_words)]
                available_substitutes = [tgt_word]
                for substitute in substitutes:
                    if substitute == tgt_word:
                        continue  # filter out original word
                    if '##' in substitute:
                        continue  # filter out sub-word
                    if substitute in filter_words:
                        continue
                    temp_replace = copy.deepcopy(final_words)
                    temp_replace[word_pos] = substitute
                    available_substitutes.append(substitute)
                    replace_texts.append(' '.join(temp_replace))

                # NOTE(review): train_epoch calls `self.clip_tokenizer.tokenize(...)`
                # while this calls the tokenizer directly -- confirm both exist.
                replace_text_input = self.clip_tokenizer(replace_texts).to(device)
                replace_embeds = self.model.encode_text(replace_text_input)

                cand_loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,
                                          positive_code,positive_mean,positive_var)
                cand_loss = cand_loss.sum(dim=-1)
                candidate_idx = int(cand_loss.argmax())

                final_words[word_pos] = available_substitutes[candidate_idx]

                if available_substitutes[candidate_idx] != tgt_word:
                    change += 1

            final_adverse.append(' '.join(final_words))

        return final_adverse

    def _label_statistics(self, label):
        """Look up the stored feature mean/variance for every label row and move them to ``self.rank``."""
        label_keys = [str(row.astype(int)) for row in label.detach().cpu().numpy()]
        mean = np.stack([self.image_mean[k] for k in label_keys])
        var = np.stack([self.image_var[k] for k in label_keys])
        mean = torch.from_numpy(mean).to(self.rank, non_blocking=True)
        var = torch.from_numpy(var).to(self.rank, non_blocking=True)
        return mean, var

    def train_epoch(self, epoch=None):
        """Attack every training caption, then evaluate and save the result.

        Args:
            epoch: accepted because ``train`` passes it; not used.
        """
        # self.change_state(mode="valid")
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)  # savemat below does not create directories
        times = 0
        adv_codes = []
        adv_label = []
        for image, text, label, index in self.train_loader:
            self.global_step += 1
            times += 1
            print(times)
            raw_text = [self.clip_tokenizer.my_decode(token) for token in text]
            image = image.to(self.rank, non_blocking=True).float()
            text = text.to(self.rank, non_blocking=True)

            negetive_mean, negative_var = self._label_statistics(label)
            negetive_code = self.model.encode_image(image)

            # targeted sample
            # Draw as many targets as there are samples in this batch so that
            # codes and labels stay aligned on a short final batch.
            current_batch = label.shape[0]
            np.random.seed(times)
            select_index = np.random.choice(len(self.train_data), size=current_batch)
            target_dataset = data.Subset(self.train_data, select_index)
            target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=current_batch)
            target_image, _, target_label, _ = next(iter(target_subset))
            target_image = target_image.to(self.rank, non_blocking=True)

            positive_mean, positive_var = self._label_statistics(target_label)
            positive_code = self.model.encode_image(target_image)

            final_adverse = self.target_adv(text, raw_text,negetive_code,negetive_mean,negative_var,
                                            positive_code,positive_mean,positive_var)
            final_text = self.clip_tokenizer.tokenize(final_adverse).to(self.rank, non_blocking=True)
            adv_code = self.model.encode_text(final_text)

            adv_codes.append(adv_code.cpu().detach().numpy())
            adv_label.append(target_label.numpy())

        adv_img = np.concatenate(adv_codes)
        adv_labels = np.concatenate(adv_label)

        _, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        retrieval_labels = self.retrieval_labels.numpy()

        mAP_t = cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
        self.logger.info(f">>>>>> MAP_t: {mAP_t}")
        result_dict = {
            'adv_img': adv_img,
            'r_txt': retrieval_txt,
            'adv_l': adv_labels,
            'r_l': retrieval_labels
            # 'q_l':query_labels
            # 'pr': pr,
            # 'pr_t': pr_t
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    def train(self):
        """Run ``train_epoch``, ``valid`` and ``save_model`` for every epoch."""
        self.logger.info("Start train.")

        for epoch in range(self.args.epochs):
            self.train_epoch(epoch)
            self.valid(epoch)
            self.save_model(epoch)

        self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")

    def make_hash_code(self, code: list) -> torch.Tensor:
        """Stack ``code``, take the argmax over the last axis and map 0 to -1."""
        code = torch.stack(code)
        # print(code.shape)
        code = code.permute(1, 0, 2)
        hash_code = torch.argmax(code, dim=-1)
        hash_code[torch.where(hash_code == 0)] = -1
        hash_code = hash_code.float()

        return hash_code

    def get_code(self, data_loader, length: int):
        """Encode every sample of ``data_loader`` with CLIP.

        Features are written at the row given by each sample's ``index``.

        Returns:
            ``(img_buffer, text_buffer)``, each ``[length, args.output_dim]``.
        """
        img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)

        for image, text, label, index in tqdm(data_loader):
            image = image.to(self.device, non_blocking=True)
            text = text.to(self.device, non_blocking=True)
            index = index.numpy()
            with torch.no_grad():
                image_feature = self.model.encode_image(image)
                text_features = self.model.encode_text(text)
            img_buffer[index, :] = image_feature.detach()
            text_buffer[index, :] = text_features.detach()

        return img_buffer, text_buffer  # img_buffer.to(self.rank), text_buffer.to(self.rank)

    def valid_attack(self,adv_images, texts, adv_labels):
        """Create the adversarial output directory; the arguments are not used."""
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

    def test(self, mode_name="i2t"):
        """Evaluate clean image->text and text->image mAP and save the features.

        Args:
            mode_name: unused.
        """
        self.logger.info("Valid Clean.")
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)

        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()

        mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
        mAPt2i = cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)

        # pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
        # pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)

        # Track each direction against its own running maximum.
        self.max_mapi2t = max(getattr(self, "max_mapi2t", mAPi2t), mAPi2t)
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")

        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    # def valid(self, epoch):
    #     self.logger.info("Valid.")
    #     self.change_state(mode="valid")
    #     query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
    #     retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
    #     # print("get all code")
    #     mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    #     # print("map map")
    #     mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    #     mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    #     mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    #     if self.max_mapi2t < mAPi2t:
    #         self.best_epoch_i = epoch
    #         self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
    #     self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
    #     if self.max_mapt2i < mAPt2i:
    #         self.best_epoch_t = epoch
    #         self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
    #     self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
    #     self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
    #             MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")

    def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
        """Save query/retrieval features and labels to a ``.mat`` file tagged with ``mode_name``."""
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()

        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
        self.logger.info(f">>>>>> save best {mode_name} data!")