new dist
This commit is contained in:
parent
c479c07a2c
commit
b482a0e942
4
main.py
4
main.py
|
|
@ -5,6 +5,8 @@ if __name__ == "__main__":
|
|||
|
||||
engine=Trainer()
|
||||
engine.test()
|
||||
engine.train()
|
||||
adv_images, texts, adv_labels= engine.train_epoch()
|
||||
|
||||
# engine.train()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from torch.nn.modules import loss
|
||||
from model.hash_model import DCMHT as DCMHT
|
||||
# from model.hash_model import DCMHT as DCMHT
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
|
@ -33,9 +33,9 @@ class Trainer(TrainBase):
|
|||
args = get_args()
|
||||
super(Trainer, self).__init__(args, rank)
|
||||
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
||||
text_representation, text_representation=self.generate_mapping()
|
||||
self.image_representation=text_representation
|
||||
self.text_representation=text_representation
|
||||
text_mean, text_var=self.generate_mapping()
|
||||
self.text_mean=text_mean
|
||||
self.text_var=text_var
|
||||
self.device=rank
|
||||
# self.run()
|
||||
|
||||
|
|
@ -112,18 +112,19 @@ class Trainer(TrainBase):
|
|||
text_var_representation[centroid.tobytes()]= text_var[i]
|
||||
return text_representation, text_var_representation
|
||||
|
||||
def target_adv(self, image, positive, negative,
|
||||
epsilon=0.03125, alpha=3/255, num_iter=100):
|
||||
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100):
|
||||
|
||||
delta = torch.zeros_like(image,requires_grad=True)
|
||||
one=torch.zeros_like(positive)
|
||||
# one=torch.zeros_like(positive)
|
||||
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
||||
for i in range(num_iter):
|
||||
self.model.zero_grad()
|
||||
anchor=self.model.encode_image(image+delta)
|
||||
loss=alienation_loss(anchor, positive, negative)
|
||||
|
||||
|
||||
loss1=alienation_loss(anchor, positive_code, negetive_code)
|
||||
negative_dist=((anchor-negetive_mean))**2 / negative_var
|
||||
positive_dist=((anchor-positive_mean))**2 /positive_var
|
||||
loss= positive_dist -negative_dist + beta* loss1
|
||||
loss.backward(retain_graph=True)
|
||||
delta.data = delta - alpha * delta.grad.detach().sign()
|
||||
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
|
||||
|
|
@ -131,38 +132,36 @@ class Trainer(TrainBase):
|
|||
|
||||
return delta.detach()
|
||||
|
||||
# def train_epoch(self, epoch):
|
||||
# self.change_state(mode="valid")
|
||||
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
||||
# all_loss = 0
|
||||
# times = 0
|
||||
# adv_images=[]
|
||||
# adv_labels=[]
|
||||
# texts=[]
|
||||
# for image, text, label, index in self.train_loader:
|
||||
# self.global_step += 1
|
||||
# times += 1
|
||||
# image.float()
|
||||
# if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
|
||||
# label = torch.ones([image.shape[0]], dtype=torch.int)
|
||||
# label = label.diag()
|
||||
# image = image.to(self.rank, non_blocking=True)
|
||||
# text = text.to(self.rank, non_blocking=True)
|
||||
# index = index.numpy()
|
||||
# image_anchor=self.image_representation(label.detach().cpu().numpy())
|
||||
# text_anchor=self.text_representation(label.detach().cpu().numpy())
|
||||
# negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
|
||||
# target_label=label.flip(dims=[0])
|
||||
# target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
|
||||
# target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
|
||||
# positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
|
||||
# delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
|
||||
# torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
|
||||
# adv_image=delta+image
|
||||
# adv_images.append(adv_image)
|
||||
# adv_labels.append(target_label)
|
||||
# texts.append(text)
|
||||
# return adv_images, texts, adv_labels
|
||||
def train_epoch(self):
|
||||
self.change_state(mode="valid")
|
||||
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
||||
all_loss = 0
|
||||
times = 0
|
||||
adv_images=[]
|
||||
adv_labels=[]
|
||||
texts=[]
|
||||
for image, text, label, index in self.train_loader:
|
||||
self.global_step += 1
|
||||
times += 1
|
||||
image.float()
|
||||
image = image.to(self.rank, non_blocking=True)
|
||||
text = text.to(self.rank, non_blocking=True)
|
||||
index = index.numpy()
|
||||
# image_anchor=self.image_representation(label.detach().cpu().numpy())
|
||||
negetive_mean=torch.cat([self.text_mean[i] for i in label.detach().cpu().numpy()])
|
||||
negative_var=torch.cat([self.text_var[i] for i in label.detach().cpu().numpy()])
|
||||
negetive_code=self.model.encode_text(text)
|
||||
target_label=label.flip(dims=[0])
|
||||
positive_mean=negetive_mean.flip(dims=[0])
|
||||
positive_var=negative_var.flip(dims=[0])
|
||||
positive_code=self.model.encode_text(text.flip(dims=[0]))
|
||||
delta=self.target_adv(image,negetive_code,negetive_mean,negative_var,
|
||||
positive_code,positive_mean,positive_var)
|
||||
adv_image=delta+image
|
||||
adv_images.append(adv_image)
|
||||
adv_labels.append(target_label)
|
||||
texts.append(text)
|
||||
return adv_images, texts, adv_labels
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,15 @@ def calc_hammingDist(B1, B2):
|
|||
distH = 0.5 * (q - B1.mm(B2.transpose(0, 1)))
|
||||
return distH
|
||||
|
||||
def calc_cosineDist(B1, B2):
|
||||
q = B2.shape[1]
|
||||
if len(B1.shape) < 2:
|
||||
B1 = B1.unsqueeze(0)
|
||||
dot_product = B1.mm(B2.transpose(0, 1))
|
||||
norms = torch.sqrt(torch.einsum('ii->i', dot_product))
|
||||
distH = dot_product/(norms[None]*norms[..., None])
|
||||
return distH
|
||||
|
||||
|
||||
def calc_map_k_matrix(qB, rB, query_L, retrieval_L, k=None, rank=0):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue