This commit is contained in:
parent
321817c86f
commit
ad01bb4f0e
|
|
@ -8,8 +8,8 @@ import torch
|
||||||
import random
|
import random
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||||
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
# from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
class BaseDataset(Dataset):
|
class BaseDataset(Dataset):
|
||||||
|
|
||||||
|
|
@ -19,7 +19,7 @@ class BaseDataset(Dataset):
|
||||||
indexs: dict,
|
indexs: dict,
|
||||||
labels: dict,
|
labels: dict,
|
||||||
is_train=True,
|
is_train=True,
|
||||||
tokenizer=Tokenizer(),
|
tokenizer=AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", model_max_length=77, truncation=True),
|
||||||
maxWords=32,
|
maxWords=32,
|
||||||
imageResolution=224,
|
imageResolution=224,
|
||||||
npy=False):
|
npy=False):
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=N
|
||||||
random_index = np.random.permutation(range(len(indexs)))
|
random_index = np.random.permutation(range(len(indexs)))
|
||||||
query_index = random_index[: query_num]
|
query_index = random_index[: query_num]
|
||||||
train_index = random_index[query_num: query_num + train_num]
|
train_index = random_index[query_num: query_num + train_num]
|
||||||
retrieval_index = random_index[query_num:]
|
retrieval_index = random_index[query_num:-100000]
|
||||||
|
|
||||||
query_indexs = indexs[query_index]
|
query_indexs = indexs[query_index]
|
||||||
query_captions = captions[query_index]
|
query_captions = captions[query_index]
|
||||||
|
|
|
||||||
6
main.py
6
main.py
|
|
@ -1,11 +1,11 @@
|
||||||
from train.text_train import Trainer
|
# from train.text_train import Trainer
|
||||||
|
from train.hash_train import Trainer
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
engine=Trainer()
|
engine=Trainer()
|
||||||
engine.test()
|
engine.test()
|
||||||
engine.train_epoch()
|
# engine.train_epoch()
|
||||||
|
|
||||||
|
|
||||||
# engine.train()
|
# engine.train()
|
||||||
|
|
|
||||||
|
|
@ -131,11 +131,16 @@ class SimpleTokenizer(object):
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def my_decode(self, tokens):
|
# def my_decode(self, tokens):
|
||||||
tokens=[item for item in tokens if item in self.decoder]
|
# print(tokens)
|
||||||
text = ''.join([self.decoder[token] for token in tokens])
|
# tem_token=[]
|
||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
# for i in tokens:
|
||||||
return text
|
# if i in self.decoder.keys():
|
||||||
|
# tem_token.append(self.decoder[i])
|
||||||
|
# print(tem_token)
|
||||||
|
# text = ''.join(tem_token)
|
||||||
|
# text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||||
|
# return text
|
||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
tokens = []
|
tokens = []
|
||||||
|
|
@ -147,3 +152,6 @@ class SimpleTokenizer(object):
|
||||||
|
|
||||||
def convert_tokens_to_ids(self, tokens):
|
def convert_tokens_to_ids(self, tokens):
|
||||||
return [self.encoder[bpe_token] for bpe_token in tokens]
|
return [self.encoder[bpe_token] for bpe_token in tokens]
|
||||||
|
|
||||||
|
def convert_ids_to_tokens(self, ids):
|
||||||
|
return [self.decoder[id] for id in ids]
|
||||||
|
|
|
||||||
|
|
@ -183,18 +183,18 @@ class Trainer(TrainBase):
|
||||||
|
|
||||||
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
|
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
|
||||||
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
|
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
|
||||||
# pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels)
|
pr_t=cal_pr(retrieval_txt,adv_img,retrieval_labels,adv_labels)
|
||||||
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
|
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
|
||||||
result_dict = {
|
result_dict = {
|
||||||
'adv_img': adv_img,
|
'adv_img': adv_img,
|
||||||
'r_txt': retrieval_txt,
|
'r_txt': retrieval_txt,
|
||||||
'adv_l': adv_labels,
|
'adv_l': adv_labels,
|
||||||
'r_l': retrieval_labels
|
'r_l': retrieval_labels,
|
||||||
# 'q_l':query_labels
|
# 'q_l':query_labels
|
||||||
# 'pr': pr,
|
# 'pr': pr,
|
||||||
# 'pr_t': pr_t
|
'pr_t': pr_t
|
||||||
}
|
}
|
||||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
|
scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"), result_dict)
|
||||||
self.logger.info(">>>>>> save all data!")
|
self.logger.info(">>>>>> save all data!")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -267,8 +267,8 @@ class Trainer(TrainBase):
|
||||||
retrieval_labels = self.retrieval_labels.numpy()
|
retrieval_labels = self.retrieval_labels.numpy()
|
||||||
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
|
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
|
||||||
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)
|
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)
|
||||||
# pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
|
pr_i2t=cal_pr(retrieval_txt,query_img,retrieval_labels,query_labels)
|
||||||
# pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)
|
pr_t2i=cal_pr(retrieval_img,query_txt,retrieval_labels,query_labels)
|
||||||
self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
|
self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
|
||||||
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
|
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
|
||||||
result_dict = {
|
result_dict = {
|
||||||
|
|
@ -277,11 +277,11 @@ class Trainer(TrainBase):
|
||||||
'r_img': retrieval_img,
|
'r_img': retrieval_img,
|
||||||
'r_txt': retrieval_txt,
|
'r_txt': retrieval_txt,
|
||||||
'q_l': query_labels,
|
'q_l': query_labels,
|
||||||
'r_l': retrieval_labels
|
'r_l': retrieval_labels,
|
||||||
# 'pr_i2t': pr_i2t,
|
'pr_i2t': pr_i2t,
|
||||||
# 'pr_t2i': pr_t2i
|
'pr_t2i': pr_t2i
|
||||||
}
|
}
|
||||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
|
scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"), result_dict)
|
||||||
self.logger.info(">>>>>> save all data!")
|
self.logger.info(">>>>>> save all data!")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,14 +12,14 @@ import numpy as np
|
||||||
|
|
||||||
from .base import TrainBase
|
from .base import TrainBase
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity, find_indices
|
||||||
from utils.calc_utils import cal_map, cal_pr
|
from utils.calc_utils import cal_map, cal_pr
|
||||||
from dataset.dataloader import dataloader
|
from dataset.dataloader import dataloader
|
||||||
import clip
|
import clip
|
||||||
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
|
import re
|
||||||
from transformers import BertForMaskedLM
|
from transformers import BertForMaskedLM
|
||||||
from model.bert_tokenizer import BertTokenizer
|
from model.bert_tokenizer import BertTokenizer
|
||||||
# from transformers import BertModel
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
@ -48,13 +48,21 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again',
|
||||||
'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
|
'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
|
||||||
"won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
|
"won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
|
||||||
'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
|
'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
|
||||||
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·']
|
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '~', '·']
|
||||||
filter_words = set(filter_words)
|
filter_words = set(filter_words)
|
||||||
|
|
||||||
|
|
||||||
|
def text_filter(text):
|
||||||
|
text = re.findall(r"<|startoftext|>(.+)<|endoftext|>", text)
|
||||||
|
text = text[2]
|
||||||
|
text = re.sub(r'</w>', ' ', text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
class GoalFunctionStatus(object):
|
class GoalFunctionStatus(object):
|
||||||
SUCCEEDED = 0 # attack succeeded
|
SUCCEEDED = 0 # attack succeeded
|
||||||
SEARCHING = 1 # In process of searching for a success
|
SEARCHING = 1 # In process of searching for a success
|
||||||
FAILED = 2 # attack failed
|
FAILED = 2 # attack failed
|
||||||
|
|
||||||
|
|
||||||
class GoalFunctionResult(object):
|
class GoalFunctionResult(object):
|
||||||
|
|
@ -82,89 +90,29 @@ class GoalFunctionResult(object):
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.text)
|
return hash(self.text)
|
||||||
|
|
||||||
# def get_bpe_substitues(substitutes, tokenizer, mlm_model):
|
|
||||||
# # substitutes L, k
|
|
||||||
# # device = mlm_model.device
|
|
||||||
# substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
|
|
||||||
#
|
|
||||||
# # find all possible candidates
|
|
||||||
#
|
|
||||||
# all_substitutes = []
|
|
||||||
# for i in range(substitutes.size(0)):
|
|
||||||
# if len(all_substitutes) == 0:
|
|
||||||
# lev_i = substitutes[i]
|
|
||||||
# all_substitutes = [[int(c)] for c in lev_i]
|
|
||||||
# else:
|
|
||||||
# lev_i = []
|
|
||||||
# for all_sub in all_substitutes:
|
|
||||||
# for j in substitutes[i]:
|
|
||||||
# lev_i.append(all_sub + [int(j)])
|
|
||||||
# all_substitutes = lev_i
|
|
||||||
#
|
|
||||||
# # all substitutes list of list of token-id (all candidates)
|
|
||||||
# c_loss = nn.CrossEntropyLoss(reduction='none')
|
|
||||||
# word_list = []
|
|
||||||
# # all_substitutes = all_substitutes[:24]
|
|
||||||
# all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
|
|
||||||
# all_substitutes = all_substitutes[:24].to(device)
|
|
||||||
# # print(substitutes.size(), all_substitutes.size())
|
|
||||||
# N, L = all_substitutes.size()
|
|
||||||
# word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
|
|
||||||
# ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
|
|
||||||
# ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
|
|
||||||
# _, word_list = torch.sort(ppl)
|
|
||||||
# word_list = [all_substitutes[i] for i in word_list]
|
|
||||||
# final_words = []
|
|
||||||
# for word in word_list:
|
|
||||||
# tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
|
|
||||||
# text = tokenizer.convert_tokens_to_string(tokens)
|
|
||||||
# final_words.append(text)
|
|
||||||
# return final_words
|
|
||||||
|
|
||||||
# def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
|
|
||||||
# # substitues L,k
|
|
||||||
# # from this matrix to recover a word
|
|
||||||
# words = []
|
|
||||||
# sub_len, k = substitutes.size() # sub-len, k
|
|
||||||
#
|
|
||||||
# if sub_len == 0:
|
|
||||||
# return words
|
|
||||||
#
|
|
||||||
# elif sub_len == 1:
|
|
||||||
# for (i, j) in zip(substitutes[0], substitutes_score[0]):
|
|
||||||
# if threshold != 0 and j < threshold:
|
|
||||||
# break
|
|
||||||
# words.append(tokenizer._convert_id_to_token(int(i)))
|
|
||||||
# else:
|
|
||||||
# if use_bpe == 1:
|
|
||||||
# words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
|
|
||||||
# else:
|
|
||||||
# return words
|
|
||||||
# #
|
|
||||||
# # print(words)
|
|
||||||
# return words
|
|
||||||
|
|
||||||
class Trainer(TrainBase):
|
class Trainer(TrainBase):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
rank=0):
|
rank=0):
|
||||||
args = get_args()
|
args = get_args()
|
||||||
super(Trainer, self).__init__(args, rank)
|
super(Trainer, self).__init__(args, rank)
|
||||||
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
||||||
image_mean, image_var=self.generate_mapping()
|
image_mean, image_var = self.generate_mapping()
|
||||||
self.image_mean=image_mean
|
self.image_mean = image_mean
|
||||||
self.image_var=image_var
|
self.image_var = image_var
|
||||||
self.device=rank
|
self.device = rank
|
||||||
self.clip_tokenizer=Tokenizer()
|
self.clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", model_max_length=77,
|
||||||
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True)
|
truncation=True)
|
||||||
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder)
|
self.bert_tokenizer = BertTokenizer.from_pretrained(self.args.text_encoder, do_lower_case=True)
|
||||||
|
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder).to(device)
|
||||||
self.attack_thred = self.args.attack_thred
|
self.attack_thred = self.args.attack_thred
|
||||||
# self.run()
|
# self.run()
|
||||||
|
|
||||||
def _init_model(self):
|
def _init_model(self):
|
||||||
self.logger.info("init model.")
|
self.logger.info("init model.")
|
||||||
model_clip, preprocess = clip.load(self.args.victim, device=device)
|
model_clip, preprocess = clip.load(self.args.victim, device=device)
|
||||||
self.model= model_clip
|
self.model = model_clip
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.model.float()
|
self.model.float()
|
||||||
|
|
||||||
|
|
@ -174,14 +122,14 @@ class Trainer(TrainBase):
|
||||||
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
|
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
|
||||||
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
|
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
|
||||||
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
|
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
|
||||||
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
|
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
|
||||||
indexFile=self.args.index_file,
|
indexFile=self.args.index_file,
|
||||||
labelFile=self.args.label_file,
|
labelFile=self.args.label_file,
|
||||||
maxWords=self.args.max_words,
|
maxWords=self.args.max_words,
|
||||||
imageResolution=self.args.resolution,
|
imageResolution=self.args.resolution,
|
||||||
query_num=self.args.query_num,
|
query_num=self.args.query_num,
|
||||||
train_num=self.args.train_num,
|
train_num=self.args.train_num,
|
||||||
seed=self.args.seed)
|
seed=self.args.seed)
|
||||||
self.train_labels = train_data.get_all_label()
|
self.train_labels = train_data.get_all_label()
|
||||||
self.query_labels = query_data.get_all_label()
|
self.query_labels = query_data.get_all_label()
|
||||||
self.retrieval_labels = retrieval_data.get_all_label()
|
self.retrieval_labels = retrieval_data.get_all_label()
|
||||||
|
|
@ -189,28 +137,28 @@ class Trainer(TrainBase):
|
||||||
self.logger.info(f"query shape: {self.query_labels.shape}")
|
self.logger.info(f"query shape: {self.query_labels.shape}")
|
||||||
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
|
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
|
||||||
self.train_loader = DataLoader(
|
self.train_loader = DataLoader(
|
||||||
dataset=train_data,
|
dataset=train_data,
|
||||||
batch_size=self.args.batch_size,
|
batch_size=self.args.batch_size,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
shuffle=True
|
shuffle=True
|
||||||
)
|
)
|
||||||
self.query_loader = DataLoader(
|
self.query_loader = DataLoader(
|
||||||
dataset=query_data,
|
dataset=query_data,
|
||||||
batch_size=self.args.batch_size,
|
batch_size=self.args.batch_size,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
shuffle=True
|
shuffle=True
|
||||||
)
|
)
|
||||||
self.retrieval_loader = DataLoader(
|
self.retrieval_loader = DataLoader(
|
||||||
dataset=retrieval_data,
|
dataset=retrieval_data,
|
||||||
batch_size=self.args.batch_size,
|
batch_size=self.args.batch_size,
|
||||||
num_workers=self.args.num_workers,
|
num_workers=self.args.num_workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
shuffle=True
|
shuffle=True
|
||||||
)
|
)
|
||||||
self.train_data=train_data
|
self.train_data = train_data
|
||||||
|
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text):
|
||||||
words = text.split(' ')
|
words = text.split(' ')
|
||||||
|
|
||||||
|
|
@ -224,8 +172,8 @@ class Trainer(TrainBase):
|
||||||
index += len(sub)
|
index += len(sub)
|
||||||
|
|
||||||
return words, sub_words, keys
|
return words, sub_words, keys
|
||||||
|
|
||||||
def get_important_scores(self, text, origin_embeds, batch_size, max_length):
|
def get_important_scores(self, text, origin_embeds, batch_size, max_length):
|
||||||
# device = origin_embeds.device
|
# device = origin_embeds.device
|
||||||
|
|
||||||
masked_words = self._get_masked(text)
|
masked_words = self._get_masked(text)
|
||||||
|
|
@ -233,17 +181,20 @@ class Trainer(TrainBase):
|
||||||
|
|
||||||
masked_embeds = []
|
masked_embeds = []
|
||||||
for i in range(0, len(masked_texts), batch_size):
|
for i in range(0, len(masked_texts), batch_size):
|
||||||
masked_text_input = self.bert_tokenizer(masked_texts[i:i+batch_size], padding='max_length', truncation=True, max_length=max_length, return_tensors='pt').to(device)
|
masked_text_input = self.bert_tokenizer(masked_texts[i:i + batch_size], padding='max_length',
|
||||||
|
truncation=True, max_length=max_length, return_tensors='pt').to(
|
||||||
|
device)
|
||||||
masked_embed = self.ref_net(masked_text_input.text_inputs, attention_mask=masked_text_input.attention_mask)
|
masked_embed = self.ref_net(masked_text_input.text_inputs, attention_mask=masked_text_input.attention_mask)
|
||||||
masked_embeds.append(masked_embed)
|
masked_embeds.append(masked_embed)
|
||||||
masked_embeds = torch.cat(masked_embeds, dim=0)
|
masked_embeds = torch.cat(masked_embeds, dim=0)
|
||||||
|
|
||||||
criterion = torch.nn.KLDivLoss(reduction='none')
|
criterion = torch.nn.KLDivLoss(reduction='none')
|
||||||
|
|
||||||
import_scores = criterion(masked_embeds.log_softmax(dim=-1), origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1))
|
import_scores = criterion(masked_embeds.log_softmax(dim=-1),
|
||||||
|
origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1))
|
||||||
|
|
||||||
return import_scores.sum(dim=-1)
|
return import_scores.sum(dim=-1)
|
||||||
|
|
||||||
def _get_masked(self, text):
|
def _get_masked(self, text):
|
||||||
words = text.split(' ')
|
words = text.split(' ')
|
||||||
len_text = len(words)
|
len_text = len(words)
|
||||||
|
|
@ -265,14 +216,14 @@ class Trainer(TrainBase):
|
||||||
def get_word_predictions(self, text):
|
def get_word_predictions(self, text):
|
||||||
_, _, keys = self._tokenize(text)
|
_, _, keys = self._tokenize(text)
|
||||||
|
|
||||||
inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.max_text_len,
|
inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.args.max_words,
|
||||||
truncation=True, return_tensors="pt")
|
truncation=True, return_tensors="pt")
|
||||||
input_ids = inputs["input_ids"].to(self.device)
|
input_ids = inputs["input_ids"].to(self.device)
|
||||||
attention_mask = inputs['attention_mask']
|
attention_mask = inputs['attention_mask']
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
word_predictions = self.ref_net(input_ids)['logits'].squeeze() # (seq_len, vocab_size)
|
word_predictions = self.ref_net(input_ids)['logits'].squeeze(0) # (seq_len, vocab_size)
|
||||||
|
# print(self.ref_net(input_ids)['logits'].shape)
|
||||||
word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.max_candidate, -1)
|
word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.args.max_candidate, -1)
|
||||||
|
|
||||||
word_predictions = word_predictions[1:-1, :] # remove [CLS] and [SEP]
|
word_predictions = word_predictions[1:-1, :] # remove [CLS] and [SEP]
|
||||||
word_pred_scores_all = word_pred_scores_all[1:-1, :]
|
word_pred_scores_all = word_pred_scores_all[1:-1, :]
|
||||||
|
|
@ -343,55 +294,65 @@ class Trainer(TrainBase):
|
||||||
ret.append(word)
|
ret.append(word)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
|
def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
|
||||||
positive_var, beta=10,temperature=0.05):
|
positive_var, beta=10, temperature=0.05):
|
||||||
trans_feature= self.clip_tokenizer(trans_texts)
|
# print(trans_texts)
|
||||||
anchor=self.model.encode_text(trans_feature)
|
trans_feature = clip.tokenize(trans_texts,context_length=77,truncate=True).to(device)
|
||||||
loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code,
|
anchor = self.model.encode_text(trans_feature)
|
||||||
|
batch_size=anchor.shape[0]
|
||||||
|
loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code.repeat(batch_size, 1), negetive_code.repeat(batch_size, 1),
|
||||||
distance_function=nn.CosineSimilarity(), reduction='none')
|
distance_function=nn.CosineSimilarity(), reduction='none')
|
||||||
negative_dist = (anchor - negetive_mean) ** 2 / negative_var
|
sim=F.cosine_similarity(anchor,positive_code.unsqueeze(0), dim=1, eps=1e-8).unsqueeze(1)
|
||||||
positive_dist = (anchor - positive_mean) ** 2 / positive_var
|
negative_dist = (anchor - negetive_mean) ** 2 / (negative_var+ 1e-5)
|
||||||
|
positive_dist = (anchor - positive_mean) ** 2 / (positive_var+ 1e-5)
|
||||||
negatives = torch.exp(negative_dist / temperature)
|
negatives = torch.exp(negative_dist / temperature)
|
||||||
positives = torch.exp(positive_dist / temperature)
|
positives = torch.exp(positive_dist / temperature)
|
||||||
loss = torch.log(positives / (positives + negatives)) + beta * loss1
|
loss = torch.log(positives / (positives + negatives)).mean(dim=1, keepdim=True) + beta * loss1
|
||||||
return loss
|
results = []
|
||||||
|
# print(loss.shape)
|
||||||
|
for i in range(len(trans_texts)):
|
||||||
|
if loss[i].shape[0] >1 or sim[i] >self.args.sim_threshold:
|
||||||
|
continue
|
||||||
|
results.append(GoalFunctionResult(trans_texts[i], score=loss[i], similarity=sim[i]))
|
||||||
|
return results
|
||||||
|
|
||||||
def generate_mapping(self):
|
def generate_mapping(self):
|
||||||
image_train=[]
|
image_train = []
|
||||||
label_train=[]
|
label_train = []
|
||||||
for image, text, label, index in self.train_loader:
|
for image, text, label, index in self.train_loader:
|
||||||
# raw_text=[self.clip_tokenizer.decode(token) for token in text]
|
# raw_text=[self.clip_tokenizer.decode(token) for token in text]
|
||||||
image=image.to(device, non_blocking=True)
|
image = image.to(device, non_blocking=True)
|
||||||
# print(self.model.vocab_size)
|
# print(self.model.vocab_size)
|
||||||
temp_image=self.model.encode_image(image)
|
temp_image = self.model.encode_image(image)
|
||||||
image_train.append(temp_image.cpu().detach().numpy())
|
image_train.append(temp_image.cpu().detach().numpy())
|
||||||
label_train.append(label.detach().numpy())
|
label_train.append(label.detach().numpy())
|
||||||
image_train=np.concatenate(image_train, axis=0)
|
image_train = np.concatenate(image_train, axis=0)
|
||||||
label_train=np.concatenate(label_train, axis=0)
|
label_train = np.concatenate(label_train, axis=0)
|
||||||
label_unipue=np.unique(label_train,axis=0)
|
label_unipue = np.unique(label_train, axis=0)
|
||||||
image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
|
image_centroids = np.stack(
|
||||||
image_var=np.stack([image_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
|
[image_train[find_indices(label_train, label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))],
|
||||||
|
axis=0)
|
||||||
|
image_var = np.stack(
|
||||||
|
[image_train[find_indices(label_train, label_unipue[i])].var(axis=0) for i in range(len(label_unipue))],
|
||||||
|
axis=0)
|
||||||
|
|
||||||
image_representation = {}
|
image_representation = {}
|
||||||
image_var_representation = {}
|
image_var_representation = {}
|
||||||
for i, centroid in enumerate(label_unipue):
|
for i, centroid in enumerate(label_unipue):
|
||||||
image_representation[str(centroid.astype(int))] = image_centroids[i]
|
image_representation[str(centroid.astype(int))] = image_centroids[i]
|
||||||
image_var_representation[str(centroid.astype(int))]= image_var[i]
|
image_var_representation[str(centroid.astype(int))] = image_var[i]
|
||||||
return image_representation, image_var_representation
|
return image_representation, image_var_representation
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
|
def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
|
||||||
positive_var, beta=10, temperature=0.05):
|
positive_var, beta=10, temperature=0.05):
|
||||||
|
# print(raw_text)
|
||||||
keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
|
keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
|
||||||
|
|
||||||
#clean state
|
#clean state
|
||||||
# clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
|
# clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
|
||||||
cur_result = GoalFunctionResult(raw_text)
|
cur_result = GoalFunctionResult(raw_text)
|
||||||
mask_idx = np.where(mask.cpu().numpy() == 1)[0]
|
mask_idx = np.where(mask.cpu().numpy() == 1)[0]
|
||||||
|
|
||||||
|
|
||||||
for idx in mask_idx:
|
for idx in mask_idx:
|
||||||
predictions = word_predictions[keys[idx][0]: keys[idx][1]]
|
predictions = word_predictions[keys[idx][0]: keys[idx][1]]
|
||||||
predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]]
|
predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]]
|
||||||
|
|
@ -402,7 +363,7 @@ class Trainer(TrainBase):
|
||||||
continue
|
continue
|
||||||
# loss function
|
# loss function
|
||||||
results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var, positive_code,
|
results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var, positive_code,
|
||||||
positive_mean,positive_var, beta,temperature)
|
positive_mean, positive_var, beta, temperature)
|
||||||
results = sorted(results, key=lambda x: x.score, reverse=True)
|
results = sorted(results, key=lambda x: x.score, reverse=True)
|
||||||
|
|
||||||
if len(results) > 0 and results[0].score > cur_result.score:
|
if len(results) > 0 and results[0].score > cur_result.score:
|
||||||
|
|
@ -427,63 +388,62 @@ class Trainer(TrainBase):
|
||||||
cur_result.status = GoalFunctionStatus.FAILED
|
cur_result.status = GoalFunctionStatus.FAILED
|
||||||
return cur_result
|
return cur_result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(self):
|
def train_epoch(self):
|
||||||
# self.change_state(mode="valid")
|
# self.change_state(mode="valid")
|
||||||
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||||||
all_loss = 0
|
all_loss = 0
|
||||||
times = 0
|
times = 0
|
||||||
adv_codes=[]
|
adv_codes = []
|
||||||
adv_label=[]
|
adv_label = []
|
||||||
for image, text, label, index in self.train_loader:
|
for image, text, label, index in self.train_loader:
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
times += 1
|
times += 1
|
||||||
print(times)
|
print(times)
|
||||||
image.float()
|
image.float()
|
||||||
|
|
||||||
raw_text=[self.clip_tokenizer.my_decode(token) for token in text]
|
|
||||||
image = image.to(self.rank, non_blocking=True)
|
image = image.to(self.rank, non_blocking=True)
|
||||||
text = text.to(self.rank, non_blocking=True)
|
text = text.to(self.rank, non_blocking=True)
|
||||||
negetive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
negetive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||||
negative_var=np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
negative_var = np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||||
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
|
negetive_mean = torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
|
||||||
negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
|
negative_var = torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
|
||||||
negetive_code=self.model.encode_image(image)
|
negetive_code = self.model.encode_image(image)
|
||||||
|
|
||||||
#targeted sample
|
#targeted sample
|
||||||
np.random.seed(times)
|
np.random.seed(times)
|
||||||
select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
|
select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
|
||||||
target_dataset = data.Subset(self.train_data, select_index)
|
target_dataset = data.Subset(self.train_data, select_index)
|
||||||
target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
|
target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
|
||||||
target_image, _, target_label, _ = next(iter(target_subset))
|
target_image, _, target_label, _ = next(iter(target_subset))
|
||||||
target_image=target_image.to(self.rank, non_blocking=True)
|
target_image = target_image.to(self.rank, non_blocking=True)
|
||||||
positive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
positive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
||||||
positive_var=np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
positive_var = np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
||||||
positive_mean=torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
|
positive_mean = torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
|
||||||
positive_var=torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
|
positive_var = torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
|
||||||
positive_code=self.model.encode_image(target_image)
|
positive_code = self.model.encode_image(target_image)
|
||||||
|
# print(self.clip_tokenizer.my_encode('This day is good!'))
|
||||||
|
raw_text = [self.clip_tokenizer.convert_ids_to_tokens(token.cpu()) for token in text]
|
||||||
final_adverse=self.target_adv(text, raw_text,negetive_code,negetive_mean,negative_var,
|
raw_text = [text_filter(self.clip_tokenizer.convert_tokens_to_string(txt)) for txt in raw_text]
|
||||||
positive_code,positive_mean,positive_var)
|
final_texts=[]
|
||||||
final_text=self.clip_tokenizer.tokenize(final_adverse.text).to(self.rank, non_blocking=True)
|
for i in range(self.args.batch_size):
|
||||||
adv_code=self.model.encode_text(final_text)
|
adv_txt=self.target_adv( raw_text[i], negetive_code[i], negetive_mean[i], negative_var[i],
|
||||||
|
positive_code[i], positive_mean[i], positive_var[i])
|
||||||
|
final_texts.append(adv_txt.text)
|
||||||
|
# final_adverse = self.target_adv( raw_text, negetive_code, negetive_mean, negative_var,
|
||||||
|
# positive_code, positive_mean, positive_var)
|
||||||
|
final_text = clip.tokenize(final_texts,context_length=77,truncate=True).to(self.rank, non_blocking=True)
|
||||||
|
adv_code = self.model.encode_text(final_text)
|
||||||
adv_codes.append(adv_code.cpu().detach().numpy())
|
adv_codes.append(adv_code.cpu().detach().numpy())
|
||||||
adv_label.append(target_label.numpy())
|
adv_label.append(target_label.numpy())
|
||||||
adv_img=np.concatenate(adv_codes)
|
adv_img = np.concatenate(adv_codes)
|
||||||
adv_labels=np.concatenate(adv_label)
|
adv_labels = np.concatenate(adv_label)
|
||||||
|
|
||||||
|
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||||
|
|
||||||
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||||
retrieval_labels = self.retrieval_labels.numpy()
|
retrieval_labels = self.retrieval_labels.numpy()
|
||||||
|
|
||||||
|
|
||||||
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
|
mAP_t = cal_map(adv_img, adv_labels, retrieval_txt, retrieval_labels)
|
||||||
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
|
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
|
||||||
result_dict = {
|
result_dict = {
|
||||||
'adv_img': adv_img,
|
'adv_img': adv_img,
|
||||||
|
|
@ -494,13 +454,9 @@ class Trainer(TrainBase):
|
||||||
# 'pr': pr,
|
# 'pr': pr,
|
||||||
# 'pr_t': pr_t
|
# 'pr_t': pr_t
|
||||||
}
|
}
|
||||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
|
scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"),
|
||||||
|
result_dict)
|
||||||
self.logger.info(">>>>>> save all data!")
|
self.logger.info(">>>>>> save all data!")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
self.logger.info("Start train.")
|
self.logger.info("Start train.")
|
||||||
|
|
@ -510,9 +466,8 @@ class Trainer(TrainBase):
|
||||||
self.valid(epoch)
|
self.valid(epoch)
|
||||||
self.save_model(epoch)
|
self.save_model(epoch)
|
||||||
|
|
||||||
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
self.logger.info(
|
||||||
|
f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
||||||
|
|
||||||
|
|
||||||
def make_hash_code(self, code: list) -> torch.Tensor:
|
def make_hash_code(self, code: list) -> torch.Tensor:
|
||||||
|
|
||||||
|
|
@ -539,25 +494,19 @@ class Trainer(TrainBase):
|
||||||
text_features = self.model.encode_text(text)
|
text_features = self.model.encode_text(text)
|
||||||
img_buffer[index, :] = image_feature.detach()
|
img_buffer[index, :] = image_feature.detach()
|
||||||
text_buffer[index, :] = text_features.detach()
|
text_buffer[index, :] = text_features.detach()
|
||||||
|
|
||||||
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
return img_buffer, text_buffer # img_buffer.to(self.rank), text_buffer.to(self.rank)
|
||||||
|
|
||||||
|
def valid_attack(self, adv_images, texts, adv_labels):
|
||||||
|
|
||||||
|
|
||||||
def valid_attack(self,adv_images, texts, adv_labels):
|
|
||||||
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test(self, mode_name="i2t"):
|
def test(self, mode_name="i2t"):
|
||||||
self.logger.info("Valid Clean.")
|
self.logger.info("Valid Clean.")
|
||||||
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
|
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
|
||||||
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||||
|
|
||||||
|
|
||||||
query_img = query_img.cpu().detach().numpy()
|
query_img = query_img.cpu().detach().numpy()
|
||||||
query_txt = query_txt.cpu().detach().numpy()
|
query_txt = query_txt.cpu().detach().numpy()
|
||||||
|
|
@ -565,8 +514,8 @@ class Trainer(TrainBase):
|
||||||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||||
query_labels = self.query_labels.numpy()
|
query_labels = self.query_labels.numpy()
|
||||||
retrieval_labels = self.retrieval_labels.numpy()
|
retrieval_labels = self.retrieval_labels.numpy()
|
||||||
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
|
mAPi2t = cal_map(query_img, query_labels, retrieval_txt, retrieval_labels)
|
||||||
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)
|
mAPt2i = cal_map(query_txt, query_labels, retrieval_img, retrieval_labels)
|
||||||
# pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
|
# pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
|
||||||
# pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)
|
# pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)
|
||||||
self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
|
self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
|
||||||
|
|
@ -579,31 +528,11 @@ class Trainer(TrainBase):
|
||||||
'q_l': query_labels,
|
'q_l': query_labels,
|
||||||
'r_l': retrieval_labels
|
'r_l': retrieval_labels
|
||||||
}
|
}
|
||||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
|
scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"),
|
||||||
|
result_dict)
|
||||||
self.logger.info(">>>>>> save all data!")
|
self.logger.info(">>>>>> save all data!")
|
||||||
|
|
||||||
|
|
||||||
# def valid(self, epoch):
|
|
||||||
# self.logger.info("Valid.")
|
|
||||||
# self.change_state(mode="valid")
|
|
||||||
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
|
||||||
# retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
|
||||||
# # print("get all code")
|
|
||||||
# mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
|
||||||
# # print("map map")
|
|
||||||
# mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
|
||||||
# mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
|
||||||
# mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
|
||||||
# if self.max_mapi2t < mAPi2t:
|
|
||||||
# self.best_epoch_i = epoch
|
|
||||||
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
|
|
||||||
# self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
|
|
||||||
# if self.max_mapt2i < mAPt2i:
|
|
||||||
# self.best_epoch_t = epoch
|
|
||||||
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
|
|
||||||
# self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
|
|
||||||
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
|
|
||||||
# MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
|
|
||||||
|
|
||||||
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
|
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
|
||||||
|
|
||||||
|
|
@ -625,7 +554,7 @@ class Trainer(TrainBase):
|
||||||
'q_l': query_labels,
|
'q_l': query_labels,
|
||||||
'r_l': retrieval_labels
|
'r_l': retrieval_labels
|
||||||
}
|
}
|
||||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
|
scio.savemat(
|
||||||
|
os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"),
|
||||||
|
result_dict)
|
||||||
self.logger.info(f">>>>>> save best {mode_name} data!")
|
self.logger.info(f">>>>>> save best {mode_name} data!")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,20 +9,22 @@ def get_args():
|
||||||
parser.add_argument("--save-dir", type=str, default="./result/64-bit")
|
parser.add_argument("--save-dir", type=str, default="./result/64-bit")
|
||||||
parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
|
parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
|
||||||
parser.add_argument("--pretrained", type=str, default="")
|
parser.add_argument("--pretrained", type=str, default="")
|
||||||
parser.add_argument("--dataset", type=str, default="flickr25k", help="choise from [coco, mirflckr25k, nuswide]")
|
parser.add_argument("--dataset", type=str, default="coco", help="choise from [coco, mirflckr25k, nuswide]")
|
||||||
parser.add_argument("--index-file", type=str, default="index.mat")
|
parser.add_argument("--index-file", type=str, default="index.mat")
|
||||||
parser.add_argument("--caption-file", type=str, default="caption.mat")
|
parser.add_argument("--caption-file", type=str, default="caption.mat")
|
||||||
parser.add_argument("--label-file", type=str, default="label.mat")
|
parser.add_argument("--label-file", type=str, default="label.mat")
|
||||||
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
|
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
|
||||||
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
|
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
|
||||||
parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
|
parser.add_argument('--victim', default='RN50', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
|
||||||
parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")
|
parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")
|
||||||
parser.add_argument("--topk", type=int, default=10)
|
parser.add_argument("--topk", type=int, default=10)
|
||||||
parser.add_argument("--num-perturbation", type=int, default=3)
|
parser.add_argument("--num-perturbation", type=int, default=3)
|
||||||
parser.add_argument("--txt-dim", type=int, default=1024)
|
parser.add_argument("--txt-dim", type=int, default=1024)
|
||||||
parser.add_argument("--output-dim", type=int, default=512)
|
parser.add_argument("--output-dim", type=int, default=1024)
|
||||||
parser.add_argument("--epochs", type=int, default=100)
|
parser.add_argument("--epochs", type=int, default=100)
|
||||||
parser.add_argument("--max-words", type=int, default=77)
|
parser.add_argument("--max-words", type=int, default=77)
|
||||||
|
parser.add_argument("--max-candidate", type=int, default=7)
|
||||||
|
parser.add_argument("--enable-bpe", type=bool, default=False)
|
||||||
parser.add_argument("--resolution", type=int, default=224)
|
parser.add_argument("--resolution", type=int, default=224)
|
||||||
parser.add_argument("--batch-size", type=int, default=8)
|
parser.add_argument("--batch-size", type=int, default=8)
|
||||||
parser.add_argument("--num-workers", type=int, default=4)
|
parser.add_argument("--num-workers", type=int, default=4)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue