text attack

This commit is contained in:
Li Wenyun 2024-06-27 18:01:28 +08:00
parent ebc17c80dc
commit ca75f34880
2 changed files with 288 additions and 111 deletions

View File

@ -51,67 +51,98 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again',
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '', '·']
filter_words = set(filter_words)
def get_bpe_substitues(substitutes, tokenizer, mlm_model):
# substitutes L, k
# device = mlm_model.device
substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
class GoalFunctionStatus(object):
SUCCEEDED = 0 # attack succeeded
SEARCHING = 1 # In process of searching for a success
FAILED = 2 # attack failed
# find all possible candidates
all_substitutes = []
for i in range(substitutes.size(0)):
if len(all_substitutes) == 0:
lev_i = substitutes[i]
all_substitutes = [[int(c)] for c in lev_i]
else:
lev_i = []
for all_sub in all_substitutes:
for j in substitutes[i]:
lev_i.append(all_sub + [int(j)])
all_substitutes = lev_i
class GoalFunctionResult(object):
goal_score = 1
# all substitutes list of list of token-id (all candidates)
c_loss = nn.CrossEntropyLoss(reduction='none')
word_list = []
# all_substitutes = all_substitutes[:24]
all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
all_substitutes = all_substitutes[:24].to(device)
# print(substitutes.size(), all_substitutes.size())
N, L = all_substitutes.size()
word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
_, word_list = torch.sort(ppl)
word_list = [all_substitutes[i] for i in word_list]
final_words = []
for word in word_list:
tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
text = tokenizer.convert_tokens_to_string(tokens)
final_words.append(text)
return final_words
def __init__(self, text, score=0, similarity=None):
self.status = GoalFunctionStatus.SEARCHING
self.text = text
self.score = score
self.similarity = similarity
def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
# substitues L,k
# from this matrix to recover a word
words = []
sub_len, k = substitutes.size() # sub-len, k
@property
def score(self):
return self.__score
if sub_len == 0:
return words
@score.setter
def score(self, value):
self.__score = value
if value >= self.goal_score:
self.status = GoalFunctionStatus.SUCCEEDED
elif sub_len == 1:
for (i, j) in zip(substitutes[0], substitutes_score[0]):
if threshold != 0 and j < threshold:
break
words.append(tokenizer._convert_id_to_token(int(i)))
else:
if use_bpe == 1:
words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
else:
return words
def __eq__(self, __o):
return self.text == __o.text
def __hash__(self):
return hash(self.text)
# def get_bpe_substitues(substitutes, tokenizer, mlm_model):
# # substitutes L, k
# # device = mlm_model.device
# substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
#
# print(words)
return words
# # find all possible candidates
#
# all_substitutes = []
# for i in range(substitutes.size(0)):
# if len(all_substitutes) == 0:
# lev_i = substitutes[i]
# all_substitutes = [[int(c)] for c in lev_i]
# else:
# lev_i = []
# for all_sub in all_substitutes:
# for j in substitutes[i]:
# lev_i.append(all_sub + [int(j)])
# all_substitutes = lev_i
#
# # all substitutes list of list of token-id (all candidates)
# c_loss = nn.CrossEntropyLoss(reduction='none')
# word_list = []
# # all_substitutes = all_substitutes[:24]
# all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
# all_substitutes = all_substitutes[:24].to(device)
# # print(substitutes.size(), all_substitutes.size())
# N, L = all_substitutes.size()
# word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
# ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
# ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
# _, word_list = torch.sort(ppl)
# word_list = [all_substitutes[i] for i in word_list]
# final_words = []
# for word in word_list:
# tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
# text = tokenizer.convert_tokens_to_string(tokens)
# final_words.append(text)
# return final_words
# def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
# # substitues L,k
# # from this matrix to recover a word
# words = []
# sub_len, k = substitutes.size() # sub-len, k
#
# if sub_len == 0:
# return words
#
# elif sub_len == 1:
# for (i, j) in zip(substitutes[0], substitutes_score[0]):
# if threshold != 0 and j < threshold:
# break
# words.append(tokenizer._convert_id_to_token(int(i)))
# else:
# if use_bpe == 1:
# words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
# else:
# return words
# #
# # print(words)
# return words
class Trainer(TrainBase):
@ -127,6 +158,7 @@ class Trainer(TrainBase):
self.clip_tokenizer=Tokenizer()
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True)
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
self.attack_thred = self.args.attack_thred
# self.run()
def _init_model(self):
@ -221,6 +253,115 @@ class Trainer(TrainBase):
# list of words
return masked_words
def get_transformations(self, text, idx, substitutes):
    """Return one variant of *text* per candidate, with the word at
    position *idx* replaced by that candidate."""
    tokens = text.split(' ')
    variants = []
    for candidate in substitutes:
        replaced = list(tokens)
        replaced[idx] = candidate
        variants.append(' '.join(replaced))
    return variants
