new text attack

This commit is contained in:
Li Wenyun 2024-06-27 21:44:48 +08:00
parent ca75f34880
commit 321817c86f
2 changed files with 18 additions and 92 deletions

View File

@ -116,8 +116,6 @@ class Trainer(TrainBase):
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
delta = torch.zeros_like(image,requires_grad=True)
# one=torch.zeros_like(positive)
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
for i in range(num_iter):
self.model.zero_grad()
anchor=self.model.encode_image(image+delta)
@ -287,27 +285,6 @@ class Trainer(TrainBase):
self.logger.info(">>>>>> save all data!")
# def valid(self, epoch):
# self.logger.info("Valid.")
# self.change_state(mode="valid")
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
# retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
# # print("get all code")
# mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# # print("map map")
# mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
# mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
# mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# if self.max_mapi2t < mAPi2t:
# self.best_epoch_i = epoch
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
# self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
# if self.max_mapt2i < mAPt2i:
# self.best_epoch_t = epoch
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
# self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
# MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):

View File

@ -343,25 +343,19 @@ class Trainer(TrainBase):
ret.append(word)
return ret
def get_goal_results(self, trans_texts, ori_text, ref_text):
text_batch = [ref_text] + trans_texts
with torch.no_grad():
feats = self.feature_extractor(text_batch, device=self.device)
feats = feats.flatten(start_dim=1)
def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
positive_var, beta=10,temperature=0.05):
# Score a batch of candidate (transformed) texts for the targeted attack:
# encode them with the CLIP text encoder, then combine
#   (a) an unreduced triplet term pulling the embedding toward
#       `positive_code` and away from `negetive_code`, and
#   (b) a variance-normalized contrastive term built from the class
#       means/variances, softened by `temperature`.
# Returns the per-candidate loss tensor (reduction='none' keeps one row
# per input text).
# NOTE(review): `nn.CosineSimilarity()` is a *similarity* (higher =
# closer), but `triplet_margin_with_distance_loss` expects a *distance*
# (lower = closer) — confirm the triplet semantics are not inverted.
# NOTE(review): `torch.log(positives / (positives + negatives))` is
# non-positive; usually this term carries a leading minus sign — confirm
# the intended optimization direction against the caller in target_adv.
# assumes `trans_feature` lands on the same device as the model — TODO confirm.
trans_feature= self.clip_tokenizer(trans_texts)
anchor=self.model.encode_text(trans_feature)
loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code,
distance_function=nn.CosineSimilarity(), reduction='none')
# Squared distance to each class mean, normalized per-dimension by the
# class variance (Mahalanobis-style with a diagonal covariance).
negative_dist = (anchor - negetive_mean) ** 2 / negative_var
positive_dist = (anchor - positive_mean) ** 2 / positive_var
negatives = torch.exp(negative_dist / temperature)
positives = torch.exp(positive_dist / temperature)
loss = torch.log(positives / (positives + negatives)) + beta * loss1
return loss
ref_feats = feats[0].unsqueeze(0)
trans_feats = feats[1:]
cos_sim = F.cosine_similarity(ref_feats.repeat(trans_feats.size(0), 1), trans_feats)
cos_sim = cos_sim.cpu().numpy()
text_sim = self.get_text_similarity(trans_texts, ori_text)
results = []
for i in range(len(trans_texts)):
if text_sim[i] is not None and text_sim[i] < self.semantic_thred:
continue
results.append(GoalFunctionResult(trans_texts[i], score=cos_sim[i], similarity=text_sim[i]))
return results
def generate_mapping(self):
image_train=[]
@ -387,13 +381,9 @@ class Trainer(TrainBase):
return image_representation, image_var_representation
# def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
# beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
def target_adv(self, texts, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
positive_var,
beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500, temperature=0.05):
# bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
positive_var, beta=10, temperature=0.05):
keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
#clean state
@ -411,7 +401,8 @@ class Trainer(TrainBase):
if len(trans_texts) == 0:
continue
# loss function
results = self.get_goal_results(trans_texts, raw_text, positive_code)
results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var, positive_code,
positive_mean,positive_var, beta,temperature)
results = sorted(results, key=lambda x: x.score, reverse=True)
if len(results) > 0 and results[0].score > cur_result.score:
@ -437,49 +428,7 @@ class Trainer(TrainBase):
return cur_result
# important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
# list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
# words, sub_words, keys = self._tokenize(text)
# final_words = copy.deepcopy(words)
# change = 0
# for top_index in list_of_index:
# if change >= self.args.num_perturbation:
# break
# tgt_word = words[top_index[0]]
# if tgt_word in filter_words:
# continue
# if keys[top_index[0]][0] > self.args.max_length - 2:
# continue
# substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k
# word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
# substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
# self.args.threshold_pred_score)
# replace_texts = [' '.join(final_words)]
# available_substitutes = [tgt_word]
# for substitute_ in substitutes:
# substitute = substitute_
# if substitute == tgt_word:
# continue # filter out original word
# if '##' in substitute:
# continue # filter out sub-word
#
# if substitute in filter_words:
# continue
# temp_replace = copy.deepcopy(final_words)
# temp_replace[top_index[0]] = substitute
# available_substitutes.append(substitute)
# replace_texts.append(' '.join(temp_replace))
# replace_text_input = self.clip_tokenizer(replace_texts).to(device)
# replace_embeds = self.model.encode_text(replace_text_input)
#
# loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var)
# loss = loss.sum(dim=-1)
# candidate_idx = loss.argmax()
# final_words[top_index[0]] = available_substitutes[candidate_idx]
# if available_substitutes[candidate_idx] != tgt_word:
# change += 1
# final_adverse.append(' '.join(final_words))
# return final_adverse
def train_epoch(self):
# self.change_state(mode="valid")
@ -519,7 +468,7 @@ class Trainer(TrainBase):
final_adverse=self.target_adv(text, raw_text,negetive_code,negetive_mean,negative_var,
positive_code,positive_mean,positive_var)
final_text=self.clip_tokenizer.tokenize(final_adverse).to(self.rank, non_blocking=True)
final_text=self.clip_tokenizer.tokenize(final_adverse.text).to(self.rank, non_blocking=True)
adv_code=self.model.encode_text(final_text)
adv_codes.append(adv_code.cpu().detach().numpy())
adv_label.append(target_label.numpy())