From ca75f34880d74cc26b811ea8b6c129131fc5c3ba Mon Sep 17 00:00:00 2001 From: Li Wenyun Date: Thu, 27 Jun 2024 18:01:28 +0800 Subject: [PATCH] text attack --- train/text_train.py | 397 ++++++++++++++++++++++++++++++++------------ utils/get_args.py | 2 +- 2 files changed, 288 insertions(+), 111 deletions(-) diff --git a/train/text_train.py b/train/text_train.py index 5d6211b..8c6f9a8 100644 --- a/train/text_train.py +++ b/train/text_train.py @@ -51,67 +51,98 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again', '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·'] filter_words = set(filter_words) -def get_bpe_substitues(substitutes, tokenizer, mlm_model): - # substitutes L, k - # device = mlm_model.device - substitutes = substitutes[0:12, 0:4] # maximum BPE candidates +class GoalFunctionStatus(object): + SUCCEEDED = 0 # attack succeeded + SEARCHING = 1 # In process of searching for a success + FAILED = 2 # attack failed - # find all possible candidates - all_substitutes = [] - for i in range(substitutes.size(0)): - if len(all_substitutes) == 0: - lev_i = substitutes[i] - all_substitutes = [[int(c)] for c in lev_i] - else: - lev_i = [] - for all_sub in all_substitutes: - for j in substitutes[i]: - lev_i.append(all_sub + [int(j)]) - all_substitutes = lev_i +class GoalFunctionResult(object): + goal_score = 1 - # all substitutes list of list of token-id (all candidates) - c_loss = nn.CrossEntropyLoss(reduction='none') - word_list = [] - # all_substitutes = all_substitutes[:24] - all_substitutes = torch.tensor(all_substitutes) # [ N, L ] - all_substitutes = all_substitutes[:24].to(device) - # print(substitutes.size(), all_substitutes.size()) - N, L = all_substitutes.size() - word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size - ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ] - ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N - _, word_list = torch.sort(ppl) - word_list = [all_substitutes[i] for i in word_list] - final_words = [] - for word in word_list: - tokens = [tokenizer._convert_id_to_token(int(i)) for i in word] - text = tokenizer.convert_tokens_to_string(tokens) - final_words.append(text) - return final_words + def __init__(self, text, score=0, similarity=None): + self.status = GoalFunctionStatus.SEARCHING + self.text = text + self.score = score + self.similarity = similarity -def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0): - # substitues L,k - # from this matrix to recover a word - words = [] - sub_len, k = substitutes.size() # sub-len, k + @property + def score(self): + return self.__score - if sub_len == 0: - return words + @score.setter + def score(self, value): + self.__score = value + if value >= self.goal_score: + self.status = GoalFunctionStatus.SUCCEEDED - elif sub_len == 1: - for (i, j) in zip(substitutes[0], substitutes_score[0]): - if threshold != 0 and j < threshold: - break - words.append(tokenizer._convert_id_to_token(int(i))) - else: - if use_bpe == 1: - words = get_bpe_substitues(substitutes, tokenizer, mlm_model) - else: - return words - # - # print(words) - return words + def __eq__(self, __o): + return self.text == __o.text + + def __hash__(self): + return hash(self.text) + +# def get_bpe_substitues(substitutes, tokenizer, mlm_model): +# # substitutes L, k +# # device = mlm_model.device +# substitutes = substitutes[0:12, 0:4] # maximum BPE candidates +# +# # find all possible candidates +# +# all_substitutes = [] +# for i in range(substitutes.size(0)): +# if len(all_substitutes) == 0: +# lev_i = substitutes[i] +# all_substitutes = [[int(c)] for c in lev_i] +# else: +# lev_i = [] +# for all_sub in all_substitutes: +# for j in substitutes[i]: +# lev_i.append(all_sub + [int(j)]) +# all_substitutes = lev_i +# +# # all substitutes list of list of token-id (all candidates) +# c_loss = nn.CrossEntropyLoss(reduction='none') +# word_list = [] +# # all_substitutes = all_substitutes[:24] +# all_substitutes = torch.tensor(all_substitutes) # [ N, L ] +# all_substitutes = all_substitutes[:24].to(device) +# # print(substitutes.size(), all_substitutes.size()) +# N, L = all_substitutes.size() +# word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size +# ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ] +# ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N +# _, word_list = torch.sort(ppl) +# word_list = [all_substitutes[i] for i in word_list] +# final_words = [] +# for word in word_list: +# tokens = [tokenizer._convert_id_to_token(int(i)) for i in word] +# text = tokenizer.convert_tokens_to_string(tokens) +# final_words.append(text) +# return final_words + +# def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0): +# # substitues L,k +# # from this matrix to recover a word +# words = [] +# sub_len, k = substitutes.size() # sub-len, k +# +# if sub_len == 0: +# return words +# +# elif sub_len == 1: +# for (i, j) in zip(substitutes[0], substitutes_score[0]): +# if threshold != 0 and j < threshold: +# break +# words.append(tokenizer._convert_id_to_token(int(i))) +# else: +# if use_bpe == 1: +# words = get_bpe_substitues(substitutes, tokenizer, mlm_model) +# else: +# return words +# # +# # print(words) +# return words class Trainer(TrainBase): @@ -127,6 +158,7 @@ class Trainer(TrainBase): self.clip_tokenizer=Tokenizer() self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True) self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder) + self.attack_thred = self.args.attack_thred # self.run() def _init_model(self): @@ -220,8 +252,117 @@ class Trainer(TrainBase): masked_words.append(words[0:i] + ['[UNK]'] + words[i + 1:]) # list of words return masked_words - - + + def get_transformations(self, text, idx, substitutes): + words = text.split(' ') + + trans_text = [] + for sub in substitutes: + words[idx] = sub + trans_text.append(' '.join(words)) + return trans_text + + def get_word_predictions(self, text): + _, _, keys = self._tokenize(text) + + inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.max_text_len, + truncation=True, return_tensors="pt") + input_ids = inputs["input_ids"].to(self.device) + attention_mask = inputs['attention_mask'] + with torch.no_grad(): + word_predictions = self.ref_net(input_ids)['logits'].squeeze() # (seq_len, vocab_size) + + word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.max_candidate, -1) + + word_predictions = word_predictions[1:-1, :] # remove [CLS] and [SEP] + word_pred_scores_all = word_pred_scores_all[1:-1, :] + + return keys, word_predictions, word_pred_scores_all, attention_mask + + def get_bpe_substitutes(self, substitutes): + # substitutes L, k + substitutes = substitutes[0:12, 0:4] # maximum BPE candidates + + # find all possible candidates + all_substitutes = [] + for i in range(substitutes.size(0)): + if len(all_substitutes) == 0: + lev_i = substitutes[i] + all_substitutes = [[int(c)] for c in lev_i] + else: + lev_i = [] + for all_sub in all_substitutes: + for j in substitutes[i]: + lev_i.append(all_sub + [int(j)]) + all_substitutes = lev_i + + # all substitutes: list of list of token-id (all candidates) + cross_entropy_loss = nn.CrossEntropyLoss(reduction='none') + + word_list = [] + all_substitutes = torch.tensor(all_substitutes) # [ N, L ] + all_substitutes = all_substitutes[:24].to(self.device) + + N, L = all_substitutes.size() + word_predictions = self.ref_net(all_substitutes)[0] # N L vocab-size + ppl = cross_entropy_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ] + ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N + + _, word_list = torch.sort(ppl) + word_list = [all_substitutes[i] for i in word_list] + final_words = [] + for word in word_list: + tokens = [self.bert_tokenizer.convert_ids_to_tokens(int(i)) for i in word] + text = ' '.join([t.strip() for t in tokens]) + final_words.append(text) + return final_words + + def get_substitutes(self, substitutes, substitutes_score, threshold=3.0): + ret = [] + num_sub, _ = substitutes.size() + if num_sub == 0: + ret = [] + elif num_sub == 1: + for id, score in zip(substitutes[0], substitutes_score[0]): + if threshold != 0 and score < threshold: + break + ret.append(self.bert_tokenizer.convert_ids_to_tokens(int(id))) + elif self.args.enable_bpe: + ret = self.get_bpe_substitutes(substitutes) + return ret + + def filter_substitutes(self, substitues): + + ret = [] + for word in substitues: + if word.lower() in filter_words: + continue + if '##' in word: + continue + + ret.append(word) + return ret + + def get_goal_results(self, trans_texts, ori_text, ref_text): + text_batch = [ref_text] + trans_texts + with torch.no_grad(): + feats = self.feature_extractor(text_batch, device=self.device) + feats = feats.flatten(start_dim=1) + + ref_feats = feats[0].unsqueeze(0) + trans_feats = feats[1:] + cos_sim = F.cosine_similarity(ref_feats.repeat(trans_feats.size(0), 1), trans_feats) + cos_sim = cos_sim.cpu().numpy() + + text_sim = self.get_text_similarity(trans_texts, ori_text) + + results = [] + for i in range(len(trans_texts)): + if text_sim[i] is not None and text_sim[i] < self.semantic_thred: + continue + results.append(GoalFunctionResult(trans_texts[i], score=cos_sim[i], similarity=text_sim[i])) + return results + def generate_mapping(self): image_train=[] label_train=[] @@ -246,63 +387,99 @@ class Trainer(TrainBase): return image_representation, image_var_representation - def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var, - beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05): + # def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var, + # beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05): + def target_adv(self, texts, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, + positive_var, + beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500, temperature=0.05): - bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt') - mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits - word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.args.topk, -1) + # bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt') + keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text) #clean state - clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask) - final_adverse = [] - - # alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) - # print(texts) - for i, text in enumerate(texts): - important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words) - list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True) - words, sub_words, keys = self._tokenize(text) - final_words = copy.deepcopy(words) - change = 0 - for top_index in list_of_index: - if change >= self.args.num_perturbation: - break - tgt_word = words[top_index[0]] - if tgt_word in filter_words: - continue - if keys[top_index[0]][0] > self.args.max_length - 2: - continue - substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k - word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]] - substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores, - self.args.threshold_pred_score) - replace_texts = [' '.join(final_words)] - available_substitutes = [tgt_word] - for substitute_ in substitutes: - substitute = substitute_ - if substitute == tgt_word: - continue # filter out original word - if '##' in substitute: - continue # filter out sub-word + # clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask) + cur_result = GoalFunctionResult(raw_text) + mask_idx = np.where(mask.cpu().numpy() == 1)[0] - if substitute in filter_words: - continue - temp_replace = copy.deepcopy(final_words) - temp_replace[top_index[0]] = substitute - available_substitutes.append(substitute) - replace_texts.append(' '.join(temp_replace)) - replace_text_input = self.clip_tokenizer(replace_texts).to(device) - replace_embeds = self.model.encode_text(replace_text_input) - - loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var) - loss = loss.sum(dim=-1) - candidate_idx = loss.argmax() - final_words[top_index[0]] = available_substitutes[candidate_idx] - if available_substitutes[candidate_idx] != tgt_word: - change += 1 - final_adverse.append(' '.join(final_words)) - return final_adverse + + for idx in mask_idx: + predictions = word_predictions[keys[idx][0]: keys[idx][1]] + predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]] + substitutes = self.get_substitutes(predictions, predictions_socre) + substitutes = self.filter_substitutes(substitutes) + trans_texts = self.get_transformations(raw_text, idx, substitutes) + if len(trans_texts) == 0: + continue + # loss function + results = self.get_goal_results(trans_texts, raw_text, positive_code) + results = sorted(results, key=lambda x: x.score, reverse=True) + + if len(results) > 0 and results[0].score > cur_result.score: + cur_result = results[0] + else: + continue + + if cur_result.status == GoalFunctionStatus.SUCCEEDED: + max_similarity = cur_result.similarity + if max_similarity is None: + # similarity is not calculated + continue + + for result in results[1:]: + if result.status != GoalFunctionStatus.SUCCEEDED: + break + if result.similarity > max_similarity: + max_similarity = result.similarity + cur_result = result + return cur_result + if cur_result.status == GoalFunctionStatus.SEARCHING: + cur_result.status = GoalFunctionStatus.FAILED + return cur_result + + + # important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words) + # list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True) + # words, sub_words, keys = self._tokenize(text) + # final_words = copy.deepcopy(words) + # change = 0 + # for top_index in list_of_index: + # if change >= self.args.num_perturbation: + # break + # tgt_word = words[top_index[0]] + # if tgt_word in filter_words: + # continue + # if keys[top_index[0]][0] > self.args.max_length - 2: + # continue + # substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k + # word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]] + # substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores, + # self.args.threshold_pred_score) + # replace_texts = [' '.join(final_words)] + # available_substitutes = [tgt_word] + # for substitute_ in substitutes: + # substitute = substitute_ + # if substitute == tgt_word: + # continue # filter out original word + # if '##' in substitute: + # continue # filter out sub-word + # + # if substitute in filter_words: + # continue + # temp_replace = copy.deepcopy(final_words) + # temp_replace[top_index[0]] = substitute + # available_substitutes.append(substitute) + # replace_texts.append(' '.join(temp_replace)) + # replace_text_input = self.clip_tokenizer(replace_texts).to(device) + # replace_embeds = self.model.encode_text(replace_text_input) + # + # loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var) + # loss = loss.sum(dim=-1) + # candidate_idx = loss.argmax() + # final_words[top_index[0]] = available_substitutes[candidate_idx] + # if available_substitutes[candidate_idx] != tgt_word: + # change += 1 + # final_adverse.append(' '.join(final_words)) + # return final_adverse def train_epoch(self): # self.change_state(mode="valid") diff --git a/utils/get_args.py b/utils/get_args.py index 36c1518..615e4d9 100644 --- a/utils/get_args.py +++ b/utils/get_args.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument("--lr-decay-freq", type=int, default=5) parser.add_argument("--display-step", type=int, default=50) parser.add_argument("--seed", type=int, default=1814) - + parser.add_argument("--attack-thred", type=float, default=0.05) parser.add_argument("--lr", type=float, default=0.001) parser.add_argument("--lr-decay", type=float, default=0.9) parser.add_argument("--clip-lr", type=float, default=0.00001)