# Adversarial caption attack trainer: rewrites captions via BERT masked-LM
# word substitutions to fool a CLIP-based cross-modal retrieval model.
import copy
|
||
from torch.nn.modules import loss
|
||
# from model.hash_model import DCMHT as DCMHT
|
||
import os
|
||
from tqdm import tqdm
|
||
import torch
|
||
import torch.nn as nn
|
||
from torch.utils.data import DataLoader
|
||
import torch.utils.data as data
|
||
import scipy.io as scio
|
||
import numpy as np
|
||
|
||
from .base import TrainBase
|
||
from torch.nn import functional as F
|
||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity, find_indices
|
||
from utils.calc_utils import cal_map, cal_pr
|
||
from dataset.dataloader import dataloader
|
||
import clip
|
||
import re
|
||
from transformers import BertForMaskedLM
|
||
from model.bert_tokenizer import BertTokenizer
|
||
from transformers import AutoTokenizer
|
||
|
||
# Default compute device for the CLIP victim model and the masked-LM reference net.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||
|
||
# Stop-words, contractions and punctuation that are never acceptable as word
# substitutes (duplicates in the literal collapse in the set).
filter_words = {
    'a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost',
    'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another',
    'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as',
    'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides',
    'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn',
    "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere',
    'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for',
    'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence',
    'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his',
    'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's",
    'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn',
    "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself',
    'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none',
    'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only',
    'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per',
    'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't", 'somehow',
    'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the', 'their', 'theirs',
    'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein',
    'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout', 'thru', 'thus', 'to', 'too',
    'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was', 'wasn', "wasn't",
    'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever', 'where',
    'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while',
    'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
    "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
    'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
    '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·',
}
|
||
|
||
|
||
def text_filter(text):
    """Extract the caption between CLIP's <|startoftext|> / <|endoftext|> markers.

    Also replaces the BPE end-of-word marker '</w>' with a space.

    Args:
        text: decoded CLIP token string containing the start/end markers.

    Returns:
        The caption text with '</w>' markers turned into spaces.

    Raises:
        IndexError: if the markers are not found in *text*.
    """
    # BUG fix: '|' must be escaped, otherwise it acts as regex alternation and
    # the pattern matches fragments ('<', 'startoftext', ...) instead of the
    # delimited span; the old code then relied on findall(...)[2] happening to
    # hold the caption.
    matches = re.findall(r"<\|startoftext\|>(.+)<\|endoftext\|>", text)
    text = matches[0]
    text = re.sub(r'</w>', ' ', text)
    return text
|
||
|
||
|
||
class GoalFunctionStatus(object):
    """Lifecycle states of a single adversarial-text search."""

    SUCCEEDED = 0  # the attack reached its goal score
    SEARCHING = 1  # still exploring substitutions
    FAILED = 2     # search finished without success
|
||
|
||
|
||
class GoalFunctionResult(object):
    """One candidate adversarial text plus its goal score and similarity.

    Assigning `score` a value >= `goal_score` promotes `status` to SUCCEEDED.
    Equality and hashing consider only `text`, so results deduplicate in sets.
    """

    # Score threshold at which the attack counts as successful.
    goal_score = 1

    def __init__(self, text, score=0, similarity=None):
        self.status = GoalFunctionStatus.SEARCHING
        self.text = text
        # Goes through the property setter below, which may flip the status.
        self.score = score
        self.similarity = similarity

    @property
    def score(self):
        """Current goal-function score of this candidate."""
        return self._score

    @score.setter
    def score(self, value):
        self._score = value
        if value >= self.goal_score:
            self.status = GoalFunctionStatus.SUCCEEDED

    def __eq__(self, __o):
        return self.text == __o.text

    def __hash__(self):
        return hash(self.text)
