This commit is contained in:
leewlving 2024-06-17 16:36:01 +08:00
parent 10df6ded0f
commit 1be09952fc
4 changed files with 28 additions and 21 deletions

View File

@@ -4,7 +4,7 @@ from train.hash_train import Trainer
if __name__ == "__main__":
engine=Trainer()
# engine.test()
engine.test()
engine.train_epoch()

View File

@@ -5,6 +5,7 @@ from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.data as data
import scipy.io as scio
import numpy as np
@@ -86,7 +87,7 @@ class Trainer(TrainBase):
pin_memory=True,
shuffle=True
)
self.train_data=train_data
def generate_mapping(self):
@@ -140,32 +141,40 @@ class Trainer(TrainBase):
times = 0
adv_codes=[]
adv_label=[]
q_label=[]
for image, text, label, index in self.train_loader:
self.global_step += 1
times += 1
print(times)
q_label.append(label.numpy())
image.float()
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
index = index.numpy()
negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
negetive_code=self.model.encode_text(text)
target_label=label.flip(dims=[0])
positive_mean=negetive_mean.flip(dims=[0])
positive_var=negative_var.flip(dims=[0])
positive_code=self.model.encode_text(text.flip(dims=[0]))
#targeted sample
np.random.seed(times)
select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
target_dataset = data.Subset(self.train_data, select_index)
target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
_, target_text, target_label, _ = next(iter(target_subset))
target_text=target_text.to(self.rank, non_blocking=True)
positive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
positive_var=np.stack([self.text_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
positive_mean=torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
positive_var=torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
positive_code=self.model.encode_text(target_text)
delta, adv_code=self.target_adv(image,negetive_code,negetive_mean,negative_var,
positive_code,positive_mean,positive_var)
adv_codes.append(adv_code.cpu().detach().numpy())
adv_label.append(target_label.numpy())
adv_img=np.concatenate(adv_codes,axis=0)
adv_labels=np.concatenate(adv_label, axis=0)
query_labels=np.concatenate(q_label, axis=0)
adv_img=np.concatenate(adv_codes)
adv_labels=np.concatenate(adv_label)
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
@@ -173,17 +182,17 @@ class Trainer(TrainBase):
retrieval_txt = retrieval_txt.cpu().detach().numpy()
retrieval_labels = self.retrieval_labels.numpy()
mAP = cal_map(adv_img,query_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
# pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels)
self.logger.info(f">>>>>> MAP: {mAP}, MAP_t: {mAP_t}")
self.logger.info(f">>>>>> MAP_t: {mAP_t}")
result_dict = {
'adv_img': adv_img,
'r_txt': retrieval_txt,
'adv_l': adv_labels,
'r_l': retrieval_labels,
'q_l':query_labels
'r_l': retrieval_labels
# 'q_l':query_labels
# 'pr': pr,
# 'pr_t': pr_t
}
@@ -230,8 +239,6 @@ class Trainer(TrainBase):
with torch.no_grad():
image_feature = self.model.encode_image(image)
text_features = self.model.encode_text(text)
image_feature /= image_feature.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
img_buffer[index, :] = image_feature.detach()
text_buffer[index, :] = text_features.detach()

View File

@@ -278,7 +278,7 @@ def cal_hamming_dis(b1, b2):
dis = 0.5 * (k - np.dot(b1, b2.transpose()))
return dis
def cal_map(query_feats, query_label, retrieval_feats, retrieval_label, top_k=500, dist_method='hamming'):
def cal_map(query_feats, query_label, retrieval_feats, retrieval_label, top_k=5, dist_method='cosine'):
"""
Calculate MAP (Mean Average Precision)
:param query_binary: binary code of query sample

View File

@@ -26,7 +26,7 @@ def get_args():
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--query-num", type=int, default=5120)
parser.add_argument("--train-num", type=int, default=128)
parser.add_argument("--train-num", type=int, default=1024)
parser.add_argument("--lr-decay-freq", type=int, default=5)
parser.add_argument("--display-step", type=int, default=50)
parser.add_argument("--seed", type=int, default=1814)