diff --git a/dataset/base.py b/dataset/base.py index cbf32d5..f22e30f 100644 --- a/dataset/base.py +++ b/dataset/base.py @@ -8,8 +8,8 @@ import torch import random from PIL import Image from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from model.simple_tokenizer import SimpleTokenizer as Tokenizer - +# from model.simple_tokenizer import SimpleTokenizer as Tokenizer +from transformers import AutoTokenizer class BaseDataset(Dataset): @@ -19,7 +19,7 @@ class BaseDataset(Dataset): indexs: dict, labels: dict, is_train=True, - tokenizer=Tokenizer(), + tokenizer=AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", model_max_length=77, truncation=True), maxWords=32, imageResolution=224, npy=False): diff --git a/dataset/dataloader.py b/dataset/dataloader.py index 2ab4ac7..62b3802 100644 --- a/dataset/dataloader.py +++ b/dataset/dataloader.py @@ -10,7 +10,7 @@ def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=N random_index = np.random.permutation(range(len(indexs))) query_index = random_index[: query_num] train_index = random_index[query_num: query_num + train_num] - retrieval_index = random_index[query_num:] + retrieval_index = random_index[query_num:-100000] query_indexs = indexs[query_index] query_captions = captions[query_index] diff --git a/main.py b/main.py index c68339f..7192f94 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,11 @@ -from train.text_train import Trainer - +# from train.text_train import Trainer +from train.hash_train import Trainer if __name__ == "__main__": engine=Trainer() engine.test() - engine.train_epoch() + # engine.train_epoch() # engine.train() diff --git a/model/simple_tokenizer.py b/model/simple_tokenizer.py index e22e0d0..0902036 100755 --- a/model/simple_tokenizer.py +++ b/model/simple_tokenizer.py @@ -131,11 +131,16 @@ class SimpleTokenizer(object): text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') return text - def my_decode(self, tokens): - tokens=[item for item in tokens if item in self.decoder] - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') - return text + # def my_decode(self, tokens): + # print(tokens) + # tem_token=[] + # for i in tokens: + # if i in self.decoder.keys(): + # tem_token.append(self.decoder[i]) + # print(tem_token) + # text = ''.join(tem_token) + # text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + # return text def tokenize(self, text): tokens = [] @@ -147,3 +152,6 @@ class SimpleTokenizer(object): def convert_tokens_to_ids(self, tokens): return [self.encoder[bpe_token] for bpe_token in tokens] + + def convert_ids_to_tokens(self, ids): + return [self.decoder[id] for id in ids] diff --git a/train/hash_train.py b/train/hash_train.py index d00ac4e..7c16699 100644 --- a/train/hash_train.py +++ b/train/hash_train.py @@ -183,18 +183,18 @@ class Trainer(TrainBase): mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels) # pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels) - # pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels) + pr_t=cal_pr(retrieval_txt,adv_img,retrieval_labels,adv_labels) self.logger.info(f">>>>>> MAP_t: {mAP_t}") result_dict = { 'adv_img': adv_img, 'r_txt': retrieval_txt, 'adv_l': adv_labels, - 'r_l': retrieval_labels + 'r_l': retrieval_labels, # 'q_l':query_labels # 'pr': pr, - # 'pr_t': pr_t + 'pr_t': pr_t } - scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict) + scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"), result_dict) self.logger.info(">>>>>> save all data!") @@ -267,8 +267,8 @@ class Trainer(TrainBase): retrieval_labels = self.retrieval_labels.numpy() mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels) mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels) - # pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels) - # pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels) + pr_i2t=cal_pr(retrieval_txt,query_img,retrieval_labels,query_labels) + pr_t2i=cal_pr(retrieval_img,query_txt,retrieval_labels,query_labels) self.max_mapt2i = max(self.max_mapt2i, mAPi2t) self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}") result_dict = { @@ -277,11 +277,11 @@ class Trainer(TrainBase): 'r_img': retrieval_img, 'r_txt': retrieval_txt, 'q_l': query_labels, - 'r_l': retrieval_labels - # 'pr_i2t': pr_i2t, - # 'pr_t2i': pr_t2i + 'r_l': retrieval_labels, + 'pr_i2t': pr_i2t, + 'pr_t2i': pr_t2i } - scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict) + scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"), result_dict) self.logger.info(">>>>>> save all data!") diff --git a/train/text_train.py b/train/text_train.py index 4b689ce..7fb3e17 100644 --- a/train/text_train.py +++ b/train/text_train.py @@ -12,14 +12,14 @@ import numpy as np from .base import TrainBase from torch.nn import functional as F -from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices +from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity, find_indices from utils.calc_utils import cal_map, cal_pr from dataset.dataloader import dataloader import clip -from model.simple_tokenizer import SimpleTokenizer as Tokenizer +import re from transformers import BertForMaskedLM from model.bert_tokenizer import BertTokenizer -# from transformers import BertModel +from transformers import AutoTokenizer device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -48,13 +48,21 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won', "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've", 'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!', - '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·'] + '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·'] filter_words = set(filter_words) + +def text_filter(text): + text = re.findall(r"<|startoftext|>(.+)<|endoftext|>", text) + text = text[2] + text = re.sub(r'', ' ', text) + return text + + class GoalFunctionStatus(object): SUCCEEDED = 0 # attack succeeded SEARCHING = 1 # In process of searching for a success - FAILED = 2 # attack failed + FAILED = 2 # attack failed class GoalFunctionResult(object): @@ -82,89 +90,29 @@ class GoalFunctionResult(object): def __hash__(self): return hash(self.text) -# def get_bpe_substitues(substitutes, tokenizer, mlm_model): -# # substitutes L, k -# # device = mlm_model.device -# substitutes = substitutes[0:12, 0:4] # maximum BPE candidates -# -# # find all possible candidates -# -# all_substitutes = [] -# for i in range(substitutes.size(0)): -# if len(all_substitutes) == 0: -# lev_i = substitutes[i] -# all_substitutes = [[int(c)] for c in lev_i] -# else: -# lev_i = [] -# for all_sub in all_substitutes: -# for j in substitutes[i]: -# lev_i.append(all_sub + [int(j)]) -# all_substitutes = lev_i -# -# # all substitutes list of list of token-id (all candidates) -# c_loss = nn.CrossEntropyLoss(reduction='none') -# word_list = [] -# # all_substitutes = all_substitutes[:24] -# all_substitutes = torch.tensor(all_substitutes) # [ N, L ] -# all_substitutes = all_substitutes[:24].to(device) -# # print(substitutes.size(), all_substitutes.size()) -# N, L = all_substitutes.size() -# word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size -# ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ] -# ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N -# _, word_list = torch.sort(ppl) -# word_list = [all_substitutes[i] for i in word_list] -# final_words = [] -# for word in word_list: -# tokens = [tokenizer._convert_id_to_token(int(i)) for i in word] -# text = tokenizer.convert_tokens_to_string(tokens) -# final_words.append(text) -# return final_words - -# def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0): -# # substitues L,k -# # from this matrix to recover a word -# words = [] -# sub_len, k = substitutes.size() # sub-len, k -# -# if sub_len == 0: -# return words -# -# elif sub_len == 1: -# for (i, j) in zip(substitutes[0], substitutes_score[0]): -# if threshold != 0 and j < threshold: -# break -# words.append(tokenizer._convert_id_to_token(int(i))) -# else: -# if use_bpe == 1: -# words = get_bpe_substitues(substitutes, tokenizer, mlm_model) -# else: -# return words -# # -# # print(words) -# return words class Trainer(TrainBase): def __init__(self, - rank=0): + rank=0): args = get_args() super(Trainer, self).__init__(args, rank) self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) - image_mean, image_var=self.generate_mapping() - self.image_mean=image_mean - self.image_var=image_var - self.device=rank - self.clip_tokenizer=Tokenizer() - self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True) - self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder) + image_mean, image_var = self.generate_mapping() + self.image_mean = image_mean + self.image_var = image_var + self.device = rank + self.clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", model_max_length=77, + truncation=True) + self.bert_tokenizer = BertTokenizer.from_pretrained(self.args.text_encoder, do_lower_case=True) + self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder).to(device) self.attack_thred = self.args.attack_thred # self.run() def _init_model(self): self.logger.info("init model.") model_clip, preprocess = clip.load(self.args.victim, device=device) - self.model= model_clip + self.model = model_clip self.model.eval() self.model.float() @@ -174,14 +122,14 @@ class Trainer(TrainBase): self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file) self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file) self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file) - train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file, - indexFile=self.args.index_file, - labelFile=self.args.label_file, - maxWords=self.args.max_words, - imageResolution=self.args.resolution, - query_num=self.args.query_num, - train_num=self.args.train_num, - seed=self.args.seed) + train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file, + indexFile=self.args.index_file, + labelFile=self.args.label_file, + maxWords=self.args.max_words, + imageResolution=self.args.resolution, + query_num=self.args.query_num, + train_num=self.args.train_num, + seed=self.args.seed) self.train_labels = train_data.get_all_label() self.query_labels = query_data.get_all_label() self.retrieval_labels = retrieval_data.get_all_label() @@ -189,28 +137,28 @@ class Trainer(TrainBase): self.logger.info(f"query shape: {self.query_labels.shape}") self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}") self.train_loader = DataLoader( - dataset=train_data, - batch_size=self.args.batch_size, - num_workers=self.args.num_workers, - pin_memory=True, - shuffle=True - ) + dataset=train_data, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=True, + shuffle=True + ) self.query_loader = DataLoader( - dataset=query_data, - batch_size=self.args.batch_size, - num_workers=self.args.num_workers, - pin_memory=True, - shuffle=True - ) + dataset=query_data, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=True, + shuffle=True + ) self.retrieval_loader = DataLoader( - dataset=retrieval_data, - batch_size=self.args.batch_size, - num_workers=self.args.num_workers, - pin_memory=True, - shuffle=True - ) - self.train_data=train_data - + dataset=retrieval_data, + batch_size=self.args.batch_size, + num_workers=self.args.num_workers, + pin_memory=True, + shuffle=True + ) + self.train_data = train_data + def _tokenize(self, text): words = text.split(' ') @@ -224,8 +172,8 @@ class Trainer(TrainBase): index += len(sub) return words, sub_words, keys - - def get_important_scores(self, text, origin_embeds, batch_size, max_length): + + def get_important_scores(self, text, origin_embeds, batch_size, max_length): # device = origin_embeds.device masked_words = self._get_masked(text) @@ -233,17 +181,20 @@ class Trainer(TrainBase): masked_embeds = [] for i in range(0, len(masked_texts), batch_size): - masked_text_input = self.bert_tokenizer(masked_texts[i:i+batch_size], padding='max_length', truncation=True, max_length=max_length, return_tensors='pt').to(device) + masked_text_input = self.bert_tokenizer(masked_texts[i:i + batch_size], padding='max_length', + truncation=True, max_length=max_length, return_tensors='pt').to( + device) masked_embed = self.ref_net(masked_text_input.text_inputs, attention_mask=masked_text_input.attention_mask) masked_embeds.append(masked_embed) masked_embeds = torch.cat(masked_embeds, dim=0) criterion = torch.nn.KLDivLoss(reduction='none') - import_scores = criterion(masked_embeds.log_softmax(dim=-1), origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1)) + import_scores = criterion(masked_embeds.log_softmax(dim=-1), + origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1)) return import_scores.sum(dim=-1) - + def _get_masked(self, text): words = text.split(' ') len_text = len(words) @@ -265,14 +216,14 @@ class Trainer(TrainBase): def get_word_predictions(self, text): _, _, keys = self._tokenize(text) - inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.max_text_len, - truncation=True, return_tensors="pt") + inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.args.max_words, + truncation=True, return_tensors="pt") input_ids = inputs["input_ids"].to(self.device) attention_mask = inputs['attention_mask'] with torch.no_grad(): - word_predictions = self.ref_net(input_ids)['logits'].squeeze() # (seq_len, vocab_size) - - word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.max_candidate, -1) + word_predictions = self.ref_net(input_ids)['logits'].squeeze(0) # (seq_len, vocab_size) + # print(self.ref_net(input_ids)['logits'].shape) + word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.args.max_candidate, -1) word_predictions = word_predictions[1:-1, :] # remove [CLS] and [SEP] word_pred_scores_all = word_pred_scores_all[1:-1, :] @@ -343,55 +294,65 @@ class Trainer(TrainBase): ret.append(word) return ret - def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, - positive_var, beta=10,temperature=0.05): - trans_feature= self.clip_tokenizer(trans_texts) - anchor=self.model.encode_text(trans_feature) - loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code, + def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, + positive_var, beta=10, temperature=0.05): + # print(trans_texts) + trans_feature = clip.tokenize(trans_texts,context_length=77,truncate=True).to(device) + anchor = self.model.encode_text(trans_feature) + batch_size=anchor.shape[0] + loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code.repeat(batch_size, 1), negetive_code.repeat(batch_size, 1), distance_function=nn.CosineSimilarity(), reduction='none') - negative_dist = (anchor - negetive_mean) ** 2 / negative_var - positive_dist = (anchor - positive_mean) ** 2 / positive_var + sim=F.cosine_similarity(anchor,positive_code.unsqueeze(0), dim=1, eps=1e-8).unsqueeze(1) + negative_dist = (anchor - negetive_mean) ** 2 / (negative_var+ 1e-5) + positive_dist = (anchor - positive_mean) ** 2 / (positive_var+ 1e-5) negatives = torch.exp(negative_dist / temperature) positives = torch.exp(positive_dist / temperature) - loss = torch.log(positives / (positives + negatives)) + beta * loss1 - return loss - + loss = torch.log(positives / (positives + negatives)).mean(dim=1, keepdim=True) + beta * loss1 + results = [] + # print(loss.shape) + for i in range(len(trans_texts)): + if loss[i].shape[0] >1 or sim[i] >self.args.sim_threshold: + continue + results.append(GoalFunctionResult(trans_texts[i], score=loss[i], similarity=sim[i])) + return results def generate_mapping(self): - image_train=[] - label_train=[] + image_train = [] + label_train = [] for image, text, label, index in self.train_loader: # raw_text=[self.clip_tokenizer.decode(token) for token in text] - image=image.to(device, non_blocking=True) + image = image.to(device, non_blocking=True) # print(self.model.vocab_size) - temp_image=self.model.encode_image(image) + temp_image = self.model.encode_image(image) image_train.append(temp_image.cpu().detach().numpy()) label_train.append(label.detach().numpy()) - image_train=np.concatenate(image_train, axis=0) - label_train=np.concatenate(label_train, axis=0) - label_unipue=np.unique(label_train,axis=0) - image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) - image_var=np.stack([image_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0) - + image_train = np.concatenate(image_train, axis=0) + label_train = np.concatenate(label_train, axis=0) + label_unipue = np.unique(label_train, axis=0) + image_centroids = np.stack( + [image_train[find_indices(label_train, label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], + axis=0) + image_var = np.stack( + [image_train[find_indices(label_train, label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], + axis=0) + image_representation = {} image_var_representation = {} for i, centroid in enumerate(label_unipue): image_representation[str(centroid.astype(int))] = image_centroids[i] - image_var_representation[str(centroid.astype(int))]= image_var[i] - return image_representation, image_var_representation - - + image_var_representation[str(centroid.astype(int))] = image_var[i] + return image_representation, image_var_representation - def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, - positive_var, beta=10, temperature=0.05): + def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, + positive_var, beta=10, temperature=0.05): + # print(raw_text) keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text) - + #clean state # clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask) cur_result = GoalFunctionResult(raw_text) mask_idx = np.where(mask.cpu().numpy() == 1)[0] - for idx in mask_idx: predictions = word_predictions[keys[idx][0]: keys[idx][1]] predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]] @@ -402,7 +363,7 @@ class Trainer(TrainBase): continue # loss function results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var, positive_code, - positive_mean,positive_var, beta,temperature) + positive_mean, positive_var, beta, temperature) results = sorted(results, key=lambda x: x.score, reverse=True) if len(results) > 0 and results[0].score > cur_result.score: @@ -427,63 +388,62 @@ class Trainer(TrainBase): cur_result.status = GoalFunctionStatus.FAILED return cur_result - - - def train_epoch(self): # self.change_state(mode="valid") save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve") all_loss = 0 times = 0 - adv_codes=[] - adv_label=[] + adv_codes = [] + adv_label = [] for image, text, label, index in self.train_loader: self.global_step += 1 times += 1 print(times) image.float() - - raw_text=[self.clip_tokenizer.my_decode(token) for token in text] + image = image.to(self.rank, non_blocking=True) text = text.to(self.rank, non_blocking=True) - negetive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()]) - negative_var=np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()]) - negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True) - negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True) - negetive_code=self.model.encode_image(image) - + negetive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()]) + negative_var = np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()]) + negetive_mean = torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True) + negative_var = torch.from_numpy(negative_var).to(self.rank, non_blocking=True) + negetive_code = self.model.encode_image(image) + #targeted sample np.random.seed(times) select_index = np.random.choice(len(self.train_data), size=self.args.batch_size) target_dataset = data.Subset(self.train_data, select_index) target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size) target_image, _, target_label, _ = next(iter(target_subset)) - target_image=target_image.to(self.rank, non_blocking=True) - positive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()]) - positive_var=np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()]) - positive_mean=torch.from_numpy(positive_mean).to(self.rank, non_blocking=True) - positive_var=torch.from_numpy(positive_var).to(self.rank, non_blocking=True) - positive_code=self.model.encode_image(target_image) - - - final_adverse=self.target_adv(text, raw_text,negetive_code,negetive_mean,negative_var, - positive_code,positive_mean,positive_var) - final_text=self.clip_tokenizer.tokenize(final_adverse.text).to(self.rank, non_blocking=True) - adv_code=self.model.encode_text(final_text) + target_image = target_image.to(self.rank, non_blocking=True) + positive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()]) + positive_var = np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()]) + positive_mean = torch.from_numpy(positive_mean).to(self.rank, non_blocking=True) + positive_var = torch.from_numpy(positive_var).to(self.rank, non_blocking=True) + positive_code = self.model.encode_image(target_image) + # print(self.clip_tokenizer.my_encode('This day is good!')) + raw_text = [self.clip_tokenizer.convert_ids_to_tokens(token.cpu()) for token in text] + raw_text = [text_filter(self.clip_tokenizer.convert_tokens_to_string(txt)) for txt in raw_text] + final_texts=[] + for i in range(self.args.batch_size): + adv_txt=self.target_adv( raw_text[i], negetive_code[i], negetive_mean[i], negative_var[i], + positive_code[i], positive_mean[i], positive_var[i]) + final_texts.append(adv_txt.text) + # final_adverse = self.target_adv( raw_text, negetive_code, negetive_mean, negative_var, + # positive_code, positive_mean, positive_var) + final_text = clip.tokenize(final_texts,context_length=77,truncate=True).to(self.rank, non_blocking=True) + adv_code = self.model.encode_text(final_text) adv_codes.append(adv_code.cpu().detach().numpy()) adv_label.append(target_label.numpy()) - adv_img=np.concatenate(adv_codes) - adv_labels=np.concatenate(adv_label) + adv_img = np.concatenate(adv_codes) + adv_labels = np.concatenate(adv_label) + + _, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) - _, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) - - - retrieval_txt = retrieval_txt.cpu().detach().numpy() retrieval_labels = self.retrieval_labels.numpy() - - mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels) + mAP_t = cal_map(adv_img, adv_labels, retrieval_txt, retrieval_labels) self.logger.info(f">>>>>> MAP_t: {mAP_t}") result_dict = { 'adv_img': adv_img, @@ -494,13 +454,9 @@ class Trainer(TrainBase): # 'pr': pr, # 'pr_t': pr_t } - scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict) + scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"), + result_dict) self.logger.info(">>>>>> save all data!") - - - - - def train(self): self.logger.info("Start train.") @@ -510,9 +466,8 @@ class Trainer(TrainBase): self.valid(epoch) self.save_model(epoch) - self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") - - + self.logger.info( + f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") def make_hash_code(self, code: list) -> torch.Tensor: @@ -539,25 +494,19 @@ class Trainer(TrainBase): text_features = self.model.encode_text(text) img_buffer[index, :] = image_feature.detach() text_buffer[index, :] = text_features.detach() - - return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) - - - - - def valid_attack(self,adv_images, texts, adv_labels): + + return img_buffer, text_buffer # img_buffer.to(self.rank), text_buffer.to(self.rank) + + def valid_attack(self, adv_images, texts, adv_labels): save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve") os.makedirs(save_dir, exist_ok=True) - - def test(self, mode_name="i2t"): self.logger.info("Valid Clean.") save_dir = os.path.join(self.args.save_dir, "PR_cruve") os.makedirs(save_dir, exist_ok=True) - query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) + query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) - query_img = query_img.cpu().detach().numpy() query_txt = query_txt.cpu().detach().numpy() @@ -565,8 +514,8 @@ class Trainer(TrainBase): retrieval_txt = retrieval_txt.cpu().detach().numpy() query_labels = self.query_labels.numpy() retrieval_labels = self.retrieval_labels.numpy() - mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels) - mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels) + mAPi2t = cal_map(query_img, query_labels, retrieval_txt, retrieval_labels) + mAPt2i = cal_map(query_txt, query_labels, retrieval_img, retrieval_labels) # pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels) # pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels) self.max_mapt2i = max(self.max_mapt2i, mAPi2t) @@ -579,31 +528,11 @@ class Trainer(TrainBase): 'q_l': query_labels, 'r_l': retrieval_labels } - scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict) + scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"), + result_dict) self.logger.info(">>>>>> save all data!") - # def valid(self, epoch): - # self.logger.info("Valid.") - # self.change_state(mode="valid") - # query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) - # retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) - # # print("get all code") - # mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) - # # print("map map") - # mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) - # mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) - # mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) - # if self.max_mapi2t < mAPi2t: - # self.best_epoch_i = epoch - # self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t") - # self.max_mapi2t = max(self.max_mapi2t, mAPi2t) - # if self.max_mapt2i < mAPt2i: - # self.best_epoch_t = epoch - # self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i") - # self.max_mapt2i = max(self.max_mapt2i, mAPt2i) - # self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \ - # MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}") def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"): @@ -625,7 +554,7 @@ class Trainer(TrainBase): 'q_l': query_labels, 'r_l': retrieval_labels } - scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) + scio.savemat( + os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), + result_dict) self.logger.info(f">>>>>> save best {mode_name} data!") - - diff --git a/utils/get_args.py b/utils/get_args.py index 615e4d9..3456cbb 100644 --- a/utils/get_args.py +++ b/utils/get_args.py @@ -9,20 +9,22 @@ def get_args(): parser.add_argument("--save-dir", type=str, default="./result/64-bit") parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.") parser.add_argument("--pretrained", type=str, default="") - parser.add_argument("--dataset", type=str, default="flickr25k", help="choise from [coco, mirflckr25k, nuswide]") + parser.add_argument("--dataset", type=str, default="coco", help="choise from [coco, mirflckr25k, nuswide]") parser.add_argument("--index-file", type=str, default="index.mat") parser.add_argument("--caption-file", type=str, default="caption.mat") parser.add_argument("--label-file", type=str, default="label.mat") parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]") parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]") - parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101']) + parser.add_argument('--victim', default='RN50', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101']) parser.add_argument("--text_encoder", type=str, default="bert-base-uncased") parser.add_argument("--topk", type=int, default=10) parser.add_argument("--num-perturbation", type=int, default=3) parser.add_argument("--txt-dim", type=int, default=1024) - parser.add_argument("--output-dim", type=int, default=512) + parser.add_argument("--output-dim", type=int, default=1024) parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--max-words", type=int, default=77) + parser.add_argument("--max-candidate", type=int, default=7) + parser.add_argument("--enable-bpe", type=bool, default=False) parser.add_argument("--resolution", type=int, default=224) parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--num-workers", type=int, default=4)