This commit is contained in:
Li Wenyun 2024-07-04 10:33:33 +08:00
parent 321817c86f
commit ad01bb4f0e
7 changed files with 186 additions and 247 deletions

View File

@ -8,8 +8,8 @@ import torch
import random import random
from PIL import Image from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from model.simple_tokenizer import SimpleTokenizer as Tokenizer # from model.simple_tokenizer import SimpleTokenizer as Tokenizer
from transformers import AutoTokenizer
class BaseDataset(Dataset): class BaseDataset(Dataset):
@ -19,7 +19,7 @@ class BaseDataset(Dataset):
indexs: dict, indexs: dict,
labels: dict, labels: dict,
is_train=True, is_train=True,
tokenizer=Tokenizer(), tokenizer=AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", model_max_length=77, truncation=True),
maxWords=32, maxWords=32,
imageResolution=224, imageResolution=224,
npy=False): npy=False):

View File

@ -10,7 +10,7 @@ def split_data(captions, indexs, labels, query_num=5000, train_num=10000, seed=N
random_index = np.random.permutation(range(len(indexs))) random_index = np.random.permutation(range(len(indexs)))
query_index = random_index[: query_num] query_index = random_index[: query_num]
train_index = random_index[query_num: query_num + train_num] train_index = random_index[query_num: query_num + train_num]
retrieval_index = random_index[query_num:] retrieval_index = random_index[query_num:-100000]
query_indexs = indexs[query_index] query_indexs = indexs[query_index]
query_captions = captions[query_index] query_captions = captions[query_index]

View File

@ -1,11 +1,11 @@
from train.text_train import Trainer # from train.text_train import Trainer
from train.hash_train import Trainer
if __name__ == "__main__": if __name__ == "__main__":
engine=Trainer() engine=Trainer()
engine.test() engine.test()
engine.train_epoch() # engine.train_epoch()
# engine.train() # engine.train()

View File

@ -131,11 +131,16 @@ class SimpleTokenizer(object):
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text return text
def my_decode(self, tokens): # def my_decode(self, tokens):
tokens=[item for item in tokens if item in self.decoder] # print(tokens)
text = ''.join([self.decoder[token] for token in tokens]) # tem_token=[]
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') # for i in tokens:
return text # if i in self.decoder.keys():
# tem_token.append(self.decoder[i])
# print(tem_token)
# text = ''.join(tem_token)
# text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
# return text
def tokenize(self, text): def tokenize(self, text):
tokens = [] tokens = []
@ -147,3 +152,6 @@ class SimpleTokenizer(object):
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
return [self.encoder[bpe_token] for bpe_token in tokens] return [self.encoder[bpe_token] for bpe_token in tokens]
def convert_ids_to_tokens(self, ids):
return [self.decoder[id] for id in ids]

View File

@ -183,18 +183,18 @@ class Trainer(TrainBase):
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels) mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels) # pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
# pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels) pr_t=cal_pr(retrieval_txt,adv_img,retrieval_labels,adv_labels)
self.logger.info(f">>>>>> MAP_t: {mAP_t}") self.logger.info(f">>>>>> MAP_t: {mAP_t}")
result_dict = { result_dict = {
'adv_img': adv_img, 'adv_img': adv_img,
'r_txt': retrieval_txt, 'r_txt': retrieval_txt,
'adv_l': adv_labels, 'adv_l': adv_labels,
'r_l': retrieval_labels 'r_l': retrieval_labels,
# 'q_l':query_labels # 'q_l':query_labels
# 'pr': pr, # 'pr': pr,
# 'pr_t': pr_t 'pr_t': pr_t
} }
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict) scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!") self.logger.info(">>>>>> save all data!")
@ -267,8 +267,8 @@ class Trainer(TrainBase):
retrieval_labels = self.retrieval_labels.numpy() retrieval_labels = self.retrieval_labels.numpy()
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels) mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels) mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)
# pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels) pr_i2t=cal_pr(retrieval_txt,query_img,retrieval_labels,query_labels)
# pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels) pr_t2i=cal_pr(retrieval_img,query_txt,retrieval_labels,query_labels)
self.max_mapt2i = max(self.max_mapt2i, mAPi2t) self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}") self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
result_dict = { result_dict = {
@ -277,11 +277,11 @@ class Trainer(TrainBase):
'r_img': retrieval_img, 'r_img': retrieval_img,
'r_txt': retrieval_txt, 'r_txt': retrieval_txt,
'q_l': query_labels, 'q_l': query_labels,
'r_l': retrieval_labels 'r_l': retrieval_labels,
# 'pr_i2t': pr_i2t, 'pr_i2t': pr_i2t,
# 'pr_t2i': pr_t2i 'pr_t2i': pr_t2i
} }
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict) scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!") self.logger.info(">>>>>> save all data!")

