# advclip/train/text_train.py
#
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# If that is intentional, the warning can be ignored.

from torch.nn.modules import loss
# from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
import scipy.io as scio
import numpy as np
from .base import TrainBase
from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
import clip
import copy
from model.bert_tokenizer import BertTokenizer
from model.simple_tokenizer import SimpleTokenizer as Tokenizer
from transformers import BertForMaskedLM
# from transformers import BertModel
# Device used for models and tensors: the first CUDA GPU if present, else CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Words and punctuation marks that are never chosen as an attack target or as a
# replacement candidate. Stored as a set for constant-time membership tests.
filter_words = {
    'a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost',
    'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another',
    'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as',
    'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides',
    'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn',
    "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere',
    'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for',
    'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence',
    'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his',
    'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's",
    'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn',
    "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself',
    'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none',
    'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only',
    'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per',
    'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't", 'somehow',
    'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the', 'their', 'theirs',
    'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein',
    'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout', 'thru', 'thus', 'to', 'too',
    'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was', 'wasn', "wasn't",
    'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever', 'where',
    'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while',
    'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
    "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
    'your', 'yours', 'yourself', 'yourselves',
    # Punctuation and miscellaneous entries (repeated entries collapse in the set).
    '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
    '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '', '·',
}
def get_bpe_substitues(substitutes, tokenizer, mlm_model,
                       max_subwords=12, max_candidates=4, max_combinations=24):
    """Rank multi-sub-word replacement candidates by masked-LM perplexity.

    Each row of ``substitutes`` holds candidate token ids for one sub-word
    position. Candidate words are the combinations that pick one id per row,
    enumerated with the first row varying slowest. Only the first
    ``max_combinations`` combinations are scored.

    Args:
        substitutes: 2-D integer tensor, one row per sub-word position and one
            column per candidate id.
        tokenizer: object providing ``_convert_id_to_token(int)`` and
            ``convert_tokens_to_string(list[str])``.
        mlm_model: callable taking an ``[N, L]`` id tensor and returning a
            sequence whose first item is an ``[N, L, vocab]`` logits tensor.
        max_subwords: number of leading rows that are considered.
        max_candidates: number of leading columns that are considered.
        max_combinations: number of combinations that are scored.

    Returns:
        The decoded candidate strings ordered by ascending perplexity, or an
        empty list when there is nothing to combine.
    """
    import itertools

    rows = substitutes[0:max_subwords, 0:max_candidates].tolist()
    if not rows:
        return []

    # Enumerate lazily: materialising the full product (up to
    # max_candidates ** max_subwords lists) just to keep a short prefix is
    # needlessly expensive. itertools.product yields the same order as the
    # nested-loop expansion (first row outermost).
    combos = list(itertools.islice(itertools.product(*rows), max_combinations))
    if not combos:
        return []

    # Put the ids where the model lives; fall back to the module-level device
    # for models that do not expose a ``device`` attribute.
    target_device = getattr(mlm_model, "device", None)
    if target_device is None:
        target_device = device
    candidates = torch.tensor(combos, dtype=torch.long).to(target_device)  # [N, L]

    n_cand, n_pos = candidates.size()
    token_loss = nn.CrossEntropyLoss(reduction='none')
    with torch.no_grad():  # scoring only, no gradients are needed
        logits = mlm_model(candidates)[0]  # [N, L, vocab]
        per_token = token_loss(logits.reshape(n_cand * n_pos, -1), candidates.reshape(-1))
        ppl = torch.exp(per_token.view(n_cand, n_pos).mean(dim=-1))  # [N]
    _, order = torch.sort(ppl)

    final_words = []
    for idx in order.tolist():
        tokens = [tokenizer._convert_id_to_token(int(tok)) for tok in combos[idx]]
        final_words.append(tokenizer.convert_tokens_to_string(tokens))
    return final_words
