From 1be09952fc1dd78d2c3c53ce4932ec2f64bd1931 Mon Sep 17 00:00:00 2001
From: leewlving
Date: Mon, 17 Jun 2024 16:36:01 +0800
Subject: [PATCH] new

---
 main.py             |  2 +-
 train/hash_train.py | 43 +++++++++++++++++++++++++------------------
 utils/calc_utils.py |  2 +-
 utils/get_args.py   |  2 +-
 4 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/main.py b/main.py
index 36da785..e3c978e 100644
--- a/main.py
+++ b/main.py
@@ -4,7 +4,7 @@ from train.hash_train import Trainer
 
 
 if __name__ == "__main__":
     engine=Trainer()
-    # engine.test()
+    engine.test()
     engine.train_epoch()
 
diff --git a/train/hash_train.py b/train/hash_train.py
index 07e62df..7f5b82e 100644
--- a/train/hash_train.py
+++ b/train/hash_train.py
@@ -5,6 +5,7 @@ from tqdm import tqdm
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
+import torch.utils.data as data
 import scipy.io as scio
 import numpy as np
 
@@ -86,7 +87,7 @@ class Trainer(TrainBase):
             pin_memory=True,
             shuffle=True
         )
-
+        self.train_data=train_data
 
 
     def generate_mapping(self):
@@ -140,32 +141,40 @@ class Trainer(TrainBase):
         times = 0
         adv_codes=[]
         adv_label=[]
-        q_label=[]
 
         for image, text, label, index in self.train_loader:
             self.global_step += 1
             times += 1
             print(times)
-            q_label.append(label.numpy())
             image.float()
             image = image.to(self.rank, non_blocking=True)
             text = text.to(self.rank, non_blocking=True)
-            index = index.numpy()
             negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
             negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
             negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
             negative_var=torch.from_numpy(negative_var).to(self.rank, non_blocking=True)
             negetive_code=self.model.encode_text(text)
-            target_label=label.flip(dims=[0])
-            positive_mean=negetive_mean.flip(dims=[0])
-            positive_var=negative_var.flip(dims=[0])
-            positive_code=self.model.encode_text(text.flip(dims=[0]))
+
+            #targeted sample
+            np.random.seed(times)
+            select_index = np.random.choice(len(self.train_data), size=self.args.batch_size)
+            target_dataset = data.Subset(self.train_data, select_index)
+            target_subset = torch.utils.data.DataLoader(target_dataset, batch_size=self.args.batch_size)
+            _, target_text, target_label, _ = next(iter(target_subset))
+            target_text=target_text.to(self.rank, non_blocking=True)
+            positive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
+            positive_var=np.stack([self.text_var[str(i.astype(int))] for i in target_label.detach().cpu().numpy()])
+            positive_mean=torch.from_numpy(positive_mean).to(self.rank, non_blocking=True)
+            positive_var=torch.from_numpy(positive_var).to(self.rank, non_blocking=True)
+            positive_code=self.model.encode_text(target_text)
+
+
             delta, adv_code=self.target_adv(image,negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var)
             adv_codes.append(adv_code.cpu().detach().numpy())
             adv_label.append(target_label.numpy())
 
-        adv_img=np.concatenate(adv_codes,axis=0)
-        adv_labels=np.concatenate(adv_label, axis=0)
-        query_labels=np.concatenate(q_label, axis=0)
+        adv_img=np.concatenate(adv_codes)
+        adv_labels=np.concatenate(adv_label)
+
         _, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
 
@@ -173,17 +182,17 @@ class Trainer(TrainBase):
         retrieval_txt = retrieval_txt.cpu().detach().numpy()
         retrieval_labels = self.retrieval_labels.numpy()
 
-        mAP = cal_map(adv_img,query_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
-        mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
+
+        mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels)
         # pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
         # pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels)
-        self.logger.info(f">>>>>> MAP: {mAP}, MAP_t: {mAP_t}")
+        self.logger.info(f">>>>>> MAP_t: {mAP_t}")
         result_dict = {
             'adv_img': adv_img,
             'r_txt': retrieval_txt,
             'adv_l': adv_labels,
-            'r_l': retrieval_labels,
-            'q_l':query_labels
+            'r_l': retrieval_labels
+            # 'q_l':query_labels
             # 'pr': pr,
             # 'pr_t': pr_t
         }
@@ -230,8 +239,6 @@ class Trainer(TrainBase):
             with torch.no_grad():
                 image_feature = self.model.encode_image(image)
                 text_features = self.model.encode_text(text)
-                image_feature /= image_feature.norm(dim=-1, keepdim=True)
-                text_features /= text_features.norm(dim=-1, keepdim=True)
                 img_buffer[index, :] = image_feature.detach()
                 text_buffer[index, :] = text_features.detach()
 
diff --git a/utils/calc_utils.py b/utils/calc_utils.py
index a27a169..c8cece1 100644
--- a/utils/calc_utils.py
+++ b/utils/calc_utils.py
@@ -278,7 +278,7 @@ def cal_hamming_dis(b1, b2):
     dis = 0.5 * (k - np.dot(b1, b2.transpose()))
     return dis
 
-def cal_map(query_feats, query_label, retrieval_feats, retrieval_label, top_k=500, dist_method='hamming'):
+def cal_map(query_feats, query_label, retrieval_feats, retrieval_label, top_k=5, dist_method='cosine'):
     """
     Calculate MAP (Mean Average Precision)
     :param query_binary: binary code of query sample
diff --git a/utils/get_args.py b/utils/get_args.py
index 99e24ac..42b6a22 100644
--- a/utils/get_args.py
+++ b/utils/get_args.py
@@ -26,7 +26,7 @@ def get_args():
     parser.add_argument("--batch-size", type=int, default=8)
     parser.add_argument("--num-workers", type=int, default=4)
     parser.add_argument("--query-num", type=int, default=5120)
-    parser.add_argument("--train-num", type=int, default=128)
+    parser.add_argument("--train-num", type=int, default=1024)
     parser.add_argument("--lr-decay-freq", type=int, default=5)
     parser.add_argument("--display-step", type=int, default=50)
     parser.add_argument("--seed", type=int, default=1814)