from torch.nn.modules import loss
from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import scipy.io as scio
import numpy as np

from .base import TrainBase
from model.optimization import BertAdam
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity, find_indices
from utils.calc_utils import calc_map_k_matrix as calc_map_k
from dataset.dataloader import dataloader
import open_clip
# from transformers import BertModel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def clamp(delta, clean_imgs):
    """Project ``delta`` so that ``clean_imgs + delta`` lies inside [0, 1].

    Works on ``.data`` (no autograd history) and returns the clamped
    adversarial image minus the clean image, i.e. the feasible perturbation.
    """
    clamp_imgs = (delta.data + clean_imgs.data).clamp(0, 1)
    clamp_delta = clamp_imgs - clean_imgs.data
    return clamp_delta


class Trainer(TrainBase):
    """Evaluator / attack driver built around an open_clip ``ViT-B-16`` model.

    On construction it encodes the whole training split with the CLIP text
    encoder and stores per-label text-feature centroids (see
    ``generate_mapping``), which serve as anchors for ``target_adv``.
    """

    def __init__(self, rank=0):
        args = get_args()
        super(Trainer, self).__init__(args, rank)
        self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
        # generate_mapping returns (per-label mean, per-label variance). Unpacking
        # both into the same name used to make the *variance* dict the
        # representation; keep the two apart.
        text_representation, text_var_representation = self.generate_mapping()
        # Only text centroids are computed; they are reused as the image-side
        # anchors too (as the original assignments intended).
        self.image_representation = text_representation
        self.text_representation = text_representation
        self.text_var_representation = text_var_representation
        self.device = rank
        # self.run()

    def _init_model(self):
        """Load the pretrained open_clip ``ViT-B-16`` model in eval / float32 mode."""
        self.logger.info("init model.")
        model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device)
        self.model = model_clip
        self.model.eval()
        self.model.float()

    def _init_dataset(self):
        """Build train / query / retrieval datasets and their DataLoaders.

        Side effect: rewrites ``args.index_file``, ``args.caption_file`` and
        ``args.label_file`` to paths under ``./dataset/<dataset>/`` and sets
        ``args.retrieval_num``.
        """
        self.logger.info("init dataset.")
        self.logger.info(f"Using {self.args.dataset} dataset.")
        self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
        self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
        self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
        train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
                                                            indexFile=self.args.index_file,
                                                            labelFile=self.args.label_file,
                                                            maxWords=self.args.max_words,
                                                            imageResolution=self.args.resolution,
                                                            query_num=self.args.query_num,
                                                            train_num=self.args.train_num,
                                                            seed=self.args.seed)
        self.train_labels = train_data.get_all_label()
        self.query_labels = query_data.get_all_label()
        self.retrieval_labels = retrieval_data.get_all_label()
        self.args.retrieval_num = len(self.retrieval_labels)
        self.logger.info(f"query shape: {self.query_labels.shape}")
        self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
        self.train_loader = DataLoader(
            dataset=train_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=True
        )
        # Evaluation loaders: codes are written back by sample index, so order
        # does not affect results; keep them deterministic.
        self.query_loader = DataLoader(
            dataset=query_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=False
        )
        self.retrieval_loader = DataLoader(
            dataset=retrieval_data,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=True,
            shuffle=False
        )

    def generate_mapping(self):
        """Compute per-label mean and variance of CLIP text features on the train split.

        Returns:
            ``(text_representation, text_var_representation)``: two dicts keyed
            by ``label_row.tobytes()`` for every unique label row of the training
            set, mapping to the mean / variance (over samples selected by
            ``utils.find_indices``) of their text embeddings, as numpy arrays.
        """
        text_feats = []
        label_rows = []
        # Pure inference over the whole train split: no autograd graph needed.
        with torch.no_grad():
            for image, text, label, index in self.train_loader:
                text = text.to(device, non_blocking=True)
                text_feats.append(self.model.encode_text(text).cpu().detach().numpy())
                label_rows.append(label.detach().numpy())
        text_train = np.concatenate(text_feats, axis=0)
        label_train = np.concatenate(label_rows, axis=0)
        label_unique = np.unique(label_train, axis=0)

        text_representation = {}
        text_var_representation = {}
        for label_vec in label_unique:
            # Select this label's samples once and derive both statistics from them.
            members = text_train[find_indices(label_train, label_vec)]
            key = label_vec.tobytes()
            text_representation[key] = members.mean(axis=0)
            text_var_representation[key] = members.var(axis=0)
        return text_representation, text_var_representation

    def target_adv(self, image, positive, negative, epsilon=0.03125, alpha=3/255, num_iter=100):
        """Targeted L-inf signed-gradient attack on the CLIP image encoder.

        Each step lowers a triplet margin loss with the perturbed image's
        embedding as anchor, pulling it toward ``positive`` and away from
        ``negative``.

        Args:
            image: clean image batch. The perturbed image is clamped to [0, 1],
                so pixels are assumed to live in that range — TODO confirm
                against the dataset transforms.
            positive: target embedding(s) to move toward.
            negative: embedding(s) to move away from.
            epsilon: L-inf bound on the perturbation.
            alpha: step size of each signed-gradient step.
            num_iter: number of iterations.

        Returns:
            The detached perturbation ``delta`` (same shape as ``image``).
        """
        # Anchors are constants for the attack; detaching them means every
        # backward pass only traverses this iteration's fresh graph.
        positive = positive.detach()
        negative = negative.detach()
        delta = torch.zeros_like(image, requires_grad=True)
        alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
        for _ in range(num_iter):
            self.model.zero_grad()
            anchor = self.model.encode_image(image + delta)
            adv_loss = alienation_loss(anchor, positive, negative)
            # The graph is rebuilt each iteration; retaining it only leaked memory.
            adv_loss.backward()
            delta.data = delta - alpha * delta.grad.detach().sign()
            delta.data = clamp(delta, image).clamp(-epsilon, epsilon)
            delta.grad.zero_()
        return delta.detach()

    # NOTE(review): ``train`` below calls ``self.train_epoch``, which only exists
    # as the commented-out draft here. As written, the draft calls the
    # representation dicts like functions; they would need a
    # ``[label_row.tobytes()]`` lookup per sample before it can be revived.
    # def train_epoch(self, epoch):
    #     self.change_state(mode="valid")
    #     self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
    #     all_loss = 0
    #     times = 0
    #     adv_images=[]
    #     adv_labels=[]
    #     texts=[]
    #     for image, text, label, index in self.train_loader:
    #         self.global_step += 1
    #         times += 1
    #         image.float()
    #         if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
    #             label = torch.ones([image.shape[0]], dtype=torch.int)
    #             label = label.diag()
    #         image = image.to(self.rank, non_blocking=True)
    #         text = text.to(self.rank, non_blocking=True)
    #         index = index.numpy()
    #         image_anchor=self.image_representation(label.detach().cpu().numpy())
    #         text_anchor=self.text_representation(label.detach().cpu().numpy())
    #         negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
    #         target_label=label.flip(dims=[0])
    #         target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
    #         target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
    #         positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
    #         delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
    #                               torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
    #         adv_image=delta+image
    #         adv_images.append(adv_image)
    #         adv_labels.append(target_label)
    #         texts.append(text)
    #     return adv_images, texts, adv_labels

    def train(self):
        """Run ``train_epoch`` / ``valid`` / ``save_model`` for each epoch and log the best scores.

        ``train_epoch`` is not defined in this file (see note above); presumably
        it must come from ``TrainBase`` — confirm.
        """
        self.logger.info("Start train.")
        for epoch in range(self.args.epochs):
            self.train_epoch(epoch)
            self.valid(epoch)
            self.save_model(epoch)
        self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")

    def make_hash_code(self, code: list) -> torch.Tensor:
        """Turn a list of per-bit 2-way score tensors into a {-1, +1} float hash code.

        ``code`` is stacked, its first two axes are swapped, and the argmax over
        the last axis is taken; argmax 0 becomes -1.
        """
        code = torch.stack(code)
        code = code.permute(1, 0, 2)
        hash_code = torch.argmax(code, dim=-1)
        hash_code[torch.where(hash_code == 0)] = -1
        hash_code = hash_code.float()
        return hash_code

    def get_code(self, data_loader, length: int):
        """Encode every sample of ``data_loader`` with CLIP.

        Args:
            data_loader: yields ``(image, text, label, index)`` batches; ``index``
                selects the output row of each sample.
            length: number of rows to allocate (the dataset size).

        Returns:
            ``(img_buffer, text_buffer)`` float tensors of shape
            ``(length, embed_dim)``; row ``i`` holds sample ``i``'s embedding.
            Rows never visited stay zero.
        """
        img_buffer = None
        text_buffer = None
        with torch.no_grad():
            for image, text, label, index in tqdm(data_loader):
                image = image.to(self.rank, non_blocking=True)
                text = text.to(self.rank, non_blocking=True)
                image_feat = self.model.encode_image(image).float()
                text_feat = self.model.encode_text(text).float()
                if img_buffer is None:
                    # Embedding width is only known once the model has produced output.
                    img_buffer = torch.zeros(length, image_feat.shape[-1], dtype=torch.float, device=image_feat.device)
                    text_buffer = torch.zeros(length, text_feat.shape[-1], dtype=torch.float, device=text_feat.device)
                rows = torch.as_tensor(index, dtype=torch.long, device=image_feat.device)
                img_buffer[rows, :] = image_feat
                text_buffer[rows, :] = text_feat
        if img_buffer is None:
            # Empty loader: return zero-width buffers instead of None.
            img_buffer = torch.zeros(length, 0, dtype=torch.float)
            text_buffer = torch.zeros(length, 0, dtype=torch.float)
        return img_buffer, text_buffer

    def get_adv_code(self, adv_data_list, text_list):
        """Encode pre-batched (adversarial) images and their texts with CLIP.

        Args:
            adv_data_list: list of image batches.
            text_list: list of tokenized text batches, parallel to ``adv_data_list``.

        Returns:
            ``(img_codes, text_codes)``: the per-batch embeddings concatenated
            along dim 0, on ``self.rank``.
        """
        img_feats = []
        text_feats = []
        with torch.no_grad():
            for image, text in tqdm(zip(adv_data_list, text_list), total=len(adv_data_list)):
                image = image.to(self.rank, non_blocking=True)
                text = text.to(self.rank, non_blocking=True)
                img_feats.append(self.model.encode_image(image).float())
                text_feats.append(self.model.encode_text(text).float())
        if not img_feats:
            empty = torch.empty(0, self.args.output_dim, dtype=torch.float).to(self.rank)
            return empty, empty.clone()
        return torch.cat(img_feats, dim=0), torch.cat(text_feats, dim=0)

    def valid_attack(self, adv_images, texts, adv_labels):
        """Placeholder for adversarial evaluation; currently only creates ``<save_dir>/adv_PR_cruve``."""
        save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
        os.makedirs(save_dir, exist_ok=True)

    def _encode(self, data_loader, length):
        """Use this class's CLIP ``get_code`` when ``args.hash_layer == "select"``, else the base-class one."""
        if self.args.hash_layer == "select":
            return self.get_code(data_loader, length)
        return super().get_code(data_loader, length)

    def _calc_all_maps(self, query_img, query_txt, retrieval_img, retrieval_txt):
        """Return ``(mAP i->t, mAP t->i, mAP i->i, mAP t->t)`` via ``calc_map_k`` (k=None)."""
        def _map(query, retrieval):
            return calc_map_k(query, retrieval, self.query_labels, self.retrieval_labels, None, self.rank)

        mAPi2t = _map(query_img, retrieval_txt)
        mAPt2i = _map(query_txt, retrieval_img)
        mAPi2i = _map(query_img, retrieval_img)
        mAPt2t = _map(query_txt, retrieval_txt)
        return mAPi2t, mAPt2i, mAPi2i, mAPt2t

    def _dump_codes(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name):
        """Write codes and labels to ``<save_dir>/PR_cruve/<output_dim>-ours-<dataset>-<mode_name>.mat``."""
        save_dir = os.path.join(self.args.save_dir, "PR_cruve")
        os.makedirs(save_dir, exist_ok=True)
        result_dict = {
            'q_img': query_img.cpu().detach().numpy(),
            'q_txt': query_txt.cpu().detach().numpy(),
            'r_img': retrieval_img.cpu().detach().numpy(),
            'r_txt': retrieval_txt.cpu().detach().numpy(),
            'q_l': self.query_labels.numpy(),
            'r_l': self.retrieval_labels.numpy()
        }
        scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)

    def test(self, mode_name="i2t"):
        """Evaluate clean retrieval mAP in all four directions and dump the codes to a .mat file."""
        self.logger.info("Valid Clean.")
        # if self.args.pretrained == "":
        #     raise RuntimeError("test step must load a model! please set the --pretrained argument.")
        # self.change_state(mode="valid")
        query_img, query_txt = self._encode(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self._encode(self.retrieval_loader, self.args.retrieval_num)
        mAPi2t, mAPt2i, mAPi2i, mAPt2t = self._calc_all_maps(query_img, query_txt, retrieval_img, retrieval_txt)
        # NOTE(review): only the t->i best score is tracked here, not i->t — confirm intended.
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
        self._dump_codes(query_img, query_txt, retrieval_img, retrieval_txt, mode_name)
        self.logger.info(">>>>>> save all data!")

    def valid(self, epoch):
        """Compute retrieval mAP for ``epoch``, saving codes whenever i->t or t->i reaches a new best."""
        self.logger.info("Valid.")
        self.change_state(mode="valid")
        query_img, query_txt = self._encode(self.query_loader, self.args.query_num)
        retrieval_img, retrieval_txt = self._encode(self.retrieval_loader, self.args.retrieval_num)
        mAPi2t, mAPt2i, mAPi2i, mAPt2t = self._calc_all_maps(query_img, query_txt, retrieval_img, retrieval_txt)
        if self.max_mapi2t < mAPi2t:
            self.best_epoch_i = epoch
            self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
        self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
        if self.max_mapt2i < mAPt2i:
            self.best_epoch_t = epoch
            self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
        self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
        self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, "
                         f"MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")

    def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
        """Dump the best-so-far codes for ``mode_name`` to a .mat file under ``<save_dir>/PR_cruve``."""
        self._dump_codes(query_img, query_txt, retrieval_img, retrieval_txt, mode_name)
        self.logger.info(f">>>>>> save best {mode_name} data!")