advclip/train/hash_train.py

365 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from torch.nn.modules import loss
from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import scipy.io as scio
import numpy as np
from .base import TrainBase
from model.optimization import BertAdam
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import calc_map_k_matrix as calc_map_k
from dataset.dataloader import dataloader
import open_clip
# from transformers import BertModel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def clamp(delta, clean_imgs):
    """Project a perturbation so the perturbed image stays in [0, 1].

    Returns the adjusted perturbation such that clean_imgs + delta is a
    valid image; operates on .data to stay outside autograd.
    """
    perturbed = (clean_imgs.data + delta.data).clamp(0, 1)
    return perturbed - clean_imgs.data
class Trainer(TrainBase):
    """Adversarial-example generator/trainer built on a frozen CLIP backbone."""

    def __init__(self, rank=0):
        """Initialize args, base-trainer state, and per-label text statistics.

        Args:
            rank: device index / process rank forwarded to the base trainer.
        """
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        # BUG FIX: the original unpacked both return values into the same
        # name (`text_representation, text_representation = ...`), so both
        # attributes silently held the *variance* dict.  generate_mapping
        # returns (centroids, variances); use the centroids here.
        # NOTE(review): only text features are encoded, so image and text
        # representations are currently the same dict -- confirm intended.
        text_representation, text_var_representation = self.generate_mapping()
        self.image_representation = text_representation
        self.text_representation = text_representation
        self.device = rank