View File

@ -12,14 +12,14 @@ import numpy as np
from .base import TrainBase from .base import TrainBase
from torch.nn import functional as F from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity, find_indices
from utils.calc_utils import cal_map, cal_pr from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader from dataset.dataloader import dataloader
import clip import clip
from model.simple_tokenizer import SimpleTokenizer as Tokenizer import re
from transformers import BertForMaskedLM from transformers import BertForMaskedLM
from model.bert_tokenizer import BertTokenizer from model.bert_tokenizer import BertTokenizer
# from transformers import BertModel from transformers import AutoTokenizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@ -48,13 +48,21 @@ filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again',
'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won', 'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
"won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've", "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!', 'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
'@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '', '·'] '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '', '·']
filter_words = set(filter_words) filter_words = set(filter_words)
def text_filter(text):
text = re.findall(r"<|startoftext|>(.+)<|endoftext|>", text)
text = text[2]
text = re.sub(r'</w>', ' ', text)
return text
class GoalFunctionStatus(object): class GoalFunctionStatus(object):
SUCCEEDED = 0 # attack succeeded SUCCEEDED = 0 # attack succeeded
SEARCHING = 1 # In process of searching for a success SEARCHING = 1 # In process of searching for a success
FAILED = 2 # attack failed FAILED = 2 # attack failed
class GoalFunctionResult(object): class GoalFunctionResult(object):
@ -82,89 +90,29 @@ class GoalFunctionResult(object):
def __hash__(self): def __hash__(self):
return hash(self.text) return hash(self.text)
# def get_bpe_substitues(substitutes, tokenizer, mlm_model):
# # substitutes L, k
# # device = mlm_model.device
# substitutes = substitutes[0:12, 0:4] # maximum BPE candidates
#
# # find all possible candidates
#
# all_substitutes = []
# for i in range(substitutes.size(0)):
# if len(all_substitutes) == 0:
# lev_i = substitutes[i]
# all_substitutes = [[int(c)] for c in lev_i]
# else:
# lev_i = []
# for all_sub in all_substitutes:
# for j in substitutes[i]:
# lev_i.append(all_sub + [int(j)])
# all_substitutes = lev_i
#
# # all substitutes list of list of token-id (all candidates)
# c_loss = nn.CrossEntropyLoss(reduction='none')
# word_list = []
# # all_substitutes = all_substitutes[:24]
# all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
# all_substitutes = all_substitutes[:24].to(device)
# # print(substitutes.size(), all_substitutes.size())
# N, L = all_substitutes.size()
# word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
# ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1)) # [ N*L ]
# ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N
# _, word_list = torch.sort(ppl)
# word_list = [all_substitutes[i] for i in word_list]
# final_words = []
# for word in word_list:
# tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
# text = tokenizer.convert_tokens_to_string(tokens)
# final_words.append(text)
# return final_words
# def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
# # substitues L,k
# # from this matrix to recover a word
# words = []
# sub_len, k = substitutes.size() # sub-len, k
#
# if sub_len == 0:
# return words
#
# elif sub_len == 1:
# for (i, j) in zip(substitutes[0], substitutes_score[0]):
# if threshold != 0 and j < threshold:
# break
# words.append(tokenizer._convert_id_to_token(int(i)))
# else:
# if use_bpe == 1:
# words = get_bpe_substitues(substitutes, tokenizer, mlm_model)
# else:
# return words
# #
# # print(words)
# return words
class Trainer(TrainBase): class Trainer(TrainBase):
def __init__(self, def __init__(self,
rank=0): rank=0):
args = get_args() args = get_args()
super(Trainer, self).__init__(args, rank) super(Trainer, self).__init__(args, rank)
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
image_mean, image_var=self.generate_mapping() image_mean, image_var = self.generate_mapping()
self.image_mean=image_mean self.image_mean = image_mean
self.image_var=image_var self.image_var = image_var
self.device=rank self.device = rank
self.clip_tokenizer=Tokenizer() self.clip_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32", model_max_length=77,
self.bert_tokenizer=BertTokenizer.from_pretrained(self.args.text_encoder,do_lower_case=True) truncation=True)
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder) self.bert_tokenizer = BertTokenizer.from_pretrained(self.args.text_encoder, do_lower_case=True)
self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder).to(device)
self.attack_thred = self.args.attack_thred self.attack_thred = self.args.attack_thred
# self.run() # self.run()
def _init_model(self): def _init_model(self):
self.logger.info("init model.") self.logger.info("init model.")
model_clip, preprocess = clip.load(self.args.victim, device=device) model_clip, preprocess = clip.load(self.args.victim, device=device)
self.model= model_clip self.model = model_clip
self.model.eval() self.model.eval()
self.model.float() self.model.float()
@ -174,14 +122,14 @@ class Trainer(TrainBase):
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file) self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file) self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file) self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file, train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
indexFile=self.args.index_file, indexFile=self.args.index_file,
labelFile=self.args.label_file, labelFile=self.args.label_file,
maxWords=self.args.max_words, maxWords=self.args.max_words,
imageResolution=self.args.resolution, imageResolution=self.args.resolution,
query_num=self.args.query_num, query_num=self.args.query_num,
train_num=self.args.train_num, train_num=self.args.train_num,
seed=self.args.seed) seed=self.args.seed)
self.train_labels = train_data.get_all_label() self.train_labels = train_data.get_all_label()
self.query_labels = query_data.get_all_label() self.query_labels = query_data.get_all_label()
self.retrieval_labels = retrieval_data.get_all_label() self.retrieval_labels = retrieval_data.get_all_label()
@ -189,28 +137,28 @@ class Trainer(TrainBase):
self.logger.info(f"query shape: {self.query_labels.shape}") self.logger.info(f"query shape: {self.query_labels.shape}")
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}") self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
self.train_loader = DataLoader( self.train_loader = DataLoader(
dataset=train_data, dataset=train_data,
batch_size=self.args.batch_size, batch_size=self.args.batch_size,
num_workers=self.args.num_workers, num_workers=self.args.num_workers,
pin_memory=True, pin_memory=True,
shuffle=True shuffle=True
) )
self.query_loader = DataLoader( self.query_loader = DataLoader(
dataset=query_data, dataset=query_data,
batch_size=self.args.batch_size, batch_size=self.args.batch_size,
num_workers=self.args.num_workers, num_workers=self.args.num_workers,
pin_memory=True, pin_memory=True,
shuffle=True shuffle=True
) )
self.retrieval_loader = DataLoader( self.retrieval_loader = DataLoader(
dataset=retrieval_data, dataset=retrieval_data,
batch_size=self.args.batch_size, batch_size=self.args.batch_size,
num_workers=self.args.num_workers, num_workers=self.args.num_workers,
pin_memory=True, pin_memory=True,
shuffle=True shuffle=True
) )
self.train_data=train_data self.train_data = train_data
def _tokenize(self, text): def _tokenize(self, text):
words = text.split(' ') words = text.split(' ')
@ -224,8 +172,8 @@ class Trainer(TrainBase):
index += len(sub) index += len(sub)
return words, sub_words, keys return words, sub_words, keys
def get_important_scores(self, text, origin_embeds, batch_size, max_length): def get_important_scores(self, text, origin_embeds, batch_size, max_length):
# device = origin_embeds.device # device = origin_embeds.device
masked_words = self._get_masked(text) masked_words = self._get_masked(text)
@ -233,17 +181,20 @@ class Trainer(TrainBase):
masked_embeds = [] masked_embeds = []
for i in range(0, len(masked_texts), batch_size): for i in range(0, len(masked_texts), batch_size):
masked_text_input = self.bert_tokenizer(masked_texts[i:i+batch_size], padding='max_length', truncation=True, max_length=max_length, return_tensors='pt').to(device) masked_text_input = self.bert_tokenizer(masked_texts[i:i + batch_size], padding='max_length',
truncation=True, max_length=max_length, return_tensors='pt').to(
device)
masked_embed = self.ref_net(masked_text_input.text_inputs, attention_mask=masked_text_input.attention_mask) masked_embed = self.ref_net(masked_text_input.text_inputs, attention_mask=masked_text_input.attention_mask)
masked_embeds.append(masked_embed) masked_embeds.append(masked_embed)
masked_embeds = torch.cat(masked_embeds, dim=0) masked_embeds = torch.cat(masked_embeds, dim=0)
criterion = torch.nn.KLDivLoss(reduction='none') criterion = torch.nn.KLDivLoss(reduction='none')
import_scores = criterion(masked_embeds.log_softmax(dim=-1), origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1)) import_scores = criterion(masked_embeds.log_softmax(dim=-1),
origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1))
return import_scores.sum(dim=-1) return import_scores.sum(dim=-1)
def _get_masked(self, text): def _get_masked(self, text):
words = text.split(' ') words = text.split(' ')
len_text = len(words) len_text = len(words)
@ -265,14 +216,14 @@ class Trainer(TrainBase):
def get_word_predictions(self, text): def get_word_predictions(self, text):
_, _, keys = self._tokenize(text) _, _, keys = self._tokenize(text)
inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.max_text_len, inputs = self.bert_tokenizer.encode_plus(text, add_special_tokens=True, max_length=self.args.max_words,
truncation=True, return_tensors="pt") truncation=True, return_tensors="pt")
input_ids = inputs["input_ids"].to(self.device) input_ids = inputs["input_ids"].to(self.device)
attention_mask = inputs['attention_mask'] attention_mask = inputs['attention_mask']
with torch.no_grad(): with torch.no_grad():
word_predictions = self.ref_net(input_ids)['logits'].squeeze() # (seq_len, vocab_size) word_predictions = self.ref_net(input_ids)['logits'].squeeze(0) # (seq_len, vocab_size)
# print(self.ref_net(input_ids)['logits'].shape)
word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.max_candidate, -1) word_pred_scores_all, word_predictions = torch.topk(word_predictions, self.args.max_candidate, -1)
word_predictions = word_predictions[1:-1, :] # remove [CLS] and [SEP] word_predictions = word_predictions[1:-1, :] # remove [CLS] and [SEP]
word_pred_scores_all = word_pred_scores_all[1:-1, :] word_pred_scores_all = word_pred_scores_all[1:-1, :]
@ -343,55 +294,65 @@ class Trainer(TrainBase):
ret.append(word) ret.append(word)
return ret return ret
def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, def get_goal_results(self, trans_texts, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
positive_var, beta=10,temperature=0.05): positive_var, beta=10, temperature=0.05):
trans_feature= self.clip_tokenizer(trans_texts) # print(trans_texts)
anchor=self.model.encode_text(trans_feature) trans_feature = clip.tokenize(trans_texts,context_length=77,truncate=True).to(device)
loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code, anchor = self.model.encode_text(trans_feature)
batch_size=anchor.shape[0]
loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code.repeat(batch_size, 1), negetive_code.repeat(batch_size, 1),
distance_function=nn.CosineSimilarity(), reduction='none') distance_function=nn.CosineSimilarity(), reduction='none')
negative_dist = (anchor - negetive_mean) ** 2 / negative_var sim=F.cosine_similarity(anchor,positive_code.unsqueeze(0), dim=1, eps=1e-8).unsqueeze(1)
positive_dist = (anchor - positive_mean) ** 2 / positive_var negative_dist = (anchor - negetive_mean) ** 2 / (negative_var+ 1e-5)
positive_dist = (anchor - positive_mean) ** 2 / (positive_var+ 1e-5)
negatives = torch.exp(negative_dist / temperature) negatives = torch.exp(negative_dist / temperature)
positives = torch.exp(positive_dist / temperature) positives = torch.exp(positive_dist / temperature)
loss = torch.log(positives / (positives + negatives)) + beta * loss1 loss = torch.log(positives / (positives + negatives)).mean(dim=1, keepdim=True) + beta * loss1
return loss results = []
# print(loss.shape)
for i in range(len(trans_texts)):
if loss[i].shape[0] >1 or sim[i] >self.args.sim_threshold:
continue
results.append(GoalFunctionResult(trans_texts[i], score=loss[i], similarity=sim[i]))
return results
def generate_mapping(self): def generate_mapping(self):
image_train=[] image_train = []
label_train=[] label_train = []
for image, text, label, index in self.train_loader: for image, text, label, index in self.train_loader:
# raw_text=[self.clip_tokenizer.decode(token) for token in text] # raw_text=[self.clip_tokenizer.decode(token) for token in text]
image=image.to(device, non_blocking=True) image = image.to(device, non_blocking=True)
# print(self.model.vocab_size) # print(self.model.vocab_size)
temp_image=self.model.encode_image(image) temp_image = self.model.encode_image(image)
image_train.append(temp_image.cpu().detach().numpy()) image_train.append(temp_image.cpu().detach().numpy())
label_train.append(label.detach().numpy()) label_train.append(label.detach().numpy())
image_train=np.concatenate(image_train, axis=0) image_train = np.concatenate(image_train, axis=0)
label_train=np.concatenate(label_train, axis=0) label_train = np.concatenate(label_train, axis=0)
label_unipue=np.unique(label_train,axis=0) label_unipue = np.unique(label_train, axis=0)
image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) image_centroids = np.stack(
image_var=np.stack([image_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0) [image_train[find_indices(label_train, label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))],
axis=0)
image_var = np.stack(
[image_train[find_indices(label_train, label_unipue[i])].var(axis=0) for i in range(len(label_unipue))],
axis=0)
image_representation = {} image_representation = {}
image_var_representation = {} image_var_representation = {}
for i, centroid in enumerate(label_unipue): for i, centroid in enumerate(label_unipue):
image_representation[str(centroid.astype(int))] = image_centroids[i] image_representation[str(centroid.astype(int))] = image_centroids[i]
image_var_representation[str(centroid.astype(int))]= image_var[i] image_var_representation[str(centroid.astype(int))] = image_var[i]
return image_representation, image_var_representation return image_representation, image_var_representation
def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, def target_adv(self, raw_text, negetive_code, negetive_mean, negative_var, positive_code, positive_mean,
positive_var, beta=10, temperature=0.05): positive_var, beta=10, temperature=0.05):
# print(raw_text)
keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text) keys, word_predictions, word_pred_scores_all, mask = self.get_word_predictions(raw_text)
#clean state #clean state
# clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask) # clean_embeds=self.ref_net(bert_inputs.input_ids, attention_mask=bert_inputs.attention_mask)
cur_result = GoalFunctionResult(raw_text) cur_result = GoalFunctionResult(raw_text)
mask_idx = np.where(mask.cpu().numpy() == 1)[0] mask_idx = np.where(mask.cpu().numpy() == 1)[0]
for idx in mask_idx: for idx in mask_idx:
predictions = word_predictions[keys[idx][0]: keys[idx][1]] predictions = word_predictions[keys[idx][0]: keys[idx][1]]
predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]] predictions_socre = word_pred_scores_all[keys[idx][0]: keys[idx][1]]
@ -402,7 +363,7 @@ class Trainer(TrainBase):
continue continue
# loss function # loss function
results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var, positive_code, results = self.get_goal_results(trans_texts, negetive_code, negetive_mean, negative_var, positive_code,
positive_mean,positive_var, beta,temperature) positive_mean, positive_var, beta, temperature)
results = sorted(results, key=lambda x: x.score, reverse=True) results = sorted(results, key=lambda x: x.score, reverse=True)
if len(results) > 0 and results[0].score > cur_result.score: if len(results) > 0 and results[0].score > cur_result.score:
@ -427,63 +388,62 @@ class Trainer(TrainBase):
cur_result.status = GoalFunctionStatus.FAILED cur_result.status = GoalFunctionStatus.FAILED
return cur_result return cur_result
def train_epoch(self): def train_epoch(self):
# self.change_state(mode="valid") # self.change_state(mode="valid")
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve") save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
all_loss = 0 all_loss = 0
times = 0 times = 0
adv_codes=[] adv_codes = []
adv_label=[] adv_label = []
for image, text, label, index in self.train_loader: for image, text, label, index in self.train_loader:
self.global_step += 1 self.global_step += 1
times += 1 times += 1
print(times) print(times)
image.float() image.float()
raw_text=[self.clip_tokenizer.my_decode(token) for token in text]
image = image.to(self.rank, non_blocking=True) image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True) text = text.to(self.rank, non_blocking=True)
negetive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()]) negetive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negative_var=np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()]) negative_var = np.stack([self.image_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True) negetive_mean = torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True) negative_var = torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
negetive_code=self.model.encode_image(image) negetive_code = self.model.encode_image(image)
#targeted sample #targeted sample
np.random.seed(times) np.random.seed(times)
select_index = np.random.choice(len(self.train_data), size=self.args.batch_size) select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
target_dataset = data.Subset(self.train_data, select_index) target_dataset = data.Subset(self.train_data, select_index)
target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size) target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
target_image, _, target_label, _ = next(iter(target_subset)) target_image, _, target_label, _ = next(iter(target_subset))
target_image=target_image.to(self.rank, non_blocking=True) target_image = target_image.to(self.rank, non_blocking=True)
positive_mean=np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()]) positive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
positive_var=np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()]) positive_var = np.stack([self.image_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
positive_mean=torch.from_numpy(positive_mean).to(self.rank, non_blocking=True) positive_mean = torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
positive_var=torch.from_numpy(positive_var).to(self.rank, non_blocking=True) positive_var = torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
positive_code=self.model.encode_image(target_image) positive_code = self.model.encode_image(target_image)
# print(self.clip_tokenizer.my_encode('This day is good!'))
raw_text = [self.clip_tokenizer.convert_ids_to_tokens(token.cpu()) for token in text]
final_adverse=self.target_adv(text, raw_text,negetive_code,negetive_mean,negative_var, raw_text = [text_filter(self.clip_tokenizer.convert_tokens_to_string(txt)) for txt in raw_text]
positive_code,positive_mean,positive_var) final_texts=[]
final_text=self.clip_tokenizer.tokenize(final_adverse.text).to(self.rank, non_blocking=True) for i in range(self.args.batch_size):
adv_code=self.model.encode_text(final_text) adv_txt=self.target_adv( raw_text[i], negetive_code[i], negetive_mean[i], negative_var[i],
positive_code[i], positive_mean[i], positive_var[i])
final_texts.append(adv_txt.text)
# final_adverse = self.target_adv( raw_text, negetive_code, negetive_mean, negative_var,
# positive_code, positive_mean, positive_var)
final_text = clip.tokenize(final_texts,context_length=77,truncate=True).to(self.rank, non_blocking=True)
adv_code = self.model.encode_text(final_text)
adv_codes.append(adv_code.cpu().detach().numpy()) adv_codes.append(adv_code.cpu().detach().numpy())
adv_label.append(target_label.numpy()) adv_label.append(target_label.numpy())
adv_img=np.concatenate(adv_codes) adv_img = np.concatenate(adv_codes)
adv_labels=np.concatenate(adv_label) adv_labels = np.concatenate(adv_label)
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
retrieval_txt = retrieval_txt.cpu().detach().numpy() retrieval_txt = retrieval_txt.cpu().detach().numpy()
retrieval_labels = self.retrieval_labels.numpy() retrieval_labels = self.retrieval_labels.numpy()
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels) mAP_t = cal_map(adv_img, adv_labels, retrieval_txt, retrieval_labels)
self.logger.info(f">>>>>> MAP_t: {mAP_t}") self.logger.info(f">>>>>> MAP_t: {mAP_t}")
result_dict = { result_dict = {
'adv_img': adv_img, 'adv_img': adv_img,
@ -494,13 +454,9 @@ class Trainer(TrainBase):
# 'pr': pr, # 'pr': pr,
# 'pr_t': pr_t # 'pr_t': pr_t
} }
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict) scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"),
result_dict)
self.logger.info(">>>>>> save all data!") self.logger.info(">>>>>> save all data!")
def train(self): def train(self):
self.logger.info("Start train.") self.logger.info("Start train.")
@ -510,9 +466,8 @@ class Trainer(TrainBase):
self.valid(epoch) self.valid(epoch)
self.save_model(epoch) self.save_model(epoch)
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") self.logger.info(
f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
def make_hash_code(self, code: list) -> torch.Tensor: def make_hash_code(self, code: list) -> torch.Tensor:
@ -539,25 +494,19 @@ class Trainer(TrainBase):
text_features = self.model.encode_text(text) text_features = self.model.encode_text(text)
img_buffer[index, :] = image_feature.detach() img_buffer[index, :] = image_feature.detach()
text_buffer[index, :] = text_features.detach() text_buffer[index, :] = text_features.detach()
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) return img_buffer, text_buffer # img_buffer.to(self.rank), text_buffer.to(self.rank)
def valid_attack(self, adv_images, texts, adv_labels):
def valid_attack(self,adv_images, texts, adv_labels):
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve") save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
def test(self, mode_name="i2t"): def test(self, mode_name="i2t"):
self.logger.info("Valid Clean.") self.logger.info("Valid Clean.")
save_dir = os.path.join(self.args.save_dir, "PR_cruve") save_dir = os.path.join(self.args.save_dir, "PR_cruve")
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
query_img = query_img.cpu().detach().numpy() query_img = query_img.cpu().detach().numpy()
query_txt = query_txt.cpu().detach().numpy() query_txt = query_txt.cpu().detach().numpy()
@ -565,8 +514,8 @@ class Trainer(TrainBase):
retrieval_txt = retrieval_txt.cpu().detach().numpy() retrieval_txt = retrieval_txt.cpu().detach().numpy()
query_labels = self.query_labels.numpy() query_labels = self.query_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy() retrieval_labels = self.retrieval_labels.numpy()
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels) mAPi2t = cal_map(query_img, query_labels, retrieval_txt, retrieval_labels)
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels) mAPt2i = cal_map(query_txt, query_labels, retrieval_img, retrieval_labels)
# pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels) # pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
# pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels) # pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)
self.max_mapt2i = max(self.max_mapt2i, mAPi2t) self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
@ -579,31 +528,11 @@ class Trainer(TrainBase):
'q_l': query_labels, 'q_l': query_labels,
'r_l': retrieval_labels 'r_l': retrieval_labels
} }
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict) scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"),
result_dict)
self.logger.info(">>>>>> save all data!") self.logger.info(">>>>>> save all data!")
# def valid(self, epoch):
# self.logger.info("Valid.")
# self.change_state(mode="valid")
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
# retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
# # print("get all code")
# mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# # print("map map")
# mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
# mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
# mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# if self.max_mapi2t < mAPi2t:
# self.best_epoch_i = epoch
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
# self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
# if self.max_mapt2i < mAPt2i:
# self.best_epoch_t = epoch
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
# self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
# MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"): def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
@ -625,7 +554,7 @@ class Trainer(TrainBase):
'q_l': query_labels, 'q_l': query_labels,
'r_l': retrieval_labels 'r_l': retrieval_labels
} }
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) scio.savemat(
os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"),
result_dict)
self.logger.info(f">>>>>> save best {mode_name} data!") self.logger.info(f">>>>>> save best {mode_name} data!")

