# advclip/train/hash_train.py
# (listing metadata: 313 lines, 13 KiB, Python)

from torch.nn.modules import loss
# from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
import scipy.io as scio
import numpy as np
from .base import TrainBase
from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader
import clip
# from transformers import BertModel
# Default compute device for module-level helpers: the first CUDA GPU when one
# is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def clamp(delta, clean_imgs):
    """Project a perturbation so that ``clean_imgs + delta`` stays in [0, 1].

    The result is detached from autograd (built from detached views of both
    inputs), so it can be assigned straight into ``delta.data``.

    Args:
        delta: additive perturbation, same shape as ``clean_imgs``.
        clean_imgs: unperturbed images.

    Returns:
        A new tensor ``clamp(clean + delta, 0, 1) - clean``.
    """
    base = clean_imgs.detach()
    projected = (delta.detach() + base).clamp(0, 1)
    return projected - base
class Trainer(TrainBase):
    """Targeted adversarial attack on the CLIP image encoder for cross-modal retrieval.

    On construction, per-label statistics (mean and variance) of the CLIP text
    embeddings of the training set are computed. ``train_epoch`` then perturbs
    every training image toward the text embedding of a randomly drawn target
    sample. It then reports the targeted image->text mAP of the adversarial
    image codes against the retrieval set.
    """

    def __init__(self,
                 rank=0):
        args = get_args()
        # NOTE(review): train_loader is read right after super().__init__, so
        # TrainBase.__init__ presumably calls _init_dataset/_init_model — confirm.
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        # Per-label text-embedding centroid and variance, keyed by
        # str(label_vector.astype(int)).
        text_mean, text_var = self.generate_mapping()
        self.text_mean = text_mean
        self.text_var = text_var
        self.device = rank

    def _init_model(self):
        """Load the victim CLIP model named by ``args.victim`` in eval mode and fp32."""
        self.logger.info("init model.")
        model_clip, _preprocess = clip.load(self.args.victim, device=device)
        self.model = model_clip
        self.model.eval()
        # Cast all weights to fp32 (presumably because CLIP may load fp16
        # weights on GPU); the attack gradients are computed in fp32.
        self.model.float()

    def _init_dataset(self):
        """Build the train/query/retrieval datasets and their DataLoaders.

        Side effects: rewrites ``args.index_file``, ``args.caption_file`` and
        ``args.label_file`` in place to paths under ``./dataset/<dataset>/``.
        Sets ``args.retrieval_num`` to the size of the retrieval set.
        """
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        dataset_root = os.path.join("./dataset", self.args.dataset)
        self.args.index_file = os.path.join(dataset_root, self.args.index_file)
        self.args.caption_file = os.path.join(dataset_root, self.args.caption_file)
        self.args.label_file = os.path.join(dataset_root, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                                            indexFile=self.args.index_file,
                                                            labelFile=self.args.label_file,
                                                            maxWords=self.args.max_words,
                                                            imageResolution=self.args.resolution,
                                                            query_num=self.args.query_num,
                                                            train_num=self.args.train_num,
                                                            seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")

        def make_loader(dataset):
            # Shuffling query/retrieval is harmless: get_code stores features at
            # each sample's dataset index, not in batch order.
            return DataLoader(
                dataset=dataset,
                batch_size=self.args.batch_size,
                num_workers=self.args.num_workers,
                pin_memory=True,
                shuffle=True,
            )

        self.train_loader = make_loader(train_data)
        self.query_loader = make_loader(query_data)
        self.retrieval_loader = make_loader(retrieval_data)
        self.train_data = train_data

    def generate_mapping(self):
        """Compute per-label statistics of the CLIP text embeddings of the train set.

        For every distinct label vector in the training set, the text embeddings
        of the samples selected by ``find_indices`` (presumably: samples with
        exactly that label vector — confirm in utils) are reduced to a
        per-dimension mean and variance.

        Returns:
            (text_mean, text_var): two dicts keyed by ``str(label.astype(int))``,
            mapping to the per-dimension mean and variance arrays respectively.
        """
        text_train = []
        label_train = []
        # Embeddings are only read back as NumPy arrays, so no autograd graph
        # needs to be built for the whole training set.
        with torch.no_grad():
            for image, text, label, index in self.train_loader:
                text = text.to(device, non_blocking=True)
                temp_text = self.model.encode_text(text)
                text_train.append(temp_text.cpu().numpy())
                label_train.append(label.detach().numpy())
        text_train = np.concatenate(text_train, axis=0)
        label_train = np.concatenate(label_train, axis=0)
        label_unique = np.unique(label_train, axis=0)
        text_representation = {}
        text_var_representation = {}
        for lbl in label_unique:
            members = text_train[find_indices(label_train, lbl)]
            key = str(lbl.astype(int))
            text_representation[key] = members.mean(axis=0)
            # NOTE(review): a label vector seen on a single sample gets zero
            # variance, which target_adv later divides by.
            text_var_representation[key] = members.var(axis=0)
        return text_representation, text_var_representation

    def _label_stats(self, labels):
        """Look up per-label text mean/variance for a batch of label vectors.

        Returns two tensors (mean, var), stacked in batch order and moved to
        ``self.rank``.
        """
        keys = [str(lbl.astype(int)) for lbl in labels.detach().cpu().numpy()]
        mean = np.stack([self.text_mean[k] for k in keys])
        var = np.stack([self.text_var[k] for k in keys])
        return (torch.from_numpy(mean).to(self.rank, non_blocking=True),
                torch.from_numpy(var).to(self.rank, non_blocking=True))

    def target_adv(self, image, negetive_code, negetive_mean, negative_var, positive_code, positive_mean, positive_var,
                   beta=10, epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
        """Craft a targeted perturbation for ``image`` by iterative sign-gradient descent.

        Each step minimises:

        * a variance-scaled distance ratio that prefers the target (positive)
          label statistics over the source (negative) ones, plus
        * ``beta`` times a triplet loss (cosine distance) that pulls the image
          embedding toward ``positive_code`` and away from ``negetive_code``.

        After every step, ``delta`` is projected so that ``image + delta`` stays
        in [0, 1] and ``|delta| <= epsilon`` elementwise.

        Args:
            image: clean image batch.
            negetive_code / positive_code: text embeddings of the source / target
                captions; treated as constants.
            negetive_mean, negative_var / positive_mean, positive_var: per-sample
                label statistics from ``generate_mapping``.
            beta, epsilon, alpha, num_iter, temperature: attack hyper-parameters.

        Returns:
            (delta, adv_code): the detached perturbation and the image embedding
            of ``image + delta``, computed without autograd.
        """
        # The text-side codes are constants for the attack. Detaching them lets
        # every iteration build and free its own graph (no retain_graph needed).
        negetive_code = negetive_code.detach()
        positive_code = positive_code.detach()

        def cosine_distance(x, y):
            # triplet_margin_with_distance_loss expects a non-negative *distance*;
            # passing raw cosine similarity would invert the triplet objective.
            return 1.0 - F.cosine_similarity(x, y)

        delta = torch.zeros_like(image, requires_grad=True)
        for _ in range(num_iter):
            self.model.zero_grad()
            anchor = self.model.encode_image(image + delta)
            loss1 = F.triplet_margin_with_distance_loss(anchor, positive_code, negetive_code,
                                                        distance_function=cosine_distance)
            # NOTE(review): a zero variance entry makes these inf/NaN.
            negative_dist = (anchor - negetive_mean) ** 2 / negative_var
            positive_dist = (anchor - positive_mean) ** 2 / positive_var
            # log(e^p / (e^p + e^n)) == logsigmoid(p - n): same value as the
            # exp-ratio, but without exp() overflowing for large distances.
            ratio_loss = F.logsigmoid((positive_dist - negative_dist) / temperature).mean()
            loss = ratio_loss + beta * loss1
            loss.backward()
            delta.data = delta - alpha * delta.grad.detach().sign()
            delta.data = clamp(delta, image).clamp(-epsilon, epsilon)
            delta.grad.zero_()
        with torch.no_grad():
            adv_code = self.model.encode_image(image + delta)
        return delta.detach(), adv_code

    def train_epoch(self, epoch=None):
        """Attack every training batch, then evaluate and save targeted i->t retrieval.

        ``epoch`` is accepted so that ``train()`` can call ``train_epoch(epoch)``;
        it is otherwise unused. Adversarial image codes are labelled with their
        *target* labels. They are scored against the clean retrieval text codes,
        and the results are written to
        ``<save_dir>/adv_PR_cruve/<victim>-adv-<dataset>.mat``.
        """
        self.change_state(mode="valid")
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        times = 0
        adv_codes = []
        adv_label = []
        for image, text, label, index in self.train_loader:
            self.global_step += 1
            times += 1
            print(times)
            image = image.float().to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            batch_size = image.size(0)

            negetive_mean, negative_var = self._label_stats(label)
            with torch.no_grad():
                negetive_code = self.model.encode_text(text)

            # Draw one random target sample per image. The size follows the
            # actual batch, so the final partial batch still lines up with its
            # targets. Seeding with the batch counter makes targets reproducible.
            np.random.seed(times)
            select_index = np.random.choice(len(self.train_data), size=batch_size)
            target_dataset = data.Subset(self.train_data, select_index)
            target_subset = DataLoader(target_dataset, batch_size=batch_size)
            _, target_text, target_label, _ = next(iter(target_subset))
            target_text = target_text.to(self.rank, non_blocking=True)

            positive_mean, positive_var = self._label_stats(target_label)
            with torch.no_grad():
                positive_code = self.model.encode_text(target_text)

            _delta, adv_code = self.target_adv(image, negetive_code, negetive_mean, negative_var,
                                               positive_code, positive_mean, positive_var)
            adv_codes.append(adv_code.cpu().detach().numpy())
            adv_label.append(target_label.numpy())

        adv_img = np.concatenate(adv_codes)
        adv_labels = np.concatenate(adv_label)
        _, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAP_t = cal_map(adv_img, adv_labels, retrieval_txt, retrieval_labels)
        pr_t = cal_pr(retrieval_txt, adv_img, retrieval_labels, adv_labels)
        self.logger.info(f">>>>>> MAP_t: {mAP_t}")
        result_dict = {
            'adv_img': adv_img,
            'r_txt': retrieval_txt,
            'adv_l': adv_labels,
            'r_l': retrieval_labels,
            'pr_t': pr_t
        }
        scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-adv-" + self.args.dataset + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    def train(self):
        """Run ``args.epochs`` rounds of attack, validation and checkpointing.

        ``valid`` and ``save_model`` are not defined in this class; presumably
        they come from TrainBase — confirm.
        """
        self.logger.info("Start train.")
        for epoch in range(self.args.epochs):
            self.train_epoch(epoch)
            self.valid(epoch)
            self.save_model(epoch)
        self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")

    def make_hash_code(self, code: list) -> torch.Tensor:
        """Turn a list of per-bit score tensors into a float code.

        The tensors are stacked, and the leading two axes are swapped. Each
        entry becomes the argmax over the last axis, with index 0 mapped to -1.
        Assumes each tensor is (batch, n_classes) — presumably n_classes == 2,
        giving +/-1 codes; confirm with callers.
        """
        stacked = torch.stack(code).permute(1, 0, 2)
        hash_code = torch.argmax(stacked, dim=-1)
        hash_code[torch.where(hash_code == 0)] = -1
        return hash_code.float()

    def get_code(self, data_loader, length: int):
        """Encode every sample of ``data_loader`` with CLIP, without gradients.

        Returns (image_features, text_features). Each is a
        ``(length, args.output_dim)`` float tensor on ``self.rank``. Row ``i``
        holds the sample whose dataset index is ``i``, so loader order does not
        matter.
        """
        img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float, device=self.rank)
        text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float, device=self.rank)
        for image, text, label, index in tqdm(data_loader):
            image = image.to(self.device, non_blocking=True)
            text = text.to(self.device, non_blocking=True)
            index = index.numpy()
            with torch.no_grad():
                image_feature = self.model.encode_image(image)
                text_features = self.model.encode_text(text)
            img_buffer[index, :] = image_feature.detach()
            text_buffer[index, :] = text_features.detach()
        return img_buffer, text_buffer

    def valid_attack(self, adv_images, texts, adv_labels):
        """Stub: only ensures ``<save_dir>/adv_PR_cruve`` exists; arguments are unused."""
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

    def test(self, mode_name="i2t"):
        """Evaluate clean (unattacked) retrieval in both directions and save the results.

        Computes i->t and t->i mAP and PR curves for query vs. retrieval codes,
        updates the best-mAP trackers, and writes
        ``<save_dir>/PR_cruve/<victim>-ours-<dataset>.mat``. ``mode_name`` is
        accepted for interface compatibility but is unused.
        """
        self.logger.info("Valid Clean.")
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        mAPi2t = cal_map(query_img, query_labels, retrieval_txt, retrieval_labels)
        mAPt2i = cal_map(query_txt, query_labels, retrieval_img, retrieval_labels)
        pr_i2t = cal_pr(retrieval_txt, query_img, retrieval_labels, query_labels)
        pr_t2i = cal_pr(retrieval_img, query_txt, retrieval_labels, query_labels)
        # Track the best mAP of each retrieval direction separately.
        self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels,
            'pr_i2t': pr_i2t,
            'pr_t2i': pr_t2i
        }
        scio.savemat(os.path.join(save_dir, str(self.args.victim).replace("/", "_") + "-ours-" + self.args.dataset + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")

    def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
        """Save query/retrieval codes and labels to
        ``<save_dir>/PR_cruve/<output_dim>-ours-<dataset>-<mode_name>.mat``.

        The four code arguments must be torch tensors; they are moved to CPU
        and converted to NumPy.
        """
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
        self.logger.info(f">>>>>> save best {mode_name} data!")