def get_word_predictions(self, text):
    """Run the masked-LM reference model over *text* and return the top-k
    candidate token ids (and their logit scores) for every sub-word position.

    Returns a 4-tuple:
        keys                 -- third value of self._tokenize(text); presumably a
                                word -> sub-word span mapping — TODO confirm against _tokenize
        word_predictions     -- top-k token ids per position, [CLS]/[SEP] rows removed
        word_pred_scores_all -- matching top-k logit scores, [CLS]/[SEP] rows removed
        attention_mask       -- tokenizer attention mask (left on CPU)
    """
    _, _, keys = self._tokenize(text)
    inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.max_text_len,
                                             truncation=True, return_tensors="pt")
    input_ids = inputs["input_ids"].to(self.device)
    attention_mask = inputs['attention_mask']
    with torch.no_grad():
        # squeeze() drops the batch dim — assumes a single input sequence
        word_predictions = self.ref_net(input_ids)['logits'].squeeze()  # (seq_len, vocab_size)
    # keep the self.max_candidate highest-scoring vocabulary entries per position
    word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.max_candidate, -1)
    word_predictions = word_predictions[1:-1, :]  # remove [CLS] and [SEP]
    word_pred_scores_all = word_pred_scores_all[1:-1, :]
    return keys, word_predictions, word_pred_scores_all, attention_mask
def get_bpe_substitutes(self, substitutes):
    """Rank multi-sub-word (BPE) substitute candidates by MLM pseudo-perplexity.

    substitutes -- (L, k) tensor of top-k candidate token ids per sub-word
                   position; only the first 12 positions and 4 candidates
                   per position are considered.
    Returns up to 24 candidate strings (space-joined token pieces), ordered
    from lowest to highest pseudo-perplexity under self.ref_net.
    """
    substitutes = substitutes[0:12, 0:4]  # maximum BPE candidates
    if substitutes.size(0) == 0:
        # no sub-word positions: nothing to combine (avoids crashing on the
        # N, L unpack of an empty 1-D tensor below)
        return []
    # Enumerate candidate token-id sequences as the Cartesian product over
    # positions.  Only the first 24 sequences (in product order) are ever
    # scored, so truncate while building instead of materialising the full
    # product (which can reach 4^12 sequences).
    all_substitutes = [[]]
    for i in range(substitutes.size(0)):
        all_substitutes = [seq + [int(tok)]
                           for seq in all_substitutes
                           for tok in substitutes[i]][:24]
    cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
    all_substitutes = torch.tensor(all_substitutes).to(self.device)  # [N, L]
    N, L = all_substitutes.size()
    word_predictions = self.ref_net(all_substitutes)[0]  # (N, L, vocab_size)
    # per-token cross-entropy of each sequence against itself, folded into a
    # per-sequence pseudo-perplexity
    ppl = cross_entropy_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1))  # [N*L]
    ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1))  # [N]
    _, order = torch.sort(ppl)
    ranked = [all_substitutes[i] for i in order]
    final_words = []
    for word in ranked:
        tokens = [self.bert_tokenizer.convert_ids_to_tokens(int(i)) for i in word]
        final_words.append(' '.join(t.strip() for t in tokens))
    return final_words
def get_substitutes(self, substitutes, substitutes_score, threshold=3.0):
    """Recover word-level substitute strings from per-sub-word candidates.

    Single sub-word: take candidates whose score clears *threshold*.
    Multiple sub-words: fall back to BPE recombination when enabled.
    """
    num_sub, _ = substitutes.size()
    if num_sub == 0:
        return []
    if num_sub == 1:
        words = []
        for token_id, token_score in zip(substitutes[0], substitutes_score[0]):
            if threshold != 0 and token_score < threshold:
                break  # candidates are score-sorted; the rest are weaker
            words.append(self.bert_tokenizer.convert_ids_to_tokens(int(token_id)))
        return words
    if self.args.enable_bpe:
        return self.get_bpe_substitutes(substitutes)
    return []
def filter_substitutes(self, substitues):
    """Drop stop-words (module-level filter_words) and sub-word pieces
    (tokens containing '##') from the candidate list."""
    return [word for word in substitues
            if word.lower() not in filter_words and '##' not in word]
def get_goal_results(self, trans_texts, ori_text, ref_text):
    """Score each transformed text against the reference text's embedding.

    Embeds [ref_text] + trans_texts in one batch, computes the cosine
    similarity between each transformed text's feature and the reference
    feature, and wraps the survivors in GoalFunctionResult objects.
    Candidates whose semantic similarity to *ori_text* is not None and
    falls below self.semantic_thred are discarded.
    """
    text_batch = [ref_text] + trans_texts
    with torch.no_grad():
        # NOTE(review): feature_extractor is defined elsewhere — assumed to
        # return one feature tensor per input text; confirm its output shape
        feats = self.feature_extractor(text_batch, device=self.device)
        feats = feats.flatten(start_dim=1)
    ref_feats = feats[0].unsqueeze(0)
    trans_feats = feats[1:]
    # cosine similarity of every transformed text vs. the single reference
    cos_sim = F.cosine_similarity(ref_feats.repeat(trans_feats.size(0), 1), trans_feats)
    cos_sim = cos_sim.cpu().numpy()
    # NOTE(review): get_text_similarity is defined elsewhere — presumably a
    # per-text semantic similarity that may contain None entries
    text_sim = self.get_text_similarity(trans_texts, ori_text)
    results = []
    for i in range(len(trans_texts)):
        if text_sim[i] is not None and text_sim[i] < self.semantic_thred:
            continue  # drifted too far from the original meaning
        results.append(GoalFunctionResult(trans_texts[i], score=cos_sim[i], similarity=text_sim[i]))
    return results