def _init_model(self):
self.logger.info("init model.")
model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device)
self.model= model_clip
self.model.eval()
self.model.float()
def _init_dataset(self):
self.logger.info("init dataset.")
self.logger.info(f"Using {self.args.dataset} dataset.")
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
indexFile=self.args.index_file,
labelFile=self.args.label_file,
maxWords=self.args.max_words,
imageResolution=self.args.resolution,
query_num=self.args.query_num,
train_num=self.args.train_num,
seed=self.args.seed)
self.train_labels = train_data.get_all_label()
self.query_labels = query_data.get_all_label()
self.retrieval_labels = retrieval_data.get_all_label()
self.args.retrieval_num = len(self.retrieval_labels)
self.logger.info(f"query shape: {self.query_labels.shape}")
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
self.train_loader = DataLoader(
dataset=train_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.query_loader = DataLoader(
dataset=query_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
self.retrieval_loader = DataLoader(
dataset=retrieval_data,
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True,
shuffle=True
)
def generate_mapping(self):
text_train=[]
label_train=[]
for image, text, label, index in self.train_loader:
text=text.to(device, non_blocking=True)
# print(self.model.vocab_size)
temp_text=self.model.encode_text(text)
text_train.append(temp_text.cpu().detach().numpy())
label_train.append(label.detach().numpy())
text_train=np.concatenate(text_train, axis=0)
label_train=np.concatenate(label_train, axis=0)
label_unipue=np.unique(label_train,axis=0)
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
text_representation = {}
text_var_representation = {}
for i, centroid in enumerate(label_unipue):
text_representation[centroid.tobytes()] = text_centroids[i]
text_var_representation[centroid.tobytes()]= text_var[i]
return text_representation, text_var_representation
def target_adv(self, image, positive, negative,
epsilon=0.03125, alpha=3/255, num_iter=100):
delta = torch.zeros_like(image,requires_grad=True)
one=torch.zeros_like(positive)
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
for i in range(num_iter):
self.model.zero_grad()
anchor=self.model.encode_image(image+delta)
loss=alienation_loss(anchor, positive, negative)
loss.backward(retain_graph=True)
delta.data = delta - alpha * delta.grad.detach().sign()
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
delta.grad.zero_()
return delta.detach()
    def train_epoch(self, epoch):
        """Generate targeted adversarial examples for one pass over the train set.

        Returns:
            (adv_images, texts, adv_labels): per-batch lists of adversarial
            image tensors, their caption tensors, and the flipped target labels.
        """
        self.change_state(mode="valid")
        self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
        # Accumulators kept from the original training loop; unused below.
        all_loss = 0
        times = 0
        adv_images=[]
        adv_labels=[]
        texts=[]
        for image, text, label, index in self.train_loader:
            self.global_step += 1
            times += 1
            # NOTE(review): result is discarded -- this line has no effect
            # (`image = image.float()` was probably intended).
            image.float()
            # Single-label datasets get an identity matrix as labels.
            if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
                label = torch.ones([image.shape[0]], dtype=torch.int)
                label = label.diag()
            image = image.to(self.rank, non_blocking=True)
            text = text.to(self.rank, non_blocking=True)
            index = index.numpy()
            # NOTE(review): generate_mapping() builds *dicts* keyed by
            # label-bytes, yet they are called here like functions; this
            # looks like it needs a per-row lookup helper -- confirm against
            # the project's utils before relying on this path.
            image_anchor=self.image_representation(label.detach().cpu().numpy())
            text_anchor=self.text_representation(label.detach().cpu().numpy())
            # "Negative": centroid of the sample's own class (move away from).
            negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
            # Target labels are the batch reversed: each sample is pushed
            # toward another sample's class.
            target_label=label.flip(dims=[0])
            target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
            target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
            # "Positive": centroid of the target class (move toward).
            positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
            delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
                torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
            adv_image=delta+image
            adv_images.append(adv_image)
            adv_labels.append(target_label)
            texts.append(text)
        return adv_images, texts, adv_labels
def train(self):
self.logger.info("Start train.")
for epoch in range(self.args.epochs):
self.train_epoch(epoch)
self.valid(epoch)
self.save_model(epoch)
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
s = torch.matmul(a, b.t())
b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s)))
return b_loss
def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
"""
"""
kl_divergence = torch.mean(a * torch.log(a / (b + 0.001)))
print("mean", torch.mean(a - b))
print("kl", kl_divergence)
return kl_divergence
    def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
        """Contrastive-style loss over pairwise feature distances.

        Args:
            a, b: feature/hash batches compared pairwise.
            label_sim: 0/1 matrix, 1 where the pair shares a label.
            threshold: slack below which positive-pair distance is ignored;
                overridden by args.sim_threshold when that is nonzero.
        Returns:
            (similarity, positive_loss, negative_loss): the raw distance
            matrix and the mean (l1) or mean-squared (l2) hinged losses.
        Raises:
            ValueError: for an unrecognized args.loss_type.
        """
        # $\vartheta$
        vartheta = self.args.vartheta
        if self.args.sim_threshold != 0:
            threshold = self.args.sim_threshold
        # Cosine branch uses 1 - cos(a, b) as distance; otherwise Euclidean.
        similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b)
        positive_similarity = similarity * label_sim
        # Original note (translated): any pair whose cosine distance falls
        # below the threshold already counts as correct, because pushing the
        # distance all the way to 2 is very hard in practice.
        negative_similarity = similarity * (1 - label_sim)
        if self.args.similarity_function == "cosine":
            positive_similarity = positive_similarity.clip(threshold) - threshold
            negative_similarity = negative_similarity.clip(max=1.)
            # Hinge: negatives are penalized by how far they fall short of 1.
            negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
        elif self.args.similarity_function == "euclidean":
            # Original note (translated): with concatenated hash codes the
            # output dimension is doubled; a pair differing in half of its
            # bits corresponds to squared distance output_dim * 2 * vartheta,
            # so distances beyond that human-chosen cap are clipped.
            max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5
            negative_similarity = negative_similarity.clip(max=max_value)
            negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
        if self.args.loss_type == "l1":
            positive_loss = positive_similarity.mean()
            negative_loss = negative_similarity.mean()
        elif self.args.loss_type == "l2":
            positive_loss = torch.pow(positive_similarity, 2).mean()
            negative_loss = torch.pow(negative_similarity, 2).mean()
        else:
            raise ValueError("argument of loss_type is not support.")
        return similarity, positive_loss, negative_loss