View File

@ -9,20 +9,22 @@ def get_args():
parser.add_argument("--save-dir", type=str, default="./result/64-bit") parser.add_argument("--save-dir", type=str, default="./result/64-bit")
parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.") parser.add_argument("--clip-path", type=str, default="./ViT-B-32.pt", help="pretrained clip path.")
parser.add_argument("--pretrained", type=str, default="") parser.add_argument("--pretrained", type=str, default="")
parser.add_argument("--dataset", type=str, default="flickr25k", help="choise from [coco, mirflckr25k, nuswide]") parser.add_argument("--dataset", type=str, default="coco", help="choise from [coco, mirflckr25k, nuswide]")
parser.add_argument("--index-file", type=str, default="index.mat") parser.add_argument("--index-file", type=str, default="index.mat")
parser.add_argument("--caption-file", type=str, default="caption.mat") parser.add_argument("--caption-file", type=str, default="caption.mat")
parser.add_argument("--label-file", type=str, default="label.mat") parser.add_argument("--label-file", type=str, default="label.mat")
parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]") parser.add_argument("--similarity-function", type=str, default="euclidean", help="choise form [cosine, euclidean]")
parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]") parser.add_argument("--loss-type", type=str, default="l2", help="choise form [l1, l2]")
parser.add_argument('--victim', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101']) parser.add_argument('--victim', default='RN50', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
parser.add_argument("--text_encoder", type=str, default="bert-base-uncased") parser.add_argument("--text_encoder", type=str, default="bert-base-uncased")
parser.add_argument("--topk", type=int, default=10) parser.add_argument("--topk", type=int, default=10)
parser.add_argument("--num-perturbation", type=int, default=3) parser.add_argument("--num-perturbation", type=int, default=3)
parser.add_argument("--txt-dim", type=int, default=1024) parser.add_argument("--txt-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=512) parser.add_argument("--output-dim", type=int, default=1024)
parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--max-words", type=int, default=77) parser.add_argument("--max-words", type=int, default=77)
parser.add_argument("--max-candidate", type=int, default=7)
parser.add_argument("--enable-bpe", type=bool, default=False)
parser.add_argument("--resolution", type=int, default=224) parser.add_argument("--resolution", type=int, default=224)
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--num-workers", type=int, default=4)