def generate_mapping(self):
image_train=[]
@ -246,63 +387,99 @@ class Trainer(TrainBase):
return image_representation, image_var_representation
def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
# def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
# beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
def target_adv(self, texts, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
positive_var,
beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500, temperature=0.05):
bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.args.topk, -1)
# bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
#clean state
clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
final_adverse = []
# clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
cur_result = GoalFunctionResult(raw_text)
mask_idx = np.where(mask.cpu().numpy() == 1)[0]
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
# print(texts)
for i, text in enumerate(texts):
important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
words, sub_words, keys = self._tokenize(text)
final_words = copy.deepcopy(words)
change = 0
for top_index in list_of_index:
if change >= self.args.num_perturbation:
for idx in mask_idx:
predictions = word_predictions[keys[idx][0]: keys[idx][1]]
predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]]
substitutes = self.get_substitutes(predictions, predictions_socre)
substitutes = self.filter_substitutes(substitutes)
trans_texts = self.get_transformations(raw_text, idx, substitutes)
if len(trans_texts) == 0:
continue
# loss function
results = self.get_goal_results(trans_texts, raw_text, positive_code)
results = sorted(results, key=lambda x: x.score, reverse=True)
if len(results) > 0 and results[0].score > cur_result.score:
cur_result = results[0]
else:
continue
if cur_result.status == GoalFunctionStatus.SUCCEEDED:
max_similarity = cur_result.similarity
if max_similarity is None:
# similarity is not calculated
continue
for result in results[1:]:
if result.status != GoalFunctionStatus.SUCCEEDED:
break
tgt_word = words[top_index[0]]
if tgt_word in filter_words:
continue
if keys[top_index[0]][0] > self.args.max_length - 2:
continue
substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k
word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
self.args.threshold_pred_score)
replace_texts = [' '.join(final_words)]
available_substitutes = [tgt_word]
for substitute_ in substitutes:
substitute = substitute_
if substitute == tgt_word:
continue # filter out original word
if '##' in substitute:
continue # filter out sub-word
if result.similarity > max_similarity:
max_similarity = result.similarity
cur_result = result
return cur_result
if cur_result.status == GoalFunctionStatus.SEARCHING:
cur_result.status = GoalFunctionStatus.FAILED
return cur_result
if substitute in filter_words:
continue
temp_replace = copy.deepcopy(final_words)
temp_replace[top_index[0]] = substitute
available_substitutes.append(substitute)
replace_texts.append(' '.join(temp_replace))
replace_text_input = self.clip_tokenizer(replace_texts).to(device)
replace_embeds = self.model.encode_text(replace_text_input)
loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var)
loss = loss.sum(dim=-1)
candidate_idx = loss.argmax()
final_words[top_index[0]] = available_substitutes[candidate_idx]
if available_substitutes[candidate_idx] != tgt_word:
change += 1
final_adverse.append(' '.join(final_words))
return final_adverse
# important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
# list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
# words, sub_words, keys = self._tokenize(text)
# final_words = copy.deepcopy(words)
# change = 0
# for top_index in list_of_index:
# if change >= self.args.num_perturbation:
# break
# tgt_word = words[top_index[0]]
# if tgt_word in filter_words:
# continue
# if keys[top_index[0]][0] > self.args.max_length - 2:
# continue
# substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k
# word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
# substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
# self.args.threshold_pred_score)
# replace_texts = [' '.join(final_words)]
# available_substitutes = [tgt_word]
# for substitute_ in substitutes:
# substitute = substitute_
# if substitute == tgt_word:
# continue # filter out original word
# if '##' in substitute:
# continue # filter out sub-word
#
# if substitute in filter_words:
# continue
# temp_replace = copy.deepcopy(final_words)
# temp_replace[top_index[0]] = substitute
# available_substitutes.append(substitute)
# replace_texts.append(' '.join(temp_replace))
# replace_text_input = self.clip_tokenizer(replace_texts).to(device)
# replace_embeds = self.model.encode_text(replace_text_input)
#
# loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var)
# loss = loss.sum(dim=-1)
# candidate_idx = loss.argmax()
# final_words[top_index[0]] = available_substitutes[candidate_idx]
# if available_substitutes[candidate_idx] != tgt_word:
# change += 1
# final_adverse.append(' '.join(final_words))
# return final_adverse
def train_epoch(self):
# self.change_state(mode="valid")

View File

@ -31,7 +31,7 @@ def get_args():
parser.add_argument("--lr-decay-freq", type=int, default=5)
parser.add_argument("--display-step", type=int, default=50)
parser.add_argument("--seed", type=int, default=1814)
parser.add_argument("--attack-thred", type=float, default=0.05)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--lr-decay", type=float, default=0.9)
parser.add_argument("--clip-lr", type=float, default=0.00001)