更新 train/hash_train.py
This commit is contained in:
parent
97418d42e1
commit
5cce92d540
|
|
@ -1,454 +1,364 @@
|
||||||
from torch.nn.modules import loss
|
from torch.nn.modules import loss
|
||||||
from model.hash_model import DCMHT as DCMHT
|
from model.hash_model import DCMHT as DCMHT
|
||||||
import os
|
import os
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import scipy.io as scio
|
import scipy.io as scio
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .base import TrainBase
|
from .base import TrainBase
|
||||||
from torch.optim import Adam
|
from model.optimization import BertAdam
|
||||||
import torch.nn.functional as F
|
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
|
||||||
# from model.optimization import BertAdam
|
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
||||||
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
|
from utils.calc_utils import calc_map_k_matrix as calc_map_k
|
||||||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
from dataset.dataloader import dataloader
|
||||||
from utils.calc_utils import calc_map_k_matrix as calc_map_k
|
import open_clip
|
||||||
from dataset.dataloader import dataloader
|
# from transformers import BertModel
|
||||||
import open_clip
|
|
||||||
# from transformers import BertModel
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
def clamp(delta, clean_imgs):
|
def clamp(delta, clean_imgs):
|
||||||
|
|
||||||
clamp_imgs = (delta.data + clean_imgs.data).clamp(0, 1)
|
clamp_imgs = (delta.data + clean_imgs.data).clamp(0, 1)
|
||||||
clamp_delta = clamp_imgs - clean_imgs.data
|
clamp_delta = clamp_imgs - clean_imgs.data
|
||||||
|
|
||||||
return clamp_delta
|
return clamp_delta
|
||||||
|
|
||||||
|
class Trainer(TrainBase):
|
||||||
|
|
||||||
class Trainer(TrainBase):
|
def __init__(self,
|
||||||
|
rank=0):
|
||||||
def __init__(self,
|
args = get_args()
|
||||||
rank=0):
|
super(Trainer, self).__init__(args, rank)
|
||||||
args = get_args()
|
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
||||||
super(Trainer, self).__init__(args, rank)
|
text_representation, text_representation=self.generate_mapping()
|
||||||
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
self.image_representation=text_representation
|
||||||
text_mean_representation, text_var_representation=self.generate_mapping()
|
self.text_representation=text_representation
|
||||||
self.text_mean=text_mean_representation
|
self.device=rank
|
||||||
self.text_var=text_var_representation
|
# self.run()
|
||||||
# self.run()
|
|
||||||
|
def _init_model(self):
|
||||||
def _init_model(self):
|
self.logger.info("init model.")
|
||||||
self.logger.info("init model.")
|
model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=device)
|
||||||
# self.generator=Generator()
|
self.model= model_clip
|
||||||
# linear = False
|
self.model.eval()
|
||||||
# if self.args.hash_layer == "linear":
|
self.model.float()
|
||||||
# linear = True
|
|
||||||
# self.bert=BertModel.from_pretrained("bert-base-cased", output_hidden_states=True).to(self.rank)
|
def _init_dataset(self):
|
||||||
# self.bert.eval()
|
self.logger.info("init dataset.")
|
||||||
# self.logger.info("ViT+GPT!")
|
self.logger.info(f"Using {self.args.dataset} dataset.")
|
||||||
# HashModel = DCMHT
|
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
|
||||||
# if self.args.victim_model == 'JDSH':
|
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
|
||||||
# from model.JDSH import TxtNet, ImgNet
|
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
|
||||||
# # self.img_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
|
||||||
# # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
indexFile=self.args.index_file,
|
||||||
# self.img_model=ImgNet(code_len=self.args.output_dim).to(self.rank)
|
labelFile=self.args.label_file,
|
||||||
# self.txt_model=TxtNet(code_len=self.args.output_dim, txt_feat_len=self.args.txt_dim).to(self.rank)
|
maxWords=self.args.max_words,
|
||||||
# path=os.path.join(self.args.checkpoints,self.args.victim_model+'/'+str(self.args.output_dim)+'_'+self.args.dataset+'latest.pth')
|
imageResolution=self.args.resolution,
|
||||||
# checkpoint=torch.load(path)
|
query_num=self.args.query_num,
|
||||||
# self.img_model.load_state_dict(torch.load(checkpoint['ImgNet'], map_location=f"cuda:{self.rank}"))
|
train_num=self.args.train_num,
|
||||||
# self.txt_model.load_state_dict(torch.load(checkpoint['TxtNet'], map_location=f"cuda:{self.rank}"))
|
seed=self.args.seed)
|
||||||
# self.img_model.eval()
|
self.train_labels = train_data.get_all_label()
|
||||||
# self.txt_model.eval()
|
self.query_labels = query_data.get_all_label()
|
||||||
# elif self.args.victim_model == 'DJSRH':
|
self.retrieval_labels = retrieval_data.get_all_label()
|
||||||
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
self.args.retrieval_num = len(self.retrieval_labels)
|
||||||
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
self.logger.info(f"query shape: {self.query_labels.shape}")
|
||||||
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
|
||||||
# elif self.args.victim_model == 'SSAH':
|
self.train_loader = DataLoader(
|
||||||
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
dataset=train_data,
|
||||||
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
batch_size=self.args.batch_size,
|
||||||
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
num_workers=self.args.num_workers,
|
||||||
# elif self.args.victim_model == 'DCHUC':
|
pin_memory=True,
|
||||||
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
shuffle=True
|
||||||
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
)
|
||||||
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
self.query_loader = DataLoader(
|
||||||
|
dataset=query_data,
|
||||||
# if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
|
batch_size=self.args.batch_size,
|
||||||
# self.logger.info("load pretrained model.")
|
num_workers=self.args.num_workers,
|
||||||
# self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
pin_memory=True,
|
||||||
model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=self.device)
|
shuffle=True
|
||||||
self.model= model_clip
|
)
|
||||||
self.model.eval()
|
self.retrieval_loader = DataLoader(
|
||||||
self.model.float()
|
dataset=retrieval_data,
|
||||||
self.optimizer =Adam(self.model.visual.parameters,lr=self.args.lr ,betas=[0.9,0.98] )
|
batch_size=self.args.batch_size,
|
||||||
# self.optimizer = BertAdam([
|
num_workers=self.args.num_workers,
|
||||||
# {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
|
pin_memory=True,
|
||||||
# {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
|
shuffle=True
|
||||||
# {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
|
)
|
||||||
# ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
|
|
||||||
# b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
|
|
||||||
# weight_decay=self.args.weight_decay, max_grad_norm=1.0)
|
|
||||||
|
def generate_mapping(self):
|
||||||
# print(self.model)
|
text_train=[]
|
||||||
|
label_train=[]
|
||||||
def _init_dataset(self):
|
for image, text, label, index in self.train_loader:
|
||||||
self.logger.info("init dataset.")
|
text=text.to(device, non_blocking=True)
|
||||||
self.logger.info(f"Using {self.args.dataset} dataset.")
|
# print(self.model.vocab_size)
|
||||||
self.args.index_file = os.path.join("./dataset", self.args.dataset, self.args.index_file)
|
temp_text=self.model.encode_text(text)
|
||||||
self.args.caption_file = os.path.join("./dataset", self.args.dataset, self.args.caption_file)
|
text_train.append(temp_text.cpu().detach().numpy())
|
||||||
self.args.label_file = os.path.join("./dataset", self.args.dataset, self.args.label_file)
|
label_train.append(label.detach().numpy())
|
||||||
train_data, query_data, retrieval_data = dataloader(captionFile=self.args.caption_file,
|
text_train=np.concatenate(text_train, axis=0)
|
||||||
indexFile=self.args.index_file,
|
label_train=np.concatenate(label_train, axis=0)
|
||||||
labelFile=self.args.label_file,
|
label_unipue=np.unique(label_train,axis=0)
|
||||||
maxWords=self.args.max_words,
|
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
|
||||||
imageResolution=self.args.resolution,
|
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
|
||||||
query_num=self.args.query_num,
|
|
||||||
train_num=self.args.train_num,
|
text_representation = {}
|
||||||
seed=self.args.seed)
|
text_var_representation = {}
|
||||||
self.train_labels = train_data.get_all_label()
|
for i, centroid in enumerate(label_unipue):
|
||||||
self.query_labels = query_data.get_all_label()
|
text_representation[centroid.tobytes()] = text_centroids[i]
|
||||||
self.retrieval_labels = retrieval_data.get_all_label()
|
text_var_representation[centroid.tobytes()]= text_var[i]
|
||||||
# self.args.retrieval_num = len(self.retrieval_labels)
|
return text_representation, text_var_representation
|
||||||
self.logger.info(f"query shape: {self.query_labels.shape}")
|
|
||||||
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
|
def target_adv(self, image, positive, negative,
|
||||||
self.train_loader = DataLoader(
|
epsilon=0.03125, alpha=3/255, num_iter=100):
|
||||||
dataset=train_data,
|
|
||||||
batch_size=self.args.batch_size,
|
delta = torch.zeros_like(image,requires_grad=True)
|
||||||
num_workers=self.args.num_workers,
|
one=torch.zeros_like(positive)
|
||||||
pin_memory=True,
|
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
||||||
shuffle=True
|
for i in range(num_iter):
|
||||||
)
|
self.model.zero_grad()
|
||||||
self.query_loader = DataLoader(
|
anchor=self.model.encode_image(image+delta)
|
||||||
dataset=query_data,
|
loss=alienation_loss(anchor, positive, negative)
|
||||||
batch_size=self.args.batch_size,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
pin_memory=True,
|
loss.backward(retain_graph=True)
|
||||||
shuffle=True
|
delta.data = delta - alpha * delta.grad.detach().sign()
|
||||||
)
|
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
|
||||||
self.retrieval_loader = DataLoader(
|
delta.grad.zero_()
|
||||||
dataset=retrieval_data,
|
|
||||||
batch_size=self.args.batch_size,
|
return delta.detach()
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
pin_memory=True,
|
def train_epoch(self, epoch):
|
||||||
shuffle=True
|
self.change_state(mode="valid")
|
||||||
)
|
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
||||||
def generate_mapping(self):
|
all_loss = 0
|
||||||
text_train=[]
|
times = 0
|
||||||
label_train=[]
|
adv_images=[]
|
||||||
# image_train=[]
|
adv_labels=[]
|
||||||
# self.change_state(mode="valid")
|
texts=[]
|
||||||
for image, text, label, index in self.train_loader:
|
for image, text, label, index in self.train_loader:
|
||||||
# image=image.to(self.device, non_blocking=True)
|
self.global_step += 1
|
||||||
text=text.to(self.device, non_blocking=True)
|
times += 1
|
||||||
temp_text=self.model.encode_text(text)
|
image.float()
|
||||||
# temp_image=self.model.encode_image(image)
|
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
|
||||||
# image_train.append(temp_image.cpu().detach().numpy())
|
label = torch.ones([image.shape[0]], dtype=torch.int)
|
||||||
text_train.append(temp_text.cpu().detach().numpy())
|
label = label.diag()
|
||||||
label_train.append(label.detach().numpy())
|
image = image.to(self.rank, non_blocking=True)
|
||||||
text_train=np.concatenate(text_train, axis=0)
|
text = text.to(self.rank, non_blocking=True)
|
||||||
# image_train=np.concatenate(image_train, axis=0)
|
index = index.numpy()
|
||||||
label_train=np.concatenate(label_train, axis=0)
|
image_anchor=self.image_representation(label.detach().cpu().numpy())
|
||||||
label_unipue=np.unique(label_train,axis=0)
|
text_anchor=self.text_representation(label.detach().cpu().numpy())
|
||||||
# image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
|
negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
|
||||||
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
|
target_label=label.flip(dims=[0])
|
||||||
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
|
target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
|
||||||
|
target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
|
||||||
text_mean_representation = {}
|
positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
|
||||||
text_var_representation = {}
|
delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
|
||||||
for i, centroid in enumerate(label_unipue):
|
torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
|
||||||
text_mean_representation[centroid.tobytes()] = text_centroids[i]
|
adv_image=delta+image
|
||||||
text_var_representation[centroid.tobytes()]= text_var[i]
|
adv_images.append(adv_image)
|
||||||
return text_mean_representation, text_var_representation
|
adv_labels.append(target_label)
|
||||||
|
texts.append(text)
|
||||||
def target_adv(self, image, positive, positive_mean,positive_var, negative, negative_mean, negative_var,
|
return adv_images, texts, adv_labels
|
||||||
epsilon=0.03125, alpha=3/255, num_iter=100):
|
|
||||||
|
|
||||||
delta = torch.zeros_like(image,requires_grad=True)
|
|
||||||
# clean_output = self.model.encode_image(image)
|
def train(self):
|
||||||
one=torch.zeros_like(positive)
|
self.logger.info("Start train.")
|
||||||
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
|
||||||
for i in range(num_iter):
|
for epoch in range(self.args.epochs):
|
||||||
self.model.zero_grad()
|
self.train_epoch(epoch)
|
||||||
anchor=self.model.encode_image(image+delta)
|
self.valid(epoch)
|
||||||
loss1=alienation_loss(anchor, positive, negative)
|
self.save_model(epoch)
|
||||||
loss=loss1 + self.args.beta * self.distribution_loss(anchor,positive_mean,positive_var,negative_mean, negative_var)
|
|
||||||
|
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
||||||
loss.backward(retain_graph=True)
|
|
||||||
delta.data = delta - alpha * delta.grad.detach().sign()
|
def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
|
||||||
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
|
|
||||||
delta.grad.zero_()
|
s = torch.matmul(a, b.t())
|
||||||
|
b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s)))
|
||||||
return delta.detach()
|
|
||||||
|
return b_loss
|
||||||
def train_epoch(self):
|
|
||||||
self.change_state(mode="valid")
|
def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
|
||||||
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
"""
|
||||||
all_loss = 0
|
"""
|
||||||
times = 0
|
kl_divergence = torch.mean(a * torch.log(a / (b + 0.001)))
|
||||||
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
print("mean", torch.mean(a - b))
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
print("kl", kl_divergence)
|
||||||
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
return kl_divergence
|
||||||
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
|
||||||
adv_images=[]
|
|
||||||
adv_labels=[]
|
def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
|
||||||
# target_texts=[]
|
|
||||||
for image, text, label, index in self.train_loader:
|
# $\vartheta$
|
||||||
self.global_step += 1
|
vartheta = self.args.vartheta
|
||||||
times += 1
|
if self.args.sim_threshold != 0:
|
||||||
image.float()
|
threshold = self.args.sim_threshold
|
||||||
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
|
similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b)
|
||||||
label = torch.ones([image.shape[0]], dtype=torch.int)
|
|
||||||
label = label.diag()
|
positive_similarity = similarity * label_sim
|
||||||
image = image.to(self.rank, non_blocking=True)
|
# 只要cosine为负值的全都算为计算正确了,因为优化到2确实很难。
|
||||||
text = text.to(self.rank, non_blocking=True)
|
negative_similarity = similarity * (1 - label_sim)
|
||||||
index = index.numpy()
|
|
||||||
# image_anchor=self.image_representation(label.detach().cpu().numpy())
|
if self.args.similarity_function == "cosine":
|
||||||
negetive_mean=self.text_mean(label.detach().cpu().numpy())
|
positive_similarity = positive_similarity.clip(threshold) - threshold
|
||||||
negetive_var=self.text_var(label.detach().cpu().numpy())
|
negative_similarity = negative_similarity.clip(max=1.)
|
||||||
# negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
|
negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
|
||||||
negetive_code=self.model.encode_text(text)
|
elif self.args.similarity_function == "euclidean":
|
||||||
target_label=label.flip(dims=[0])
|
# 有euclidean距离可知,当有一半长度的hash码不同时,其negative_similarity距离应该是长度(concat操作将outputdim翻倍),所以这里clip掉认为认定的值
|
||||||
# target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
|
# 人为认定的最大值是一半长度的hash码不同。
|
||||||
positive_mean=self.text_mean(target_label.detach().cpu().numpy())
|
max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5
|
||||||
positive_var=self.text_var(target_label.detach().cpu().numpy())
|
negative_similarity = negative_similarity.clip(max=max_value)
|
||||||
# positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
|
negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
|
||||||
positive_code=self.model.encode_text(text.flip(dims=[0]))
|
|
||||||
# print("text shape:", text.shape)
|
if self.args.loss_type == "l1":
|
||||||
# index = index.numpy()
|
positive_loss = positive_similarity.mean()
|
||||||
# print(text.shape)
|
negative_loss = negative_similarity.mean()
|
||||||
delta=self.target_adv(image,positive_code,torch.from_numpy(positive_mean).to(self.rank, non_blocking=True), torch.from_numpy(positive_var).to(self.rank, non_blocking=True),
|
elif self.args.loss_type == "l2":
|
||||||
negetive_code, torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True), torch.from_numpy(negetive_var).to(self.rank, non_blocking=True))
|
positive_loss = torch.pow(positive_similarity, 2).mean()
|
||||||
adv_image=delta+image
|
negative_loss = torch.pow(negative_similarity, 2).mean()
|
||||||
adv_images.append(self.model.encode_image(adv_image))
|
else:
|
||||||
adv_labels.append(target_label)
|
raise ValueError("argument of loss_type is not support.")
|
||||||
# target_texts.append(self.model.encode_text(text))
|
|
||||||
adv_image=torch.cat(adv_image).to(self.device)
|
return similarity, positive_loss, negative_loss
|
||||||
adv_labels=torch.cat(adv_labels).to(self.device)
|
|
||||||
|
def make_hash_code(self, code: list) -> torch.Tensor:
|
||||||
mAPi2t = calc_map_k(adv_images, retrieval_txt, adv_labels, self.retrieval_labels, None, self.rank)
|
|
||||||
mAPt2t = calc_map_k(adv_images, retrieval_img, adv_labels, self.retrieval_labels, None, self.rank)
|
code = torch.stack(code)
|
||||||
self.logger.info(f">>>>>> t-MAP(i->t): {mAPi2t}, t-MAP(t->t): {mAPt2t}")
|
# print(code.shape)
|
||||||
adv_images = adv_images.cpu().detach().numpy()
|
code = code.permute(1, 0, 2)
|
||||||
# query_txt = query_txt.cpu().detach().numpy()
|
hash_code = torch.argmax(code, dim=-1)
|
||||||
retrieval_img = retrieval_img.cpu().detach().numpy()
|
hash_code[torch.where(hash_code == 0)] = -1
|
||||||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
hash_code = hash_code.float()
|
||||||
adv_labels = adv_labels.numpy()
|
|
||||||
retrieval_labels = self.retrieval_labels.numpy()
|
return hash_code
|
||||||
|
|
||||||
result_dict = {
|
def get_code(self, data_loader, length: int):
|
||||||
'adv_img': adv_images,
|
|
||||||
# 'q_txt': query_txt,
|
img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
|
||||||
'r_img': retrieval_img,
|
text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
|
||||||
'r_txt': retrieval_txt,
|
|
||||||
'adv_l': adv_labels,
|
for image, text, label, index in tqdm(data_loader):
|
||||||
'r_l': retrieval_labels
|
image = image.to(self.rank, non_blocking=True)
|
||||||
}
|
text = text.to(self.rank, non_blocking=True)
|
||||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
|
index = index.numpy()
|
||||||
self.logger.info(">>>>>> save all data!")
|
image_hash=self.img_model(image)
|
||||||
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
|
text_feat=self.bert(text)[0]
|
||||||
# return adv_images, texts, adv_labels
|
text_hash=self.txt_model(text_feat)
|
||||||
|
img_buffer[index, :] = image_hash.data
|
||||||
|
text_buffer[index, :] = text_hash.data
|
||||||
|
|
||||||
def train(self):
|
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
||||||
self.logger.info("Start train.")
|
|
||||||
self.valid()
|
def get_adv_code(self, adv_data_list,text_list):
|
||||||
self.train_epoch()
|
|
||||||
# self.valid()
|
img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
|
||||||
# for epoch in range(self.args.epochs):
|
text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)
|
||||||
# self.train_epoch()
|
|
||||||
# self.valid(epoch)
|
for i in tqdm(range(len(adv_data_list))):
|
||||||
# self.save_model(epoch)
|
image = adv_data_list[i].to(self.rank, non_blocking=True)
|
||||||
|
text = text_list[i].to(self.rank, non_blocking=True)
|
||||||
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
# index = index.numpy()
|
||||||
|
image_hash=self.img_model(image)
|
||||||
def bayesian_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
|
text_feat=self.bert(text)[0]
|
||||||
|
text_hash=self.txt_model(text_feat)
|
||||||
s = torch.matmul(a, b.t())
|
# text_hash = self.make_hash_code(text_hash)
|
||||||
b_loss = -torch.mean(label_sim * s - torch.log(1 + torch.exp(s)))
|
# image_hash.to(self.rank)
|
||||||
|
# text_hash.to(self.rank)
|
||||||
return b_loss
|
img_buffer[i, :] = image_hash.data
|
||||||
|
text_buffer[i, :] = text_hash.data
|
||||||
def distribution_loss(self, x: torch.Tensor, positive_mean,positive_var, negative_mean, negative_var):
|
|
||||||
"""
|
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
||||||
"""
|
|
||||||
norm_fun= lambda mean, var, x: 50- torch.mean(torch.exp(-(x-mean) **2 /(2*var)) /(2* torch.pi * var))
|
|
||||||
positive_distribution=norm_fun(positive_mean,positive_var,x)
|
def valid_attack(self,adv_images, texts, adv_labels):
|
||||||
negative_distribution=norm_fun(negative_mean,negative_var,x)
|
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||||||
# alienation_loss=nn.MarginRankingLoss()
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
return F.margin_ranking_loss(positive_distribution,negative_distribution,1)
|
|
||||||
|
|
||||||
|
def test(self, mode_name="i2t"):
|
||||||
def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
|
if self.args.pretrained == "":
|
||||||
|
raise RuntimeError("test step must load a model! please set the --pretrained argument.")
|
||||||
# $\vartheta$
|
self.change_state(mode="valid")
|
||||||
vartheta = self.args.vartheta
|
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
||||||
if self.args.sim_threshold != 0:
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
threshold = self.args.sim_threshold
|
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||||
similarity = (1 - cosine_similarity(a, b)) if self.args.similarity_function == "cosine" else euclidean_similarity(a, b)
|
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||||
|
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
positive_similarity = similarity * label_sim
|
# print("map map")
|
||||||
# 只要cosine为负值的全都算为计算正确了,因为优化到2确实很难。
|
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
negative_similarity = similarity * (1 - label_sim)
|
mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
|
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
if self.args.similarity_function == "cosine":
|
self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
|
||||||
positive_similarity = positive_similarity.clip(threshold) - threshold
|
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
|
||||||
negative_similarity = negative_similarity.clip(max=1.)
|
|
||||||
negative_similarity = torch.tensor([1.]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
|
query_img = query_img.cpu().detach().numpy()
|
||||||
elif self.args.similarity_function == "euclidean":
|
query_txt = query_txt.cpu().detach().numpy()
|
||||||
# 有euclidean距离可知,当有一半长度的hash码不同时,其negative_similarity距离应该是长度(concat操作将outputdim翻倍),所以这里clip掉认为认定的值
|
retrieval_img = retrieval_img.cpu().detach().numpy()
|
||||||
# 人为认定的最大值是一半长度的hash码不同。
|
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||||
max_value = float(self.args.output_dim * 2 * vartheta) ** 0.5
|
query_labels = self.query_labels.numpy()
|
||||||
negative_similarity = negative_similarity.clip(max=max_value)
|
retrieval_labels = self.retrieval_labels.numpy()
|
||||||
negative_similarity = torch.tensor([max_value]).expand_as(negative_similarity).to(self.rank) * (1 - label_sim) - negative_similarity
|
|
||||||
|
result_dict = {
|
||||||
if self.args.loss_type == "l1":
|
'q_img': query_img,
|
||||||
positive_loss = positive_similarity.mean()
|
'q_txt': query_txt,
|
||||||
negative_loss = negative_similarity.mean()
|
'r_img': retrieval_img,
|
||||||
elif self.args.loss_type == "l2":
|
'r_txt': retrieval_txt,
|
||||||
positive_loss = torch.pow(positive_similarity, 2).mean()
|
'q_l': query_labels,
|
||||||
negative_loss = torch.pow(negative_similarity, 2).mean()
|
'r_l': retrieval_labels
|
||||||
else:
|
}
|
||||||
raise ValueError("argument of loss_type is not support.")
|
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
|
||||||
|
self.logger.info(">>>>>> save all data!")
|
||||||
return similarity, positive_loss, negative_loss
|
|
||||||
|
|
||||||
def make_hash_code(self, code: list) -> torch.Tensor:
|
def valid(self, epoch):
|
||||||
|
self.logger.info("Valid.")
|
||||||
code = torch.stack(code)
|
self.change_state(mode="valid")
|
||||||
# print(code.shape)
|
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||||
code = code.permute(1, 0, 2)
|
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||||
hash_code = torch.argmax(code, dim=-1)
|
# print("get all code")
|
||||||
hash_code[torch.where(hash_code == 0)] = -1
|
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
hash_code = hash_code.float()
|
# print("map map")
|
||||||
|
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
return hash_code
|
mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
|
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
def get_code(self, data_loader, length: int):
    """Encode every (image, text) pair from *data_loader* into hash-code buffers.

    Args:
        data_loader: iterable yielding ``(image, text, label, index)`` batches;
            ``index`` gives each sample's global position in the split.
        length: total number of samples in the split; sizes the output buffers.

    Returns:
        Tuple ``(img_buffer, text_buffer)`` of float tensors with shape
        ``(length, self.args.output_dim)`` on device ``self.rank``.
    """
    img_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
    text_buffer = torch.empty(length, self.args.output_dim, dtype=torch.float).to(self.rank)
    for image, text, label, index in tqdm(data_loader):
        image = image.to(self.rank, non_blocking=True)
        text = text.to(self.rank, non_blocking=True)
        # index stays on CPU; numpy fancy indexing scatters the batch into
        # its global slots in the buffers.
        index = index.numpy()
        image_hash = self.model.encode_image(image)
        # text_feat = self.bert(text)[0]  # disabled in the original revision
        text_hash = self.model.encode_text(text)
        # .data detaches from the autograd graph before buffering.
        img_buffer[index, :] = image_hash.data
        text_buffer[index, :] = text_hash.data
    return img_buffer, text_buffer
|
|
||||||
|
query_img = query_img.cpu().detach().numpy()
|
||||||
def get_adv_code(self, adv_data_list, text_list):
    """Encode pre-built adversarial image/text batches into hash-code buffers.

    Unlike :meth:`get_code`, this routes images through ``self.img_model`` and
    texts through ``self.bert`` + ``self.txt_model`` (separate hashing heads),
    and fills the buffers positionally (one row per list element).

    Args:
        adv_data_list: list of adversarial image tensors.
        text_list: list of text tensors, parallel to ``adv_data_list``.

    Returns:
        Tuple ``(img_buffer, text_buffer)`` of float tensors with shape
        ``(len(list), self.args.output_dim)`` on device ``self.rank``.
    """
    img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
    text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)
    for i in tqdm(range(len(adv_data_list))):
        image = adv_data_list[i].to(self.rank, non_blocking=True)
        text = text_list[i].to(self.rank, non_blocking=True)
        image_hash = self.img_model(image)
        # self.bert(text) returns a tuple; [0] is presumably the hidden-state
        # features fed to the text hashing head — TODO confirm against model code.
        text_feat = self.bert(text)[0]
        text_hash = self.txt_model(text_feat)
        # text_hash = self.make_hash_code(text_hash)  # disabled in the original
        img_buffer[i, :] = image_hash.data
        text_buffer[i, :] = text_hash.data
    return img_buffer, text_buffer
