advclip/train/text_train.py

683 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
from torch.nn.modules import loss
# from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
import scipy.io as scio
import numpy as np
from .base import TrainBase
from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
import clip
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
from transformers import BertForMaskedLM
from model.bert_tokenizer import BertTokenizer
# from transformers import BertModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Stop-words, contractions and punctuation that must never be proposed as
# adversarial substitutes (membership test only, hence a set).
filter_words = {
    'a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost',
    'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another',
    'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as',
    'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides',
    'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn',
    "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere',
    'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for',
    'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence',
    'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his',
    'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's",
    'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn',
    "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself',
    'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none',
    'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only',
    'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per',
    'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't", 'somehow',
    'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the', 'their', 'theirs',
    'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein',
    'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout', 'thru', 'thus', 'to', 'too',
    'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was', 'wasn', "wasn't",
    'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever', 'where',
    'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while',
    'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
    "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
    'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
    '@', '%', '^', '*', '(', ')', "-", '+', '=', '<', '>', '|', ':', ";", '', '·',
}
class GoalFunctionStatus(object):
    """Lifecycle states of one adversarial-attack attempt on a text."""

    SUCCEEDED = 0  # the attack reached its goal score
    SEARCHING = 1  # still searching for a successful perturbation
    FAILED = 2     # search exhausted without success
class GoalFunctionResult(object):
    """Outcome of scoring one candidate adversarial text.

    `status` starts as SEARCHING and is promoted to SUCCEEDED by the `score`
    property setter as soon as the score reaches `goal_score`.  `similarity`
    is the semantic similarity to the original text, or None when it was not
    computed.  Equality and hashing are based on `text` alone, so results can
    be deduplicated in sets regardless of their scores.
    """

    # Score threshold at which an attempt counts as successful.
    goal_score = 1

    def __init__(self, text, score=0, similarity=None):
        self.status = GoalFunctionStatus.SEARCHING
        self.text = text
        # Assignment goes through the property setter below, so a high enough
        # initial score marks the result SUCCEEDED immediately.
        self.score = score
        self.similarity = similarity

    @property
    def score(self):
        return self.__score

    @score.setter
    def score(self, value):
        # Store the raw value, then promote the status once the goal is met.
        self.__score = value
        if value >= self.goal_score:
            self.status = GoalFunctionStatus.SUCCEEDED

    def __eq__(self, __o):
        # Two results are equal when they wrap the same candidate text.
        return self.text == __o.text

    def __hash__(self):
        return hash(self.text)
