import copy
from torch.nn.modules import loss
# from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
import scipy.io as scio
import numpy as np
from .base import TrainBase
from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity, find_indices
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
import clip
import re
from transformers import BertForMaskedLM
from model.bert_tokenizer import BertTokenizer
from transformers import AutoTokenizer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stop-words and punctuation that are never eligible as substitution candidates
# during the word-level attack (see Trainer.filter_substitutes).
filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost',
                'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another',
                'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as',
                'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides',
                'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn',
                "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere',
                'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for',
                'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence',
                'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his',
                'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's",
                'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn',
                "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself',
                'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none',
                'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only',
                'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per',
                'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't",
                'somehow', 'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the',
                'their', 'theirs', 'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby',
                'therefore', 'therein', 'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout',
                'thru', 'thus', 'to', 'too', 'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used',
                've', 'was', 'wasn', "wasn't", 'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence',
                'whenever', 'where', 'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether',
                'which', 'while', 'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within',
                'without', 'won', "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll",
                "you're", "you've", 'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some',
                '"', ',', 'b', '&', '!', '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";",
                '~', '·']
filter_words = set(filter_words)


def text_filter(text):
    """Extract the caption between CLIP's ``<|startoftext|>``/``<|endoftext|>`` markers.

    Fix: the original pattern left ``|`` unescaped, so the regex was parsed as an
    alternation (``<`` | ``startoftext`` | ``>(.+)<`` | ...) and relied on
    ``findall`` returning the caption at index 2 by accident.  The delimiters are
    now escaped and the single captured group is taken directly.
    """
    matches = re.findall(r"<\|startoftext\|>(.+?)<\|endoftext\|>", text)
    text = matches[0]
    # Fix: the original `re.sub(r'', ' ', text)` matched the empty string and
    # inserted a space between every character, destroying the caption.
    # Collapse whitespace runs instead.
    text = re.sub(r"\s+", " ", text).strip()
    return text


class GoalFunctionStatus(object):
    """Enumeration of the attack search states for one candidate text."""
    SUCCEEDED = 0  # attack succeeded
    SEARCHING = 1  # in process of searching for a success
    FAILED = 2     # attack failed


class GoalFunctionResult(object):
    """One candidate adversarial text plus its attack score and similarity.

    Assigning ``score`` automatically promotes ``status`` to SUCCEEDED once the
    score reaches ``goal_score``.  Equality/hash are on the text only, so results
    with identical text de-duplicate in sets.
    """

    goal_score = 1  # score at/above which the attack counts as succeeded

    def __init__(self, text, score=0, similarity=None):
        self.status = GoalFunctionStatus.SEARCHING
        self.text = text
        self.score = score  # routed through the property setter below
        self.similarity = similarity

    @property
    def score(self):
        return self.__score

    @score.setter
    def score(self, value):
        self.__score = value
        # Crossing the goal threshold flips this result into the SUCCEEDED state.
        if value >= self.goal_score:
            self.status = GoalFunctionStatus.SUCCEEDED

    def __eq__(self, __o):
        return self.text == __o.text

    def __hash__(self):
        return hash(self.text)