def make_hash_code(self, code: list) -> torch.Tensor:
code = torch.stack(code)
# print(code.shape)
code = code.permute(1, 0, 2)
hash_code = torch.argmax(code, dim=-1)
hash_code[torch.where(hash_code == 0)] = -1
hash_code = hash_code.float()
return hash_code
def get_code(self, data_loader, length: int):
img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
for image, text, label, index in tqdm(data_loader):
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
index = index.numpy()
image_hash=self.img_model(image)
text_feat=self.bert(text)[0]
text_hash=self.txt_model(text_feat)
img_buffer[index, :] = image_hash.data
text_buffer[index, :] = text_hash.data
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def get_adv_code(self, adv_data_list,text_list):
img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)
for i in tqdm(range(len(adv_data_list))):
image = adv_data_list[i].to(self.rank, non_blocking=True)
text = text_list[i].to(self.rank, non_blocking=True)
# index = index.numpy()
image_hash=self.img_model(image)
text_feat=self.bert(text)[0]
text_hash=self.txt_model(text_feat)
# text_hash = self.make_hash_code(text_hash)
# image_hash.to(self.rank)
# text_hash.to(self.rank)
img_buffer[i, :] = image_hash.data
text_buffer[i, :] = text_hash.data
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def valid_attack(self,adv_images, texts, adv_labels):
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
os.makedirs(save_dir, exist_ok=True)
    def test(self, mode_name="i2t"):
        """Evaluate a pretrained model: compute mAP in all four directions and
        dump query/retrieval codes plus labels to a .mat file.

        Args:
            mode_name: tag used in the output filename (e.g. "i2t").
        Raises:
            RuntimeError: when --pretrained was not supplied.
        """
        if self.args.pretrained == "":
            raise RuntimeError("test step must load a model! please set the --pretrained argument.")
        self.change_state(mode="valid")
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        # "select" hash layer uses this class's get_code; otherwise the base class's.
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
        mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
        # print("map map")
        mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
        mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
        mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
        # NOTE(review): only max_mapt2i is updated here, while valid() updates
        # both directions -- confirm this asymmetry is intended.
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
        query_img = query_img.cpu().detach().numpy()
        query_txt = query_txt.cpu().detach().numpy()
        retrieval_img = retrieval_img.cpu().detach().numpy()
        retrieval_txt = retrieval_txt.cpu().detach().numpy()
        query_labels = self.query_labels.numpy()
        retrieval_labels = self.retrieval_labels.numpy()
        result_dict = {
            'q_img': query_img,
            'q_txt': query_txt,
            'r_img': retrieval_img,
            'r_txt': retrieval_txt,
            'q_l': query_labels,
            'r_l': retrieval_labels
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
        self.logger.info(">>>>>> save all data!")
    def valid(self, epoch):
        """Validation pass: compute four-direction mAP, save codes whenever a
        direction improves on its best, and update the best-epoch trackers.

        Args:
            epoch: current epoch index (for logging/bookkeeping).
        """
        self.logger.info("Valid.")
        self.change_state(mode="valid")
        # "select" hash layer uses this class's get_code; otherwise the base class's.
        query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
        # print("get all code")
        mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
        # print("map map")
        mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
        mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
        mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
        if self.max_mapi2t < mAPi2t:
            self.best_epoch_i = epoch
            self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
        self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
        if self.max_mapt2i < mAPt2i:
            self.best_epoch_t = epoch
            self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
os.makedirs(save_dir, exist_ok=True)
query_img = query_img.cpu().detach().numpy()
query_txt = query_txt.cpu().detach().numpy()
retrieval_img = retrieval_img.cpu().detach().numpy()
retrieval_txt = retrieval_txt.cpu().detach().numpy()
query_labels = self.query_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy()
result_dict = {
'q_img': query_img,
'q_txt': query_txt,
'r_img': retrieval_img,
'r_txt': retrieval_txt,
'q_l': query_labels,
'r_l': retrieval_labels
}
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
self.logger.info(f">>>>>> save best {mode_name} data!")