|
||
|
||
|
||
class Trainer(TrainBase):
|
||
|
||
    def __init__(self, rank=0):
        """Set up the attack trainer on top of TrainBase.

        Args:
            rank: device index / process rank; tensors are later moved with
                `.to(rank)` throughout this class.
        """
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        # Per-label image-feature mean/variance used by the attack's goal loss.
        image_mean, image_var = self.generate_mapping()
        self.image_mean = image_mean
        self.image_var = image_var
        self.device = rank
        # CLIP tokenizer: used to decode token ids back into raw caption text.
        self.clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", model_max_length=77,
                                                            truncation=True)
        self.bert_tokenizer = BertTokenizer.from_pretrained(self.args.text_encoder, do_lower_case=True)
        # Masked-LM reference network that proposes per-position word substitutes.
        self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder).to(device)
        self.attack_thred = self.args.attack_thred
        # self.run()
|
||
|
||
def _init_model(self):
|
||
self.logger.info("init model.")
|
||
model_clip, preprocess = clip.load(self.args.victim, device=device)
|
||
self.model = model_clip
|
||
self.model.eval()
|
||
self.model.float()
|
||
|
||
def _init_dataset(self):
|
||
self.logger.info("init dataset.")
|
||
self.logger.info(f"Using {self.args.dataset} dataset.")
|
||
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
|
||
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
|
||
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
|
||
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
|
||
indexFile=self.args.index_file,
|
||
labelFile=self.args.label_file,
|
||
maxWords=self.args.max_words,
|
||
imageResolution=self.args.resolution,
|
||
query_num=self.args.query_num,
|
||
train_num=self.args.train_num,
|
||
seed=self.args.seed)
|
||
self.train_labels = train_data.get_all_label()
|
||
self.query_labels = query_data.get_all_label()
|
||
self.retrieval_labels = retrieval_data.get_all_label()
|
||
self.args.retrieval_num = len(self.retrieval_labels)
|
||
self.logger.info(f"query shape: {self.query_labels.shape}")
|
||
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
|
||
self.train_loader = DataLoader(
|
||
dataset=train_data,
|
||
batch_size=self.args.batch_size,
|
||
num_workers=self.args.num_workers,
|
||
pin_memory=True,
|
||
shuffle=True
|
||
)
|
||
self.query_loader = DataLoader(
|
||
dataset=query_data,
|
||
batch_size=self.args.batch_size,
|
||
num_workers=self.args.num_workers,
|
||
pin_memory=True,
|
||
shuffle=True
|
||
)
|
||
self.retrieval_loader = DataLoader(
|
||
dataset=retrieval_data,
|
||
batch_size=self.args.batch_size,
|
||
num_workers=self.args.num_workers,
|
||
pin_memory=True,
|
||
shuffle=True
|
||
)
|
||
self.train_data = train_data
|
||
|
||
def _tokenize(self, text):
|
||
words = text.split(' ')
|
||
|
||
sub_words = []
|
||
keys = []
|
||
index = 0
|
||
for word in words:
|
||
sub = self.bert_tokenizer.tokenize(word)
|
||
sub_words += sub
|
||
keys.append([index, index + len(sub)])
|
||
index += len(sub)
|
||
|
||
return words, sub_words, keys
|
||
|
||
def get_important_scores(self, text, origin_embeds, batch_size, max_length):
|
||
# device = origin_embeds.device
|
||
|
||
masked_words = self._get_masked(text)
|
||
masked_texts = [' '.join(words) for words in masked_words] # list of text of masked words
|
||
|
||
masked_embeds = []
|
||
for i in range(0, len(masked_texts), batch_size):
|
||
masked_text_input = self.bert_tokenizer(masked_texts[i:i + batch_size], padding='max_length',
|
||
truncation=True, max_length=max_length, return_tensors='pt').to(
|
||
device)
|
||
masked_embed = self.ref_net(masked_text_input.text_inputs, attention_mask=masked_text_input.attention_mask)
|
||
masked_embeds.append(masked_embed)
|
||
masked_embeds = torch.cat(masked_embeds, dim=0)
|
||
|
||
criterion = torch.nn.KLDivLoss(reduction='none')
|
||
|
||
import_scores = criterion(masked_embeds.log_softmax(dim=-1),
|
||
origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1))
|
||
|
||
return import_scores.sum(dim=-1)
|
||
|
||
def _get_masked(self, text):
|
||
words = text.split(' ')
|
||
len_text = len(words)
|
||
masked_words = []
|
||
for i in range(len_text):
|
||
masked_words.append(words[0:i] + ['[UNK]'] + words[i + 1:])
|
||
# list of words
|
||
return masked_words
|
||
|
||
def get_transformations(self, text, idx, substitutes):
|
||
words = text.split(' ')
|
||
|
||
trans_text = []
|
||
for sub in substitutes:
|
||
words[idx] = sub
|
||
trans_text.append(' '.join(words))
|
||
return trans_text
|
||
|
||
    def get_word_predictions(self, text):
        """Run the masked-LM over *text* and return top-k token candidates per position.

        Returns:
            keys: per-word [start, end) sub-token spans from `_tokenize`.
            word_predictions: (seq_len - 2, max_candidate) candidate token ids
                ([CLS]/[SEP] rows stripped).
            word_pred_scores_all: matching prediction scores, same shape.
            attention_mask: tokenizer attention mask (left on CPU).
        """
        _, _, keys = self._tokenize(text)

        inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.args.max_words,
                                                 truncation=True, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs['attention_mask']
        with torch.no_grad():
            word_predictions = self.ref_net(input_ids)['logits'].squeeze(0)  # (seq_len, vocab_size)
            # print(self.ref_net(input_ids)['logits'].shape)
            # Keep only the top `max_candidate` tokens per position.
            word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.args.max_candidate, -1)

        word_predictions = word_predictions[1:-1, :]  # remove [CLS] and [SEP]
        word_pred_scores_all = word_pred_scores_all[1:-1, :]

        return keys, word_predictions, word_pred_scores_all, attention_mask