|
|
||||||
|
|
||||||
|
|
||||||
def valid_attack(self, adv_images, texts, adv_labels):
    """Prepare the output directory for adversarial PR-curve results.

    NOTE(review): only the directory-setup lines are visible in this revision;
    the evaluation of *adv_images*/*texts*/*adv_labels* appears to have been
    removed or lives elsewhere — confirm before relying on this method.
    """
    save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
    os.makedirs(save_dir, exist_ok=True)
||||||
def test(self, mode_name="i2t"):
    """Evaluation entry point: compute codes, report mAP, and dump a .mat file.

    Requires a pretrained checkpoint (``--pretrained``); raises RuntimeError
    otherwise. Saves all query/retrieval codes and labels under
    ``<save_dir>/PR_cruve`` for offline PR-curve plotting.
    """
    if self.args.pretrained == "":
        raise RuntimeError("test step must load a model! please set the --pretrained argument.")
    self.change_state(mode="valid")
    save_dir = os.path.join(self.args.save_dir, "PR_cruve")
    os.makedirs(save_dir, exist_ok=True)
    # "select" hash layer uses this class's get_code; otherwise fall back to the base class.
    query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
    retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
    # Cross-modal and intra-modal mean average precision.
    mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
    self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
    query_img = query_img.cpu().detach().numpy()
    query_txt = query_txt.cpu().detach().numpy()
    retrieval_img = retrieval_img.cpu().detach().numpy()
    retrieval_txt = retrieval_txt.cpu().detach().numpy()
    query_labels = self.query_labels.numpy()
    retrieval_labels = self.retrieval_labels.numpy()
    result_dict = {
        'q_img': query_img,
        'q_txt': query_txt,
        'r_img': retrieval_img,
        'r_txt': retrieval_txt,
        'q_l': query_labels,
        'r_l': retrieval_labels
    }
    scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
    self.logger.info(">>>>>> save all data!")
