This commit is contained in:
leewlving 2024-06-10 22:01:05 +08:00
parent b482a0e942
commit 19bec9bfc2
3 changed files with 16 additions and 13 deletions

View File

@@ -4,7 +4,7 @@ from train.hash_train import Trainer
if __name__ == "__main__":
engine=Trainer()
engine.test()
# engine.test()
adv_images, texts, adv_labels= engine.train_epoch()
# engine.train()

View File

@@ -9,8 +9,7 @@ import scipy.io as scio
import numpy as np
from .base import TrainBase
from model.optimization import BertAdam
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import calc_map_k_matrix as calc_map_k
from dataset.dataloader import dataloader
@@ -108,12 +107,12 @@ class Trainer(TrainBase):
text_representation = {}
text_var_representation = {}
for i, centroid in enumerate(label_unipue):
text_representation[centroid.tobytes()] = text_centroids[i]
text_var_representation[centroid.tobytes()]= text_var[i]
text_representation[str(centroid.astype(int))] = text_centroids[i]
text_var_representation[str(centroid.astype(int))]= text_var[i]
return text_representation, text_var_representation
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100):
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100, temperature=0.05):
delta = torch.zeros_like(image,requires_grad=True)
# one=torch.zeros_like(positive)
@@ -121,10 +120,12 @@ class Trainer(TrainBase):
for i in range(num_iter):
self.model.zero_grad()
anchor=self.model.encode_image(image+delta)
loss1=alienation_loss(anchor, positive_code, negetive_code)
negative_dist=((anchor-negetive_mean))**2 / negative_var
positive_dist=((anchor-positive_mean))**2 /positive_var
loss= positive_dist -negative_dist + beta* loss1
loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity())
negative_dist=(anchor-negetive_mean)**2 / negative_var
positive_dist=(anchor-positive_mean)**2 /positive_var
negatives=torch.exp(negative_dist / temperature)
positives= torch.exp(positive_dist / temperature)
loss= torch.log(positives/(positives+negatives)).mean() + beta* loss1
loss.backward(retain_graph=True)
delta.data = delta - alpha * delta.grad.detach().sign()
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
@@ -148,8 +149,10 @@ class Trainer(TrainBase):
text = text.to(self.rank, non_blocking=True)
index = index.numpy()
# image_anchor=self.image_representation(label.detach().cpu().numpy())
negetive_mean=torch.cat([self.text_mean[i] for i in label.detach().cpu().numpy()])
negative_var=torch.cat([self.text_var[i] for i in label.detach().cpu().numpy()])
negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
negetive_code=self.model.encode_text(text)
target_label=label.flip(dims=[0])
positive_mean=negetive_mean.flip(dims=[0])

View File

@@ -23,7 +23,7 @@ def get_args():
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--max-words", type=int, default=77)
parser.add_argument("--resolution", type=int, default=224)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--query-num", type=int, default=5120)
parser.add_argument("--train-num", type=int, default=10240)