text attack
This commit is contained in:
parent
ebc17c80dc
commit
ca75f34880
|
|
@ -51,67 +51,98 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again',
|
|||
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·']
|
||||
filter_words = set(filter_words)
|
||||
|
||||
def get_bpe_substitues(substitutes, tokenizer, mlm_model):
|
||||
# substitutes L, k
|
||||
# device = mlm_model.device
|
||||
substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
|
||||
class GoalFunctionStatus(object):
|
||||
SUCCEEDED = 0 # attack succeeded
|
||||
SEARCHING = 1 # In process of searching for a success
|
||||
FAILED = 2 # attack failed
|
||||
|
||||
# find all possible candidates
|
||||
|
||||
all_substitutes = []
|
||||
for i in range(substitutes.size(0)):
|
||||
if len(all_substitutes) == 0:
|
||||
lev_i = substitutes[i]
|
||||
all_substitutes = [[int(c)] for c in lev_i]
|
||||
else:
|
||||
lev_i = []
|
||||
for all_sub in all_substitutes:
|
||||
for j in substitutes[i]:
|
||||
lev_i.append(all_sub + [int(j)])
|
||||
all_substitutes = lev_i
|
||||
class GoalFunctionResult(object):
|
||||
goal_score = 1
|
||||
|
||||
# all substitutes list of list of token-id (all candidates)
|
||||
c_loss = nn.CrossEntropyLoss(reduction='none')
|
||||
word_list = []
|
||||
# all_substitutes = all_substitutes[:24]
|
||||
all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
|
||||
all_substitutes = all_substitutes[:24].to(device)
|
||||
# print(substitutes.size(), all_substitutes.size())
|
||||
N, L = all_substitutes.size()
|
||||
word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
|
||||
ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
|
||||
ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
|
||||
_, word_list = torch.sort(ppl)
|
||||
word_list = [all_substitutes[i] for i in word_list]
|
||||
final_words = []
|
||||
for word in word_list:
|
||||
tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
|
||||
text = tokenizer.convert_tokens_to_string(tokens)
|
||||
final_words.append(text)
|
||||
return final_words
|
||||
def __init__(self, text, score=0, similarity=None):
|
||||
self.status = GoalFunctionStatus.SEARCHING
|
||||
self.text = text
|
||||
self.score = score
|
||||
self.similarity = similarity
|
||||
|
||||
def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
|
||||
# substitues L,k
|
||||
# from this matrix to recover a word
|
||||
words = []
|
||||
sub_len, k = substitutes.size() # sub-len, k
|
||||
@property
|
||||
def score(self):
|
||||
return self.__score
|
||||
|
||||
if sub_len == 0:
|
||||
return words
|
||||
@score.setter
|
||||
def score(self, value):
|
||||
self.__score = value
|
||||
if value >= self.goal_score:
|
||||
self.status = GoalFunctionStatus.SUCCEEDED
|
||||
|
||||
elif sub_len == 1:
|
||||
for (i, j) in zip(substitutes[0], substitutes_score[0]):
|
||||
if threshold != 0 and j < threshold:
|
||||
break
|
||||
words.append(tokenizer._convert_id_to_token(int(i)))
|
||||
else:
|
||||
if use_bpe == 1:
|
||||
words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
|
||||
else:
|
||||
return words
|
||||
def __eq__(self, __o):
|
||||
return self.text == __o.text
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.text)
|
||||
|
||||
# def get_bpe_substitues(substitutes, tokenizer, mlm_model):
|
||||
# # substitutes L, k
|
||||
# # device = mlm_model.device
|
||||
# substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
|
||||
#
|
||||
# print(words)
|
||||
return words
|
||||
# # find all possible candidates
|
||||
#
|
||||
# all_substitutes = []
|
||||
# for i in range(substitutes.size(0)):
|
||||
# if len(all_substitutes) == 0:
|
||||
# lev_i = substitutes[i]
|
||||
# all_substitutes = [[int(c)] for c in lev_i]
|
||||
# else:
|
||||
# lev_i = []
|
||||
# for all_sub in all_substitutes:
|
||||
# for j in substitutes[i]:
|
||||
# lev_i.append(all_sub + [int(j)])
|
||||
# all_substitutes = lev_i
|
||||
#
|
||||
# # all substitutes list of list of token-id (all candidates)
|
||||
# c_loss = nn.CrossEntropyLoss(reduction='none')
|
||||
# word_list = []
|
||||
# # all_substitutes = all_substitutes[:24]
|
||||
# all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
|
||||
# all_substitutes = all_substitutes[:24].to(device)
|
||||
# # print(substitutes.size(), all_substitutes.size())
|
||||
# N, L = all_substitutes.size()
|
||||
# word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
|
||||
# ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
|
||||
# ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
|
||||
# _, word_list = torch.sort(ppl)
|
||||
# word_list = [all_substitutes[i] for i in word_list]
|
||||
# final_words = []
|
||||
# for word in word_list:
|
||||
# tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
|
||||
# text = tokenizer.convert_tokens_to_string(tokens)
|
||||
# final_words.append(text)
|
||||
# return final_words
|
||||
|
||||
# def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
|
||||
# # substitues L,k
|
||||
# # from this matrix to recover a word
|
||||
# words = []
|
||||
# sub_len, k = substitutes.size() # sub-len, k
|
||||
#
|
||||
# if sub_len == 0:
|
||||
# return words
|
||||
#
|
||||
# elif sub_len == 1:
|
||||
# for (i, j) in zip(substitutes[0], substitutes_score[0]):
|
||||
# if threshold != 0 and j < threshold:
|
||||
# break
|
||||
# words.append(tokenizer._convert_id_to_token(int(i)))
|
||||
# else:
|
||||
# if use_bpe == 1:
|
||||
# words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
|
||||
# else:
|
||||
# return words
|
||||
# #
|
||||
# # print(words)
|
||||
# return words
|
||||
|
||||
class Trainer(TrainBase):
|
||||
|
||||
|
|
@ -127,6 +158,7 @@ class Trainer(TrainBase):
|
|||
self.clip_tokenizer=Tokenizer()
|
||||
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True)
|
||||
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
|
||||
self.attack_thred = self.args.attack_thred
|
||||
# self.run()
|
||||
|
||||
def _init_model(self):
|
||||
|
|
@ -221,6 +253,115 @@ class Trainer(TrainBase):
|
|||
# list of words
|
||||
return masked_words
|
||||
|
||||
def get_transformations(self, text, idx, substitutes):
|
||||
words = text.split(' ')
|
||||
|
||||
trans_text = []
|
||||
for sub in substitutes:
|
||||
words[idx] = sub
|
||||
trans_text.append(' '.join(words))
|
||||
return trans_text
|
||||
|
||||
def get_word_predictions(self, text):
|
||||
_, _, keys = self._tokenize(text)
|
||||
|
||||
inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.max_text_len,
|
||||
truncation=True, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"].to(self.device)
|
||||
attention_mask = inputs['attention_mask']
|
||||
with torch.no_grad():
|
||||
word_predictions = self.ref_net(input_ids)['logits'].squeeze() # (seq_len, vocab_size)
|
||||
|
||||
word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.max_candidate, -1)
|
||||
|
||||
word_predictions = word_predictions[1:-1, :] # remove [CLS] and [SEP]
|
||||
word_pred_scores_all = word_pred_scores_all[1:-1, :]
|
||||
|
||||
return keys, word_predictions, word_pred_scores_all, attention_mask
|
||||
|
||||
def get_bpe_substitutes(self, substitutes):
|
||||
# substitutes L, k
|
||||
substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
|
||||
|
||||
# find all possible candidates
|
||||
all_substitutes = []
|
||||
for i in range(substitutes.size(0)):
|
||||
if len(all_substitutes) == 0:
|
||||
lev_i = substitutes[i]
|
||||
all_substitutes = [[int(c)] for c in lev_i]
|
||||
else:
|
||||
lev_i = []
|
||||
for all_sub in all_substitutes:
|
||||
for j in substitutes[i]:
|
||||
lev_i.append(all_sub + [int(j)])
|
||||
all_substitutes = lev_i
|
||||
|
||||
# all substitutes: list of list of token-id (all candidates)
|
||||
cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
|
||||
|
||||
word_list = []
|
||||
all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
|
||||
all_substitutes = all_substitutes[:24].to(self.device)
|
||||
|
||||
N, L = all_substitutes.size()
|
||||
word_predictions = self.ref_net(all_substitutes)[0] # N L vocab-size
|
||||
ppl = cross_entropy_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
|
||||
ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
|
||||
|
||||
_, word_list = torch.sort(ppl)
|
||||
word_list = [all_substitutes[i] for i in word_list]
|
||||
final_words = []
|
||||
for word in word_list:
|
||||
tokens = [self.bert_tokenizer.convert_ids_to_tokens(int(i)) for i in word]
|
||||
text = ' '.join([t.strip() for t in tokens])
|
||||
final_words.append(text)
|
||||
return final_words
|
||||
|
||||
def get_substitutes(self, substitutes, substitutes_score, threshold=3.0):
|
||||
ret = []
|
||||
num_sub, _ = substitutes.size()
|
||||
if num_sub == 0:
|
||||
ret = []
|
||||
elif num_sub == 1:
|
||||
for id, score in zip(substitutes[0], substitutes_score[0]):
|
||||
if threshold != 0 and score < threshold:
|
||||
break
|
||||
ret.append(self.bert_tokenizer.convert_ids_to_tokens(int(id)))
|
||||
elif self.args.enable_bpe:
|
||||
ret = self.get_bpe_substitutes(substitutes)
|
||||
return ret
|
||||
|
||||
def filter_substitutes(self, substitues):
|
||||
|
||||
ret = []
|
||||
for word in substitues:
|
||||
if word.lower() in filter_words:
|
||||
continue
|
||||
if '##' in word:
|
||||
continue
|
||||
|
||||
ret.append(word)
|
||||
return ret
|
||||
|
||||
def get_goal_results(self, trans_texts, ori_text, ref_text):
|
||||
text_batch = [ref_text] + trans_texts
|
||||
with torch.no_grad():
|
||||
feats = self.feature_extractor(text_batch, device=self.device)
|
||||
feats = feats.flatten(start_dim=1)
|
||||
|
||||
ref_feats = feats[0].unsqueeze(0)
|
||||
trans_feats = feats[1:]
|
||||
cos_sim = F.cosine_similarity(ref_feats.repeat(trans_feats.size(0), 1), trans_feats)
|
||||
cos_sim = cos_sim.cpu().numpy()
|
||||
|
||||
text_sim = self.get_text_similarity(trans_texts, ori_text)
|
||||
|
||||
results = []
|
||||
for i in range(len(trans_texts)):
|
||||
if text_sim[i] is not None and text_sim[i] < self.semantic_thred:
|
||||
continue
|
||||
results.append(GoalFunctionResult(trans_texts[i], score=cos_sim[i], similarity=text_sim[i]))
|
||||
return results
|
||||
|
||||
def generate_mapping(self):
|
||||
image_train=[]
|
||||
|
|
@ -246,63 +387,99 @@ class Trainer(TrainBase):
|
|||
return image_representation, image_var_representation
|
||||
|
||||
|
||||
def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||
# def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||
# beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
|
||||
def target_adv(self, texts, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
|
||||
positive_var,
|
||||
beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500, temperature=0.05):
|
||||
|
||||
bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
|
||||
mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits
|
||||
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.args.topk, -1)
|
||||
# bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
|
||||
keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
|
||||
|
||||
#clean state
|
||||
clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
|
||||
final_adverse = []
|
||||
# clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
|
||||
cur_result = GoalFunctionResult(raw_text)
|
||||
mask_idx = np.where(mask.cpu().numpy() == 1)[0]
|
||||
|
||||
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
||||
# print(texts)
|
||||
for i, text in enumerate(texts):
|
||||
important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
|
||||
list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
|
||||
words, sub_words, keys = self._tokenize(text)
|
||||
final_words = copy.deepcopy(words)
|
||||
change = 0
|
||||
for top_index in list_of_index:
|
||||
if change >= self.args.num_perturbation:
|
||||
|
||||
for idx in mask_idx:
|
||||
predictions = word_predictions[keys[idx][0]: keys[idx][1]]
|
||||
predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]]
|
||||
substitutes = self.get_substitutes(predictions, predictions_socre)
|
||||
substitutes = self.filter_substitutes(substitutes)
|
||||
trans_texts = self.get_transformations(raw_text, idx, substitutes)
|
||||
if len(trans_texts) == 0:
|
||||
continue
|
||||
# loss function
|
||||
results = self.get_goal_results(trans_texts, raw_text, positive_code)
|
||||
results = sorted(results, key=lambda x: x.score, reverse=True)
|
||||
|
||||
if len(results) > 0 and results[0].score > cur_result.score:
|
||||
cur_result = results[0]
|
||||
else:
|
||||
continue
|
||||
|
||||
if cur_result.status == GoalFunctionStatus.SUCCEEDED:
|
||||
max_similarity = cur_result.similarity
|
||||
if max_similarity is None:
|
||||
# similarity is not calculated
|
||||
continue
|
||||
|
||||
for result in results[1:]:
|
||||
if result.status != GoalFunctionStatus.SUCCEEDED:
|
||||
break
|
||||
tgt_word = words[top_index[0]]
|
||||
if tgt_word in filter_words:
|
||||
continue
|
||||
if keys[top_index[0]][0] > self.args.max_length - 2:
|
||||
continue
|
||||
substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k
|
||||
word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
|
||||
substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
|
||||
self.args.threshold_pred_score)
|
||||
replace_texts = [' '.join(final_words)]
|
||||
available_substitutes = [tgt_word]
|
||||
for substitute_ in substitutes:
|
||||
substitute = substitute_
|
||||
if substitute == tgt_word:
|
||||
continue # filter out original word
|
||||
if '##' in substitute:
|
||||
continue # filter out sub-word
|
||||
if result.similarity > max_similarity:
|
||||
max_similarity = result.similarity
|
||||
cur_result = result
|
||||
return cur_result
|
||||
if cur_result.status == GoalFunctionStatus.SEARCHING:
|
||||
cur_result.status = GoalFunctionStatus.FAILED
|
||||
return cur_result
|
||||
|
||||
if substitute in filter_words:
|
||||
continue
|
||||
temp_replace = copy.deepcopy(final_words)
|
||||
temp_replace[top_index[0]] = substitute
|
||||
available_substitutes.append(substitute)
|
||||
replace_texts.append(' '.join(temp_replace))
|
||||
replace_text_input = self.clip_tokenizer(replace_texts).to(device)
|
||||
replace_embeds = self.model.encode_text(replace_text_input)
|
||||
|
||||
loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var)
|
||||
loss = loss.sum(dim=-1)
|
||||
candidate_idx = loss.argmax()
|
||||
final_words[top_index[0]] = available_substitutes[candidate_idx]
|
||||
if available_substitutes[candidate_idx] != tgt_word:
|
||||
change += 1
|
||||
final_adverse.append(' '.join(final_words))
|
||||
return final_adverse
|
||||
# important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
|
||||
# list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
|
||||
# words, sub_words, keys = self._tokenize(text)
|
||||
# final_words = copy.deepcopy(words)
|
||||
# change = 0
|
||||
# for top_index in list_of_index:
|
||||
# if change >= self.args.num_perturbation:
|
||||
# break
|
||||
# tgt_word = words[top_index[0]]
|
||||
# if tgt_word in filter_words:
|
||||
# continue
|
||||
# if keys[top_index[0]][0] > self.args.max_length - 2:
|
||||
# continue
|
||||
# substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k
|
||||
# word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
|
||||
# substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
|
||||
# self.args.threshold_pred_score)
|
||||
# replace_texts = [' '.join(final_words)]
|
||||
# available_substitutes = [tgt_word]
|
||||
# for substitute_ in substitutes:
|
||||
# substitute = substitute_
|
||||
# if substitute == tgt_word:
|
||||
# continue # filter out original word
|
||||
# if '##' in substitute:
|
||||
# continue # filter out sub-word
|
||||
#
|
||||
# if substitute in filter_words:
|
||||
# continue
|
||||
# temp_replace = copy.deepcopy(final_words)
|
||||
# temp_replace[top_index[0]] = substitute
|
||||
# available_substitutes.append(substitute)
|
||||
# replace_texts.append(' '.join(temp_replace))
|
||||
# replace_text_input = self.clip_tokenizer(replace_texts).to(device)
|
||||
# replace_embeds = self.model.encode_text(replace_text_input)
|
||||
#
|
||||
# loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var)
|
||||
# loss = loss.sum(dim=-1)
|
||||
# candidate_idx = loss.argmax()
|
||||
# final_words[top_index[0]] = available_substitutes[candidate_idx]
|
||||
# if available_substitutes[candidate_idx] != tgt_word:
|
||||
# change += 1
|
||||
# final_adverse.append(' '.join(final_words))
|
||||
# return final_adverse
|
||||
|
||||
def train_epoch(self):
|
||||
# self.change_state(mode="valid")
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def get_args():
|
|||
parser.add_argument("--lr-decay-freq", type=int, default=5)
|
||||
parser.add_argument("--display-step", type=int, default=50)
|
||||
parser.add_argument("--seed", type=int, default=1814)
|
||||
|
||||
parser.add_argument("--attack-thred", type=float, default=0.05)
|
||||
parser.add_argument("--lr", type=float, default=0.001)
|
||||
parser.add_argument("--lr-decay", type=float, default=0.9)
|
||||
parser.add_argument("--clip-lr", type=float, default=0.00001)
|
||||
|
|
|
|||
Loading…
Reference in New Issue