new
This commit is contained in:
parent
10df6ded0f
commit
1be09952fc
2
main.py
2
main.py
|
|
@ -4,7 +4,7 @@ from train.hash_train import Trainer
|
|||
if __name__ == "__main__":
|
||||
|
||||
engine=Trainer()
|
||||
# engine.test()
|
||||
engine.test()
|
||||
engine.train_epoch()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from tqdm import tqdm
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.utils.data as data
|
||||
import scipy.io as scio
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -86,7 +87,7 @@ class Trainer(TrainBase):
|
|||
pin_memory=True,
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
self.train_data=train_data
|
||||
|
||||
|
||||
def generate_mapping(self):
|
||||
|
|
@ -140,32 +141,40 @@ class Trainer(TrainBase):
|
|||
times = 0
|
||||
adv_codes=[]
|
||||
adv_label=[]
|
||||
q_label=[]
|
||||
for image, text, label, index in self.train_loader:
|
||||
self.global_step += 1
|
||||
times += 1
|
||||
print(times)
|
||||
q_label.append(label.numpy())
|
||||
image.float()
|
||||
image = image.to(self.rank, non_blocking=True)
|
||||
text = text.to(self.rank, non_blocking=True)
|
||||
index = index.numpy()
|
||||
negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||
negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
|
||||
negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
|
||||
negetive_code=self.model.encode_text(text)
|
||||
target_label=label.flip(dims=[0])
|
||||
positive_mean=negetive_mean.flip(dims=[0])
|
||||
positive_var=negative_var.flip(dims=[0])
|
||||
positive_code=self.model.encode_text(text.flip(dims=[0]))
|
||||
|
||||
#targeted sample
|
||||
np.random.seed(times)
|
||||
select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
|
||||
target_dataset = data.Subset(self.train_data, select_index)
|
||||
target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
|
||||
_, target_text, target_label, _ = next(iter(target_subset))
|
||||
target_text=target_text.to(self.rank, non_blocking=True)
|
||||
positive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
||||
positive_var=np.stack([self.text_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
|
||||
positive_mean=torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
|
||||
positive_var=torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
|
||||
positive_code=self.model.encode_text(target_text)
|
||||
|
||||
|
||||
delta, adv_code=self.target_adv(image,negetive_code,negetive_mean,negative_var,
|
||||
positive_code,positive_mean,positive_var)
|
||||
adv_codes.append(adv_code.cpu().detach().numpy())
|
||||
adv_label.append(target_label.numpy())
|
||||
adv_img=np.concatenate(adv_codes,axis=0)
|
||||
adv_labels=np.concatenate(adv_label, axis=0)
|
||||
query_labels=np.concatenate(q_label, axis=0)
|
||||
adv_img=np.concatenate(adv_codes)
|
||||
adv_labels=np.concatenate(adv_label)
|
||||
|
||||
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||
|
||||
|
||||
|
|
@ -173,17 +182,17 @@ class Trainer(TrainBase):
|
|||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||
retrieval_labels = self.retrieval_labels.numpy()
|
||||
|
||||
mAP = cal_map(adv_img,query_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
|
||||
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
|
||||
|
||||
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
|
||||
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
|
||||
# pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels)
|
||||
self.logger.info(f">>>>>> MAP: {mAP}, MAP_t: {mAP_t}")
|
||||
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
|
||||
result_dict = {
|
||||
'adv_img': adv_img,
|
||||
'r_txt': retrieval_txt,
|
||||
'adv_l': adv_labels,
|
||||
'r_l': retrieval_labels,
|
||||
'q_l':query_labels
|
||||
'r_l': retrieval_labels
|
||||
# 'q_l':query_labels
|
||||
# 'pr': pr,
|
||||
# 'pr_t': pr_t
|
||||
}
|
||||
|
|
@ -230,8 +239,6 @@ class Trainer(TrainBase):
|
|||
with torch.no_grad():
|
||||
image_feature = self.model.encode_image(image)
|
||||
text_features = self.model.encode_text(text)
|
||||
image_feature /= image_feature.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
img_buffer[index, :] = image_feature.detach()
|
||||
text_buffer[index, :] = text_features.detach()
|
||||
|
||||
|
|
|
|||
|
|
@ -278,7 +278,7 @@ def cal_hamming_dis(b1, b2):
|
|||
dis = 0.5 * (k - np.dot(b1, b2.transpose()))
|
||||
return dis
|
||||
|
||||
def cal_map(query_feats, query_label, retrieval_feats, retrieval_label, top_k=500, dist_method='hamming'):
|
||||
def cal_map(query_feats, query_label, retrieval_feats, retrieval_label, top_k=5, dist_method='cosine'):
|
||||
"""
|
||||
Calculate MAP (Mean Average Precision)
|
||||
:param query_binary: binary code of query sample
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ def get_args():
|
|||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--num-workers", type=int, default=4)
|
||||
parser.add_argument("--query-num", type=int, default=5120)
|
||||
parser.add_argument("--train-num", type=int, default=128)
|
||||
parser.add_argument("--train-num", type=int, default=1024)
|
||||
parser.add_argument("--lr-decay-freq", type=int, default=5)
|
||||
parser.add_argument("--display-step", type=int, default=50)
|
||||
parser.add_argument("--seed", type=int, default=1814)
|
||||
|
|
|
|||
Loading…
Reference in New Issue