From 19bec9bfc238dd6334ac54ef0d62447b5fb46005 Mon Sep 17 00:00:00 2001 From: leewlving Date: Mon, 10 Jun 2024 22:01:05 +0800 Subject: [PATCH] debug --- main.py | 2 +- train/hash_train.py | 25 ++++++++++++++----------- utils/get_args.py | 2 +- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index ec46890..5eceeda 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ from train.hash_train import Trainer if __name__ == "__main__": engine=Trainer() - engine.test() + # engine.test() adv_images, texts, adv_labels= engine.train_epoch() # engine.train() diff --git a/train/hash_train.py b/train/hash_train.py index 4c11cf9..8d56bbe 100644 --- a/train/hash_train.py +++ b/train/hash_train.py @@ -9,8 +9,7 @@ import scipy.io as scio import numpy as np from .base import TrainBase -from model.optimization import BertAdam -# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss +from torch.nn import functional as F from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices from utils.calc_utils import calc_map_k_matrix as calc_map_k from dataset.dataloader import dataloader @@ -108,12 +107,12 @@ class Trainer(TrainBase): text_representation = {} text_var_representation = {} for i, centroid in enumerate(label_unipue): - text_representation[centroid.tobytes()] = text_centroids[i] - text_var_representation[centroid.tobytes()]= text_var[i] + text_representation[str(centroid.astype(int))] = text_centroids[i] + text_var_representation[str(centroid.astype(int))]= text_var[i] return text_representation, text_var_representation def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var, - beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100): + beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100, temperature=0.05): delta = torch.zeros_like(image,requires_grad=True) # one=torch.zeros_like(positive) @@ -121,10 +120,12 @@ class Trainer(TrainBase): for i in range(num_iter): self.model.zero_grad() anchor=self.model.encode_image(image+delta) - loss1=alienation_loss(anchor, positive_code, negetive_code) - negative_dist=((anchor-negetive_mean))**2 / negative_var - positive_dist=((anchor-positive_mean))**2 /positive_var - loss= positive_dist -negative_dist + beta* loss1 + loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity()) + negative_dist=(anchor-negetive_mean)**2 / negative_var + positive_dist=(anchor-positive_mean)**2 /positive_var + negatives=torch.exp(negative_dist / temperature) + positives= torch.exp(positive_dist / temperature) + loss= torch.log(positives/(positives+negatives)).mean() + beta* loss1 loss.backward(retain_graph=True) delta.data = delta - alpha * delta.grad.detach().sign() delta.data =clamp(delta, image).clamp(-epsilon, epsilon) @@ -148,8 +149,10 @@ class Trainer(TrainBase): text = text.to(self.rank, non_blocking=True) index = index.numpy() # image_anchor=self.image_representation(label.detach().cpu().numpy()) - negetive_mean=torch.cat([self.text_mean[i] for i in label.detach().cpu().numpy()]) - negative_var=torch.cat([self.text_var[i] for i in label.detach().cpu().numpy()]) + negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()]) + negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()]) + negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True) + negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True) negetive_code=self.model.encode_text(text) target_label=label.flip(dims=[0]) positive_mean=negetive_mean.flip(dims=[0]) diff --git a/utils/get_args.py b/utils/get_args.py index 7216497..013e963 100644 --- a/utils/get_args.py +++ b/utils/get_args.py @@ -23,7 +23,7 @@ def get_args(): parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--max-words", type=int, default=77) parser.add_argument("--resolution", type=int, default=224) - parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--query-num", type=int, default=5120) parser.add_argument("--train-num", type=int, default=10240)