new text attack
parent ca75f34880
commit 321817c86f
@@ -116,8 +116,6 @@ class Trainer(TrainBase):
                    beta=10, epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
 
         delta = torch.zeros_like(image, requires_grad=True)
-        # one = torch.zeros_like(positive)
-        # alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
         for i in range(num_iter):
             self.model.zero_grad()
             anchor = self.model.encode_image(image + delta)
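
Note: this hunk sits inside a PGD-style image attack: `delta` is an additive perturbation optimized for `num_iter` steps with step size `alpha` inside an L-inf ball of radius `epsilon`. A minimal self-contained sketch of that loop shape (the update step itself is outside this hunk, so the sign-gradient ascent and clamp below are assumptions; `loss_fn` stands in for the triplet/contrastive objective):

    import torch

    # Minimal PGD-style sketch; the sign-gradient update and clamp are
    # assumed, not shown in the hunk above.
    def pgd_perturb(image, loss_fn, epsilon=0.03125, alpha=3 / 255, num_iter=1500):
        delta = torch.zeros_like(image, requires_grad=True)
        for _ in range(num_iter):
            loss = loss_fn(image + delta)
            loss.backward()
            with torch.no_grad():
                delta += alpha * delta.grad.sign()  # ascend the loss
                delta.clamp_(-epsilon, epsilon)     # stay inside the L-inf ball
            delta.grad.zero_()
        return (image + delta).detach()
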
@@ -287,27 +285,6 @@ class Trainer(TrainBase):
         self.logger.info(">>>>>> save all data!")
 
 
-    # def valid(self, epoch):
-    #     self.logger.info("Valid.")
-    #     self.change_state(mode="valid")
-    #     query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
-    #     retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
-    #     # print("get all code")
-    #     mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
-    #     # print("map map")
-    #     mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
-    #     mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
-    #     mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
-    #     if self.max_mapi2t < mAPi2t:
-    #         self.best_epoch_i = epoch
-    #         self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
-    #     self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
-    #     if self.max_mapt2i < mAPt2i:
-    #         self.best_epoch_t = epoch
-    #         self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
-    #     self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
-    #     self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
-    #         MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
 
     def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
 
@@ -343,25 +343,19 @@ class Trainer(TrainBase):
             ret.append(word)
         return ret
 
-    def get_goal_results(self, trans_texts, ori_text, ref_text):
-        text_batch = [ref_text] + trans_texts
-        with torch.no_grad():
-            feats = self.feature_extractor(text_batch, device=self.device)
-            feats = feats.flatten(start_dim=1)
+    def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
+                         positive_var, beta=10, temperature=0.05):
+        trans_feature = self.clip_tokenizer(trans_texts)
+        anchor = self.model.encode_text(trans_feature)
+        loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code,
+                                                    distance_function=nn.CosineSimilarity(), reduction='none')
+        negative_dist = (anchor - negetive_mean) ** 2 / negative_var
+        positive_dist = (anchor - positive_mean) ** 2 / positive_var
+        negatives = torch.exp(negative_dist / temperature)
+        positives = torch.exp(positive_dist / temperature)
+        loss = torch.log(positives / (positives + negatives)) + beta * loss1
+        return loss
 
-        ref_feats = feats[0].unsqueeze(0)
-        trans_feats = feats[1:]
-        cos_sim = F.cosine_similarity(ref_feats.repeat(trans_feats.size(0), 1), trans_feats)
-        cos_sim = cos_sim.cpu().numpy()
-
-        text_sim = self.get_text_similarity(trans_texts, ori_text)
-
-        results = []
-        for i in range(len(trans_texts)):
-            if text_sim[i] is not None and text_sim[i] < self.semantic_thred:
-                continue
-            results.append(GoalFunctionResult(trans_texts[i], score=cos_sim[i], similarity=text_sim[i]))
-        return results
 
     def generate_mapping(self):
         image_train = []
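
Note: the rewritten get_goal_results scores candidate texts directly in CLIP's text-embedding space: a cosine triplet term keeps the anchor near the positive code and away from the negative one, and a Gaussian-weighted log-ratio contrasts per-dimension distances to the positive and negative statistics. A standalone sketch (shapes and the final per-candidate reduction are assumptions; the diff adds the (N, D) log-ratio to the (N,) triplet term directly):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def goal_loss(anchor, pos_code, neg_code, pos_mean, pos_var, neg_mean, neg_var,
                  beta=10, temperature=0.05):
        # Cosine similarity is passed as the "distance", as in the diff.
        triplet = F.triplet_margin_with_distance_loss(
            anchor, pos_code, neg_code,
            distance_function=nn.CosineSimilarity(), reduction='none')  # (N,)
        # Per-dimension squared distances scaled by diagonal Gaussian variances.
        pos_dist = (anchor - pos_mean) ** 2 / pos_var                   # (N, D)
        neg_dist = (anchor - neg_mean) ** 2 / neg_var
        positives = torch.exp(pos_dist / temperature)
        negatives = torch.exp(neg_dist / temperature)
        ratio = torch.log(positives / (positives + negatives))
        # Assumption: reduce over the feature dimension so the shapes line up.
        return ratio.sum(dim=-1) + beta * triplet
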
@@ -387,13 +381,9 @@ class Trainer(TrainBase):
         return image_representation, image_var_representation
 
 
-    # def target_adv(self, texts, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, positive_var,
-    #                beta=10, epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
-    def target_adv(self, texts, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
-                   positive_var,
-                   beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500, temperature=0.05):
 
-        # bert_inputs = self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
+    def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
+                   positive_var, beta=10, temperature=0.05):
         keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
 
         # clean state
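
Note: the new target_adv signature drops the pixel-space knobs (epsilon, alpha, num_iter): the text attack perturbs words via masked-LM predictions (get_word_predictions) rather than pixels. A hedged skeleton of the greedy word-substitution loop this method appears to implement (control flow inferred from the commented-out block removed further down; all names below are illustrative):

    # Greedy word-substitution skeleton; `score_fn` stands in for the goal loss
    # and `candidates_by_pos` for the masked-LM substitute lists per position.
    def substitute_words(words, candidates_by_pos, score_fn, max_changes=1):
        best = list(words)
        changes = 0
        for pos, candidates in candidates_by_pos:  # positions ordered by importance
            if changes >= max_changes:
                break
            scored = [(score_fn(best[:pos] + [c] + best[pos + 1:]), c)
                      for c in candidates if c != best[pos]]
            if not scored:
                continue
            loss, substitute = max(scored, key=lambda t: t[0])
            if loss > score_fn(best):              # keep the loss-maximizing substitute
                best[pos] = substitute
                changes += 1
        return ' '.join(best)
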
@@ -411,7 +401,8 @@ class Trainer(TrainBase):
             if len(trans_texts) == 0:
                 continue
             # loss function
-            results = self.get_goal_results(trans_texts, raw_text, positive_code)
+            results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var, positive_code,
+                                            positive_mean, positive_var, beta, temperature)
             results = sorted(results, key=lambda x: x.score, reverse=True)
 
             if len(results) > 0 and results[0].score > cur_result.score:
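
Note: this caller still sorts by a .score attribute, while the rewritten get_goal_results returns a loss tensor, so the scores presumably get wrapped into result objects somewhere. A hypothetical wrapper mirroring the removed GoalFunctionResult(text, score=..., similarity=...) usage:

    from dataclasses import dataclass
    from typing import Optional

    # Hypothetical result wrapper, mirroring the removed GoalFunctionResult calls.
    @dataclass
    class GoalFunctionResult:
        text: str
        score: float
        similarity: Optional[float] = None
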
@@ -437,49 +428,7 @@ class Trainer(TrainBase):
         return cur_result
 
 
-        # important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
-        # list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
-        # words, sub_words, keys = self._tokenize(text)
-        # final_words = copy.deepcopy(words)
-        # change = 0
-        # for top_index in list_of_index:
-        #     if change >= self.args.num_perturbation:
-        #         break
-        #     tgt_word = words[top_index[0]]
-        #     if tgt_word in filter_words:
-        #         continue
-        #     if keys[top_index[0]][0] > self.args.max_length - 2:
-        #         continue
-        #     substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]]  # L, k
-        #     word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
-        #     substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
-        #                                  self.args.threshold_pred_score)
-        #     replace_texts = [' '.join(final_words)]
-        #     available_substitutes = [tgt_word]
-        #     for substitute_ in substitutes:
-        #         substitute = substitute_
-        #         if substitute == tgt_word:
-        #             continue  # filter out original word
-        #         if '##' in substitute:
-        #             continue  # filter out sub-word
-        #
-        #         if substitute in filter_words:
-        #             continue
-        #         temp_replace = copy.deepcopy(final_words)
-        #         temp_replace[top_index[0]] = substitute
-        #         available_substitutes.append(substitute)
-        #         replace_texts.append(' '.join(temp_replace))
-        #     replace_text_input = self.clip_tokenizer(replace_texts).to(device)
-        #     replace_embeds = self.model.encode_text(replace_text_input)
-        #
-        #     loss = self.adv_loss(replace_embeds, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, positive_var)
-        #     loss = loss.sum(dim=-1)
-        #     candidate_idx = loss.argmax()
-        #     final_words[top_index[0]] = available_substitutes[candidate_idx]
-        #     if available_substitutes[candidate_idx] != tgt_word:
-        #         change += 1
-        # final_adverse.append(' '.join(final_words))
-        # return final_adverse
 
+
     def train_epoch(self):
         # self.change_state(mode="valid")
@@ -519,7 +468,7 @@ class Trainer(TrainBase):
 
                 final_adverse = self.target_adv(text, raw_text, negetive_code, negetive_mean, negative_var,
                                                 positive_code, positive_mean, positive_var)
-                final_text = self.clip_tokenizer.tokenize(final_adverse).to(self.rank, non_blocking=True)
+                final_text = self.clip_tokenizer.tokenize(final_adverse.text).to(self.rank, non_blocking=True)
                 adv_code = self.model.encode_text(final_text)
                 adv_codes.append(adv_code.cpu().detach().numpy())
                 adv_label.append(target_label.numpy())
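
Note: the one-line fix tokenizes the .text field of the result object returned by target_adv instead of the object itself. A minimal usage sketch (tokenizer, model, and result stand in for the repo's clip_tokenizer, CLIP model, and the target_adv return value; only the .text access is taken from the diff):

    # Hedged sketch: re-encode the adversarial text into an embedding code.
    def encode_adversarial(result, tokenizer, model, device):
        tokens = tokenizer.tokenize(result.text).to(device, non_blocking=True)
        return model.encode_text(tokens).detach().cpu().numpy()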