313 lines
13 KiB
Python
313 lines
13 KiB
Python
from torch.nn.modules import loss
|
|
# from model.hash_model import DCMHT as DCMHT
|
|
import os
|
|
from tqdm import tqdm
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import DataLoader
|
|
import torch.utils.data as data
|
|
import scipy.io as scio
|
|
import numpy as np
|
|
|
|
from .base import TrainBase
|
|
from torch.nn import functional as F
|
|
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
|
from utils.calc_utils import cal_map, cal_pr
|
|
from dataset.dataloader import dataloader
|
|
import clip
|
|
# from transformers import BertModel
|
|
|
|
# Run on the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
def clamp(delta, clean_imgs):
    """Project a perturbation so the perturbed image stays in [0, 1].

    ``clean_imgs + delta`` is clipped to the valid image range and the
    (possibly reduced) perturbation is returned.
    """
    perturbed = (delta.data + clean_imgs.data).clamp(0, 1)
    return perturbed - clean_imgs.data
|
|
|
|
class Trainer(TrainBase):
|
|
|
|
def __init__(self,
|
|
rank=0):
|
|
args = get_args()
|
|
super(Trainer, self).__init__(args, rank)
|
|
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
|
text_mean, text_var=self.generate_mapping()
|
|
self.text_mean=text_mean
|
|
self.text_var=text_var
|
|
self.device=rank
|
|
# self.run()
|
|
|
|
def _init_model(self):
|
|
self.logger.info("init model.")
|
|
model_clip, preprocess = clip.load(self.args.victim, device=device)
|
|
self.model= model_clip
|
|
self.model.eval()
|
|
self.model.float()
|
|
|
|
    def _init_dataset(self):
        """Build the train/query/retrieval splits and their DataLoaders.

        Resolves the dataset file names into paths under ./dataset/<name>/,
        records the label tensors for later mAP/PR computation, and keeps a
        reference to the train split for targeted-sample drawing.
        """
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        # Rewrite the (relative) file names from the CLI args into full
        # paths rooted at ./dataset/<dataset name>/.
        self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
        self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
        self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                                            indexFile=self.args.index_file,
                                                            labelFile=self.args.label_file,
                                                            maxWords=self.args.max_words,
                                                            imageResolution=self.args.resolution,
                                                            query_num=self.args.query_num,
                                                            train_num=self.args.train_num,
                                                            seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        # The retrieval-set size is derived from the data, not a CLI argument.
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
        # NOTE(review): the query/retrieval loaders are also shuffled.
        # get_code writes rows by dataset index, so its output is unaffected,
        # but confirm no other consumer relies on iteration order.
        self.train_loader = DataLoader(
            dataset=train_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.query_loader = DataLoader(
            dataset=query_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        self.retrieval_loader = DataLoader(
            dataset=retrieval_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        # Kept so train_epoch can draw random targeted samples via Subset.
        self.train_data = train_data
|
|
|
|
|
|
def generate_mapping(self):
|
|
text_train=[]
|
|
label_train=[]
|
|
for image, text, label, index in self.train_loader:
|
|
text=text.to(device, non_blocking=True)
|
|
# print(self.model.vocab_size)
|
|
temp_text=self.model.encode_text(text)
|
|
text_train.append(temp_text.cpu().detach().numpy())
|
|
label_train.append(label.detach().numpy())
|
|
text_train=np.concatenate(text_train, axis=0)
|
|
label_train=np.concatenate(label_train, axis=0)
|
|
label_unipue=np.unique(label_train,axis=0)
|
|
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
|
|
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
|
|
|
|
text_representation = {}
|
|
text_var_representation = {}
|
|
for i, centroid in enumerate(label_unipue):
|
|
text_representation[str(centroid.astype(int))] = text_centroids[i]
|
|
text_var_representation[str(centroid.astype(int))]= text_var[i]
|
|
return text_representation, text_var_representation
|
|
|
|
    def target_adv(self, image, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, positive_var,
                   beta=10, epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
        """PGD-style targeted attack on the CLIP image encoder.

        Iteratively optimises an additive perturbation ``delta``, kept inside
        an L-inf ball of radius ``epsilon`` and projected so the perturbed
        image stays in [0, 1], driving the adversarial embedding toward the
        target ("positive") text statistics and away from the true
        ("negetive"/"negative") ones.

        Args:
            image: clean input images.
            negetive_code / positive_code: text embeddings of the true and
                target captions.
            negetive_mean, negative_var / positive_mean, positive_var:
                per-label embedding statistics from generate_mapping.
            beta: weight of the triplet-loss term.
            epsilon: L-inf perturbation budget.
            alpha: step size per iteration.
            num_iter: number of attack iterations.
            temperature: softmax temperature of the contrastive term.

        Returns:
            (Tensor, Tensor): the final perturbation (detached) and the
            embedding of the adversarial image.
        """
        delta = torch.zeros_like(image, requires_grad=True)
        for i in range(num_iter):
            self.model.zero_grad()
            anchor = self.model.encode_image(image + delta)
            # Triplet term: anchor should score higher against positive_code
            # than against negetive_code under cosine similarity.
            loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code, distance_function=nn.CosineSimilarity())
            # Variance-normalised squared distances to the class statistics.
            negative_dist = (anchor - negetive_mean)**2 / negative_var
            positive_dist = (anchor - positive_mean)**2 / positive_var
            negatives = torch.exp(negative_dist / temperature)
            positives = torch.exp(positive_dist / temperature)
            # InfoNCE-style ratio over the exponentiated distances plus the
            # weighted triplet loss.
            loss = torch.log(positives/(positives+negatives)).mean() + beta * loss1
            loss.backward(retain_graph=True)
            # Signed-gradient descent step on delta, then projection onto the
            # valid-image range and the epsilon ball. Updates go through
            # .data so autograd does not track the projection.
            delta.data = delta - alpha * delta.grad.detach().sign()
            delta.data = clamp(delta, image).clamp(-epsilon, epsilon)
            delta.grad.zero_()
        adv_code = self.model.encode_image(image + delta)
        return delta.detach(), adv_code
|
|
|
|
def train_epoch(self):
|
|
self.change_state(mode="valid")
|
|
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
|
all_loss = 0
|
|
times = 0
|
|
adv_codes=[]
|
|
adv_label=[]
|
|
for image, text, label, index in self.train_loader:
|
|
self.global_step += 1
|
|
times += 1
|
|
print(times)
|
|
image.float()
|
|
image = image.to(self.rank, non_blocking=True)
|
|
text = text.to(self.rank, non_blocking=True)
|
|
negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
|
negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
|
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
|
|
negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
|
|
negetive_code=self.model.encode_text(text)
|
|
|
|
#targeted sample
|
|
np.random.seed(times)
|
|
select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
|
|
target_dataset = data.Subset(self.train_data, select_index)
|
|
target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
|
|
_, target_text, target_label, _ = next(iter(target_subset))
|
|
target_text=target_text.to(self.rank, non_blocking=True)
|
|
positive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
|
positive_var=np.stack([self.text_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
|
positive_mean=torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
|
|
positive_var=torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
|
|
positive_code=self.model.encode_text(target_text)
|
|
|
|
|
|
delta, adv_code=self.target_adv(image,negetive_code,negetive_mean,negative_var,
|
|
positive_code,positive_mean,positive_var)
|
|
adv_codes.append(adv_code.cpu().detach().numpy())
|
|
adv_label.append(target_label.numpy())
|
|
adv_img=np.concatenate(adv_codes)
|
|
adv_labels=np.concatenate(adv_label)
|
|
|
|
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
|
|
|
|
|
|
|
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
|
retrieval_labels = self.retrieval_labels.numpy()
|
|
|
|
|
|
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
|
|
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
|
|
pr_t=cal_pr(retrieval_txt,adv_img,retrieval_labels,adv_labels)
|
|
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
|
|
result_dict = {
|
|
'adv_img': adv_img,
|
|
'r_txt': retrieval_txt,
|
|
'adv_l': adv_labels,
|
|
'r_l': retrieval_labels,
|
|
# 'q_l':query_labels
|
|
# 'pr': pr,
|
|
'pr_t': pr_t
|
|
}
|
|
scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"), result_dict)
|
|
self.logger.info(">>>>>> save all data!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
self.logger.info("Start train.")
|
|
|
|
for epoch in range(self.args.epochs):
|
|
self.train_epoch(epoch)
|
|
self.valid(epoch)
|
|
self.save_model(epoch)
|
|
|
|
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
|
|
|
|
|
|
|
def make_hash_code(self, code: list) -> torch.Tensor:
|
|
|
|
code = torch.stack(code)
|
|
# print(code.shape)
|
|
code = code.permute(1, 0, 2)
|
|
hash_code = torch.argmax(code, dim=-1)
|
|
hash_code[torch.where(hash_code == 0)] = -1
|
|
hash_code = hash_code.float()
|
|
|
|
return hash_code
|
|
|
|
def get_code(self, data_loader, length: int):
|
|
|
|
img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
|
|
text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
|
|
|
|
for image, text, label, index in tqdm(data_loader):
|
|
image = image.to(self.device, non_blocking=True)
|
|
text = text.to(self.device, non_blocking=True)
|
|
index = index.numpy()
|
|
with torch.no_grad():
|
|
image_feature = self.model.encode_image(image)
|
|
text_features = self.model.encode_text(text)
|
|
img_buffer[index, :] = image_feature.detach()
|
|
text_buffer[index, :] = text_features.detach()
|
|
|
|
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
|
|
|
|
|
|
|
|
|
    def valid_attack(self, adv_images, texts, adv_labels):
        # Stub: only prepares the adversarial-PR output directory; the actual
        # evaluation currently lives in train_epoch. Parameters are unused.
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
|
|
|
|
|
|
|
|
def test(self, mode_name="i2t"):
|
|
self.logger.info("Valid Clean.")
|
|
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
|
|
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
|
|
|
|
|
query_img = query_img.cpu().detach().numpy()
|
|
query_txt = query_txt.cpu().detach().numpy()
|
|
retrieval_img = retrieval_img.cpu().detach().numpy()
|
|
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
|
query_labels = self.query_labels.numpy()
|
|
retrieval_labels = self.retrieval_labels.numpy()
|
|
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
|
|
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)
|
|
pr_i2t=cal_pr(retrieval_txt,query_img,retrieval_labels,query_labels)
|
|
pr_t2i=cal_pr(retrieval_img,query_txt,retrieval_labels,query_labels)
|
|
self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
|
|
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
|
|
result_dict = {
|
|
'q_img': query_img,
|
|
'q_txt': query_txt,
|
|
'r_img': retrieval_img,
|
|
'r_txt': retrieval_txt,
|
|
'q_l': query_labels,
|
|
'r_l': retrieval_labels,
|
|
'pr_i2t': pr_i2t,
|
|
'pr_t2i': pr_t2i
|
|
}
|
|
scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"), result_dict)
|
|
self.logger.info(">>>>>> save all data!")
|
|
|
|
|
|
|
|
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
|
|
|
|
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
|
|
query_img = query_img.cpu().detach().numpy()
|
|
query_txt = query_txt.cpu().detach().numpy()
|
|
retrieval_img = retrieval_img.cpu().detach().numpy()
|
|
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
|
query_labels = self.query_labels.numpy()
|
|
retrieval_labels = self.retrieval_labels.numpy()
|
|
|
|
result_dict = {
|
|
'q_img': query_img,
|
|
'q_txt': query_txt,
|
|
'r_img': retrieval_img,
|
|
'r_txt': retrieval_txt,
|
|
'q_l': query_labels,
|
|
'r_l': retrieval_labels
|
|
}
|
|
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
|
|
self.logger.info(f">>>>>> save best {mode_name} data!")
|
|
|
|
|