debug
This commit is contained in:
parent
b482a0e942
commit
19bec9bfc2
2
main.py
2
main.py
|
|
@ -4,7 +4,7 @@ from train.hash_train import Trainer
|
|||
if __name__ == "__main__":
|
||||
|
||||
engine=Trainer()
|
||||
engine.test()
|
||||
# engine.test()
|
||||
adv_images, texts, adv_labels= engine.train_epoch()
|
||||
|
||||
# engine.train()
|
||||
|
|
|
|||
|
|
@ -9,8 +9,7 @@ import scipy.io as scio
|
|||
import numpy as np
|
||||
|
||||
from .base import TrainBase
|
||||
from model.optimization import BertAdam
|
||||
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
|
||||
from torch.nn import functional as F
|
||||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
||||
from utils.calc_utils import calc_map_k_matrix as calc_map_k
|
||||
from dataset.dataloader import dataloader
|
||||
|
|
@ -108,12 +107,12 @@ class Trainer(TrainBase):
|
|||
text_representation = {}
|
||||
text_var_representation = {}
|
||||
for i, centroid in enumerate(label_unipue):
|
||||
text_representation[centroid.tobytes()] = text_centroids[i]
|
||||
text_var_representation[centroid.tobytes()]= text_var[i]
|
||||
text_representation[str(centroid.astype(int))] = text_centroids[i]
|
||||
text_var_representation[str(centroid.astype(int))]= text_var[i]
|
||||
return text_representation, text_var_representation
|
||||
|
||||
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100):
|
||||
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100, temperature=0.05):
|
||||
|
||||
delta = torch.zeros_like(image,requires_grad=True)
|
||||
# one=torch.zeros_like(positive)
|
||||
|
|
@ -121,10 +120,12 @@ class Trainer(TrainBase):
|
|||
for i in range(num_iter):
|
||||
self.model.zero_grad()
|
||||
anchor=self.model.encode_image(image+delta)
|
||||
loss1=alienation_loss(anchor, positive_code, negetive_code)
|
||||
negative_dist=((anchor-negetive_mean))**2 / negative_var
|
||||
positive_dist=((anchor-positive_mean))**2 /positive_var
|
||||
loss= positive_dist -negative_dist + beta* loss1
|
||||
loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity())
|
||||
negative_dist=(anchor-negetive_mean)**2 / negative_var
|
||||
positive_dist=(anchor-positive_mean)**2 /positive_var
|
||||
negatives=torch.exp(negative_dist / temperature)
|
||||
positives= torch.exp(positive_dist / temperature)
|
||||
loss= torch.log(positives/(positives+negatives)).mean() + beta* loss1
|
||||
loss.backward(retain_graph=True)
|
||||
delta.data = delta - alpha * delta.grad.detach().sign()
|
||||
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
|
||||
|
|
@ -148,8 +149,10 @@ class Trainer(TrainBase):
|
|||
text = text.to(self.rank, non_blocking=True)
|
||||
index = index.numpy()
|
||||
# image_anchor=self.image_representation(label.detach().cpu().numpy())
|
||||
negetive_mean=torch.cat([self.text_mean[i] for i in label.detach().cpu().numpy()])
|
||||
negative_var=torch.cat([self.text_var[i] for i in label.detach().cpu().numpy()])
|
||||
negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||
negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
|
||||
negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
|
||||
negetive_code=self.model.encode_text(text)
|
||||
target_label=label.flip(dims=[0])
|
||||
positive_mean=negetive_mean.flip(dims=[0])
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def get_args():
|
|||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--max-words", type=int, default=77)
|
||||
parser.add_argument("--resolution", type=int, default=224)
|
||||
parser.add_argument("--batch-size", type=int, default=64)
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--num-workers", type=int, default=4)
|
||||
parser.add_argument("--query-num", type=int, default=5120)
|
||||
parser.add_argument("--train-num", type=int, default=10240)
|
||||
|
|
|
|||
Loading…
Reference in New Issue