update hash_train.py
This commit is contained in:
parent
e79be260b4
commit
97418d42e1
|
|
@ -47,12 +47,12 @@ class TrainBase(object):
|
|||
else:
|
||||
self.test()
|
||||
|
||||
def change_state(self, mode):
    """Switch ``self.model`` between training and evaluation mode.

    The method is kept live (rather than commented out) because
    ``train_epoch`` still calls ``self.change_state(mode="valid")``.

    Args:
        mode: ``"train"`` calls ``self.model.train()``; ``"valid"`` calls
            ``self.model.eval()``. Any other value leaves the model untouched.
    """
    if mode == "train":
        self.model.train()
    elif mode == "valid":
        self.model.eval()
|
||||
|
||||
def get_code(self, data_loader, length: int):
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ import scipy.io as scio
|
|||
import numpy as np
|
||||
|
||||
from .base import TrainBase
|
||||
from model.optimization import BertAdam
|
||||
from torch.optim import Adam
|
||||
import torch.nn.functional as F
|
||||
# from model.optimization import BertAdam
|
||||
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
|
||||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
||||
from utils.calc_utils import calc_map_k_matrix as calc_map_k
|
||||
|
|
@ -24,6 +26,8 @@ def clamp(delta, clean_imgs):
|
|||
|
||||
return clamp_delta
|
||||
|
||||
|
||||
|
||||
class Trainer(TrainBase):
|
||||
|
||||
def __init__(self,
|
||||
|
|
@ -31,9 +35,9 @@ class Trainer(TrainBase):
|
|||
args = get_args()
|
||||
super(Trainer, self).__init__(args, rank)
|
||||
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
||||
image_representation, text_representation=self.generate_mapping()
|
||||
self.image_representation=image_representation
|
||||
self.text_representation=text_representation
|
||||
text_mean_representation, text_var_representation=self.generate_mapping()
|
||||
self.text_mean=text_mean_representation
|
||||
self.text_var=text_var_representation
|
||||
# self.run()
|
||||
|
||||
def _init_model(self):
|
||||
|
|
@ -78,6 +82,7 @@ class Trainer(TrainBase):
|
|||
self.model= model_clip
|
||||
self.model.eval()
|
||||
self.model.float()
|
||||
self.optimizer =Adam(self.model.visual.parameters,lr=self.args.lr ,betas=[0.9,0.98] )
|
||||
# self.optimizer = BertAdam([
|
||||
# {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
|
||||
# {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
|
||||
|
|
@ -105,7 +110,7 @@ class Trainer(TrainBase):
|
|||
self.train_labels = train_data.get_all_label()
|
||||
self.query_labels = query_data.get_all_label()
|
||||
self.retrieval_labels = retrieval_data.get_all_label()
|
||||
self.args.retrieval_num = len(self.retrieval_labels)
|
||||
# self.args.retrieval_num = len(self.retrieval_labels)
|
||||
self.logger.info(f"query shape: {self.query_labels.shape}")
|
||||
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
|
||||
self.train_loader = DataLoader(
|
||||
|
|
@ -150,14 +155,14 @@ class Trainer(TrainBase):
|
|||
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
|
||||
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
|
||||
|
||||
text_representation = {}
|
||||
text_mean_representation = {}
|
||||
text_var_representation = {}
|
||||
for i, centroid in enumerate(label_unipue):
|
||||
text_representation[centroid.tobytes()] = text_centroids[i]
|
||||
text_mean_representation[centroid.tobytes()] = text_centroids[i]
|
||||
text_var_representation[centroid.tobytes()]= text_var[i]
|
||||
return text_representation, text_var_representation
|
||||
return text_mean_representation, text_var_representation
|
||||
|
||||
def target_adv(self, image, positive, negative,
|
||||
def target_adv(self, image, positive, positive_mean,positive_var, negative, negative_mean, negative_var,
|
||||
epsilon=0.03125, alpha=3/255, num_iter=100):
|
||||
|
||||
delta = torch.zeros_like(image,requires_grad=True)
|
||||
|
|
@ -167,8 +172,8 @@ class Trainer(TrainBase):
|
|||
for i in range(num_iter):
|
||||
self.model.zero_grad()
|
||||
anchor=self.model.encode_image(image+delta)
|
||||
loss=alienation_loss(anchor, positive, negative)
|
||||
|
||||
loss1=alienation_loss(anchor, positive, negative)
|
||||
loss=loss1 + self.args.beta * self.distribution_loss(anchor,positive_mean,positive_var,negative_mean, negative_var)
|
||||
|
||||
loss.backward(retain_graph=True)
|
||||
delta.data = delta - alpha * delta.grad.detach().sign()
|
||||
|
|
@ -177,14 +182,18 @@ class Trainer(TrainBase):
|
|||
|
||||
return delta.detach()
|
||||
|
||||
def train_epoch(self, epoch):
|
||||
def train_epoch(self):
|
||||
self.change_state(mode="valid")
|
||||
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
||||
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
||||
all_loss = 0
|
||||
times = 0
|
||||
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||
adv_images=[]
|
||||
adv_labels=[]
|
||||
texts=[]
|
||||
# target_texts=[]
|
||||
for image, text, label, index in self.train_loader:
|
||||
self.global_step += 1
|
||||
times += 1
|
||||
|
|
@ -195,34 +204,63 @@ class Trainer(TrainBase):
|
|||
image = image.to(self.rank, non_blocking=True)
|
||||
text = text.to(self.rank, non_blocking=True)
|
||||
index = index.numpy()
|
||||
image_anchor=self.image_representation(label.detach().cpu().numpy())
|
||||
text_anchor=self.text_representation(label.detach().cpu().numpy())
|
||||
negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
|
||||
# image_anchor=self.image_representation(label.detach().cpu().numpy())
|
||||
negetive_mean=self.text_mean(label.detach().cpu().numpy())
|
||||
negetive_var=self.text_var(label.detach().cpu().numpy())
|
||||
# negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
|
||||
negetive_code=self.model.encode_text(text)
|
||||
target_label=label.flip(dims=[0])
|
||||
target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
|
||||
target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
|
||||
positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
|
||||
# target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
|
||||
positive_mean=self.text_mean(target_label.detach().cpu().numpy())
|
||||
positive_var=self.text_var(target_label.detach().cpu().numpy())
|
||||
# positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
|
||||
positive_code=self.model.encode_text(text.flip(dims=[0]))
|
||||
# print("text shape:", text.shape)
|
||||
# index = index.numpy()
|
||||
# print(text.shape)
|
||||
delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
|
||||
torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
|
||||
delta=self.target_adv(image,positive_code,torch.from_numpy(positive_mean).to(self.rank, non_blocking=True), torch.from_numpy(positive_var).to(self.rank, non_blocking=True),
|
||||
negetive_code, torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True), torch.from_numpy(negetive_var).to(self.rank, non_blocking=True))
|
||||
adv_image=delta+image
|
||||
adv_images.append(adv_image)
|
||||
adv_images.append(self.model.encode_image(adv_image))
|
||||
adv_labels.append(target_label)
|
||||
texts.append(text)
|
||||
# target_texts.append(self.model.encode_text(text))
|
||||
adv_image=torch.cat(adv_image).to(self.device)
|
||||
adv_labels=torch.cat(adv_labels).to(self.device)
|
||||
|
||||
mAPi2t = calc_map_k(adv_images, retrieval_txt, adv_labels, self.retrieval_labels, None, self.rank)
|
||||
mAPt2t = calc_map_k(adv_images, retrieval_img, adv_labels, self.retrieval_labels, None, self.rank)
|
||||
self.logger.info(f">>>>>> t-MAP(i->t): {mAPi2t}, t-MAP(t->t): {mAPt2t}")
|
||||
adv_images = adv_images.cpu().detach().numpy()
|
||||
# query_txt = query_txt.cpu().detach().numpy()
|
||||
retrieval_img = retrieval_img.cpu().detach().numpy()
|
||||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||
adv_labels = adv_labels.numpy()
|
||||
retrieval_labels = self.retrieval_labels.numpy()
|
||||
|
||||
result_dict = {
|
||||
'adv_img': adv_images,
|
||||
# 'q_txt': query_txt,
|
||||
'r_img': retrieval_img,
|
||||
'r_txt': retrieval_txt,
|
||||
'adv_l': adv_labels,
|
||||
'r_l': retrieval_labels
|
||||
}
|
||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
|
||||
self.logger.info(">>>>>> save all data!")
|
||||
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
|
||||
return adv_images, texts, adv_labels
|
||||
# return adv_images, texts, adv_labels
|
||||
|
||||
|
||||
|
||||
def train(self):
|
||||
self.logger.info("Start train.")
|
||||
|
||||
for epoch in range(self.args.epochs):
|
||||
self.train_epoch(epoch)
|
||||
self.valid(epoch)
|
||||
self.save_model(epoch)
|
||||
self.valid()
|
||||
self.train_epoch()
|
||||
# self.valid()
|
||||
# for epoch in range(self.args.epochs):
|
||||
# self.train_epoch()
|
||||
# self.valid(epoch)
|
||||
# self.save_model(epoch)
|
||||
|
||||
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
||||
|
||||
|
|
@ -233,13 +271,15 @@ class Trainer(TrainBase):
|
|||
|
||||
return b_loss
|
||||
|
||||
def distribution_loss(self, x: torch.Tensor, positive_mean, positive_var, negative_mean, negative_var, eps=1e-8):
    """Rank ``x`` against a positive and a negative per-dimension Gaussian.

    For each (mean, var) pair the term ``50 - mean(density(x))`` is computed,
    where ``density`` is ``exp(-(x - mean)^2 / (2 * var)) / (2 * pi * var)``.
    The two scalar terms are then compared with a margin ranking loss whose
    target is +1 and whose margin is the default 0.

    Args:
        x: tensor being evaluated (the anchor embedding in ``target_adv``).
        positive_mean, positive_var: tensors broadcastable against ``x``.
        negative_mean, negative_var: tensors broadcastable against ``x``.
        eps: lower bound applied to the variances so that a zero variance
            does not produce ``inf``/``NaN``.

    Returns:
        A 0-dim tensor holding the ranking loss.
    """
    def density_term(mean, var):
        # Clamp so that a zero variance cannot divide by zero.
        var = var.clamp(min=eps)
        # NOTE(review): the normaliser is 2*pi*var, not sqrt(2*pi*var) as in
        # the standard normal pdf; kept as written -- confirm the intent.
        density = torch.exp(-(x - mean) ** 2 / (2 * var)) / (2 * torch.pi * var)
        return 50 - torch.mean(density)

    positive_distribution = density_term(positive_mean, positive_var)
    negative_distribution = density_term(negative_mean, negative_var)

    # margin_ranking_loss needs a tensor target; the bare int 1 used before
    # fails inside the functional when it calls ``target.dim()``.
    target = torch.ones_like(positive_distribution)
    return F.margin_ranking_loss(positive_distribution, negative_distribution, target)
|
||||
|
||||
|
||||
def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
|
||||
|
|
@ -296,9 +336,9 @@ class Trainer(TrainBase):
|
|||
image = image.to(self.rank, non_blocking=True)
|
||||
text = text.to(self.rank, non_blocking=True)
|
||||
index = index.numpy()
|
||||
image_hash=self.img_model(image)
|
||||
text_feat=self.bert(text)[0]
|
||||
text_hash=self.txt_model(text_feat)
|
||||
image_hash=self.model.encode_image(image)
|
||||
# text_feat=self.bert(text)[0]
|
||||
text_hash=self.model.encode_text(text)
|
||||
img_buffer[index, :] = image_hash.data
|
||||
text_buffer[index, :] = text_hash.data
|
||||
|
||||
|
|
@ -368,7 +408,7 @@ class Trainer(TrainBase):
|
|||
|
||||
def valid(self, epoch):
|
||||
self.logger.info("Valid.")
|
||||
self.change_state(mode="valid")
|
||||
# self.change_state(mode="valid")
|
||||
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||
# print("get all code")
|
||||
|
|
|
|||
|
|
@ -2,6 +2,11 @@ import torch
|
|||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
def cosine_distance_torch(x1, x2=None, eps=1e-8):
    """Return pairwise cosine distances (``1 - cosine similarity``) between rows.

    Args:
        x1: tensor of row vectors with shape (n, d). A 1-D tensor of length d
            is treated as a single row, which is how the per-query callers
            (``qB[iter, :]``) invoke this function.
        x2: tensor of row vectors with shape (m, d), or a 1-D tensor of
            length d. Defaults to ``x1``.
        eps: lower bound applied to the product of the norms so that
            zero-length vectors do not cause a division by zero.

    Returns:
        Tensor of shape (n, m) whose entry (i, j) is the cosine distance
        between row i of ``x1`` and row j of ``x2``.
    """
    # norm(dim=1) and torch.mm both need 2-D input; promote single vectors.
    if x1.dim() == 1:
        x1 = x1.unsqueeze(0)
    if x2 is None:
        x2 = x1
    elif x2.dim() == 1:
        x2 = x2.unsqueeze(0)

    w1 = x1.norm(p=2, dim=1, keepdim=True)
    # Reuse the norms when both operands are the same tensor.
    w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
    return 1 - torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)
|
||||
|
||||
def calc_hammingDist(B1, B2):
|
||||
q = B2.shape[1]
|
||||
|
|
@ -22,7 +27,8 @@ def calc_map_k_matrix(qB, rB, query_L, retrieval_L, k=None, rank=0):
|
|||
k = retrieval_L.shape[0]
|
||||
gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
|
||||
tsums = torch.sum(gnds, dim=-1, keepdim=True, dtype=torch.int32)
|
||||
hamms = calc_hammingDist(qB, rB)
|
||||
# hamms=torch.dist()
|
||||
hamms = cosine_distance_torch(qB, rB)
|
||||
_, ind = torch.sort(hamms, dim=-1)
|
||||
|
||||
totals = torch.min(tsums, torch.tensor([k], dtype=torch.int32).expand_as(tsums))
|
||||
|
|
@ -53,7 +59,7 @@ def calc_map_k(qB, rB, query_L, retrieval_L, k=None, rank=0):
|
|||
tsum = torch.sum(gnd)
|
||||
if tsum == 0:
|
||||
continue
|
||||
hamm = calc_hammingDist(qB[iter, :], rB)
|
||||
hamm = cosine_distance_torch(qB[iter, :], rB)
|
||||
_, ind = torch.sort(hamm)
|
||||
ind.squeeze_()
|
||||
gnd = gnd[ind]
|
||||
|
|
@ -84,7 +90,7 @@ def calc_precisions_topn_matrix(qB, rB, query_L, retrieval_L, recall_gas=0.02, n
|
|||
# num_retrieval = retrieval_L.shape[0]
|
||||
precisions = [0] * int(1 / recall_gas)
|
||||
gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
|
||||
hamms = calc_hammingDist(qB, rB)
|
||||
hamms = cosine_distance_torch(qB, rB)
|
||||
_, inds = torch.sort(hamms, dim=-1)
|
||||
for iter in range(num_query):
|
||||
gnd = gnds[iter]
|
||||
|
|
@ -115,7 +121,7 @@ def calc_precisions_topn(qB, rB, query_L, retrieval_L, recall_gas=0.02, num_retr
|
|||
if len(q_L.shape) < 2:
|
||||
q_L = q_L.unsqueeze(0) # [1, hash length]
|
||||
gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
|
||||
hamm = calc_hammingDist(qB[iter, :], rB)
|
||||
hamm = cosine_distance_torch(qB[iter, :], rB)
|
||||
_, ind = torch.sort(hamm)
|
||||
ind.squeeze_()
|
||||
gnd = gnd[ind]
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ def get_args():
|
|||
parser.add_argument("--lr-decay-freq", type=int, default=5)
|
||||
parser.add_argument("--display-step", type=int, default=50)
|
||||
parser.add_argument("--seed", type=int, default=1814)
|
||||
parser.add_argument("--beta", type=float, default=0.24)
|
||||
|
||||
parser.add_argument("--lr", type=float, default=0.001)
|
||||
parser.add_argument("--lr-decay", type=float, default=0.9)
|
||||
|
|
|
|||
Loading…
Reference in New Issue