|
|
||||||
|
|
||||||
|
|
||||||
def valid(self, epoch):
    """Run retrieval validation for *epoch* and checkpoint the best results.

    Computes i->t, t->i, i->i and t->t mAP; when a new best i->t or t->i mAP
    is reached, records the epoch and saves the codes via :meth:`save_mat`.
    """
    self.logger.info("Valid.")
    # self.change_state(mode="valid")
    # "select" hash layer uses this class's get_code; otherwise fall back to the base class.
    query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
    retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
    mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    if self.max_mapi2t < mAPi2t:
        self.best_epoch_i = epoch
        self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
    self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
    if self.max_mapt2i < mAPt2i:
        self.best_epoch_t = epoch
        self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
    self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
    self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
        MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
|
|
||||||
|
|
||||||
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
    """Persist query/retrieval codes and labels to a MATLAB .mat file.

    Args:
        query_img, query_txt: query-set image/text code tensors.
        retrieval_img, retrieval_txt: retrieval-set image/text code tensors.
        mode_name: tag for the output filename ("i2t" or "t2i").

    The file lands at ``<save_dir>/PR_cruve/<dim>-ours-<dataset>-<mode>.mat``
    and is consumed by offline PR-curve plotting.
    """
    save_dir = os.path.join(self.args.save_dir, "PR_cruve")
    os.makedirs(save_dir, exist_ok=True)
    # Move everything to CPU numpy arrays; savemat cannot take torch tensors.
    query_img = query_img.cpu().detach().numpy()
    query_txt = query_txt.cpu().detach().numpy()
    retrieval_img = retrieval_img.cpu().detach().numpy()
    retrieval_txt = retrieval_txt.cpu().detach().numpy()
    query_labels = self.query_labels.numpy()
    retrieval_labels = self.retrieval_labels.numpy()
    result_dict = {
        'q_img': query_img,
        'q_txt': query_txt,
        'r_img': retrieval_img,
        'r_txt': retrieval_txt,
        'q_l': query_labels,
        'r_l': retrieval_labels
    }
    scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
    self.logger.info(f">>>>>> save best {mode_name} data!")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue