This commit is contained in:
leewlving 2024-06-08 18:13:15 +08:00
parent c479c07a2c
commit b482a0e942
3 changed files with 53 additions and 43 deletions

View File

@ -5,6 +5,8 @@ if __name__ == "__main__":
engine=Trainer()
engine.test()
engine.train()
adv_images, texts, adv_labels= engine.train_epoch()
# engine.train()

View File

@ -1,5 +1,5 @@
from torch.nn.modules import loss
from model.hash_model import DCMHT as DCMHT
# from model.hash_model import DCMHT as DCMHT
import os
from tqdm import tqdm
import torch
@ -33,9 +33,9 @@ class Trainer(TrainBase):
args = get_args()
super(Trainer, self).__init__(args, rank)
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
text_representation, text_representation=self.generate_mapping()
self.image_representation=text_representation
self.text_representation=text_representation
text_mean, text_var=self.generate_mapping()
self.text_mean=text_mean
self.text_var=text_var
self.device=rank
# self.run()
@ -112,18 +112,19 @@ class Trainer(TrainBase):
text_var_representation[centroid.tobytes()]= text_var[i]
return text_representation, text_var_representation
def target_adv(self, image, negetive_code, negetive_mean, negative_var,
               positive_code, positive_mean, positive_var,
               beta=10, epsilon=0.03125, alpha=3/255, num_iter=100):
    """Craft a targeted adversarial perturbation for `image` via PGD.

    The perturbation pulls the image hash code toward the target
    (`positive_*`) statistics and away from the original (`negetive_*`)
    ones. NOTE(review): parameter spellings ("negetive") are kept to
    preserve the call interface.

    Args:
        image: input image batch (presumably float, on self.device —
            TODO confirm against caller).
        negetive_code / positive_code: hash codes to repel from / attract to.
        negetive_mean, negative_var / positive_mean, positive_var:
            per-sample mean and (diagonal) variance statistics for the
            variance-normalized distance terms.
        beta: weight of the triplet term in the combined loss.
        epsilon: L-inf bound on the perturbation.
        alpha: PGD step size.
        num_iter: number of PGD iterations.

    Returns:
        The detached perturbation `delta` (same shape as `image`).
    """
    delta = torch.zeros_like(image, requires_grad=True)
    alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
    for _ in range(num_iter):
        self.model.zero_grad()
        anchor = self.model.encode_image(image + delta)
        loss1 = alienation_loss(anchor, positive_code, negetive_code)
        # Variance-normalized squared distances to the class statistics.
        negative_dist = (anchor - negetive_mean) ** 2 / negative_var
        positive_dist = (anchor - positive_mean) ** 2 / positive_var
        # BUG FIX: the distance terms are element-wise tensors, so the
        # combined loss must be reduced to a scalar before backward().
        loss = (positive_dist - negative_dist).mean() + beta * loss1
        # retain_graph is not needed: the graph is rebuilt every iteration.
        loss.backward()
        # Gradient-descent step on the (to-be-minimized) loss.
        delta.data = delta - alpha * delta.grad.detach().sign()
        # `clamp` is a project helper (keeps image+delta in valid range);
        # the final .clamp enforces the L-inf epsilon ball.
        delta.data = clamp(delta, image).clamp(-epsilon, epsilon)
        # BUG FIX: zero the leaf gradient, otherwise PGD gradients
        # accumulate across iterations and corrupt later steps.
        delta.grad.zero_()
    return delta.detach()
# def train_epoch(self, epoch):
# self.change_state(mode="valid")
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
# all_loss = 0
# times = 0
# adv_images=[]
# adv_labels=[]
# texts=[]
# for image, text, label, index in self.train_loader:
# self.global_step += 1
# times += 1
# image.float()
# if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
# label = torch.ones([image.shape[0]], dtype=torch.int)
# label = label.diag()
# image = image.to(self.rank, non_blocking=True)
# text = text.to(self.rank, non_blocking=True)
# index = index.numpy()
# image_anchor=self.image_representation(label.detach().cpu().numpy())
# text_anchor=self.text_representation(label.detach().cpu().numpy())
# negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
# target_label=label.flip(dims=[0])
# target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
# target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
# positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
# delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
# torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
# adv_image=delta+image
# adv_images.append(adv_image)
# adv_labels.append(target_label)
# texts.append(text)
# return adv_images, texts, adv_labels
def train_epoch(self):
    """Generate targeted adversarial examples for one pass over the train set.

    Each batch is paired with its order-flipped self: the flipped labels,
    text statistics and text codes serve as the attack targets passed to
    `target_adv`, so every image is pushed toward another sample's hash
    region.

    Returns:
        (adv_images, texts, adv_labels): parallel lists of per-batch
        tensors — adversarial images, their original texts, and the
        (flipped) target labels.
    """
    self.change_state(mode="valid")
    adv_images = []
    adv_labels = []
    texts = []
    for image, text, label, index in self.train_loader:
        self.global_step += 1
        # BUG FIX: Tensor.float() is out-of-place; the original discarded
        # its result, leaving `image` in its original dtype.
        image = image.float().to(self.rank, non_blocking=True)
        text = text.to(self.rank, non_blocking=True)
        # Hoisted: the numpy labels are used for both mean and var lookup.
        label_np = label.detach().cpu().numpy()
        negetive_mean = torch.cat([self.text_mean[i] for i in label_np])
        negative_var = torch.cat([self.text_var[i] for i in label_np])
        negetive_code = self.model.encode_text(text)
        # Flipping the batch pairs each sample with another sample's
        # label/statistics as its attack target.
        target_label = label.flip(dims=[0])
        positive_mean = negetive_mean.flip(dims=[0])
        positive_var = negative_var.flip(dims=[0])
        positive_code = self.model.encode_text(text.flip(dims=[0]))
        delta = self.target_adv(image, negetive_code, negetive_mean, negative_var,
                                positive_code, positive_mean, positive_var)
        adv_images.append(image + delta)
        adv_labels.append(target_label)
        texts.append(text)
    return adv_images, texts, adv_labels

View File

@ -10,6 +10,15 @@ def calc_hammingDist(B1, B2):
distH = 0.5 * (q - B1.mm(B2.transpose(0, 1)))
return distH
def calc_cosineDist(B1, B2):
    """Pairwise cosine similarity between the rows of B1 and the rows of B2.

    A 1-D `B1` is treated as a single row. Returns a (n1, n2) matrix whose
    (i, j) entry is cos(B1[i], B2[j]). (Despite the name, this is a
    similarity, not a distance — kept for interface compatibility.)
    """
    if len(B1.shape) < 2:
        B1 = B1.unsqueeze(0)
    dot_product = B1.mm(B2.transpose(0, 1))
    # BUG FIX: the original took norms from the diagonal of B1 @ B2.T
    # (einsum 'ii->i'), which equals the row norms only when B1 is B2 and
    # requires a square product matrix. Compute each matrix's own row
    # norms so B1 != B2 (and non-square shapes) work; identical result in
    # the B1 == B2 case.
    norms1 = torch.sqrt((B1 ** 2).sum(dim=1))
    norms2 = torch.sqrt((B2 ** 2).sum(dim=1))
    distH = dot_product / (norms1[..., None] * norms2[None])
    return distH
def calc_map_k_matrix(qB, rB, query_L, retrieval_L, k=None, rank=0):