# def get_bpe_substitues(substitutes, tokenizer, mlm_model):
# # substitutes L, k
# # device = mlm_model.device
# substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
#
# # find all possible candidates
#
# all_substitutes = []
# for i in range(substitutes.size(0)):
# if len(all_substitutes) == 0:
# lev_i = substitutes[i]
# all_substitutes = [[int(c)] for c in lev_i]
# else:
# lev_i = []
# for all_sub in all_substitutes:
# for j in substitutes[i]:
# lev_i.append(all_sub + [int(j)])
# all_substitutes = lev_i
#
# # all substitutes list of list of token-id (all candidates)
# c_loss = nn.CrossEntropyLoss(reduction='none')
# word_list = []
# # all_substitutes = all_substitutes[:24]
# all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
# all_substitutes = all_substitutes[:24].to(device)
# # print(substitutes.size(), all_substitutes.size())
# N, L = all_substitutes.size()
# word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
# ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
# ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
# _, word_list = torch.sort(ppl)
# word_list = [all_substitutes[i] for i in word_list]
# final_words = []
# for word in word_list:
# tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
# text = tokenizer.convert_tokens_to_string(tokens)
# final_words.append(text)
# return final_words
# def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
# # substitues L,k
# # from this matrix to recover a word
# words = []
# sub_len, k = substitutes.size() # sub-len, k
#
# if sub_len == 0:
# return words
#
# elif sub_len == 1:
# for (i, j) in zip(substitutes[0], substitutes_score[0]):
# if threshold != 0 and j < threshold:
# break
# words.append(tokenizer._convert_id_to_token(int(i)))
# else:
# if use_bpe == 1:
# words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
# else:
# return words
# #
# # print(words)
# return words
class Trainer(TrainBase):
def __init__(self,
rank=0):
args = get_args()
super(Trainer, self).__init__(args, rank)
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
image_mean, image_var=self.generate_mapping()
self.image_mean=image_mean
self.image_var=image_var
self.device=rank
self.clip_tokenizer=Tokenizer()
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True)
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
self.attack_thred = self.args.attack_thred
# self.run()
def _init_model(self):
self.logger.info("init model.")
model_clip, preprocess = clip.load(self.args.victim, device=device)
self.model= model_clip
self.model.eval()
self.model.float()
def _init_dataset(self):
self.logger.info("init dataset.")
self.logger.info(f"Using {self.args.dataset} dataset.")
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
indexFile=self.args.index_file,
labelFile=self.args.label_file,
maxWords=self.args.max_words,
imageResolution=self.args.resolution,
query_num=self.args.query_num,
train_num=self.args.train_num,
seed=self.args.seed)
self.train_labels = train_data.get_all_label()
self.query_labels = query_data.get_all_label()
self.retrieval_labels = retrieval_data.get_all_label()
self.args.retrieval_num = len(self.retrieval_labels)
self.logger.info(f"query shape: {self.query_labels.shape}")
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
self.train_loader = DataLoader(
dataset=train_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.query_loader = DataLoader(
dataset=query_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.retrieval_loader = DataLoader(
dataset=retrieval_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.train_data=train_data
def _tokenize(self, text):
words = text.split(' ')
sub_words = []
keys = []
index = 0
for word in words:
sub = self.bert_tokenizer.tokenize(word)
sub_words += sub
keys.append([index, index + len(sub)])
index += len(sub)
return words, sub_words, keys
def get_important_scores(self, text, origin_embeds, batch_size, max_length):
# device = origin_embeds.device
masked_words = self._get_masked(text)
masked_texts = [' '.join(words) for words in masked_words] # list of text of masked words
masked_embeds = []
for i in range(0, len(masked_texts), batch_size):
masked_text_input = self.bert_tokenizer(masked_texts[i:i+batch_size], padding='max_length', truncation=True, max_length=max_length, return_tensors='pt').to(device)
masked_embed = self.ref_net(masked_text_input.text_inputs, attention_mask=masked_text_input.attention_mask)
masked_embeds.append(masked_embed)
masked_embeds = torch.cat(masked_embeds, dim=0)
criterion = torch.nn.KLDivLoss(reduction='none')
import_scores = criterion(masked_embeds.log_softmax(dim=-1), origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1))
return import_scores.sum(dim=-1)
def _get_masked(self, text):
words = text.split(' ')
len_text = len(words)
masked_words = []
for i in range(len_text):
masked_words.append(words[0:i] + ['[UNK]'] + words[i + 1:])
# list of words
return masked_words
def get_transformations(self, text, idx, substitutes):
words = text.split(' ')
trans_text = []
for sub in substitutes:
words[idx] = sub
trans_text.append(' '.join(words))
return trans_text
    def get_word_predictions(self, text):
        """Run the masked-LM over `text` and collect top-k token candidates
        for every position.

        Returns:
            keys: per-word [start, end) sub-token spans from _tokenize.
            word_predictions: (seq_len - 2, k) candidate token ids, with the
                [CLS]/[SEP] rows stripped.
            word_pred_scores_all: matching logit scores, same shape.
            attention_mask: tokenizer attention mask (left on CPU).
        """
        _, _, keys = self._tokenize(text)
        # NOTE(review): self.max_text_len and self.max_candidate are not set
        # anywhere in this file — presumably inherited from TrainBase or args;
        # confirm.
        inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.max_text_len,
                                                 truncation=True, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs['attention_mask']
        with torch.no_grad():
            word_predictions = self.ref_net(input_ids)['logits'].squeeze()  # (seq_len, vocab_size)
            word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.max_candidate, -1)
        word_predictions = word_predictions[1:-1, :]  # remove [CLS] and [SEP]
        word_pred_scores_all = word_pred_scores_all[1:-1, :]
        return keys, word_predictions, word_pred_scores_all, attention_mask
def get_bpe_substitutes(self, substitutes):
# substitutes L, k
substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
# find all possible candidates
all_substitutes = []
for i in range(substitutes.size(0)):
if len(all_substitutes) == 0:
lev_i = substitutes[i]
all_substitutes = [[int(c)] for c in lev_i]
else:
lev_i = []
for all_sub in all_substitutes:
for j in substitutes[i]:
lev_i.append(all_sub + [int(j)])
all_substitutes = lev_i
# all substitutes: list of list of token-id (all candidates)
cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
word_list = []
all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
all_substitutes = all_substitutes[:24].to(self.device)
N, L = all_substitutes.size()
word_predictions = self.ref_net(all_substitutes)[0] # N L vocab-size
ppl = cross_entropy_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
_, word_list = torch.sort(ppl)
word_list = [all_substitutes[i] for i in word_list]
final_words = []
for word in word_list:
tokens = [self.bert_tokenizer.convert_ids_to_tokens(int(i)) for i in word]
text = ' '.join([t.strip() for t in tokens])
final_words.append(text)
return final_words
def get_substitutes(self, substitutes, substitutes_score, threshold=3.0):
ret = []
num_sub, _ = substitutes.size()
if num_sub == 0:
ret = []
elif num_sub == 1:
for id, score in zip(substitutes[0], substitutes_score[0]):
if threshold != 0 and score < threshold:
break
ret.append(self.bert_tokenizer.convert_ids_to_tokens(int(id)))
elif self.args.enable_bpe:
ret = self.get_bpe_substitutes(substitutes)
return ret
def filter_substitutes(self, substitues):
ret = []
for word in substitues:
if word.lower() in filter_words:
continue
if '##' in word:
continue
ret.append(word)
return ret
def get_goal_results(self, trans_texts, ori_text, ref_text):
text_batch = [ref_text] + trans_texts
with torch.no_grad():
feats = self.feature_extractor(text_batch, device=self.device)
feats = feats.flatten(start_dim=1)
ref_feats = feats[0].unsqueeze(0)
trans_feats = feats[1:]
cos_sim = F.cosine_similarity(ref_feats.repeat(trans_feats.size(0), 1), trans_feats)
cos_sim = cos_sim.cpu().numpy()
text_sim = self.get_text_similarity(trans_texts, ori_text)
results = []
for i in range(len(trans_texts)):
if text_sim[i] is not None and text_sim[i] < self.semantic_thred:
continue
results.append(GoalFunctionResult(trans_texts[i], score=cos_sim[i], similarity=text_sim[i]))
return results
def generate_mapping(self):
image_train=[]
label_train=[]
for image, text, label, index in self.train_loader:
# raw_text=[self.clip_tokenizer.decode(token) for token in text]
image=image.to(device, non_blocking=True)
# print(self.model.vocab_size)
temp_image=self.model.encode_image(image)
image_train.append(temp_image.cpu().detach().numpy())
label_train.append(label.detach().numpy())
image_train=np.concatenate(image_train, axis=0)
label_train=np.concatenate(label_train, axis=0)
label_unipue=np.unique(label_train,axis=0)
image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
image_var=np.stack([image_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
image_representation = {}
image_var_representation = {}
for i, centroid in enumerate(label_unipue):
image_representation[str(centroid.astype(int))] = image_centroids[i]
image_var_representation[str(centroid.astype(int))]= image_var[i]
return image_representation, image_var_representation
# def target_adv(self, texts, raw_text, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
# beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
    def target_adv(self, texts, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
                   positive_var,
                   beta=10, epsilon=0.03125, alpha=3 / 255, num_iter=1500, temperature=0.05):
        """Greedy word-substitution attack steering `raw_text` toward the
        target (positive) image features.

        Only `raw_text` and `positive_code` are used by the current code path;
        the remaining parameters (texts, negetive_*, beta, epsilon, alpha,
        num_iter, temperature) are kept for interface compatibility with the
        commented-out optimisation variant below.

        Returns:
            GoalFunctionResult — SUCCEEDED on the first winning substitution,
            otherwise FAILED after the search is exhausted.
        """
        # bert_inputs=self.bert_tokenizer(raw_text, padding='max_length', truncation=True, max_length=self.args.max_words, return_tensors='pt')
        keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
        # clean state
        # clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
        cur_result = GoalFunctionResult(raw_text)
        # NOTE(review): `mask` comes back from the tokenizer with shape
        # (1, seq_len), so np.where(...)[0] yields the ROW indices — an array
        # of zeros — rather than the token positions ([1] or squeezing the
        # mask first looks intended, which would make `idx` iterate real
        # positions).  Also note `keys` is indexed per *word* while the mask
        # is per *sub-token*; confirm the intended alignment.
        mask_idx = np.where(mask.cpu().numpy() == 1)[0]
        for idx in mask_idx:
            # Candidate tokens for the sub-token span of word `idx`.
            predictions = word_predictions[keys[idx][0]: keys[idx][1]]
            predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]]
            substitutes = self.get_substitutes(predictions, predictions_socre)
            substitutes = self.filter_substitutes(substitutes)
            trans_texts = self.get_transformations(raw_text, idx, substitutes)
            if len(trans_texts) == 0:
                continue
            # Score all candidate sentences against the target image features.
            results = self.get_goal_results(trans_texts, raw_text, positive_code)
            results = sorted(results, key=lambda x: x.score, reverse=True)
            if len(results) > 0 and results[0].score > cur_result.score:
                cur_result = results[0]
            else:
                continue
            if cur_result.status == GoalFunctionStatus.SUCCEEDED:
                max_similarity = cur_result.similarity
                if max_similarity is None:
                    # similarity is not calculated
                    continue
                # Among all succeeded candidates (sorted by score), prefer the
                # one most semantically similar to the original text.
                for result in results[1:]:
                    if result.status != GoalFunctionStatus.SUCCEEDED:
                        break
                    if result.similarity > max_similarity:
                        max_similarity = result.similarity
                        cur_result = result
                return cur_result
        if cur_result.status == GoalFunctionStatus.SEARCHING:
            cur_result.status = GoalFunctionStatus.FAILED
        return cur_result
        # important_scores = self.get_important_scores(text, clean_embeds, self.args.batch_size, self.args.max_words)
        # list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)
        # words, sub_words, keys = self._tokenize(text)
        # final_words = copy.deepcopy(words)
        # change = 0
        # for top_index in list_of_index:
        #     if change >= self.args.num_perturbation:
        #         break
        #     tgt_word = words[top_index[0]]
        #     if tgt_word in filter_words:
        #         continue
        #     if keys[top_index[0]][0] > self.args.max_length - 2:
        #         continue
        #     substitutes = word_predictions[i, keys[top_index[0]][0]:keys[top_index[0]][1]]  # L, k
        #     word_pred_scores = word_pred_scores_all[i, keys[top_index[0]][0]:keys[top_index[0]][1]]
        #     substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
        #                                  self.args.threshold_pred_score)
        #     replace_texts = [' '.join(final_words)]
        #     available_substitutes = [tgt_word]
        #     for substitute_ in substitutes:
        #         substitute = substitute_
        #         if substitute == tgt_word:
        #             continue  # filter out original word
        #         if '##' in substitute:
        #             continue  # filter out sub-word
        #         if substitute in filter_words:
        #             continue
        #         temp_replace = copy.deepcopy(final_words)
        #         temp_replace[top_index[0]] = substitute
        #         available_substitutes.append(substitute)
        #         replace_texts.append(' '.join(temp_replace))
        #     replace_text_input = self.clip_tokenizer(replace_texts).to(device)
        #     replace_embeds = self.model.encode_text(replace_text_input)
        #     loss = self.adv_loss(replace_embeds, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, positive_var)
        #     loss = loss.sum(dim=-1)
        #     candidate_idx = loss.argmax()
        #     final_words[top_index[0]] = available_substitutes[candidate_idx]
        #     if available_substitutes[candidate_idx] != tgt_word:
        #         change += 1
        # final_adverse.append(' '.join(final_words))
        # return final_adverse
def train_epoch(self):
# self.change_state(mode="valid")
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
all_loss = 0
times = 0
adv_codes=[]
adv_label=[]
for image, text, label, index in self.train_loader:
self.global_step += 1
times += 1
print(times)
image.float()
raw_text=[self.clip_tokenizer.my_decode(token) for token in text]
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
negetive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negative_var=np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
negetive_code=self.model.encode_image(image)
#targeted sample
np.random.seed(times)
select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
target_dataset = data.Subset(self.train_data, select_index)
target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
target_image, _, target_label, _ = next(iter(target_subset))
target_image=target_image.to(self.rank, non_blocking=True)
positive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
positive_var=np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
positive_mean=torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
positive_var=torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
positive_code=self.model.encode_image(target_image)
final_adverse=self.target_adv(text, raw_text,negetive_code,negetive_mean,negative_var,
positive_code,positive_mean,positive_var)
final_text=self.clip_tokenizer.tokenize(final_adverse).to(self.rank, non_blocking=True)
adv_code=self.model.encode_text(final_text)
adv_codes.append(adv_code.cpu().detach().numpy())
adv_label.append(target_label.numpy())
adv_img=np.concatenate(adv_codes)
adv_labels=np.concatenate(adv_label)
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
retrieval_txt = retrieval_txt.cpu().detach().numpy()
retrieval_labels = self.retrieval_labels.numpy()
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
result_dict = {
'adv_img': adv_img,
'r_txt': retrieval_txt,
'adv_l': adv_labels,
'r_l': retrieval_labels
# 'q_l':query_labels
# 'pr': pr,
# 'pr_t': pr_t
}
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!")
    def train(self):
        """Run the attack for args.epochs epochs, validating and saving a
        checkpoint after each one."""
        self.logger.info("Start train.")
        for epoch in range(self.args.epochs):
            # NOTE(review): train_epoch as defined in this file takes no epoch
            # argument, and valid()/save_model() are not defined here (valid is
            # commented out below) — presumably inherited from TrainBase; confirm.
            self.train_epoch(epoch)
            self.valid(epoch)
            self.save_model(epoch)
        self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
def make_hash_code(self, code: list) -> torch.Tensor:
code = torch.stack(code)
# print(code.shape)
code = code.permute(1, 0, 2)
hash_code = torch.argmax(code, dim=-1)
hash_code[torch.where(hash_code == 0)] = -1
hash_code = hash_code.float()
return hash_code
def get_code(self, data_loader, length: int):
img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
for image, text, label, index in tqdm(data_loader):
image = image.to(self.device, non_blocking=True)
text = text.to(self.device, non_blocking=True)
index = index.numpy()
with torch.no_grad():
image_feature = self.model.encode_image(image)
text_features = self.model.encode_text(text)
img_buffer[index, :] = image_feature.detach()
text_buffer[index, :] = text_features.detach()
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
    def valid_attack(self, adv_images, texts, adv_labels):
        """Evaluate attacked samples against the retrieval set.

        NOTE(review): this is an unfinished stub — it only creates the output
        directory and ignores all three arguments; the evaluation/save logic
        appears to live in train_epoch instead.
        """
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
def test(self, mode_name="i2t"):
self.logger.info("Valid Clean.")
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
os.makedirs(save_dir, exist_ok=True)
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
query_img = query_img.cpu().detach().numpy()
query_txt = query_txt.cpu().detach().numpy()
retrieval_img = retrieval_img.cpu().detach().numpy()
retrieval_txt = retrieval_txt.cpu().detach().numpy()
query_labels = self.query_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy()
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)
# pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
# pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)
self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
result_dict = {
'q_img': query_img,
'q_txt': query_txt,
'r_img': retrieval_img,
'r_txt': retrieval_txt,
'q_l': query_labels,
'r_l': retrieval_labels
}
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!")
# def valid(self, epoch):
# self.logger.info("Valid.")
# self.change_state(mode="valid")
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
# retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
# # print("get all code")
# mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# # print("map map")
# mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
# mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
# mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# if self.max_mapi2t < mAPi2t:
# self.best_epoch_i = epoch
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
# self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
# if self.max_mapt2i < mAPt2i:
# self.best_epoch_t = epoch
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
# self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
# MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
os.makedirs(save_dir, exist_ok=True)
query_img = query_img.cpu().detach().numpy()
query_txt = query_txt.cpu().detach().numpy()
retrieval_img = retrieval_img.cpu().detach().numpy()
retrieval_txt = retrieval_txt.cpu().detach().numpy()
query_labels = self.query_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy()
result_dict = {
'q_img': query_img,
'q_txt': query_txt,
'r_img': retrieval_img,
'r_txt': retrieval_txt,
'q_l': query_labels,
'r_l': retrieval_labels
}
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
self.logger.info(f">>>>>> save best {mode_name} data!")