|
||
|
||
def get_bpe_substitutes(self, substitutes):
|
||
# substitutes L, k
|
||
substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
|
||
|
||
# find all possible candidates
|
||
all_substitutes = []
|
||
for i in range(substitutes.size(0)):
|
||
if len(all_substitutes) == 0:
|
||
lev_i = substitutes[i]
|
||
all_substitutes = [[int(c)] for c in lev_i]
|
||
else:
|
||
lev_i = []
|
||
for all_sub in all_substitutes:
|
||
for j in substitutes[i]:
|
||
lev_i.append(all_sub + [int(j)])
|
||
all_substitutes = lev_i
|
||
|
||
# all substitutes: list of list of token-id (all candidates)
|
||
cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
|
||
|
||
word_list = []
|
||
all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
|
||
all_substitutes = all_substitutes[:24].to(self.device)
|
||
|
||
N, L = all_substitutes.size()
|
||
word_predictions = self.ref_net(all_substitutes)[0] # N L vocab-size
|
||
ppl = cross_entropy_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
|
||
ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
|
||
|
||
_, word_list = torch.sort(ppl)
|
||
word_list = [all_substitutes[i] for i in word_list]
|
||
final_words = []
|
||
for word in word_list:
|
||
tokens = [self.bert_tokenizer.convert_ids_to_tokens(int(i)) for i in word]
|
||
text = ' '.join([t.strip() for t in tokens])
|
||
final_words.append(text)
|
||
return final_words
|
||
|
||
def get_substitutes(self, substitutes, substitutes_score, threshold=3.0):
|
||
ret = []
|
||
num_sub, _ = substitutes.size()
|
||
if num_sub == 0:
|
||
ret = []
|
||
elif num_sub == 1:
|
||
for id, score in zip(substitutes[0], substitutes_score[0]):
|
||
if threshold != 0 and score < threshold:
|
||
break
|
||
ret.append(self.bert_tokenizer.convert_ids_to_tokens(int(id)))
|
||
elif self.args.enable_bpe:
|
||
ret = self.get_bpe_substitutes(substitutes)
|
||
return ret
|
||
|
||
def filter_substitutes(self, substitues):
|
||
|
||
ret = []
|
||
for word in substitues:
|
||
if word.lower() in filter_words:
|
||
continue
|
||
if '##' in word:
|
||
continue
|
||
|
||
ret.append(word)
|
||
return ret
|
||
|
||
    def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
                         positive_var, beta=10, temperature=0.05):
        """Score candidate adversarial texts against positive/negative image statistics.

        Each candidate is encoded with CLIP; the goal combines a contrastive
        term over the per-label feature mean/variance with a cosine-distance
        triplet term weighted by *beta*. Candidates already too similar to the
        positive code (> args.sim_threshold) are dropped.
        """
        # print(trans_texts)
        trans_feature = clip.tokenize(trans_texts,context_length=77,truncate=True).to(device)
        anchor = self.model.encode_text(trans_feature)
        batch_size=anchor.shape[0]
        # Triplet term: pull the adversarial text toward the target (positive)
        # image feature and away from its own (negative) image feature.
        loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code.repeat(batch_size, 1), negetive_code.repeat(batch_size, 1),
                                                    distance_function=nn.CosineSimilarity(), reduction='none')
        sim=F.cosine_similarity(anchor,positive_code.unsqueeze(0), dim=1, eps=1e-8).unsqueeze(1)
        # Variance-normalized squared distances to the per-label feature stats.
        negative_dist = (anchor - negetive_mean) ** 2 / (negative_var+ 1e-5)
        positive_dist = (anchor - positive_mean) ** 2 / (positive_var+ 1e-5)
        negatives = torch.exp(negative_dist / temperature)
        positives = torch.exp(positive_dist / temperature)
        # NOTE(review): loss1 is 1-D (batch,) while the log term is (batch, 1),
        # so this sum broadcasts to (batch, batch) whenever batch > 1, and the
        # `loss[i].shape[0] > 1` guard below then skips every candidate —
        # confirm whether `loss1.unsqueeze(1)` was intended.
        loss = torch.log(positives / (positives + negatives)).mean(dim=1, keepdim=True) + beta * loss1
        results = []
        # print(loss.shape)
        for i in range(len(trans_texts)):
            if loss[i].shape[0] >1 or sim[i] >self.args.sim_threshold:
                continue
            results.append(GoalFunctionResult(trans_texts[i], score=loss[i], similarity=sim[i]))
        return results
