text attack
This commit is contained in:
parent
ebc17c80dc
commit
ca75f34880
|
|
@ -51,67 +51,98 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again',
|
||||||
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·']
|
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·']
|
||||||
filter_words = set(filter_words)
|
filter_words = set(filter_words)
|
||||||
|
|
||||||
def get_bpe_substitues(substitutes, tokenizer, mlm_model):
|
class GoalFunctionStatus(object):
|
||||||
# substitutes L, k
|
SUCCEEDED = 0 # attack succeeded
|
||||||
# device = mlm_model.device
|
SEARCHING = 1 # In process of searching for a success
|
||||||
substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
|
FAILED = 2 # attack failed
|
||||||
|
|
||||||
# find all possible candidates
|
|
||||||
|
|
||||||
all_substitutes = []
|
class GoalFunctionResult(object):
|
||||||
for i in range(substitutes.size(0)):
|
goal_score = 1
|
||||||
if len(all_substitutes) == 0:
|
|
||||||
lev_i = substitutes[i]
|
|
||||||
all_substitutes = [[int(c)] for c in lev_i]
|
|
||||||
else:
|
|
||||||
lev_i = []
|
|
||||||
for all_sub in all_substitutes:
|
|
||||||
for j in substitutes[i]:
|
|
||||||
lev_i.append(all_sub + [int(j)])
|
|
||||||
all_substitutes = lev_i
|
|
||||||
|
|
||||||
# all substitutes list of list of token-id (all candidates)
|
def __init__(self, text, score=0, similarity=None):
|
||||||
c_loss = nn.CrossEntropyLoss(reduction='none')
|
self.status = GoalFunctionStatus.SEARCHING
|
||||||
word_list = []
|
self.text = text
|
||||||
# all_substitutes = all_substitutes[:24]
|
self.score = score
|
||||||
all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
|
self.similarity = similarity
|
||||||
all_substitutes = all_substitutes[:24].to(device)
|
|
||||||
# print(substitutes.size(), all_substitutes.size())
|
|
||||||
N, L = all_substitutes.size()
|
|
||||||
word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
|
|
||||||
ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
|
|
||||||
ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
|
|
||||||
_, word_list = torch.sort(ppl)
|
|
||||||
word_list = [all_substitutes[i] for i in word_list]
|
|
||||||
final_words = []
|
|
||||||
for word in word_list:
|
|
||||||
tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
|
|
||||||
text = tokenizer.convert_tokens_to_string(tokens)
|
|
||||||
final_words.append(text)
|
|
||||||
return final_words
|
|
||||||
|
|
||||||
def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
|
@property
|
||||||
# substitues L,k
|
def score(self):
|
||||||
# from this matrix to recover a word
|
return self.__score
|
||||||
words = []
|
|
||||||
sub_len, k = substitutes.size() # sub-len, k
|
|
||||||
|
|
||||||
if sub_len == 0:
|
@score.setter
|
||||||
return words
|
def score(self, value):
|
||||||
|
self.__score = value
|
||||||
|
if value >= self.goal_score:
|
||||||
|
self.status = GoalFunctionStatus.SUCCEEDED
|
||||||
|
|
||||||
elif sub_len == 1:
|
def __eq__(self, __o):
|
||||||
for (i, j) in zip(substitutes[0], substitutes_score[0]):
|
return self.text == __o.text
|
||||||
if threshold != 0 and j < threshold:
|
|
||||||
break
|
def __hash__(self):
|
||||||
words.append(tokenizer._convert_id_to_token(int(i)))
|
return hash(self.text)
|
||||||
else:
|
|
||||||
if use_bpe == 1:
|
# def get_bpe_substitues(substitutes, tokenizer, mlm_model):
|
||||||
words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
|
# # substitutes L, k
|
||||||
else:
|
# # device = mlm_model.device
|
||||||
return words
|
# substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
|
||||||
#
|
#
|
||||||
# print(words)
|
# # find all possible candidates
|
||||||
return words
|
#
|
||||||
|
# all_substitutes = []
|
||||||
|
# for i in range(substitutes.size(0)):
|
||||||
|
# if len(all_substitutes) == 0:
|
||||||
|
# lev_i = substitutes[i]
|
||||||
|
# all_substitutes = [[int(c)] for c in lev_i]
|
||||||
|
# else:
|
||||||
|
# lev_i = []
|
||||||
|
# for all_sub in all_substitutes:
|
||||||
|
# for j in substitutes[i]:
|
||||||
|
# lev_i.append(all_sub + [int(j)])
|
||||||
|
# all_substitutes = lev_i
|
||||||
|
#
|
||||||
|
# # all substitutes list of list of token-id (all candidates)
|
||||||
|
# c_loss = nn.CrossEntropyLoss(reduction='none')
|
||||||
|
# word_list = []
|
||||||
|
# # all_substitutes = all_substitutes[:24]
|
||||||
|
# all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
|
||||||
|
# all_substitutes = all_substitutes[:24].to(device)
|
||||||
|
# # print(substitutes.size(), all_substitutes.size())
|
||||||
|
# N, L = all_substitutes.size()
|
||||||
|
# word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
|
||||||
|
# ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
|
||||||
|
# ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
|
||||||
|
# _, word_list = torch.sort(ppl)
|
||||||
|
# word_list = [all_substitutes[i] for i in word_list]
|
||||||
|
# final_words = []
|
||||||
|
# for word in word_list:
|
||||||
|
# tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
|
||||||
|
# text = tokenizer.convert_tokens_to_string(tokens)
|
||||||
|
# final_words.append(text)
|
||||||
|
# return final_words
|
||||||
|
|
||||||
|
# def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
|
||||||
|
# # substitues L,k
|
||||||
|
# # from this matrix to recover a word
|
||||||
|
# words = []
|
||||||
|
# sub_len, k = substitutes.size() # sub-len, k
|
||||||
|
#
|
||||||
|
# if sub_len == 0:
|
||||||
|
# return words
|
||||||
|
#
|
||||||
|
# elif sub_len == 1:
|
||||||
|
# for (i, j) in zip(substitutes[0], substitutes_score[0]):
|
||||||
|
# if threshold != 0 and j < threshold:
|
||||||
|
# break
|
||||||
|
# words.append(tokenizer._convert_id_to_token(int(i)))
|
||||||
|
# else:
|
||||||
|
# if use_bpe == 1:
|
||||||
|
# words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
|
||||||
|
# else:
|
||||||
|
# return words
|
||||||
|
# #
|
||||||
|
# # print(words)
|
||||||
|
# return words
|
||||||
|
|
||||||
class Trainer(TrainBase):
|
class Trainer(TrainBase):
|
||||||
|
|
||||||
|
|
@ -127,6 +158,7 @@ class Trainer(TrainBase):
|
||||||
self.clip_tokenizer=Tokenizer()
|
self.clip_tokenizer=Tokenizer()
|
||||||
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True)
|
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True)
|
||||||
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
|
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
|
||||||
|
self.attack_thred = self.args.attack_thred
|
||||||
# self.run()
|
# self.run()
|
||||||
|
|
||||||
def _init_model(self):
|
def _init_model(self):
|
||||||
|
|
@ -221,6 +253,115 @@ class Trainer(TrainBase):
|
||||||
# list of words
|
# list of words
|
||||||
return masked_words
|
return masked_words
|
||||||
|
|
||||||
|
def get_transformations(self, text, idx, substitutes):
|
||||||
|
words = text.split(' ')
|
||||||
|
|
||||||
|
trans_text = []
|
||||||
|
for sub in substitutes:
|
||||||
|
words[idx] = sub
|
||||||
|
trans_text.append(' '.join(words))
|
||||||
|
return trans_text
|
||||||
|
|
||||||
|
def get_word_predictions(self, text):
|
||||||
|
_, _, keys = self._tokenize(text)
|
||||||
|
|
||||||
|
inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.max_text_len,
|
||||||
|
truncation=True, return_tensors="pt")
|
||||||
|
input_ids = inputs["input_ids"].to(self.device)
|
||||||
|
attention_mask = inputs['attention_mask']
|
||||||
|
with torch.no_grad():
|
||||||
|
word_predictions = self.ref_net(input_ids)['logits'].squeeze() # (seq_len, vocab_size)
|
||||||
|
|
||||||
|
word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.max_candidate, -1)
|
||||||
|
|
||||||
|
word_predictions = word_predictions[1:-1, :] # remove [CLS] and [SEP]
|
||||||
|
word_pred_scores_all = word_pred_scores_all[1:-1, :]
|
||||||
|
|
||||||
|
return keys, word_predictions, word_pred_scores_all, attention_mask
|
||||||
|
|
||||||
|
def get_bpe_substitutes(self, substitutes):
|
||||||
|
# substitutes L, k
|
||||||
|
substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
|
||||||
|
|
||||||
|
# find all possible candidates
|
||||||
|
all_substitutes = []
|
||||||
|
for i in range(substitutes.size(0)):
|
||||||
|
if len(all_substitutes) == 0:
|
||||||
|
lev_i = substitutes[i]
|
||||||
|
all_substitutes = [[int(c)] for c in lev_i]
|
||||||
|
else:
|
||||||
|
lev_i = []
|
||||||
|
for all_sub in all_substitutes:
|
||||||
|
for j in substitutes[i]:
|
||||||
|
lev_i.append(all_sub + [int(j)])
|
||||||
|
all_substitutes = lev_i
|
||||||
|
|
||||||
|
# all substitutes: list of list of token-id (all candidates)
|
||||||
|
cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
|
||||||
|
|
||||||
|
word_list = []
|
||||||
|
all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
|
||||||
|
all_substitutes = all_substitutes[:24].to(self.device)
|
||||||
|
|
||||||
|
N, L = all_substitutes.size()
|
||||||
|
word_predictions = self.ref_net(all_substitutes)[0] # N L vocab-size
|
||||||
|
ppl = cross_entropy_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
|
||||||
|
ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
|
||||||
|
|
||||||
|
_, word_list = torch.sort(ppl)
|
||||||
|
word_list = [all_substitutes[i] for i in word_list]
|
||||||
|
final_words = []
|
||||||
|
for word in word_list:
|
||||||
|
tokens = [self.bert_tokenizer.convert_ids_to_tokens(int(i)) for i in word]
|
||||||
|
text = ' '.join([t.strip() for t in tokens])
|
||||||
|
final_words.append(text)
|
||||||
|
return final_words
|
||||||
|
|
||||||
|
def get_substitutes(self, substitutes, substitutes_score, threshold=3.0):
|
||||||
|
ret = []
|
||||||
|
num_sub, _ = substitutes.size()
|
||||||
|
if num_sub == 0:
|
||||||
|
ret = []
|
||||||
|
elif num_sub == 1:
|
||||||
|
for id, score in zip(substitutes[0], substitutes_score[0]):
|
||||||
|
if threshold != 0 and score < threshold:
|
||||||
|
break
|
||||||
|
ret.append(self.bert_tokenizer.convert_ids_to_tokens(int(id)))
|
||||||
|
elif self.args.enable_bpe:
|
||||||
|
ret = self.get_bpe_substitutes(substitutes)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def filter_substitutes(self, substitues):
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
for word in substitues:
|
||||||
|
if word.lower() in filter_words:
|
||||||
|
continue
|
||||||
|
if '##' in word:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ret.append(word)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def get_goal_results(self, trans_texts, ori_text, ref_text):
|
||||||
|
text_batch = [ref_text] + trans_texts
|
||||||
|
with torch.no_grad():
|
||||||
|
feats = self.feature_extractor(text_batch, device=self.device)
|
||||||
|
feats = feats.flatten(start_dim=1)
|
||||||
|
|
||||||
|
ref_feats = feats[0].unsqueeze(0)
|
||||||
|
trans_feats = feats[1:]
|
||||||
|
cos_sim = F.cosine_similarity(ref_feats.repeat(trans_feats.size(0), 1), trans_feats)
|
||||||
|
cos_sim = cos_sim.cpu().numpy()
|
||||||
|
|
||||||
|
text_sim = self.get_text_similarity(trans_texts, ori_text)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for i in range(len(trans_texts)):
|
||||||
|
if text_sim[i] is not None and text_sim[i] < self.semantic_thred:
|
||||||
|
continue
|
||||||
|
results.append(GoalFunctionResult(trans_texts[i], score=cos_sim[i], similarity=text_sim[i]))
|
||||||
|
return results
|
||||||
|
|
||||||
def generate_mapping(self):
|
def generate_mapping(self):
|
||||||
image_train=[]
|
image_train=[]
|
||||||
|
|
@ -246,63 +387,99 @@ class Trainer(TrainBase):
|
||||||
return image_representation, image_var_representation
|
return image_representation, image_var_representation
|
||||||
|
|
||||||
|
|
||||||
def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
# def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||||
|
# beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
|
||||||
|
def target_adv(self, texts, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
|
||||||
|
positive_var,
|
||||||
beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500, temperature=0.05):
|
beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500, temperature=0.05):
|
||||||
|
|
||||||
bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
|
# bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
|
||||||
mlm_logits = self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask).logits
|
keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
|
||||||
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.args.topk, -1)
|
|
||||||
|
|
||||||
#clean state
|
#clean state
|
||||||
clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
|
# clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
|
||||||
final_adverse = []
|
cur_result = GoalFunctionResult(raw_text)
|
||||||
|
mask_idx = np.where(mask.cpu().numpy() == 1)[0]
|
||||||
|
|
||||||
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
|
||||||
# print(texts)
|
for idx in mask_idx:
|
||||||
for i, text in enumerate(texts):
|
predictions = word_predictions[keys[idx][0]: keys[idx][1]]
|
||||||
important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
|
predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]]
|
||||||
list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
|
substitutes = self.get_substitutes(predictions, predictions_socre)
|
||||||
words, sub_words, keys = self._tokenize(text)
|
substitutes = self.filter_substitutes(substitutes)
|
||||||
final_words = copy.deepcopy(words)
|
trans_texts = self.get_transformations(raw_text, idx, substitutes)
|
||||||
change = 0
|
if len(trans_texts) == 0:
|
||||||
for top_index in list_of_index:
|
continue
|
||||||
if change >= self.args.num_perturbation:
|
# loss function
|
||||||
|
results = self.get_goal_results(trans_texts, raw_text, positive_code)
|
||||||
|
results = sorted(results, key=lambda x: x.score, reverse=True)
|
||||||
|
|
||||||
|
if len(results) > 0 and results[0].score > cur_result.score:
|
||||||
|
cur_result = results[0]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cur_result.status == GoalFunctionStatus.SUCCEEDED:
|
||||||
|
max_similarity = cur_result.similarity
|
||||||
|
if max_similarity is None:
|
||||||
|
# similarity is not calculated
|
||||||
|
continue
|
||||||
|
|
||||||
|
for result in results[1:]:
|
||||||
|
if result.status != GoalFunctionStatus.SUCCEEDED:
|
||||||
break
|
break
|
||||||
tgt_word = words[top_index[0]]
|
if result.similarity > max_similarity:
|
||||||
if tgt_word in filter_words:
|
max_similarity = result.similarity
|
||||||
continue
|
cur_result = result
|
||||||
if keys[top_index[0]][0] > self.args.max_length - 2:
|
return cur_result
|
||||||
continue
|
if cur_result.status == GoalFunctionStatus.SEARCHING:
|
||||||
substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k
|
cur_result.status = GoalFunctionStatus.FAILED
|
||||||
word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
|
return cur_result
|
||||||
substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
|
|
||||||
self.args.threshold_pred_score)
|
|
||||||
replace_texts = [' '.join(final_words)]
|
|
||||||
available_substitutes = [tgt_word]
|
|
||||||
for substitute_ in substitutes:
|
|
||||||
substitute = substitute_
|
|
||||||
if substitute == tgt_word:
|
|
||||||
continue # filter out original word
|
|
||||||
if '##' in substitute:
|
|
||||||
continue # filter out sub-word
|
|
||||||
|
|
||||||
if substitute in filter_words:
|
|
||||||
continue
|
|
||||||
temp_replace = copy.deepcopy(final_words)
|
|
||||||
temp_replace[top_index[0]] = substitute
|
|
||||||
available_substitutes.append(substitute)
|
|
||||||
replace_texts.append(' '.join(temp_replace))
|
|
||||||
replace_text_input = self.clip_tokenizer(replace_texts).to(device)
|
|
||||||
replace_embeds = self.model.encode_text(replace_text_input)
|
|
||||||
|
|
||||||
loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var)
|
# important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
|
||||||
loss = loss.sum(dim=-1)
|
# list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
|
||||||
candidate_idx = loss.argmax()
|
# words, sub_words, keys = self._tokenize(text)
|
||||||
final_words[top_index[0]] = available_substitutes[candidate_idx]
|
# final_words = copy.deepcopy(words)
|
||||||
if available_substitutes[candidate_idx] != tgt_word:
|
# change = 0
|
||||||
change += 1
|
# for top_index in list_of_index:
|
||||||
final_adverse.append(' '.join(final_words))
|
# if change >= self.args.num_perturbation:
|
||||||
return final_adverse
|
# break
|
||||||
|
# tgt_word = words[top_index[0]]
|
||||||
|
# if tgt_word in filter_words:
|
||||||
|
# continue
|
||||||
|
# if keys[top_index[0]][0] > self.args.max_length - 2:
|
||||||
|
# continue
|
||||||
|
# substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k
|
||||||
|
# word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
|
||||||
|
# substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
|
||||||
|
# self.args.threshold_pred_score)
|
||||||
|
# replace_texts = [' '.join(final_words)]
|
||||||
|
# available_substitutes = [tgt_word]
|
||||||
|
# for substitute_ in substitutes:
|
||||||
|
# substitute = substitute_
|
||||||
|
# if substitute == tgt_word:
|
||||||
|
# continue # filter out original word
|
||||||
|
# if '##' in substitute:
|
||||||
|
# continue # filter out sub-word
|
||||||
|
#
|
||||||
|
# if substitute in filter_words:
|
||||||
|
# continue
|
||||||
|
# temp_replace = copy.deepcopy(final_words)
|
||||||
|
# temp_replace[top_index[0]] = substitute
|
||||||
|
# available_substitutes.append(substitute)
|
||||||
|
# replace_texts.append(' '.join(temp_replace))
|
||||||
|
# replace_text_input = self.clip_tokenizer(replace_texts).to(device)
|
||||||
|
# replace_embeds = self.model.encode_text(replace_text_input)
|
||||||
|
#
|
||||||
|
# loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var)
|
||||||
|
# loss = loss.sum(dim=-1)
|
||||||
|
# candidate_idx = loss.argmax()
|
||||||
|
# final_words[top_index[0]] = available_substitutes[candidate_idx]
|
||||||
|
# if available_substitutes[candidate_idx] != tgt_word:
|
||||||
|
# change += 1
|
||||||
|
# final_adverse.append(' '.join(final_words))
|
||||||
|
# return final_adverse
|
||||||
|
|
||||||
def train_epoch(self):
|
def train_epoch(self):
|
||||||
# self.change_state(mode="valid")
|
# self.change_state(mode="valid")
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ def get_args():
|
||||||
parser.add_argument("--lr-decay-freq", type=int, default=5)
|
parser.add_argument("--lr-decay-freq", type=int, default=5)
|
||||||
parser.add_argument("--display-step", type=int, default=50)
|
parser.add_argument("--display-step", type=int, default=50)
|
||||||
parser.add_argument("--seed", type=int, default=1814)
|
parser.add_argument("--seed", type=int, default=1814)
|
||||||
|
parser.add_argument("--attack-thred", type=float, default=0.05)
|
||||||
parser.add_argument("--lr", type=float, default=0.001)
|
parser.add_argument("--lr", type=float, default=0.001)
|
||||||
parser.add_argument("--lr-decay", type=float, default=0.9)
|
parser.add_argument("--lr-decay", type=float, default=0.9)
|
||||||
parser.add_argument("--clip-lr", type=float, default=0.00001)
|
parser.add_argument("--clip-lr", type=float, default=0.00001)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue