update hash_train.py

This commit is contained in:
Li Wenyun 2024-05-20 19:16:47 +08:00
parent e79be260b4
commit 97418d42e1
4 changed files with 95 additions and 48 deletions

View File

@ -47,12 +47,12 @@ class TrainBase(object):
else:
self.test()
def change_state(self, mode):
# def change_state(self, mode):
if mode == "train":
self.model.train()
elif mode == "valid":
self.model.eval()
# if mode == "train":
# self.model.train()
# elif mode == "valid":
# self.model.eval()
def get_code(self, data_loader, length: int):

View File

@ -9,7 +9,9 @@ import scipy.io as scio
import numpy as np
from .base import TrainBase
from model.optimization import BertAdam
from torch.optim import Adam
import torch.nn.functional as F
# from model.optimization import BertAdam
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import calc_map_k_matrix as calc_map_k
@ -24,6 +26,8 @@ def clamp(delta, clean_imgs):
return clamp_delta
class Trainer(TrainBase):
def __init__(self,
@ -31,9 +35,9 @@ class Trainer(TrainBase):
args = get_args()
super(Trainer, self).__init__(args, rank)
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
image_representation, text_representation=self.generate_mapping()
self.image_representation=image_representation
self.text_representation=text_representation
text_mean_representation, text_var_representation=self.generate_mapping()
self.text_mean=text_mean_representation
self.text_var=text_var_representation
# self.run()
def _init_model(self):
@ -78,6 +82,7 @@ class Trainer(TrainBase):
self.model= model_clip
self.model.eval()
self.model.float()
self.optimizer =Adam(self.model.visual.parameters,lr=self.args.lr ,betas=[0.9,0.98] )
# self.optimizer = BertAdam([
# {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
# {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
@ -105,7 +110,7 @@ class Trainer(TrainBase):
self.train_labels = train_data.get_all_label()
self.query_labels = query_data.get_all_label()
self.retrieval_labels = retrieval_data.get_all_label()
self.args.retrieval_num = len(self.retrieval_labels)
# self.args.retrieval_num = len(self.retrieval_labels)
self.logger.info(f"query shape: {self.query_labels.shape}")
self.logger.info(f"retrieval shape: {self.retrieval_labels.shape}")
self.train_loader = DataLoader(
@ -150,14 +155,14 @@ class Trainer(TrainBase):
text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
text_representation = {}
text_mean_representation = {}
text_var_representation = {}
for i, centroid in enumerate(label_unipue):
text_representation[centroid.tobytes()] = text_centroids[i]
text_mean_representation[centroid.tobytes()] = text_centroids[i]
text_var_representation[centroid.tobytes()]= text_var[i]
return text_representation, text_var_representation
return text_mean_representation, text_var_representation
def target_adv(self, image, positive, negative,
def target_adv(self, image, positive, positive_mean,positive_var, negative, negative_mean, negative_var,
epsilon=0.03125, alpha=3/255, num_iter=100):
delta = torch.zeros_like(image,requires_grad=True)
@ -167,8 +172,8 @@ class Trainer(TrainBase):
for i in range(num_iter):
self.model.zero_grad()
anchor=self.model.encode_image(image+delta)
loss=alienation_loss(anchor, positive, negative)
loss1=alienation_loss(anchor, positive, negative)
loss=loss1 + self.args.beta * self.distribution_loss(anchor,positive_mean,positive_var,negative_mean, negative_var)
loss.backward(retain_graph=True)
delta.data = delta - alpha * delta.grad.detach().sign()
@ -177,14 +182,18 @@ class Trainer(TrainBase):
return delta.detach()
def train_epoch(self, epoch):
def train_epoch(self):
self.change_state(mode="valid")
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
all_loss = 0
times = 0
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
os.makedirs(save_dir, exist_ok=True)
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
adv_images=[]
adv_labels=[]
texts=[]
# target_texts=[]
for image, text, label, index in self.train_loader:
self.global_step += 1
times += 1
@ -195,34 +204,63 @@ class Trainer(TrainBase):
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
index = index.numpy()
image_anchor=self.image_representation(label.detach().cpu().numpy())
text_anchor=self.text_representation(label.detach().cpu().numpy())
negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
# image_anchor=self.image_representation(label.detach().cpu().numpy())
negetive_mean=self.text_mean(label.detach().cpu().numpy())
negetive_var=self.text_var(label.detach().cpu().numpy())
# negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
negetive_code=self.model.encode_text(text)
target_label=label.flip(dims=[0])
target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
# target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
positive_mean=self.text_mean(target_label.detach().cpu().numpy())
positive_var=self.text_var(target_label.detach().cpu().numpy())
# positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
positive_code=self.model.encode_text(text.flip(dims=[0]))
# print("text shape:", text.shape)
# index = index.numpy()
# print(text.shape)
delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
delta=self.target_adv(image,positive_code,torch.from_numpy(positive_mean).to(self.rank, non_blocking=True), torch.from_numpy(positive_var).to(self.rank, non_blocking=True),
negetive_code, torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True), torch.from_numpy(negetive_var).to(self.rank, non_blocking=True))
adv_image=delta+image
adv_images.append(adv_image)
adv_images.append(self.model.encode_image(adv_image))
adv_labels.append(target_label)
texts.append(text)
# target_texts.append(self.model.encode_text(text))
adv_image=torch.cat(adv_image).to(self.device)
adv_labels=torch.cat(adv_labels).to(self.device)
mAPi2t = calc_map_k(adv_images, retrieval_txt, adv_labels, self.retrieval_labels, None, self.rank)
mAPt2t = calc_map_k(adv_images, retrieval_img, adv_labels, self.retrieval_labels, None, self.rank)
self.logger.info(f">>>>>> t-MAP(i->t): {mAPi2t}, t-MAP(t->t): {mAPt2t}")
adv_images = adv_images.cpu().detach().numpy()
# query_txt = query_txt.cpu().detach().numpy()
retrieval_img = retrieval_img.cpu().detach().numpy()
retrieval_txt = retrieval_txt.cpu().detach().numpy()
adv_labels = adv_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy()
result_dict = {
'adv_img': adv_images,
# 'q_txt': query_txt,
'r_img': retrieval_img,
'r_txt': retrieval_txt,
'adv_l': adv_labels,
'r_l': retrieval_labels
}
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!")
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
return adv_images, texts, adv_labels
# return adv_images, texts, adv_labels
def train(self):
self.logger.info("Start train.")
for epoch in range(self.args.epochs):
self.train_epoch(epoch)
self.valid(epoch)
self.save_model(epoch)
self.valid()
self.train_epoch()
# self.valid()
# for epoch in range(self.args.epochs):
# self.train_epoch()
# self.valid(epoch)
# self.save_model(epoch)
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
@ -233,13 +271,15 @@ class Trainer(TrainBase):
return b_loss
def distribution_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor):
def distribution_loss(self, x: torch.Tensor, positive_mean,positive_var, negative_mean, negative_var):
"""
"""
kl_divergence = torch.mean(a * torch.log(a / (b + 0.001)))
print("mean", torch.mean(a - b))
print("kl", kl_divergence)
return kl_divergence
norm_fun= lambda mean, var, x: 50- torch.mean(torch.exp(-(x-mean) **2 /(2*var)) /(2* torch.pi * var))
positive_distribution=norm_fun(positive_mean,positive_var,x)
negative_distribution=norm_fun(negative_mean,negative_var,x)
# alienation_loss=nn.MarginRankingLoss()
return F.margin_ranking_loss(positive_distribution,negative_distribution,1)
def similarity_loss(self, a: torch.Tensor, b: torch.Tensor, label_sim: torch.Tensor, threshold=0.05):
@ -296,9 +336,9 @@ class Trainer(TrainBase):
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
index = index.numpy()
image_hash=self.img_model(image)
text_feat=self.bert(text)[0]
text_hash=self.txt_model(text_feat)
image_hash=self.model.encode_image(image)
# text_feat=self.bert(text)[0]
text_hash=self.model.encode_text(text)
img_buffer[index, :] = image_hash.data
text_buffer[index, :] = text_hash.data
@ -368,7 +408,7 @@ class Trainer(TrainBase):
def valid(self, epoch):
self.logger.info("Valid.")
self.change_state(mode="valid")
# self.change_state(mode="valid")
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
# print("get all code")

View File

@ -2,6 +2,11 @@ import torch
import numpy as np
from tqdm import tqdm
def cosine_distance_torch(x1, x2=None, eps=1e-8):
    """Pairwise cosine distance between the rows of two matrices.

    Args:
        x1: tensor of shape (n, d).
        x2: optional tensor of shape (m, d); when omitted, distances are
            computed between the rows of ``x1`` itself.
        eps: lower bound applied to the norm product to avoid dividing by
            zero when a row has (near-)zero norm.

    Returns:
        Tensor of shape (n, m) whose (i, j) entry is
        ``1 - cos(x1[i], x2[j])`` — 0 for parallel rows, 1 for orthogonal.
    """
    if x2 is None:
        x2 = x1
    norm1 = x1.norm(p=2, dim=1, keepdim=True)
    # Reuse the already-computed norms when x2 aliases x1.
    norm2 = norm1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
    similarity = torch.mm(x1, x2.t()) / (norm1 * norm2.t()).clamp(min=eps)
    return 1 - similarity
def calc_hammingDist(B1, B2):
q = B2.shape[1]
@ -22,7 +27,8 @@ def calc_map_k_matrix(qB, rB, query_L, retrieval_L, k=None, rank=0):
k = retrieval_L.shape[0]
gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
tsums = torch.sum(gnds, dim=-1, keepdim=True, dtype=torch.int32)
hamms = calc_hammingDist(qB, rB)
# hamms=torch.dist()
hamms = cosine_distance_torch(qB, rB)
_, ind = torch.sort(hamms, dim=-1)
totals = torch.min(tsums, torch.tensor([k], dtype=torch.int32).expand_as(tsums))
@ -53,7 +59,7 @@ def calc_map_k(qB, rB, query_L, retrieval_L, k=None, rank=0):
tsum = torch.sum(gnd)
if tsum == 0:
continue
hamm = calc_hammingDist(qB[iter, :], rB)
hamm = cosine_distance_torch(qB[iter, :], rB)
_, ind = torch.sort(hamm)
ind.squeeze_()
gnd = gnd[ind]
@ -84,7 +90,7 @@ def calc_precisions_topn_matrix(qB, rB, query_L, retrieval_L, recall_gas=0.02, n
# num_retrieval = retrieval_L.shape[0]
precisions = [0] * int(1 / recall_gas)
gnds = (query_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
hamms = calc_hammingDist(qB, rB)
hamms = cosine_distance_torch(qB, rB)
_, inds = torch.sort(hamms, dim=-1)
for iter in range(num_query):
gnd = gnds[iter]
@ -115,7 +121,7 @@ def calc_precisions_topn(qB, rB, query_L, retrieval_L, recall_gas=0.02, num_retr
if len(q_L.shape) < 2:
q_L = q_L.unsqueeze(0) # [1, hash length]
gnd = (q_L.mm(retrieval_L.transpose(0, 1)) > 0).squeeze().type(torch.float32)
hamm = calc_hammingDist(qB[iter, :], rB)
hamm = cosine_distance_torch(qB[iter, :], rB)
_, ind = torch.sort(hamm)
ind.squeeze_()
gnd = gnd[ind]

View File

@ -30,6 +30,7 @@ def get_args():
parser.add_argument("--lr-decay-freq", type=int, default=5)
parser.add_argument("--display-step", type=int, default=50)
parser.add_argument("--seed", type=int, default=1814)
parser.add_argument("--beta", type=float, default=0.24)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--lr-decay", type=float, default=0.9)