update i2t and t2i

This commit is contained in:
leewlving 2024-06-18 22:19:25 +08:00
parent cb449df1c5
commit 053a58b07a
2 changed files with 38 additions and 24 deletions

View File

@ -136,7 +136,7 @@ class Trainer(TrainBase):
def train_epoch(self):
self.change_state(mode="valid")
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
save_dir = os.path.join(self.args.save_dir, "adv_PR_i2t")
all_loss = 0
times = 0
adv_codes=[]

View File

@ -125,7 +125,7 @@ class Trainer(TrainBase):
self.image_var=image_var
self.device=rank
self.clip_tokenizer=Tokenizer()
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder)
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True)
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
# self.run()
@ -202,6 +202,20 @@ class Trainer(TrainBase):
image_var_representation[str(centroid.astype(int))]= image_var[i]
return image_representation, image_var_representation
def _tokenize(self, text):
words = text.split(' ')
sub_words = []
keys = []
index = 0
for word in words:
sub = self.bert_tokenizer.tokenize(word)
sub_words += sub
keys.append([index, index + len(sub)])
index += len(sub)
return words, sub_words, keys
def _get_masked(self, text):
words = text.split(' ')
len_text = len(words)
@ -220,7 +234,7 @@ class Trainer(TrainBase):
masked_embeds = []
for i in range(0, len(masked_texts), batch_size):
masked_text_input = self.bert_tokenizer(masked_texts[i:i+batch_size], padding='max_length', truncation=True, max_length=max_length, return_tensors='pt').to(device)
masked_embed = self.model.encode_text(masked_text_input)
masked_embed = self.ref_net(masked_text_input.text_inputs, attention_mask=masked_text_input.attention_mask)
masked_embeds.append(masked_embed)
masked_embeds = torch.cat(masked_embeds, dim=0)
@ -231,16 +245,15 @@ class Trainer(TrainBase):
return import_scores.sum(dim=-1)
def target_adv(self, index, captions ,negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
def target_adv(self, text_tokens, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
texts=[captions[i] for i in index]
texts=[self.clip_tokenizer.decode(token) for token in text_tokens]
text_inputs = self.bert_tokenizer(texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device, non_blocking=True)
mlm_logits = self.ref_net(text_inputs.input_ids, attention_mask=text_inputs.attention_mask).logits
word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)
#clean state
clean_text=clip.tokenize(texts).to(device, non_blocking=True)
clean_embeds=self.model.encode_text(clean_text)
clean_embeds=self.model.encode_text(text_tokens)
final_adverse = []
for i, text in enumerate(texts):
important_scores = self.get_important_scores(text, clean_embeds, self.batch_size, self.max_length)
@ -249,19 +262,19 @@ class Trainer(TrainBase):
final_words = copy.deepcopy(words)
change = 0
for top_index in list_of_index:
if change >= self.num_perturbation:
if change >= self.args.num_perturbation:
break
tgt_word = words[top_index[0]]
if tgt_word in filter_words:
continue
if keys[top_index[0]][0] > self.max_length - 2:
if keys[top_index[0]][0] > self.args.max_length - 2:
continue
substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]] # L, k
word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
self.threshold_pred_score)
self.args.threshold_pred_score)
replace_texts = [' '.join(final_words)]
@ -280,9 +293,11 @@ class Trainer(TrainBase):
temp_replace[top_index[0]] = substitute
available_substitutes.append(substitute)
replace_texts.append(' '.join(temp_replace))
replace_text_input = self.tokenizer(replace_texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device)
replace_output = net.inference_text(replace_text_input)
loss = criterion(replace_embeds.log_softmax(dim=-1), origin_embeds[i].softmax(dim=-1).repeat(len(replace_embeds), 1))
replace_text_input = self.clip_tokenizer(replace_texts).to(device)
replace_embeds = self.model.encode_text(replace_text_input)
criterion = torch.nn.KLDivLoss(reduction='none')
loss = criterion(replace_embeds.log_softmax(dim=-1), clean_embeds[i].softmax(dim=-1).repeat(len(replace_embeds), 1))
loss = loss.sum(dim=-1)
candidate_idx = loss.argmax()
final_words[top_index[0]] = available_substitutes[candidate_idx]
@ -293,7 +308,7 @@ class Trainer(TrainBase):
def train_epoch(self):
self.change_state(mode="valid")
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
save_dir = os.path.join(self.args.save_dir, "adv_PR_t2i")
all_loss = 0
times = 0
adv_codes=[]
@ -325,33 +340,32 @@ class Trainer(TrainBase):
positive_code=self.model.encode_image(target_image)
delta, adv_code=self.target_adv(image,negetive_code,negetive_mean,negative_var,
final_adverse=self.target_adv(image,negetive_code,negetive_mean,negative_var,
positive_code,positive_mean,positive_var)
final_text=self.clip_tokenizer.tokenize(final_adverse).to(self.rank, non_blocking=True)
adv_code=self.model.encode_text(final_text)
adv_codes.append(adv_code.cpu().detach().numpy())
adv_label.append(target_label.numpy())
adv_img=np.concatenate(adv_codes)
adv_txt=np.concatenate(adv_codes)
adv_labels=np.concatenate(adv_label)
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
retrieval_img, _ = self.get_code(self.retrieval_loader, self.args.retrieval_num)
retrieval_txt = retrieval_txt.cpu().detach().numpy()
retrieval_img = retrieval_img.cpu().detach().numpy()
retrieval_labels = self.retrieval_labels.numpy()
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
mAP_t=cal_map(adv_txt,adv_labels,retrieval_img,retrieval_labels)
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
# pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels)
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
result_dict = {
'adv_img': adv_img,
'r_txt': retrieval_txt,
'adv_txt': adv_txt,
'r_img': retrieval_img,
'adv_l': adv_labels,
'r_l': retrieval_labels
# 'q_l':query_labels
# 'pr': pr,
# 'pr_t': pr_t
}
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!")