|
||
|
||
def generate_mapping(self):
|
||
image_train = []
|
||
label_train = []
|
||
for image, text, label, index in self.train_loader:
|
||
# raw_text=[self.clip_tokenizer.decode(token) for token in text]
|
||
image = image.to(device, non_blocking=True)
|
||
# print(self.model.vocab_size)
|
||
temp_image = self.model.encode_image(image)
|
||
image_train.append(temp_image.cpu().detach().numpy())
|
||
label_train.append(label.detach().numpy())
|
||
image_train = np.concatenate(image_train, axis=0)
|
||
label_train = np.concatenate(label_train, axis=0)
|
||
label_unipue = np.unique(label_train, axis=0)
|
||
image_centroids = np.stack(
|
||
[image_train[find_indices(label_train, label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))],
|
||
axis=0)
|
||
image_var = np.stack(
|
||
[image_train[find_indices(label_train, label_unipue[i])].var(axis=0) for i in range(len(label_unipue))],
|
||
axis=0)
|
||
|
||
image_representation = {}
|
||
image_var_representation = {}
|
||
for i, centroid in enumerate(label_unipue):
|
||
image_representation[str(centroid.astype(int))] = image_centroids[i]
|
||
image_var_representation[str(centroid.astype(int))] = image_var[i]
|
||
return image_representation, image_var_representation
|
||
|
||
    def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
                   positive_var, beta=10, temperature=0.05):
        """Greedy word-substitution attack on a single caption.

        Tries MLM-proposed substitutes position by position, keeping the
        rewrite with the best goal score; once the goal succeeds, prefers the
        successful rewrite with the highest similarity and returns early.
        """
        # print(raw_text)
        keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)

        #clean state
        # clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
        cur_result = GoalFunctionResult(raw_text)
        # NOTE(review): `mask` is 2-D (1, seq_len), so np.where(...)[0] yields
        # the *row* indices — an all-zeros array — and every iteration below
        # targets position 0. Confirm whether `[1]` (column indices) was meant.
        mask_idx = np.where(mask.cpu().numpy() == 1)[0]

        for idx in mask_idx:
            # Candidate tokens for the sub-token span of word `idx`.
            predictions = word_predictions[keys[idx][0]: keys[idx][1]]
            predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]]
            substitutes = self.get_substitutes(predictions, predictions_socre)
            substitutes = self.filter_substitutes(substitutes)
            trans_texts = self.get_transformations(raw_text, idx, substitutes)
            if len(trans_texts) == 0:
                continue
            # loss function
            results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var, positive_code,
                                            positive_mean, positive_var, beta, temperature)
            results = sorted(results, key=lambda x: x.score, reverse=True)

            # Keep the best-scoring rewrite only if it improves on the current one.
            if len(results) > 0 and results[0].score > cur_result.score:
                cur_result = results[0]
            else:
                continue

            if cur_result.status == GoalFunctionStatus.SUCCEEDED:
                max_similarity = cur_result.similarity
                if max_similarity is None:
                    # similarity is not calculated
                    continue

                # Among the other successful rewrites (results are score-sorted,
                # so successes are a prefix), prefer the most similar one.
                for result in results[1:]:
                    if result.status != GoalFunctionStatus.SUCCEEDED:
                        break
                    if result.similarity > max_similarity:
                        max_similarity = result.similarity
                        cur_result = result
                return cur_result
        # Exhausted all positions without success.
        if cur_result.status == GoalFunctionStatus.SEARCHING:
            cur_result.status = GoalFunctionStatus.FAILED
        return cur_result
