更新 train/hash_train.py

This commit is contained in:
liwenyun 2024-06-08 16:03:17 +08:00
parent 97418d42e1
commit 5cce92d540
1 changed files with 364 additions and 454 deletions

View File

@ -1,454 +1,364 @@
from torch.nn.modules import loss from torch.nn.modules import loss
from model.hash_model import DCMHT as DCMHT from model.hash_model import DCMHT as DCMHT
import os import os
from tqdm import tqdm from tqdm import tqdm
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import scipy.io as scio import scipy.io as scio
import numpy as np import numpy as np
from .base import TrainBase from .base import TrainBase
from torch.optim import Adam from model.optimization import BertAdam
import torch.nn.functional as F # from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
# from model.optimization import BertAdam from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss from utils.calc_utils import calc_map_k_matrix as calc_map_k
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices from dataset.dataloader import dataloader
from utils.calc_utils import calc_map_k_matrix as calc_map_k import open_clip
from dataset.dataloader import dataloader # from transformers import BertModel
import open_clip
# from transformers import BertModel device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def clamp(delta, clean_imgs): def clamp(delta, clean_imgs):
clamp_imgs = (delta.data + clean_imgs.data).clamp(0, 1) clamp_imgs = (delta.data + clean_imgs.data).clamp(0, 1)
clamp_delta = clamp_imgs - clean_imgs.data clamp_delta = clamp_imgs - clean_imgs.data
return clamp_delta return clamp_delta
class Trainer(TrainBase):
class Trainer(TrainBase): def __init__(self,
rank=0):
def __init__(self, args = get_args()
rank=0): super(Trainer, self).__init__(args, rank)
args = get_args() self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
super(Trainer, self).__init__(args, rank) text_representation, text_representation=self.generate_mapping()
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) self.image_representation=text_representation
text_mean_representation, text_var_representation=self.generate_mapping() self.text_representation=text_representation
self.text_mean=text_mean_representation self.device=rank
self.text_var=text_var_representation # self.run()
# self.run()
def _init_model(self):
def _init_model(self): self.logger.info("init model.")
self.logger.info("init model.") model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device)
# self.generator=Generator() self.model= model_clip
# linear = False self.model.eval()
# if self.args.hash_layer == "linear": self.model.float()
# linear = True
# self.bert=BertModel.from_pretrained("bert-base-cased", output_hidden_states=True).to(self.rank) def _init_dataset(self):
# self.bert.eval() self.logger.info("init dataset.")
# self.logger.info("ViT+GPT!") self.logger.info(f"Using {self.args.dataset} dataset.")
# HashModel = DCMHT self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
# if self.args.victim_model == 'JDSH': self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
# from model.JDSH import TxtNet, ImgNet self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
# # self.img_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
# # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) indexFile=self.args.index_file,
# self.img_model=ImgNet(code_len=self.args.output_dim).to(self.rank) labelFile=self.args.label_file,
# self.txt_model=TxtNet(code_len=self.args.output_dim, txt_feat_len=self.args.txt_dim).to(self.rank) maxWords=self.args.max_words,
# path=os.path.join(self.args.checkpoints,self.args.victim_model+'/'+str(self.args.output_dim)+'_'+self.args.dataset+'latest.pth') imageResolution=self.args.resolution,
# checkpoint=torch.load(path) query_num=self.args.query_num,
# self.img_model.load_state_dict(torch.load(checkpoint['ImgNet'], map_location=f"cuda:{self.rank}")) train_num=self.args.train_num,
# self.txt_model.load_state_dict(torch.load(checkpoint['TxtNet'], map_location=f"cuda:{self.rank}")) seed=self.args.seed)
# self.img_model.eval() self.train_labels = train_data.get_all_label()
# self.txt_model.eval() self.query_labels = query_data.get_all_label()
# elif self.args.victim_model == 'DJSRH': self.retrieval_labels = retrieval_data.get_all_label()
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, self.args.retrieval_num = len(self.retrieval_labels)
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) self.logger.info(f"query shape: {self.query_labels.shape}")
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
# elif self.args.victim_model == 'SSAH': self.train_loader = DataLoader(
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, dataset=train_data,
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) batch_size=self.args.batch_size,
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) num_workers=self.args.num_workers,
# elif self.args.victim_model == 'DCHUC': pin_memory=True,
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, shuffle=True
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) )
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) self.query_loader = DataLoader(
dataset=query_data,
# if self.args.pretrained != "" and os.path.exists(self.args.pretrained): batch_size=self.args.batch_size,
# self.logger.info("load pretrained model.") num_workers=self.args.num_workers,
# self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) pin_memory=True,
model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=self.device) shuffle=True
self.model= model_clip )
self.model.eval() self.retrieval_loader = DataLoader(
self.model.float() dataset=retrieval_data,
self.optimizer =Adam(self.model.visual.parameters,lr=self.args.lr ,betas=[0.9,0.98] ) batch_size=self.args.batch_size,
# self.optimizer = BertAdam([ num_workers=self.args.num_workers,
# {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr}, pin_memory=True,
# {'params': self.model.image_hash.parameters(), 'lr': self.args.lr}, shuffle=True
# {'params': self.model.text_hash.parameters(), 'lr': self.args.lr} )
# ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
# b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
# weight_decay=self.args.weight_decay, max_grad_norm=1.0)
def generate_mapping(self):
# print(self.model) text_train=[]
label_train=[]
def _init_dataset(self): for image, text, label, index in self.train_loader:
self.logger.info("init dataset.") text=text.to(device, non_blocking=True)
self.logger.info(f"Using {self.args.dataset} dataset.") # print(self.model.vocab_size)
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file) temp_text=self.model.encode_text(text)
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file) text_train.append(temp_text.cpu().detach().numpy())
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file) label_train.append(label.detach().numpy())
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file, text_train=np.concatenate(text_train, axis=0)
indexFile=self.args.index_file, label_train=np.concatenate(label_train, axis=0)
labelFile=self.args.label_file, label_unipue=np.unique(label_train,axis=0)
maxWords=self.args.max_words, text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
imageResolution=self.args.resolution, text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
query_num=self.args.query_num,
train_num=self.args.train_num, text_representation = {}
seed=self.args.seed) text_var_representation = {}
self.train_labels = train_data.get_all_label() for i, centroid in enumerate(label_unipue):
self.query_labels = query_data.get_all_label() text_representation[centroid.tobytes()] = text_centroids[i]
self.retrieval_labels = retrieval_data.get_all_label() text_var_representation[centroid.tobytes()]= text_var[i]
# self.args.retrieval_num = len(self.retrieval_labels) return text_representation, text_var_representation
self.logger.info(f"query shape: {self.query_labels.shape}")
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}") def target_adv(self, image, positive, negative,
self.train_loader = DataLoader( epsilon=0.03125, alpha=3/255, num_iter=100):
dataset=train_data,
batch_size=self.args.batch_size, delta = torch.zeros_like(image,requires_grad=True)
num_workers=self.args.num_workers, one=torch.zeros_like(positive)
pin_memory=True, alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
shuffle=True for i in range(num_iter):
) self.model.zero_grad()
self.query_loader = DataLoader( anchor=self.model.encode_image(image+delta)
dataset=query_data, loss=alienation_loss(anchor, positive, negative)
batch_size=self.args.batch_size,
num_workers=self.args.num_workers,
pin_memory=True, loss.backward(retain_graph=True)
shuffle=True delta.data = delta - alpha * delta.grad.detach().sign()
) delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
self.retrieval_loader = DataLoader( delta.grad.zero_()
dataset=retrieval_data,
batch_size=self.args.batch_size, return delta.detach()
num_workers=self.args.num_workers,
pin_memory=True, def train_epoch(self, epoch):
shuffle=True self.change_state(mode="valid")
) self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
def generate_mapping(self): all_loss = 0
text_train=[] times = 0
label_train=[] adv_images=[]
# image_train=[] adv_labels=[]
# self.change_state(mode="valid") texts=[]
for image, text, label, index in self.train_loader: for image, text, label, index in self.train_loader:
# image=image.to(self.device, non_blocking=True) self.global_step += 1
text=text.to(self.device, non_blocking=True) times += 1
temp_text=self.model.encode_text(text) image.float()
# temp_image=self.model.encode_image(image) if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
# image_train.append(temp_image.cpu().detach().numpy()) label = torch.ones([image.shape[0]], dtype=torch.int)
text_train.append(temp_text.cpu().detach().numpy()) label = label.diag()
label_train.append(label.detach().numpy()) image = image.to(self.rank, non_blocking=True)
text_train=np.concatenate(text_train, axis=0) text = text.to(self.rank, non_blocking=True)
# image_train=np.concatenate(image_train, axis=0) index = index.numpy()
label_train=np.concatenate(label_train, axis=0) image_anchor=self.image_representation(label.detach().cpu().numpy())
label_unipue=np.unique(label_train,axis=0) text_anchor=self.text_representation(label.detach().cpu().numpy())
# image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0) target_label=label.flip(dims=[0])
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0) target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
text_mean_representation = {} positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
text_var_representation = {} delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
for i, centroid in enumerate(label_unipue): torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
text_mean_representation[centroid.tobytes()] = text_centroids[i] adv_image=delta+image
text_var_representation[centroid.tobytes()]= text_var[i] adv_images.append(adv_image)
return text_mean_representation, text_var_representation adv_labels.append(target_label)
texts.append(text)
def target_adv(self, image, positive, positive_mean,positive_var, negative, negative_mean, negative_var, return adv_images, texts, adv_labels
epsilon=0.03125, alpha=3/255, num_iter=100):
delta = torch.zeros_like(image,requires_grad=True)
# clean_output = self.model.encode_image(image) def train(self):
one=torch.zeros_like(positive) self.logger.info("Start train.")
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
for i in range(num_iter): for epoch in range(self.args.epochs):
self.model.zero_grad() self.train_epoch(epoch)
anchor=self.model.encode_image(image+delta) self.valid(epoch)
loss1=alienation_loss(anchor, positive, negative) self.save_model(epoch)
loss=loss1 + self.args.beta * self.distribution_loss(anchor,positive_mean,positive_var,negative_mean, negative_var)
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
loss.backward(retain_graph=True)
delta.data = delta - alpha * delta.grad.detach().sign() def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
delta.grad.zero_() s = torch.matmul(a, b.t())
b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s)))
return delta.detach()
return b_loss
def train_epoch(self):
self.change_state(mode="valid") def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) """
all_loss = 0 """
times = 0 kl_divergence = torch.mean(a * torch.log(a / (b + 0.001)))
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve") print("mean", torch.mean(a - b))
os.makedirs(save_dir, exist_ok=True) print("kl", kl_divergence)
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) return kl_divergence
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
adv_images=[]
adv_labels=[] def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
# target_texts=[]
for image, text, label, index in self.train_loader: # $\vartheta$
self.global_step += 1 vartheta = self.args.vartheta
times += 1 if self.args.sim_threshold != 0:
image.float() threshold = self.args.sim_threshold
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]: similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b)
label = torch.ones([image.shape[0]], dtype=torch.int)
label = label.diag() positive_similarity = similarity * label_sim
image = image.to(self.rank, non_blocking=True) # 只要cosine为负值的全都算为计算正确了因为优化到2确实很难。
text = text.to(self.rank, non_blocking=True) negative_similarity = similarity * (1 - label_sim)
index = index.numpy()
# image_anchor=self.image_representation(label.detach().cpu().numpy()) if self.args.similarity_function == "cosine":
negetive_mean=self.text_mean(label.detach().cpu().numpy()) positive_similarity = positive_similarity.clip(threshold) - threshold
negetive_var=self.text_var(label.detach().cpu().numpy()) negative_similarity = negative_similarity.clip(max=1.)
# negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0) negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
negetive_code=self.model.encode_text(text) elif self.args.similarity_function == "euclidean":
target_label=label.flip(dims=[0]) # 有euclidean距离可知当有一半长度的hash码不同时其negative_similarity距离应该是长度concat操作将outputdim翻倍所以这里clip掉认为认定的值
# target_image_anchor=self.image_representation(target_label.detach().cpu().numpy()) # 人为认定的最大值是一半长度的hash码不同。
positive_mean=self.text_mean(target_label.detach().cpu().numpy()) max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5
positive_var=self.text_var(target_label.detach().cpu().numpy()) negative_similarity = negative_similarity.clip(max=max_value)
# positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0) negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
positive_code=self.model.encode_text(text.flip(dims=[0]))
# print("text shape:", text.shape) if self.args.loss_type == "l1":
# index = index.numpy() positive_loss = positive_similarity.mean()
# print(text.shape) negative_loss = negative_similarity.mean()
delta=self.target_adv(image,positive_code,torch.from_numpy(positive_mean).to(self.rank, non_blocking=True), torch.from_numpy(positive_var).to(self.rank, non_blocking=True), elif self.args.loss_type == "l2":
negetive_code, torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True), torch.from_numpy(negetive_var).to(self.rank, non_blocking=True)) positive_loss = torch.pow(positive_similarity, 2).mean()
adv_image=delta+image negative_loss = torch.pow(negative_similarity, 2).mean()
adv_images.append(self.model.encode_image(adv_image)) else:
adv_labels.append(target_label) raise ValueError("argument of loss_type is not support.")
# target_texts.append(self.model.encode_text(text))
adv_image=torch.cat(adv_image).to(self.device) return similarity, positive_loss, negative_loss
adv_labels=torch.cat(adv_labels).to(self.device)
def make_hash_code(self, code: list) -> torch.Tensor:
mAPi2t = calc_map_k(adv_images, retrieval_txt, adv_labels, self.retrieval_labels, None, self.rank)
mAPt2t = calc_map_k(adv_images, retrieval_img, adv_labels, self.retrieval_labels, None, self.rank) code = torch.stack(code)
self.logger.info(f">>>>>> t-MAP(i->t): {mAPi2t}, t-MAP(t->t): {mAPt2t}") # print(code.shape)
adv_images = adv_images.cpu().detach().numpy() code = code.permute(1, 0, 2)
# query_txt = query_txt.cpu().detach().numpy() hash_code = torch.argmax(code, dim=-1)
retrieval_img = retrieval_img.cpu().detach().numpy() hash_code[torch.where(hash_code == 0)] = -1
retrieval_txt = retrieval_txt.cpu().detach().numpy() hash_code = hash_code.float()
adv_labels = adv_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy() return hash_code
result_dict = { def get_code(self, data_loader, length: int):
'adv_img': adv_images,
# 'q_txt': query_txt, img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
'r_img': retrieval_img, text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
'r_txt': retrieval_txt,
'adv_l': adv_labels, for image, text, label, index in tqdm(data_loader):
'r_l': retrieval_labels image = image.to(self.rank, non_blocking=True)
} text = text.to(self.rank, non_blocking=True)
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict) index = index.numpy()
self.logger.info(">>>>>> save all data!") image_hash=self.img_model(image)
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}") text_feat=self.bert(text)[0]
# return adv_images, texts, adv_labels text_hash=self.txt_model(text_feat)
img_buffer[index, :] = image_hash.data
text_buffer[index, :] = text_hash.data
def train(self): return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
self.logger.info("Start train.")
self.valid() def get_adv_code(self, adv_data_list,text_list):
self.train_epoch()
# self.valid() img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
# for epoch in range(self.args.epochs): text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)
# self.train_epoch()
# self.valid(epoch) for i in tqdm(range(len(adv_data_list))):
# self.save_model(epoch) image = adv_data_list[i].to(self.rank, non_blocking=True)
text = text_list[i].to(self.rank, non_blocking=True)
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}") # index = index.numpy()
image_hash=self.img_model(image)
def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor): text_feat=self.bert(text)[0]
text_hash=self.txt_model(text_feat)
s = torch.matmul(a, b.t()) # text_hash = self.make_hash_code(text_hash)
b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s))) # image_hash.to(self.rank)
# text_hash.to(self.rank)
return b_loss img_buffer[i, :] = image_hash.data
text_buffer[i, :] = text_hash.data
def distribution_loss(self, x: torch.Tensor, positive_mean,positive_var, negative_mean, negative_var):
""" return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
"""
norm_fun= lambda mean, var, x: 50- torch.mean(torch.exp(-(x-mean) **2 /(2*var)) /(2* torch.pi * var))
positive_distribution=norm_fun(positive_mean,positive_var,x) def valid_attack(self,adv_images, texts, adv_labels):
negative_distribution=norm_fun(negative_mean,negative_var,x) save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
# alienation_loss=nn.MarginRankingLoss() os.makedirs(save_dir, exist_ok=True)
return F.margin_ranking_loss(positive_distribution,negative_distribution,1)
def test(self, mode_name="i2t"):
def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05): if self.args.pretrained == "":
raise RuntimeError("test step must load a model! please set the --pretrained argument.")
# $\vartheta$ self.change_state(mode="valid")
vartheta = self.args.vartheta save_dir = os.path.join(self.args.save_dir, "PR_cruve")
if self.args.sim_threshold != 0: os.makedirs(save_dir, exist_ok=True)
threshold = self.args.sim_threshold query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b) retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
positive_similarity = similarity * label_sim # print("map map")
# 只要cosine为负值的全都算为计算正确了因为优化到2确实很难。 mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
negative_similarity = similarity * (1 - label_sim) mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
if self.args.similarity_function == "cosine": self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
positive_similarity = positive_similarity.clip(threshold) - threshold self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
negative_similarity = negative_similarity.clip(max=1.)
negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity query_img = query_img.cpu().detach().numpy()
elif self.args.similarity_function == "euclidean": query_txt = query_txt.cpu().detach().numpy()
# 有euclidean距离可知当有一半长度的hash码不同时其negative_similarity距离应该是长度concat操作将outputdim翻倍所以这里clip掉认为认定的值 retrieval_img = retrieval_img.cpu().detach().numpy()
# 人为认定的最大值是一半长度的hash码不同。 retrieval_txt = retrieval_txt.cpu().detach().numpy()
max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5 query_labels = self.query_labels.numpy()
negative_similarity = negative_similarity.clip(max=max_value) retrieval_labels = self.retrieval_labels.numpy()
negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
result_dict = {
if self.args.loss_type == "l1": 'q_img': query_img,
positive_loss = positive_similarity.mean() 'q_txt': query_txt,
negative_loss = negative_similarity.mean() 'r_img': retrieval_img,
elif self.args.loss_type == "l2": 'r_txt': retrieval_txt,
positive_loss = torch.pow(positive_similarity, 2).mean() 'q_l': query_labels,
negative_loss = torch.pow(negative_similarity, 2).mean() 'r_l': retrieval_labels
else: }
raise ValueError("argument of loss_type is not support.") scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!")
return similarity, positive_loss, negative_loss
def make_hash_code(self, code: list) -> torch.Tensor: def valid(self, epoch):
self.logger.info("Valid.")
code = torch.stack(code) self.change_state(mode="valid")
# print(code.shape) query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
code = code.permute(1, 0, 2) retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
hash_code = torch.argmax(code, dim=-1) # print("get all code")
hash_code[torch.where(hash_code == 0)] = -1 mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
hash_code = hash_code.float() # print("map map")
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
return hash_code mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
def get_code(self, data_loader, length: int): if self.max_mapi2t < mAPi2t:
self.best_epoch_i = epoch
img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank) self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
if self.max_mapt2i < mAPt2i:
for image, text, label, index in tqdm(data_loader): self.best_epoch_t = epoch
image = image.to(self.rank, non_blocking=True) self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
text = text.to(self.rank, non_blocking=True) self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
index = index.numpy() self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
image_hash=self.model.encode_image(image) MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
# text_feat=self.bert(text)[0]
text_hash=self.model.encode_text(text) def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
img_buffer[index, :] = image_hash.data
text_buffer[index, :] = text_hash.data save_dir = os.path.join(self.args.save_dir, "PR_cruve")
os.makedirs(save_dir, exist_ok=True)
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
query_img = query_img.cpu().detach().numpy()
def get_adv_code(self, adv_data_list,text_list): query_txt = query_txt.cpu().detach().numpy()
retrieval_img = retrieval_img.cpu().detach().numpy()
img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank) retrieval_txt = retrieval_txt.cpu().detach().numpy()
text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank) query_labels = self.query_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy()
for i in tqdm(range(len(adv_data_list))):
image = adv_data_list[i].to(self.rank, non_blocking=True) result_dict = {
text = text_list[i].to(self.rank, non_blocking=True) 'q_img': query_img,
# index = index.numpy() 'q_txt': query_txt,
image_hash=self.img_model(image) 'r_img': retrieval_img,
text_feat=self.bert(text)[0] 'r_txt': retrieval_txt,
text_hash=self.txt_model(text_feat) 'q_l': query_labels,
# text_hash = self.make_hash_code(text_hash) 'r_l': retrieval_labels
# image_hash.to(self.rank) }
# text_hash.to(self.rank) scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
img_buffer[i, :] = image_hash.data self.logger.info(f">>>>>> save best {mode_name} data!")
text_buffer[i, :] = text_hash.data
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def valid_attack(self,adv_images, texts, adv_labels):
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
os.makedirs(save_dir, exist_ok=True)
def test(self, mode_name="i2t"):
if self.args.pretrained == "":
raise RuntimeError("test step must load a model! please set the --pretrained argument.")
self.change_state(mode="valid")
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
os.makedirs(save_dir, exist_ok=True)
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# print("map map")
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
query_img = query_img.cpu().detach().numpy()
query_txt = query_txt.cpu().detach().numpy()
retrieval_img = retrieval_img.cpu().detach().numpy()
retrieval_txt = retrieval_txt.cpu().detach().numpy()
query_labels = self.query_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy()
result_dict = {
'q_img': query_img,
'q_txt': query_txt,
'r_img': retrieval_img,
'r_txt': retrieval_txt,
'q_l': query_labels,
'r_l': retrieval_labels
}
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!")
def valid(self, epoch):
self.logger.info("Valid.")
# self.change_state(mode="valid")
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
# print("get all code")
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# print("map map")
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
if self.max_mapi2t < mAPi2t:
self.best_epoch_i = epoch
self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
if self.max_mapt2i < mAPt2i:
self.best_epoch_t = epoch
self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
os.makedirs(save_dir, exist_ok=True)
query_img = query_img.cpu().detach().numpy()
query_txt = query_txt.cpu().detach().numpy()
retrieval_img = retrieval_img.cpu().detach().numpy()
retrieval_txt = retrieval_txt.cpu().detach().numpy()
query_labels = self.query_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy()
result_dict = {
'q_img': query_img,
'q_txt': query_txt,
'r_img': retrieval_img,
'r_txt': retrieval_txt,
'q_l': query_labels,
'r_l': retrieval_labels
}
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
self.logger.info(f">>>>>> save best {mode_name} data!")