From 321817c86fb89d87576a8e55c70853e7b3f9a100 Mon Sep 17 00:00:00 2001
From: Li Wenyun
Date: Thu, 27 Jun 2024 21:44:48 +0800
Subject: [PATCH] new text attack: greedy word substitution with a triplet +
 distribution objective

Rework the text attack around a single goal function. get_goal_results()
now scores candidate substitutions directly with the attack objective: a
triplet term over the positive/negative codes, using cosine similarity as
the distance function, plus a temperature-scaled ratio of distances to the
class feature statistics produced by generate_mapping(), normalized by the
per-dimension variances. target_adv() drops the unused image-attack
parameters (epsilon, alpha, num_iter) and the unused texts argument, and
train_epoch() is updated to match the new signatures. Dead commented-out
code is removed from hash_train.py.
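For review, a self-contained sketch of the objective's shape contract. The
unit-norm stand-ins for CLIP features, N=4, D=512, and the variance range
are assumptions, and the GoalFunctionResult wrapper is left out:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    N, D = 4, 512                                        # candidates, feature dim (assumed)
    anchor = F.normalize(torch.randn(N, D), dim=-1)      # stand-in for encode_text() output
    positive_code = F.normalize(torch.randn(N, D), dim=-1)
    negetive_code = F.normalize(torch.randn(N, D), dim=-1)
    positive_mean = F.normalize(torch.randn(D), dim=-1)
    negetive_mean = F.normalize(torch.randn(D), dim=-1)
    positive_var = torch.rand(D) * 0.5 + 0.5             # kept away from zero
    negative_var = torch.rand(D) * 0.5 + 0.5
    beta, temperature = 10, 0.05

    loss1 = F.triplet_margin_with_distance_loss(
        anchor, positive_code, negetive_code,
        distance_function=nn.CosineSimilarity(), reduction='none')  # shape (N,)
    positive_dist = (anchor - positive_mean) ** 2 / positive_var    # shape (N, D)
    negative_dist = (anchor - negetive_mean) ** 2 / negative_var
    positives = torch.exp(positive_dist / temperature)
    negatives = torch.exp(negative_dist / temperature)
    # summing over the feature dim reduces the distribution term to (N,)
    scores = torch.log(positives / (positives + negatives)).sum(dim=-1) + beta * loss1
    assert scores.shape == (N,)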
---
 train/hash_train.py | 23 -------
 train/text_train.py | 98 ++++++++++++++++---------------
 2 files changed, 28 insertions(+), 93 deletions(-)

diff --git a/train/hash_train.py b/train/hash_train.py
index 6ed4d16..d00ac4e 100644
--- a/train/hash_train.py
+++ b/train/hash_train.py
@@ -116,8 +116,6 @@ class Trainer(TrainBase):
                    beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
         delta = torch.zeros_like(image,requires_grad=True)
-        # one=torch.zeros_like(positive)
-        # alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
         for i in range(num_iter):
             self.model.zero_grad()
             anchor=self.model.encode_image(image+delta)
@@ -287,27 +285,6 @@ class Trainer(TrainBase):
 
         self.logger.info(">>>>>> save all data!")
 
-    # def valid(self, epoch):
-    #     self.logger.info("Valid.")
-    #     self.change_state(mode="valid")
-    #     query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
-    #     retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
-    #     # print("get all code")
-    #     mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
-    #     # print("map map")
-    #     mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
-    #     mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
-    #     mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
-    #     if self.max_mapi2t < mAPi2t:
-    #         self.best_epoch_i = epoch
-    #         self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
-    #     self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
-    #     if self.max_mapt2i < mAPt2i:
-    #         self.best_epoch_t = epoch
-    #         self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
-    #     self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
-    #     self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
-    #         MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
 
 
     def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):

diff --git a/train/text_train.py b/train/text_train.py
index 8c6f9a8..4b689ce 100644
--- a/train/text_train.py
+++ b/train/text_train.py
@@ -343,25 +343,26 @@
                 ret.append(word)
         return ret
 
-    def get_goal_results(self, trans_texts, ori_text, ref_text):
-        text_batch = [ref_text] + trans_texts
-        with torch.no_grad():
-            feats = self.feature_extractor(text_batch, device=self.device)
-            feats = feats.flatten(start_dim=1)
-
-        ref_feats = feats[0].unsqueeze(0)
-        trans_feats = feats[1:]
-        cos_sim = F.cosine_similarity(ref_feats.repeat(trans_feats.size(0), 1), trans_feats)
-        cos_sim = cos_sim.cpu().numpy()
-
-        text_sim = self.get_text_similarity(trans_texts, ori_text)
-
-        results = []
-        for i in range(len(trans_texts)):
-            if text_sim[i] is not None and text_sim[i] < self.semantic_thred:
-                continue
-            results.append(GoalFunctionResult(trans_texts[i], score=cos_sim[i], similarity=text_sim[i]))
-        return results
+    def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code,
+                         positive_mean, positive_var, beta=10, temperature=0.05):
+        trans_feature = self.clip_tokenizer.tokenize(trans_texts).to(self.rank, non_blocking=True)
+        with torch.no_grad():
+            anchor = self.model.encode_text(trans_feature)
+        # Rank candidate substitutions by the attack objective: a triplet term
+        # (cosine similarity as the distance function) plus a temperature-scaled
+        # ratio of variance-normalized distances to the class feature statistics.
+        loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code,
+                                                    distance_function=nn.CosineSimilarity(), reduction='none')
+        negative_dist = (anchor - negetive_mean) ** 2 / negative_var
+        positive_dist = (anchor - positive_mean) ** 2 / positive_var
+        negatives = torch.exp(negative_dist / temperature)
+        positives = torch.exp(positive_dist / temperature)
+        # Reduce the per-dimension distribution term to shape (N,) before adding
+        # the triplet term; the removed adv_loss draft applied the same .sum(dim=-1).
+        scores = (torch.log(positives / (positives + negatives)).sum(dim=-1) + beta * loss1).cpu().numpy()
+        return [GoalFunctionResult(trans_texts[i], score=scores[i], similarity=None)
+                for i in range(len(trans_texts))]
 
 
     def generate_mapping(self):
         image_train=[]
@@ -387,13 +388,9 @@
 
         return image_representation, image_var_representation
 
-    # def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
-    #                beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
-    def target_adv(self, texts, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
-                   positive_var,
-                   beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500, temperature=0.05):
-        # bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
+    def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
+                   positive_var, beta=10, temperature=0.05):
         keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
         #clean state
@@ -411,7 +408,10 @@
             if len(trans_texts) == 0:
                 continue
             # loss function
-            results = self.get_goal_results(trans_texts, raw_text, positive_code)
+            results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var, positive_code,
+                                            positive_mean, positive_var, beta, temperature)
 
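+            # Greedy acceptance: keep a substitution only if it strictly increases
+            # the attack objective returned by get_goal_results().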
             results = sorted(results, key=lambda x: x.score, reverse=True)
             if len(results) > 0 and results[0].score > cur_result.score:
@@ -437,49 +437,7 @@
 
         return cur_result
 
-        # important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
-        # list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
-        # words, sub_words, keys = self._tokenize(text)
-        # final_words = copy.deepcopy(words)
-        # change = 0
-        # for top_index in list_of_index:
-        #     if change >= self.args.num_perturbation:
-        #         break
-        #     tgt_word = words[top_index[0]]
-        #     if tgt_word in filter_words:
-        #         continue
-        #     if keys[top_index[0]][0] > self.args.max_length - 2:
-        #         continue
-        #     substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]]  # L, k
-        #     word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
-        #     substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
-        #                                  self.args.threshold_pred_score)
-        #     replace_texts = [' '.join(final_words)]
-        #     available_substitutes = [tgt_word]
-        #     for substitute_ in substitutes:
-        #         substitute = substitute_
-        #         if substitute == tgt_word:
-        #             continue  # filter out original word
-        #         if '##' in substitute:
-        #             continue  # filter out sub-word
-        #
-        #         if substitute in filter_words:
-        #             continue
-        #         temp_replace = copy.deepcopy(final_words)
-        #         temp_replace[top_index[0]] = substitute
-        #         available_substitutes.append(substitute)
-        #         replace_texts.append(' '.join(temp_replace))
-        #     replace_text_input = self.clip_tokenizer(replace_texts).to(device)
-        #     replace_embeds = self.model.encode_text(replace_text_input)
-        #
-        #     loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var)
-        #     loss = loss.sum(dim=-1)
-        #     candidate_idx = loss.argmax()
-        #     final_words[top_index[0]] = available_substitutes[candidate_idx]
-        #     if available_substitutes[candidate_idx] != tgt_word:
-        #         change += 1
-        # final_adverse.append(' '.join(final_words))
-        # return final_adverse
+
 
     def train_epoch(self):
         # self.change_state(mode="valid")
@@ -519,7 +477,7 @@
-            final_adverse=self.target_adv(text, raw_text,negetive_code,negetive_mean,negative_var,
+            final_adverse=self.target_adv(raw_text, negetive_code, negetive_mean, negative_var,
                                           positive_code,positive_mean,positive_var)
-            final_text=self.clip_tokenizer.tokenize(final_adverse).to(self.rank, non_blocking=True)
+            final_text=self.clip_tokenizer.tokenize(final_adverse.text).to(self.rank, non_blocking=True)
             adv_code=self.model.encode_text(final_text)
             adv_codes.append(adv_code.cpu().detach().numpy())
             adv_label.append(target_label.numpy())