|
||
|
||
def train_epoch(self):
|
||
# self.change_state(mode="valid")
|
||
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||
all_loss = 0
|
||
times = 0
|
||
adv_codes = []
|
||
adv_label = []
|
||
for image, text, label, index in self.train_loader:
|
||
self.global_step += 1
|
||
times += 1
|
||
print(times)
|
||
image.float()
|
||
|
||
image = image.to(self.rank, non_blocking=True)
|
||
text = text.to(self.rank, non_blocking=True)
|
||
negetive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||
negative_var = np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||
negetive_mean = torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
|
||
negative_var = torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
|
||
negetive_code = self.model.encode_image(image)
|
||
|
||
#targeted sample
|
||
np.random.seed(times)
|
||
select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
|
||
target_dataset = data.Subset(self.train_data, select_index)
|
||
target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
|
||
target_image, _, target_label, _ = next(iter(target_subset))
|
||
target_image = target_image.to(self.rank, non_blocking=True)
|
||
positive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
||
positive_var = np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
||
positive_mean = torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
|
||
positive_var = torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
|
||
positive_code = self.model.encode_image(target_image)
|
||
# print(self.clip_tokenizer.my_encode('This day is good!'))
|
||
raw_text = [self.clip_tokenizer.convert_ids_to_tokens(token.cpu()) for token in text]
|
||
raw_text = [text_filter(self.clip_tokenizer.convert_tokens_to_string(txt)) for txt in raw_text]
|
||
final_texts=[]
|
||
for i in range(self.args.batch_size):
|
||
adv_txt=self.target_adv( raw_text[i], negetive_code[i], negetive_mean[i], negative_var[i],
|
||
positive_code[i], positive_mean[i], positive_var[i])
|
||
final_texts.append(adv_txt.text)
|
||
# final_adverse = self.target_adv( raw_text, negetive_code, negetive_mean, negative_var,
|
||
# positive_code, positive_mean, positive_var)
|
||
final_text = clip.tokenize(final_texts,context_length=77,truncate=True).to(self.rank, non_blocking=True)
|
||
adv_code = self.model.encode_text(final_text)
|
||
adv_codes.append(adv_code.cpu().detach().numpy())
|
||
adv_label.append(target_label.numpy())
|
||
adv_img = np.concatenate(adv_codes)
|
||
adv_labels = np.concatenate(adv_label)
|
||
|
||
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
||
|
||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||
retrieval_labels = self.retrieval_labels.numpy()
|
||
|
||
mAP_t = cal_map(adv_img, adv_labels, retrieval_txt, retrieval_labels)
|
||
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
|
||
result_dict = {
|
||
'adv_img': adv_img,
|
||
'r_txt': retrieval_txt,
|
||
'adv_l': adv_labels,
|
||
'r_l': retrieval_labels
|
||
# 'q_l':query_labels
|
||
# 'pr': pr,
|
||
# 'pr_t': pr_t
|
||
}
|
||
scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"),
|
||
result_dict)
|
||
self.logger.info(">>>>>> save all data!")
|
||
|
||
    def train(self):
        """Run the attack for `args.epochs` epochs, validating and checkpointing each one.

        NOTE(review): as written in this file, `train_epoch` takes no epoch
        argument, so `self.train_epoch(epoch)` would raise a TypeError —
        confirm the intended signature. `valid`/`save_model` are presumably
        provided by TrainBase (not defined in this file).
        """
        self.logger.info("Start train.")

        for epoch in range(self.args.epochs):
            self.train_epoch(epoch)
            self.valid(epoch)
            self.save_model(epoch)

        self.logger.info(
            f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
||
|
||
def make_hash_code(self, code: list) -> torch.Tensor:
|
||
|
||
code = torch.stack(code)
|
||
# print(code.shape)
|
||
code = code.permute(1, 0, 2)
|
||
hash_code = torch.argmax(code, dim=-1)
|
||
hash_code[torch.where(hash_code == 0)] = -1
|
||
hash_code = hash_code.float()
|
||
|
||
return hash_code
|
||
|
||
def get_code(self, data_loader, length: int):
|
||
|
||
img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
|
||
text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
|
||
|
||
for image, text, label, index in tqdm(data_loader):
|
||
image = image.to(self.device, non_blocking=True)
|
||
text = text.to(self.device, non_blocking=True)
|
||
index = index.numpy()
|
||
with torch.no_grad():
|
||
image_feature = self.model.encode_image(image)
|
||
text_features = self.model.encode_text(text)
|
||
img_buffer[index, :] = image_feature.detach()
|
||
text_buffer[index, :] = text_features.detach()
|
||
|
||
return img_buffer, text_buffer # img_buffer.to(self.rank), text_buffer.to(self.rank)
|
||
|
||
def valid_attack(self, adv_images, texts, adv_labels):
|
||
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||
os.makedirs(save_dir, exist_ok=True)
|
||
|
||
    def test(self, mode_name="i2t"):
        """Evaluate clean (un-attacked) retrieval performance and dump features.

        Args:
            mode_name: kept for interface compatibility; unused in the body.
        """
        self.logger.info("Valid Clean.")
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)

        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAPi2t = cal_map(query_img, query_labels, retrieval_txt, retrieval_labels)
        mAPt2i = cal_map(query_txt, query_labels, retrieval_img, retrieval_labels)
        # pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
        # pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)
        # NOTE(review): max_mapt2i is updated with the *i2t* mAP here — confirm
        # whether mAPt2i was intended.
        self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"),
                     result_dict)
        self.logger.info(">>>>>> save all data!")
|
||
|
||
|
||
|
||
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
|
||
|
||
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
||
os.makedirs(save_dir, exist_ok=True)
|
||
|
||
query_img = query_img.cpu().detach().numpy()
|
||
query_txt = query_txt.cpu().detach().numpy()
|
||
retrieval_img = retrieval_img.cpu().detach().numpy()
|
||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||
query_labels = self.query_labels.numpy()
|
||
retrieval_labels = self.retrieval_labels.numpy()
|
||
|
||
result_dict = {
|
||
'q_img': query_img,
|
||
'q_txt': query_txt,
|
||
'r_img': retrieval_img,
|
||
'r_txt': retrieval_txt,
|
||
'q_l': query_labels,
|
||
'r_l': retrieval_labels
|
||
}
|
||
scio.savemat(
|
||
os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"),
|
||
result_dict)
|
||
self.logger.info(f">>>>>> save best {mode_name} data!")
|