update i2t and t2i

parent cb449df1c5
commit 053a58b07a
@@ -136,7 +136,7 @@ class Trainer(TrainBase):

     def train_epoch(self):
         self.change_state(mode="valid")
-        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
+        save_dir = os.path.join(self.args.save_dir, "adv_PR_i2t")
         all_loss = 0
         times = 0
         adv_codes=[]
@@ -125,7 +125,7 @@ class Trainer(TrainBase):
         self.image_var=image_var
         self.device=rank
         self.clip_tokenizer=Tokenizer()
-        self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder)
+        self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True)
         self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
         # self.run()

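The only change here is `do_lower_case=True`: it pins the tokenizer's casing behaviour instead of relying on the checkpoint's default. A minimal sketch of the effect, assuming an uncased BERT checkpoint stands behind `self.args.text_encoder` (the checkpoint name below is hypothetical):

    from transformers import BertTokenizer

    # Hypothetical checkpoint; the trainer reads the real one from args.text_encoder.
    tok = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
    print(tok.tokenize("A Dog Runs"))   # ['a', 'dog', 'runs'] -- casing normalised away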
@@ -202,6 +202,20 @@ class Trainer(TrainBase):
             image_var_representation[str(centroid.astype(int))]= image_var[i]
         return image_representation, image_var_representation

+    def _tokenize(self, text):
+        words = text.split(' ')
+
+        sub_words = []
+        keys = []
+        index = 0
+        for word in words:
+            sub = self.bert_tokenizer.tokenize(word)
+            sub_words += sub
+            keys.append([index, index + len(sub)])
+            index += len(sub)
+
+        return words, sub_words, keys
+
     def _get_masked(self, text):
         words = text.split(' ')
         len_text = len(words)
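The new `_tokenize` records, in `keys`, the span of WordPiece pieces belonging to each whitespace-split word, which is what lets later code translate a word index into sub-token positions. A hypothetical illustration of the three return values (`trainer` is a stand-in instance; the sub-tokens assume an uncased BERT tokenizer):

    words, sub_words, keys = trainer._tokenize("the embeddings matter")
    # words     == ['the', 'embeddings', 'matter']
    # sub_words == ['the', 'em', '##bed', '##ding', '##s', 'matter']
    # keys      == [[0, 1], [1, 5], [5, 6]]   # word i -> sub_words[keys[i][0]:keys[i][1]]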
@@ -220,7 +234,7 @@ class Trainer(TrainBase):
         masked_embeds = []
         for i in range(0, len(masked_texts), batch_size):
             masked_text_input = self.bert_tokenizer(masked_texts[i:i+batch_size], padding='max_length', truncation=True, max_length=max_length, return_tensors='pt').to(device)
-            masked_embed = self.model.encode_text(masked_text_input)
+            masked_embed = self.ref_net(masked_text_input.text_inputs, attention_mask=masked_text_input.attention_mask)
            masked_embeds.append(masked_embed)
         masked_embeds = torch.cat(masked_embeds, dim=0)

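This loop batch-encodes the `[MASK]`-substituted variants of a caption; downstream, the gap between each masked variant's embedding and the clean caption's embedding becomes the masked word's importance score. A minimal sketch of that scoring idea (the function and tensor names are illustrative, not the repo's):

    import torch

    def importance_from_embeddings(clean_embed, masked_embeds):
        # masked_embeds: (num_words, D), one row per single-word masking.
        # The further the embedding drifts when a word is masked,
        # the more that word contributed to the clean representation.
        sims = torch.nn.functional.cosine_similarity(
            masked_embeds, clean_embed.unsqueeze(0), dim=-1)
        return 1.0 - sims   # (num_words,): higher = more important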
@@ -231,16 +245,15 @@ class Trainer(TrainBase):
         return import_scores.sum(dim=-1)


-    def target_adv(self, index, captions ,negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
+    def target_adv(self, text_tokens, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
                    beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
-        texts=[captions[i] for i in index]
+        texts=[self.clip_tokenizer.decode(token) for token in text_tokens]
         text_inputs = self.bert_tokenizer(texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device, non_blocking=True)
         mlm_logits = self.ref_net(text_inputs.input_ids, attention_mask=text_inputs.attention_mask).logits
         word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)

         #clean state
-        clean_text=clip.tokenize(texts).to(device, non_blocking=True)
-        clean_embeds=self.model.encode_text(clean_text)
+        clean_embeds=self.model.encode_text(text_tokens)
         final_adverse = []
         for i, text in enumerate(texts):
             important_scores = self.get_important_scores(text, clean_embeds, self.batch_size, self.max_length)
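`word_predictions` now holds, for every sub-token position of every caption, the ids of the `topk` fills the reference MLM ranks highest; `keys` from `_tokenize` later slices out the rows for a single word. A small shape check with hypothetical sizes (30522 is the bert-base vocabulary size):

    import torch

    B, L, V, k = 2, 16, 30522, 10       # hypothetical batch/seq/vocab/top-k sizes
    mlm_logits = torch.randn(B, L, V)   # stand-in for self.ref_net(...).logits
    scores, preds = torch.topk(mlm_logits, k, -1)
    print(scores.shape, preds.shape)    # torch.Size([2, 16, 10]) for both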
@@ -249,19 +262,19 @@ class Trainer(TrainBase):
             final_words = copy.deepcopy(words)
             change = 0
             for top_index in list_of_index:
-                if change >= self.num_perturbation:
+                if change >= self.args.num_perturbation:
                     break
                 tgt_word = words[top_index[0]]
                 if tgt_word in filter_words:
                     continue
-                if keys[top_index[0]][0] > self.max_length - 2:
+                if keys[top_index[0]][0] > self.args.max_length - 2:
                     continue

                 substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]]  # L, k
                 word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]

                 substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
-                                             self.threshold_pred_score)
+                                             self.args.threshold_pred_score)


                 replace_texts = [' '.join(final_words)]
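The loop visits words in descending importance and stops after `num_perturbation` successful swaps, skipping stop-words in `filter_words` and words whose sub-tokens fall outside the `max_length` window (the `- 2` reserves room for `[CLS]` and `[SEP]`). A compact sketch of that selection policy with hypothetical inputs:

    def select_targets(words, importance, filter_words, budget):
        # Greedy: most-important words first, at most `budget` of them.
        order = sorted(range(len(words)), key=lambda i: -importance[i])
        picked = []
        for i in order:
            if len(picked) >= budget:
                break
            if words[i] in filter_words:
                continue
            picked.append(i)
        return picked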
@@ -280,9 +293,11 @@ class Trainer(TrainBase):
                     temp_replace[top_index[0]] = substitute
                     available_substitutes.append(substitute)
                     replace_texts.append(' '.join(temp_replace))
-                replace_text_input = self.tokenizer(replace_texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device)
-                replace_output = net.inference_text(replace_text_input)
-                loss = criterion(replace_embeds.log_softmax(dim=-1), origin_embeds[i].softmax(dim=-1).repeat(len(replace_embeds), 1))
+                replace_text_input = self.clip_tokenizer(replace_texts).to(device)
+                replace_embeds = self.model.encode_text(replace_text_input)
+
+                criterion = torch.nn.KLDivLoss(reduction='none')
+                loss = criterion(replace_embeds.log_softmax(dim=-1), clean_embeds[i].softmax(dim=-1).repeat(len(replace_embeds), 1))
                 loss = loss.sum(dim=-1)
                 candidate_idx = loss.argmax()
                 final_words[top_index[0]] = available_substitutes[candidate_idx]
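The rewritten candidate selection softmaxes the clean caption's CLIP embedding into a reference distribution and keeps the substitute whose embedding diverges from it most, i.e. the single-word swap that pushes the text farthest from its clean meaning. A self-contained sketch of that scoring step:

    import torch

    def most_disruptive(replace_embeds, clean_embed):
        # replace_embeds: (N, D) candidate text embeddings; clean_embed: (D,)
        criterion = torch.nn.KLDivLoss(reduction='none')
        target = clean_embed.softmax(dim=-1).repeat(len(replace_embeds), 1)
        kl = criterion(replace_embeds.log_softmax(dim=-1), target).sum(dim=-1)
        return kl.argmax()   # index of the candidate farthest from the clean text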
@@ -293,7 +308,7 @@ class Trainer(TrainBase):

     def train_epoch(self):
         self.change_state(mode="valid")
-        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
+        save_dir = os.path.join(self.args.save_dir, "adv_PR_t2i")
         all_loss = 0
         times = 0
         adv_codes=[]
@@ -325,33 +340,32 @@ class Trainer(TrainBase):
             positive_code=self.model.encode_image(target_image)


-            delta, adv_code=self.target_adv(image,negetive_code,negetive_mean,negative_var,
+            final_adverse=self.target_adv(image,negetive_code,negetive_mean,negative_var,
                                           positive_code,positive_mean,positive_var)
+            final_text=self.clip_tokenizer.tokenize(final_adverse).to(self.rank, non_blocking=True)
+            adv_code=self.model.encode_text(final_text)
             adv_codes.append(adv_code.cpu().detach().numpy())
             adv_label.append(target_label.numpy())
-        adv_img=np.concatenate(adv_codes)
+        adv_txt=np.concatenate(adv_codes)
         adv_labels=np.concatenate(adv_label)

-        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
+        retrieval_img, _ = self.get_code(self.retrieval_loader, self.args.retrieval_num)


-        retrieval_txt = retrieval_txt.cpu().detach().numpy()
         retrieval_img = retrieval_img.cpu().detach().numpy()
         retrieval_labels = self.retrieval_labels.numpy()


-        mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
-        # pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
+        mAP_t=cal_map(adv_txt,adv_labels,retrieval_img,retrieval_labels)
+        # pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels)
         self.logger.info(f">>>>>> MAP_t: {mAP_t}")
         result_dict = {
-            'adv_img': adv_img,
-            'r_txt': retrieval_txt,
+            'adv_txt': adv_txt,
+            'r_img': retrieval_img,
             'adv_l': adv_labels,
             'r_l': retrieval_labels
             # 'q_l':query_labels
             # 'pr': pr,
             # 'pr_t': pr_t
         }
         scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
         self.logger.info(">>>>>> save all data!")
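`cal_map` itself is defined elsewhere in the repo; this hunk only changes its inputs (adversarial text codes ranked against the image retrieval set). For orientation, a hedged sketch of a standard Hamming-ranking mAP of the kind hashing-retrieval evaluations use, assuming sign-binarised codes and multi-hot labels (an assumption, not the repo's implementation):

    import numpy as np

    def map_sketch(q_codes, q_labels, r_codes, r_labels):
        # Hypothetical stand-in for cal_map: mean average precision
        # under Hamming ranking of +/-1 binary codes.
        qb, rb = np.sign(q_codes), np.sign(r_codes)
        aps = []
        for i in range(len(qb)):
            rel = (q_labels[i] @ r_labels.T > 0).astype(float)    # shared label => relevant
            order = np.argsort(0.5 * (rb.shape[1] - rb @ qb[i]))  # ascending Hamming distance
            rel = rel[order]
            if rel.sum() == 0:
                continue
            prec = np.cumsum(rel) / (np.arange(len(rel)) + 1)
            aps.append(float((prec * rel).sum() / rel.sum()))
        return float(np.mean(aps))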