def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
    """Turn a matrix of candidate token ids into candidate replacement words.

    Args:
        substitutes: 2-D tensor of token ids, one row per sub-word position of
            the target word and one column per candidate.
        tokenizer: object providing ``_convert_id_to_token(int)``.
        mlm_model: masked language model, only forwarded to
            ``get_bpe_substitues`` for multi-sub-word targets.
        use_bpe: when equal to 1, multi-sub-word targets are expanded through
            ``get_bpe_substitues``; otherwise they yield no candidates.
        substitutes_score: optional tensor shaped like ``substitutes`` holding
            the score of every candidate. When omitted, no score threshold is
            applied.
        threshold: for single-sub-word targets, candidates are taken in column
            order until the first score below this value. 0 disables the
            cut-off.

    Returns:
        A list of candidate strings (possibly empty).
    """
    sub_len, _ = substitutes.size()

    if sub_len == 0:
        return []

    if sub_len > 1:
        # Word split into several sub-words: only handled in BPE mode.
        if use_bpe == 1:
            return get_bpe_substitues(substitutes, tokenizer, mlm_model)
        return []

    # Single sub-word target.
    candidate_ids = substitutes[0]
    if substitutes_score is None or threshold == 0:
        # No scores supplied (the default) or cut-off disabled: keep everything
        # rather than failing on ``None[0]``.
        return [tokenizer._convert_id_to_token(int(tok)) for tok in candidate_ids]

    words = []
    for tok, score in zip(candidate_ids, substitutes_score[0]):
        if score < threshold:
            break
        words.append(tokenizer._convert_id_to_token(int(tok)))
    return words