class Trainer(TrainBase):
    """Targeted word-substitution attack on a CLIP-based retrieval model.

    Uses a BERT masked-LM (``ref_net``) to propose word substitutions and the
    victim CLIP model to score candidates against positive (target-class) and
    negative (source-class) image-feature statistics.
    """

    def __init__(self, rank=0):
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        # Per-class mean/variance of CLIP image features over the train set.
        image_mean, image_var = self.generate_mapping()
        self.image_mean = image_mean
        self.image_var = image_var
        self.device = rank
        self.clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", model_max_length=77,
                                                            truncation=True)
        self.bert_tokenizer = BertTokenizer.from_pretrained(self.args.text_encoder, do_lower_case=True)
        # Masked language model used to generate substitution candidates.
        self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder).to(device)
        self.attack_thred = self.args.attack_thred
        # self.run()

    def _init_model(self):
        """Load the victim CLIP model in eval mode (fp32)."""
        self.logger.info("init model.")
        model_clip, preprocess = clip.load(self.args.victim, device=device)
        self.model = model_clip
        self.model.eval()
        self.model.float()

    def _init_dataset(self):
        """Build train/query/retrieval splits and their DataLoaders."""
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
        self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
        self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                                            indexFile=self.args.index_file,
                                                            labelFile=self.args.label_file,
                                                            maxWords=self.args.max_words,
                                                            imageResolution=self.args.resolution,
                                                            query_num=self.args.query_num,
                                                            train_num=self.args.train_num,
                                                            seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
        self.train_loader = DataLoader(
            dataset=train_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        # NOTE(review): shuffle=True on query/retrieval loaders is harmless for
        # get_code (it writes by sample index) but is unusual for eval loaders.
        self.query_loader = DataLoader(
            dataset=query_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.retrieval_loader = DataLoader(
            dataset=retrieval_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.train_data = train_data

    def _tokenize(self, text):
        """Split ``text`` on spaces and sub-tokenize each word with BERT.

        Returns (words, sub_words, keys) where keys[i] = [start, end) span of
        word i inside ``sub_words``.
        """
        words = text.split(' ')
        sub_words = []
        keys = []
        index = 0
        for word in words:
            sub = self.bert_tokenizer.tokenize(word)
            sub_words += sub
            keys.append([index, index + len(sub)])
            index += len(sub)
        return words, sub_words, keys

    def get_important_scores(self, text, origin_embeds, batch_size, max_length):
        """Score each word's importance via KL divergence between the reference
        net's output on the original text and on the text with that word masked.

        Fix: ``BatchEncoding`` exposes ``input_ids`` (the original accessed a
        nonexistent ``text_inputs`` attribute), and ``BertForMaskedLM`` returns a
        ModelOutput whose tensor lives in ``.logits`` (required for torch.cat).
        NOTE(review): ``origin_embeds`` must broadcast against the stacked
        per-mask logits — confirm its shape against the (currently absent) caller.
        """
        masked_words = self._get_masked(text)
        masked_texts = [' '.join(words) for words in masked_words]  # list of text of masked words
        masked_embeds = []
        for i in range(0, len(masked_texts), batch_size):
            masked_text_input = self.bert_tokenizer(masked_texts[i:i + batch_size], padding='max_length',
                                                    truncation=True, max_length=max_length,
                                                    return_tensors='pt').to(device)
            masked_embed = self.ref_net(masked_text_input.input_ids,
                                        attention_mask=masked_text_input.attention_mask).logits
            masked_embeds.append(masked_embed)
        masked_embeds = torch.cat(masked_embeds, dim=0)
        criterion = torch.nn.KLDivLoss(reduction='none')
        import_scores = criterion(masked_embeds.log_softmax(dim=-1),
                                  origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1))
        return import_scores.sum(dim=-1)

    def _get_masked(self, text):
        """Return one word-list per position, with that position replaced by [UNK]."""
        words = text.split(' ')
        len_text = len(words)
        masked_words = []
        for i in range(len_text):
            masked_words.append(words[0:i] + ['[UNK]'] + words[i + 1:])
        return masked_words

    def get_transformations(self, text, idx, substitutes):
        """Return ``text`` variants with the word at position ``idx`` replaced by
        each candidate in ``substitutes``."""
        words = text.split(' ')
        trans_text = []
        for sub in substitutes:
            words[idx] = sub
            trans_text.append(' '.join(words))
        return trans_text

    def get_word_predictions(self, text):
        """Run the masked LM once over ``text`` and return, per sub-token, the
        top-``max_candidate`` predicted token ids and scores (specials stripped),
        plus the word->sub-token spans and the attention mask."""
        _, _, keys = self._tokenize(text)
        inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.args.max_words,
                                                 truncation=True, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs['attention_mask']
        with torch.no_grad():
            word_predictions = self.ref_net(input_ids)['logits'].squeeze(0)  # (seq_len, vocab_size)
        word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.args.max_candidate, -1)
        word_predictions = word_predictions[1:-1, :]  # remove [CLS] and [SEP]
        word_pred_scores_all = word_pred_scores_all[1:-1, :]
        return keys, word_predictions, word_pred_scores_all, attention_mask

    def get_bpe_substitutes(self, substitutes):
        """Rank multi-sub-token substitution sequences by masked-LM perplexity.

        ``substitutes`` is (L, k): k candidate token ids per sub-token position.
        The Cartesian product is capped (12 positions x 4 candidates, first 24
        sequences) to bound cost.
        """
        substitutes = substitutes[0:12, 0:4]  # maximum BPE candidates
        # Enumerate all candidate id sequences (Cartesian product over positions).
        all_substitutes = []
        for i in range(substitutes.size(0)):
            if len(all_substitutes) == 0:
                lev_i = substitutes[i]
                all_substitutes = [[int(c)] for c in lev_i]
            else:
                lev_i = []
                for all_sub in all_substitutes:
                    for j in substitutes[i]:
                        lev_i.append(all_sub + [int(j)])
                all_substitutes = lev_i
        cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
        all_substitutes = torch.tensor(all_substitutes)  # [N, L]
        all_substitutes = all_substitutes[:24].to(self.device)
        N, L = all_substitutes.size()
        word_predictions = self.ref_net(all_substitutes)[0]  # (N, L, vocab_size)
        # Per-sequence perplexity: exp of mean token cross-entropy.
        ppl = cross_entropy_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1))  # [N*L]
        ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1))  # [N]
        _, word_list = torch.sort(ppl)
        word_list = [all_substitutes[i] for i in word_list]
        final_words = []
        for word in word_list:
            tokens = [self.bert_tokenizer.convert_ids_to_tokens(int(i)) for i in word]
            text = ' '.join([t.strip() for t in tokens])
            final_words.append(text)
        return final_words

    def get_substitutes(self, substitutes, substitutes_score, threshold=3.0):
        """Convert candidate token ids for one word into substitute strings.

        Single sub-token words keep candidates scoring >= ``threshold``;
        multi-sub-token words go through BPE ranking when enabled.
        """
        ret = []
        num_sub, _ = substitutes.size()
        if num_sub == 0:
            ret = []
        elif num_sub == 1:
            for sub_id, score in zip(substitutes[0], substitutes_score[0]):
                if threshold != 0 and score < threshold:
                    break  # candidates are sorted by score; the rest are lower
                ret.append(self.bert_tokenizer.convert_ids_to_tokens(int(sub_id)))
        elif self.args.enable_bpe:
            ret = self.get_bpe_substitutes(substitutes)
        return ret

    def filter_substitutes(self, substitues):
        """Drop stop-words and word-piece continuations ('##...') from candidates."""
        ret = []
        for word in substitues:
            if word.lower() in filter_words:
                continue
            if '##' in word:
                continue
            ret.append(word)
        return ret

    def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code,
                         positive_mean, positive_var, beta=10, temperature=0.05):
        """Score each candidate text against the target (positive) and source
        (negative) image statistics; return results below the similarity cap.

        NOTE(review): ``loss`` mixes a (B, 1) mean with the (B,) triplet loss,
        which broadcasts to (B, B); the ``loss[i].shape[0] > 1`` guard below then
        filters everything when B > 1 — confirm against the intended objective.
        """
        trans_feature = clip.tokenize(trans_texts, context_length=77, truncate=True).to(device)
        anchor = self.model.encode_text(trans_feature)
        batch_size = anchor.shape[0]
        loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code.repeat(batch_size, 1),
                                                    negetive_code.repeat(batch_size, 1),
                                                    distance_function=nn.CosineSimilarity(),
                                                    reduction='none')
        sim = F.cosine_similarity(anchor, positive_code.unsqueeze(0), dim=1, eps=1e-8).unsqueeze(1)
        # Mahalanobis-style distances to the class feature distributions.
        negative_dist = (anchor - negetive_mean) ** 2 / (negative_var + 1e-5)
        positive_dist = (anchor - positive_mean) ** 2 / (positive_var + 1e-5)
        negatives = torch.exp(negative_dist / temperature)
        positives = torch.exp(positive_dist / temperature)
        loss = torch.log(positives / (positives + negatives)).mean(dim=1, keepdim=True) + beta * loss1
        results = []
        for i in range(len(trans_texts)):
            # Skip candidates that are already too similar to the target code.
            if loss[i].shape[0] > 1 or sim[i] > self.args.sim_threshold:
                continue
            results.append(GoalFunctionResult(trans_texts[i], score=loss[i], similarity=sim[i]))
        return results

    def generate_mapping(self):
        """Compute per-class mean and variance of CLIP image features over the
        training set, keyed by the stringified integer label vector."""
        image_train = []
        label_train = []
        for image, text, label, index in self.train_loader:
            image = image.to(device, non_blocking=True)
            temp_image = self.model.encode_image(image)
            image_train.append(temp_image.cpu().detach().numpy())
            label_train.append(label.detach().numpy())
        image_train = np.concatenate(image_train, axis=0)
        label_train = np.concatenate(label_train, axis=0)
        label_unipue = np.unique(label_train, axis=0)
        image_centroids = np.stack(
            [image_train[find_indices(label_train, label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))],
            axis=0)
        image_var = np.stack(
            [image_train[find_indices(label_train, label_unipue[i])].var(axis=0) for i in range(len(label_unipue))],
            axis=0)
        image_representation = {}
        image_var_representation = {}
        for i, centroid in enumerate(label_unipue):
            image_representation[str(centroid.astype(int))] = image_centroids[i]
            image_var_representation[str(centroid.astype(int))] = image_var[i]
        return image_representation, image_var_representation

    def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code,
                   positive_mean, positive_var, beta=10, temperature=0.05):
        """Greedy word-by-word targeted attack on one caption.

        Tries substitutions position by position, keeping the best-scoring
        candidate; returns early once a SUCCEEDED result is found (preferring
        the succeeded candidate most similar to the target), else FAILED.
        """
        keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
        cur_result = GoalFunctionResult(raw_text)
        # Fix: the original built indices with np.where(mask == 1)[0] on a
        # (1, seq_len) attention mask, which yields the *row* indices (all
        # zeros), so only word 0 was ever perturbed. Iterate word positions.
        for idx in range(len(keys)):
            predictions = word_predictions[keys[idx][0]: keys[idx][1]]
            predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]]
            substitutes = self.get_substitutes(predictions, predictions_socre)
            substitutes = self.filter_substitutes(substitutes)
            trans_texts = self.get_transformations(raw_text, idx, substitutes)
            if len(trans_texts) == 0:
                continue
            results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var,
                                            positive_code, positive_mean, positive_var, beta, temperature)
            results = sorted(results, key=lambda x: x.score, reverse=True)
            if len(results) > 0 and results[0].score > cur_result.score:
                cur_result = results[0]
            else:
                continue
            if cur_result.status == GoalFunctionStatus.SUCCEEDED:
                max_similarity = cur_result.similarity
                if max_similarity is None:  # similarity is not calculated
                    continue
                # Among succeeded candidates, prefer the one closest to target.
                for result in results[1:]:
                    if result.status != GoalFunctionStatus.SUCCEEDED:
                        break
                    if result.similarity > max_similarity:
                        max_similarity = result.similarity
                        cur_result = result
                return cur_result
        if cur_result.status == GoalFunctionStatus.SEARCHING:
            cur_result.status = GoalFunctionStatus.FAILED
        return cur_result

    def train_epoch(self, epoch=None):
        """Attack every training caption, then evaluate adversarial text codes
        against the retrieval set and save the results as a .mat file.

        ``epoch`` is accepted (and unused) so that ``train``'s call
        ``self.train_epoch(epoch)`` no longer raises TypeError.
        """
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)  # fix: dir was used below without being created
        all_loss = 0
        times = 0
        adv_codes = []
        adv_label = []
        for image, text, label, index in self.train_loader:
            self.global_step += 1
            times += 1
            print(times)
            image = image.float()  # fix: Tensor.float() is out-of-place; result was discarded
            image = image.to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            # Source-class (negative) feature statistics for this batch.
            negetive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
            negative_var = np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
            negetive_mean = torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
            negative_var = torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
            negetive_code = self.model.encode_image(image)
            # Targeted sample: a random batch drawn with a per-step seed.
            np.random.seed(times)
            select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
            target_dataset = data.Subset(self.train_data, select_index)
            target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
            target_image, _, target_label, _ = next(iter(target_subset))
            target_image = target_image.to(self.rank, non_blocking=True)
            positive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
            positive_var = np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
            positive_mean = torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
            positive_var = torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
            positive_code = self.model.encode_image(target_image)
            # Decode CLIP token ids back to the raw caption string.
            raw_text = [self.clip_tokenizer.convert_ids_to_tokens(token.cpu()) for token in text]
            raw_text = [text_filter(self.clip_tokenizer.convert_tokens_to_string(txt)) for txt in raw_text]
            final_texts = []
            for i in range(self.args.batch_size):
                adv_txt = self.target_adv(raw_text[i], negetive_code[i], negetive_mean[i], negative_var[i],
                                          positive_code[i], positive_mean[i], positive_var[i])
                final_texts.append(adv_txt.text)
            final_text = clip.tokenize(final_texts, context_length=77, truncate=True).to(self.rank, non_blocking=True)
            adv_code = self.model.encode_text(final_text)
            adv_codes.append(adv_code.cpu().detach().numpy())
            adv_label.append(target_label.numpy())
        adv_img = np.concatenate(adv_codes)
        adv_labels = np.concatenate(adv_label)
        _, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAP_t = cal_map(adv_img, adv_labels, retrieval_txt, retrieval_labels)
        self.logger.info(f">>>>>> MAP_t: {mAP_t}")
        result_dict = {
            'adv_img': adv_img,
            'r_txt': retrieval_txt,
            'adv_l': adv_labels,
            'r_l': retrieval_labels
            # 'q_l': query_labels
            # 'pr': pr,
            # 'pr_t': pr_t
        }
        scio.savemat(os.path.join(save_dir,
                                  str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"),
                     result_dict)
        self.logger.info(">>>>>> save all data!")

    def train(self):
        """Run the attack for ``args.epochs`` epochs with validation + checkpointing."""
        self.logger.info("Start train.")
        for epoch in range(self.args.epochs):
            self.train_epoch(epoch)
            self.valid(epoch)
            self.save_model(epoch)
        self.logger.info(
            f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, "
            f"T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")

    def make_hash_code(self, code: list) -> torch.Tensor:
        """Convert stacked per-bit logits into a {-1, +1} hash code via argmax."""
        code = torch.stack(code)
        code = code.permute(1, 0, 2)
        hash_code = torch.argmax(code, dim=-1)
        hash_code[torch.where(hash_code == 0)] = -1  # map class 0 to -1
        hash_code = hash_code.float()
        return hash_code

    def get_code(self, data_loader, length: int):
        """Encode every sample in ``data_loader`` and scatter the image/text
        features into buffers by dataset index (robust to loader shuffling)."""
        img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        for image, text, label, index in tqdm(data_loader):
            image = image.to(self.device, non_blocking=True)
            text = text.to(self.device, non_blocking=True)
            index = index.numpy()
            with torch.no_grad():
                image_feature = self.model.encode_image(image)
                text_features = self.model.encode_text(text)
            img_buffer[index, :] = image_feature.detach()
            text_buffer[index, :] = text_features.detach()
        return img_buffer, text_buffer

    def valid_attack(self, adv_images, texts, adv_labels):
        """Placeholder: only prepares the adversarial-PR output directory."""
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

    def test(self, mode_name="i2t"):
        """Evaluate clean (unattacked) retrieval mAP in both directions and save
        all codes/labels to a .mat file."""
        self.logger.info("Valid Clean.")
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAPi2t = cal_map(query_img, query_labels, retrieval_txt, retrieval_labels)
        mAPt2i = cal_map(query_txt, query_labels, retrieval_img, retrieval_labels)
        # pr_i2t = cal_pr(retrieval_txt, query_img, query_labels, retrieval_labels)
        # pr_t2i = cal_pr(retrieval_img, query_txt, query_labels, retrieval_labels)
        # NOTE(review): mAPi2t is folded into max_mapt2i here; max_mapi2t looks
        # intended — confirm before changing, callers may rely on current value.
        self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir,
                                  str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"),
                     result_dict)
        self.logger.info(">>>>>> save all data!")

    def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
        """Save query/retrieval codes and labels to a mode-tagged .mat file."""
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(
            os.path.join(save_dir,
                         str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"),
            result_dict)
        self.logger.info(f">>>>>> save best {mode_name} data!")