class Trainer(TrainBase):
    """Word-substitution text attack and retrieval evaluation on a CLIP model.

    The class loads the victim model with ``clip.load``, precomputes per-label
    mean/variance of the training image features, and uses a BERT masked
    language model to propose word replacements for captions.

    NOTE(review): ``self.args``, ``self.rank``, ``self.logger``,
    ``self.global_step``, ``self.max_mapt2i`` and ``change_state`` are used but
    not defined in this class; presumably they come from ``TrainBase``.
    ``self.max_length``, ``self.topk`` and ``self.batch_size`` are read in
    ``target_adv`` but never assigned here either - confirm they exist.
    """

    # Lower bound applied to per-label variances before dividing by them, so a
    # label with zero variance in some dimension does not produce inf/nan.
    _VAR_EPS = 1e-8

    def __init__(self,
                 rank=0):
        """Build the trainer: parse args, run base init, load the helper models.

        Args:
            rank: device index forwarded to ``TrainBase`` and stored as
                ``self.device``.
        """
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        self.image_mean, self.image_var = self.generate_mapping()
        self.device = rank
        self.clip_tokenizer = Tokenizer()
        self.bert_tokenizer = BertTokenizer.from_pretrained(self.args.text_encoder, do_lower_case=True)
        # The tokenized inputs are moved to the module-level ``device`` before
        # being fed to the reference model, so the model must live there too.
        self.ref_net = BertForMaskedLM.from_pretrained(self.args.text_encoder).to(device)
        self.ref_net.eval()

    def _init_model(self):
        """Load the victim CLIP model in float32 evaluation mode."""
        self.logger.info("init model.")
        model_clip, _ = clip.load(self.args.victim, device=device)
        self.model = model_clip
        self.model.eval()
        self.model.float()

    def _init_dataset(self):
        """Resolve dataset file paths and build train/query/retrieval loaders."""
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        dataset_root = os.path.join("./dataset", self.args.dataset)
        self.args.index_file = os.path.join(dataset_root, self.args.index_file)
        self.args.caption_file = os.path.join(dataset_root, self.args.caption_file)
        self.args.label_file = os.path.join(dataset_root, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                                            indexFile=self.args.index_file,
                                                            labelFile=self.args.label_file,
                                                            maxWords=self.args.max_words,
                                                            imageResolution=self.args.resolution,
                                                            query_num=self.args.query_num,
                                                            train_num=self.args.train_num,
                                                            seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")

        loader_kwargs = dict(
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True,
        )
        self.train_loader = DataLoader(dataset=train_data, **loader_kwargs)
        self.query_loader = DataLoader(dataset=query_data, **loader_kwargs)
        self.retrieval_loader = DataLoader(dataset=retrieval_data, **loader_kwargs)
        self.train_data = train_data

    @torch.no_grad()
    def generate_mapping(self):
        """Compute per-label mean and variance of the training image features.

        Returns:
            Two dicts keyed by ``str(label_vector.astype(int))``: the mean
            feature vector and the per-dimension variance of all training
            images carrying that label vector.
        """
        feature_chunks = []
        label_chunks = []
        for image, text, label, index in self.train_loader:
            image = image.to(device, non_blocking=True)
            features = self.model.encode_image(image)
            feature_chunks.append(features.cpu().detach().numpy())
            label_chunks.append(label.detach().numpy())
        image_train = np.concatenate(feature_chunks, axis=0)
        label_train = np.concatenate(label_chunks, axis=0)
        label_unique = np.unique(label_train, axis=0)

        image_representation = {}
        image_var_representation = {}
        for label_vec in label_unique:
            # Select the members once and reuse them for both statistics.
            members = image_train[find_indices(label_train, label_vec)]
            key = str(label_vec.astype(int))
            image_representation[key] = members.mean(axis=0)
            image_var_representation[key] = members.var(axis=0)
        return image_representation, image_var_representation

    def _tokenize(self, text):
        """Split ``text`` on spaces and map every word to its sub-word span.

        Returns:
            ``(words, sub_words, keys)`` where ``keys[i]`` is the
            ``[start, end)`` range of ``words[i]`` inside ``sub_words``.
        """
        words = text.split(' ')
        sub_words = []
        keys = []
        cursor = 0
        for word in words:
            pieces = self.bert_tokenizer.tokenize(word)
            sub_words.extend(pieces)
            keys.append([cursor, cursor + len(pieces)])
            cursor += len(pieces)
        return words, sub_words, keys

    def _get_masked(self, text):
        """Return one word list per word of ``text`` with that word set to '[UNK]'."""
        words = text.split(' ')
        return [words[:pos] + ['[UNK]'] + words[pos + 1:] for pos in range(len(words))]

    @torch.no_grad()
    def get_important_scores(self, text, origin_embeds, batch_size, max_length):
        """Score every word of ``text`` by how much masking it changes the output.

        Each word is replaced in turn by '[UNK]', the masked sentence is run
        through ``self.ref_net`` and its flattened logits are compared with
        ``origin_embeds`` through an element-wise KL divergence.

        Args:
            text: the clean sentence.
            origin_embeds: flattened reference-model logits of the clean
                sentence (1-D, or 2-D with a single row), obtained with the
                same ``max_length`` padding.
            batch_size: number of masked sentences evaluated per forward pass.
            max_length: padding/truncation length for the tokenizer.

        Returns:
            A 1-D tensor with one score per space-separated word of ``text``.
        """
        masked_texts = [' '.join(words) for words in self._get_masked(text)]
        masked_embeds = []
        for start in range(0, len(masked_texts), batch_size):
            masked_text_input = self.bert_tokenizer(masked_texts[start:start + batch_size],
                                                    padding='max_length', truncation=True,
                                                    max_length=max_length, return_tensors='pt').to(device)
            # The model returns an output object; the tensor of interest is
            # its ``logits`` field, flattened to one vector per sentence.
            masked_logits = self.ref_net(masked_text_input.input_ids,
                                         attention_mask=masked_text_input.attention_mask).logits
            masked_embeds.append(masked_logits.flatten(1))
        masked_embeds = torch.cat(masked_embeds, dim=0)
        criterion = torch.nn.KLDivLoss(reduction='none')
        import_scores = criterion(masked_embeds.log_softmax(dim=-1),
                                  origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1))
        return import_scores.sum(dim=-1)

    def adv_loss(self, anchor, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, positive_var):
        """Score candidate text embeddings against a source and a target image.

        Args:
            anchor: candidate text embeddings, one row per candidate.
            negetive_code, negetive_mean, negative_var: embedding of the
                original image and the mean/variance of its label.
            positive_code, positive_mean, positive_var: the same quantities
                for the target image. All six must be broadcastable to
                ``anchor``.

        Returns:
            A tensor shaped like ``anchor``; the caller sums over the last
            dimension to get one score per candidate.
        """
        triplet = F.triplet_margin_with_distance_loss(anchor,
                                                      positive_code.expand_as(anchor),
                                                      negetive_code.expand_as(anchor),
                                                      distance_function=nn.CosineSimilarity(),
                                                      reduction='none')
        negative_dist = (anchor - negetive_mean) ** 2 / negative_var.clamp_min(self._VAR_EPS)
        positive_dist = (anchor - positive_mean) ** 2 / positive_var.clamp_min(self._VAR_EPS)
        # log(P / (P + N)) with P = exp(pd / T), N = exp(nd / T) equals
        # -softplus((nd - pd) / T); this form does not overflow.
        contrast = -F.softplus((negative_dist - positive_dist) / self.args.temperature)
        # The triplet term is one value per candidate; add it to every feature
        # dimension of that candidate.
        return contrast + self.args.beta * triplet.unsqueeze(-1)

    @torch.no_grad()
    def target_adv(self, text_tokens, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, positive_var,
                   beta=10, epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
        """Greedily replace the most important words of every caption.

        Args:
            text_tokens: batch of CLIP token id sequences, one per caption.
            negetive_code, negetive_mean, negative_var: per-sample embedding
                of the original image and its label statistics.
            positive_code, positive_mean, positive_var: per-sample embedding
                of the target image and its label statistics.
            beta, epsilon, alpha, num_iter, temperature: accepted for
                interface compatibility; not used by this method.

        Returns:
            A list with one (possibly modified) caption string per input.
        """
        # Decode from plain Python ints.
        # NOTE(review): how the project tokenizer treats padding ids is not
        # visible here - confirm.
        texts = [self.clip_tokenizer.decode(token.tolist() if torch.is_tensor(token) else token)
                 for token in text_tokens]
        text_inputs = self.bert_tokenizer(texts, padding='max_length', truncation=True,
                                          max_length=self.max_length, return_tensors='pt').to(device, non_blocking=True)
        mlm_logits = self.ref_net(text_inputs.input_ids, attention_mask=text_inputs.attention_mask).logits
        word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)
        # Clean state: the same forward pass, flattened to one vector per caption.
        clean_embeds = mlm_logits.flatten(1)

        final_adverse = []
        for i, text in enumerate(texts):
            important_scores = self.get_important_scores(text, clean_embeds[i], self.batch_size, self.max_length)
            ranked = sorted(enumerate(important_scores.tolist()), key=lambda item: item[1], reverse=True)
            words, sub_words, keys = self._tokenize(text)
            final_words = copy.deepcopy(words)
            change = 0

            # Statistics of the current sample only; kept 2-D so they
            # broadcast over however many candidates are scored.
            neg_code, neg_mean, neg_var = negetive_code[i:i + 1], negetive_mean[i:i + 1], negative_var[i:i + 1]
            pos_code, pos_mean, pos_var = positive_code[i:i + 1], positive_mean[i:i + 1], positive_var[i:i + 1]

            for word_idx, _ in ranked:
                if change >= self.args.num_perturbation:
                    break
                tgt_word = words[word_idx]
                if tgt_word in filter_words:
                    continue
                start, end = keys[word_idx]
                if start > self.args.max_length - 2:
                    continue
                # NOTE(review): if the BERT tokenizer prepends a [CLS] token,
                # logits positions are shifted by one relative to ``keys`` -
                # verify against the tokenizer.
                substitutes = word_predictions[i, start:end]  # L, k
                word_pred_scores = word_pred_scores_all[i, start:end]
                # The candidate ids come from the BERT vocabulary, so they are
                # decoded with the BERT tokenizer.
                substitutes = get_substitues(substitutes, self.bert_tokenizer, self.ref_net, 1, word_pred_scores,
                                             self.args.threshold_pred_score)
                replace_texts = [' '.join(final_words)]
                available_substitutes = [tgt_word]
                for substitute in substitutes:
                    if substitute == tgt_word:
                        continue  # filter out original word
                    if '##' in substitute:
                        continue  # filter out sub-word
                    if substitute in filter_words:
                        continue
                    temp_replace = copy.deepcopy(final_words)
                    temp_replace[word_idx] = substitute
                    available_substitutes.append(substitute)
                    replace_texts.append(' '.join(temp_replace))
                # NOTE(review): the tokenizer is called directly here but via
                # ``.tokenize(...)`` in train_epoch - confirm which form the
                # project tokenizer supports.
                replace_text_input = self.clip_tokenizer(replace_texts).to(device)
                replace_embeds = self.model.encode_text(replace_text_input)
                loss = self.adv_loss(replace_embeds, neg_code, neg_mean, neg_var, pos_code, pos_mean, pos_var)
                loss = loss.sum(dim=-1)
                candidate_idx = int(loss.argmax())
                final_words[word_idx] = available_substitutes[candidate_idx]
                if available_substitutes[candidate_idx] != tgt_word:
                    change += 1
            final_adverse.append(' '.join(final_words))
        return final_adverse

    @torch.no_grad()
    def train_epoch(self):
        """Attack every training caption and evaluate text-to-image retrieval.

        For each batch a random set of target samples is drawn, the captions
        are perturbed with ``target_adv`` and re-encoded. The adversarial
        codes are then scored against the retrieval images using the target
        labels, and everything is written to a ``.mat`` file.
        """
        self.change_state(mode="valid")
        save_dir = os.path.join(self.args.save_dir, "adv_PR_t2i")
        os.makedirs(save_dir, exist_ok=True)  # savemat does not create directories
        times = 0
        adv_codes = []
        adv_label = []
        for image, text, label, index in self.train_loader:
            self.global_step += 1
            times += 1
            print(times)
            image = image.float().to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            label_np = label.detach().cpu().numpy()
            negetive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in label_np])
            negative_var = np.stack([self.image_var[str(i.astype(int))] for i in label_np])
            negetive_mean = torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
            negative_var = torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
            negetive_code = self.model.encode_image(image)

            # Targeted samples: one per item of the current batch, so a short
            # final batch stays aligned with its targets and labels.
            np.random.seed(times)
            select_index = np.random.choice(len(self.train_data), size=image.size(0))
            target_dataset = data.Subset(self.train_data, select_index)
            target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=len(select_index))
            target_image, _, target_label, _ = next(iter(target_subset))
            target_image = target_image.to(self.rank, non_blocking=True)
            target_np = target_label.detach().cpu().numpy()
            positive_mean = np.stack([self.image_mean[str(i.astype(int))] for i in target_np])
            positive_var = np.stack([self.image_var[str(i.astype(int))] for i in target_np])
            positive_mean = torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
            positive_var = torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
            positive_code = self.model.encode_image(target_image)

            # The attack works on the caption tokens, not on the image.
            final_adverse = self.target_adv(text, negetive_code, negetive_mean, negative_var,
                                            positive_code, positive_mean, positive_var)
            # NOTE(review): assumes ``tokenize`` returns a tensor of token ids
            # for a list of strings - confirm against the project tokenizer.
            final_text = self.clip_tokenizer.tokenize(final_adverse).to(self.rank, non_blocking=True)
            adv_code = self.model.encode_text(final_text)
            adv_codes.append(adv_code.cpu().detach().numpy())
            adv_label.append(target_label.numpy())

        adv_txt = np.concatenate(adv_codes)
        adv_labels = np.concatenate(adv_label)
        retrieval_img, _ = self.get_code(self.retrieval_loader, self.args.retrieval_num)
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAP_t = cal_map(adv_txt, adv_labels, retrieval_img, retrieval_labels)
        self.logger.info(f">>>>>> MAP_t: {mAP_t}")
        result_dict = {
            'adv_txt': adv_txt,
            'r_img': retrieval_img,
            'adv_l': adv_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    def get_code(self, data_loader, length: int):
        """Encode every sample of ``data_loader`` with the victim model.

        Args:
            data_loader: loader yielding ``(image, text, label, index)``.
            length: number of rows to allocate in the output buffers.

        Returns:
            ``(img_buffer, text_buffer)``: image and text features, each
            written at the row given by the sample's ``index``. Rows whose
            index never appears in the loader are left uninitialised.
        """
        img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
        for image, text, label, index in tqdm(data_loader):
            image = image.to(self.device, non_blocking=True)
            text = text.to(self.device, non_blocking=True)
            index = index.numpy()
            with torch.no_grad():
                image_feature = self.model.encode_image(image)
                text_features = self.model.encode_text(text)
            img_buffer[index, :] = image_feature.detach()
            text_buffer[index, :] = text_features.detach()
        return img_buffer, text_buffer

    def valid_attack(self, adv_images, texts, adv_labels):
        """Create the output directory for adversarial PR curves; arguments are unused."""
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

    def test(self, mode_name="i2t"):
        """Evaluate clean retrieval in both directions and save the features.

        Args:
            mode_name: accepted for interface compatibility; not used.
        """
        self.logger.info("Valid Clean.")
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAPi2t = cal_map(query_img, query_labels, retrieval_txt, retrieval_labels)
        mAPt2i = cal_map(query_txt, query_labels, retrieval_img, retrieval_labels)
        # Track the best text-to-image score with the text-to-image value.
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
        """Write query/retrieval features and labels to a ``.mat`` file.

        Args:
            query_img, query_txt, retrieval_img, retrieval_txt: feature
                tensors to store.
            mode_name: suffix used in the output file name and log message.
        """
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        result_dict = {
            'q_img': query_img.cpu().detach().numpy(),
            'q_txt': query_txt.cpu().detach().numpy(),
            'r_img': retrieval_img.cpu().detach().numpy(),
            'r_txt': retrieval_txt.cpu().detach().numpy(),
            'q_l': self.query_labels.numpy(),
            'r_l': self.retrieval_labels.numpy()
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
        self.logger.info(f">>>>>> save best {mode_name} data!")