This commit is contained in:
leewlving 2024-06-14 11:35:12 +08:00
parent 19bec9bfc2
commit c655bb59e9
5 changed files with 179 additions and 171 deletions

View File

@ -35,7 +35,7 @@ def dataloader(captionFile: str,
maxWords=77, maxWords=77,
imageResolution=224, imageResolution=224,
query_num=5000, query_num=5000,
train_num=10000, train_num=1000,
seed=None, seed=None,
npy=False): npy=False):
if captionFile.endswith("mat"): if captionFile.endswith("mat"):

View File

@ -4,8 +4,9 @@ from train.hash_train import Trainer
if __name__ == "__main__": if __name__ == "__main__":
engine=Trainer() engine=Trainer()
# engine.test() engine.test()
adv_images, texts, adv_labels= engine.train_epoch() engine.train_epoch()
# engine.train() # engine.train()

View File

@ -11,7 +11,7 @@ import numpy as np
from .base import TrainBase from .base import TrainBase
from torch.nn import functional as F from torch.nn import functional as F
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import calc_map_k_matrix as calc_map_k from utils.calc_utils import cal_map, cal_pr
from dataset.dataloader import dataloader from dataset.dataloader import dataloader
import open_clip import open_clip
# from transformers import BertModel # from transformers import BertModel
@ -112,11 +112,11 @@ class Trainer(TrainBase):
return text_representation, text_var_representation return text_representation, text_var_representation
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var, def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100, temperature=0.05): beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
delta = torch.zeros_like(image,requires_grad=True) delta = torch.zeros_like(image,requires_grad=True)
# one=torch.zeros_like(positive) # one=torch.zeros_like(positive)
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) # alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
for i in range(num_iter): for i in range(num_iter):
self.model.zero_grad() self.model.zero_grad()
anchor=self.model.encode_image(image+delta) anchor=self.model.encode_image(image+delta)
@ -130,25 +130,26 @@ class Trainer(TrainBase):
delta.data = delta - alpha * delta.grad.detach().sign() delta.data = delta - alpha * delta.grad.detach().sign()
delta.data =clamp(delta, image).clamp(-epsilon, epsilon) delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
delta.grad.zero_() delta.grad.zero_()
adv_code=self.model.encode_image(image+delta)
return delta.detach() return delta.detach() , adv_code
def train_epoch(self): def train_epoch(self):
self.change_state(mode="valid") self.change_state(mode="valid")
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
all_loss = 0 all_loss = 0
times = 0 times = 0
adv_images=[] adv_codes=[]
adv_labels=[] adv_label=[]
texts=[] q_label=[]
for image, text, label, index in self.train_loader: for image, text, label, index in self.train_loader:
self.global_step += 1 self.global_step += 1
times += 1 times += 1
print(times)
q_label.append(label.numpy())
image.float() image.float()
image = image.to(self.rank, non_blocking=True) image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True) text = text.to(self.rank, non_blocking=True)
index = index.numpy() index = index.numpy()
# image_anchor=self.image_representation(label.detach().cpu().numpy())
negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()]) negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()]) negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True) negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
@ -158,13 +159,39 @@ class Trainer(TrainBase):
positive_mean=negetive_mean.flip(dims=[0]) positive_mean=negetive_mean.flip(dims=[0])
positive_var=negative_var.flip(dims=[0]) positive_var=negative_var.flip(dims=[0])
positive_code=self.model.encode_text(text.flip(dims=[0])) positive_code=self.model.encode_text(text.flip(dims=[0]))
delta=self.target_adv(image,negetive_code,negetive_mean,negative_var, delta, adv_code=self.target_adv(image,negetive_code,negetive_mean,negative_var,
positive_code,positive_mean,positive_var) positive_code,positive_mean,positive_var)
adv_image=delta+image adv_codes.append(adv_code.cpu().detach().numpy())
adv_images.append(adv_image) adv_label.append(target_label.numpy())
adv_labels.append(target_label) adv_img=np.concatenate(adv_codes,axis=0)
texts.append(text) adv_labels=np.concatenate(adv_label, axis=0)
return adv_images, texts, adv_labels query_labels=np.concatenate(q_label, axis=0)
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
retrieval_txt = retrieval_txt.cpu().detach().numpy()
retrieval_labels = self.retrieval_labels.numpy()
mAP = cal_map(adv_img,query_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
# pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels)
self.logger.info(f">>>>>> MAP: {mAP}, MAP_t: {mAP_t}")
result_dict = {
'adv_img': adv_img,
'r_txt': retrieval_txt,
'adv_l': adv_labels,
'r_l': retrieval_labels,
'q_l':query_labels
# 'pr': pr,
# 'pr_t': pr_t
}
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!")
@ -210,25 +237,7 @@ class Trainer(TrainBase):
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def get_adv_code(self, adv_data_list,text_list):
img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)
for i in tqdm(range(len(adv_data_list))):
image = adv_data_list[i].to(self.rank, non_blocking=True)
text = text_list[i].to(self.rank, non_blocking=True)
# index = index.numpy()
image_hash=self.img_model(image)
text_feat=self.bert(text)[0]
text_hash=self.txt_model(text_feat)
# text_hash = self.make_hash_code(text_hash)
# image_hash.to(self.rank)
# text_hash.to(self.rank)
img_buffer[i, :] = image_hash.data
text_buffer[i, :] = text_hash.data
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def valid_attack(self,adv_images, texts, adv_labels): def valid_attack(self,adv_images, texts, adv_labels):
@ -239,20 +248,11 @@ class Trainer(TrainBase):
def test(self, mode_name="i2t"): def test(self, mode_name="i2t"):
self.logger.info("Valid Clean.") self.logger.info("Valid Clean.")
# if self.args.pretrained == "":
# raise RuntimeError("test step must load a model! please set the --pretrained argument.")
# self.change_state(mode="valid")
save_dir = os.path.join(self.args.save_dir, "PR_cruve") save_dir = os.path.join(self.args.save_dir, "PR_cruve")
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# print("map map")
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
query_img = query_img.cpu().detach().numpy() query_img = query_img.cpu().detach().numpy()
query_txt = query_txt.cpu().detach().numpy() query_txt = query_txt.cpu().detach().numpy()
@ -260,7 +260,12 @@ class Trainer(TrainBase):
retrieval_txt = retrieval_txt.cpu().detach().numpy() retrieval_txt = retrieval_txt.cpu().detach().numpy()
query_labels = self.query_labels.numpy() query_labels = self.query_labels.numpy()
retrieval_labels = self.retrieval_labels.numpy() retrieval_labels = self.retrieval_labels.numpy()
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)
# pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
# pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)
self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
result_dict = { result_dict = {
'q_img': query_img, 'q_img': query_img,
'q_txt': query_txt, 'q_txt': query_txt,
@ -268,32 +273,34 @@ class Trainer(TrainBase):
'r_txt': retrieval_txt, 'r_txt': retrieval_txt,
'q_l': query_labels, 'q_l': query_labels,
'r_l': retrieval_labels 'r_l': retrieval_labels
# 'pr_i2t': pr_i2t,
# 'pr_t2i': pr_t2i
} }
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict) scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
self.logger.info(">>>>>> save all data!") self.logger.info(">>>>>> save all data!")
def valid(self, epoch): # def valid(self, epoch):
self.logger.info("Valid.") # self.logger.info("Valid.")
self.change_state(mode="valid") # self.change_state(mode="valid")
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num) # query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num) # retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
# print("get all code") # # print("get all code")
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) # mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
# print("map map") # # print("map map")
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) # mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank) # mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank) # mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
if self.max_mapi2t < mAPi2t: # if self.max_mapi2t < mAPi2t:
self.best_epoch_i = epoch # self.best_epoch_i = epoch
self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t") # self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
self.max_mapi2t = max(self.max_mapi2t, mAPi2t) # self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
if self.max_mapt2i < mAPt2i: # if self.max_mapt2i < mAPt2i:
self.best_epoch_t = epoch # self.best_epoch_t = epoch
self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i") # self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
self.max_mapt2i = max(self.max_mapt2i, mAPt2i) # self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \ # self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}") # MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"): def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):

View File

@ -16,7 +16,7 @@ def calc_cosineDist(B1, B2):
B1 = B1.unsqueeze(0) B1 = B1.unsqueeze(0)
dot_product = B1.mm(B2.transpose(0, 1)) dot_product = B1.mm(B2.transpose(0, 1))
norms = torch.sqrt(torch.einsum('ii->i', dot_product)) norms = torch.sqrt(torch.einsum('ii->i', dot_product))
distH = dot_product/(norms[None]*norms[..., None]) distH = 0.5 * (q - B1.mm(B2.transpose(0, 1)))
return distH return distH
@ -174,66 +174,7 @@ def calc_precisions_hash(qB, rB, query_L, retrieval_L):
recalls[i] += (recall_num / total_right) recalls[i] += (recall_num / total_right)
return precisions, recalls return precisions, recalls
def calc_precisions_hash_my(qB, rB, *, Gnd, num_query, num_retrieval):
if not isinstance(qB, torch.Tensor):
qB = torch.from_numpy(qB)
if not isinstance(rB, torch.Tensor):
rB = torch.from_numpy(rB)
if not isinstance(Gnd, torch.Tensor):
Gnd = torch.from_numpy(Gnd)
def CalcHammingDist_np(B1, B2):
q = B2.shape[1]
distH = 0.5 * (q - np.dot(B1, B2.transpose()))
return distH
bit = qB.shape[1]
# if isinstance(qB, np.ndarray):
# hamm = CalcHammingDist_np(qB, rB)
# else:
hamm = calc_hammingDist(qB, rB)
hamm = hamm.type(torch.ByteTensor)
total_num = [0] * (bit + 1)
max_hamm = int(torch.max(hamm))
gnd = Gnd
total_right = torch.sum(gnd>0)
precisions = np.zeros([max_hamm + 1])
recalls = np.zeros([max_hamm + 1])
right_num = 0
recall_num = 0
for i, radius in enumerate(range(0, max_hamm+1)):
recall = torch.nonzero(hamm == radius)
right = gnd[recall.split(1, dim=1)]
recall_num += recall.shape[0]
del recall
right_num += torch.nonzero(right).shape[0]
del right
precisions[i] += (right_num / (recall_num + 1e-8))
recalls[i] += (recall_num / num_retrieval / num_query)
# recalls[i] += (recall_num / total_right)
p = precisions.round(2)
r = recalls.round(2)
# return p, r
precisions = []
recalls = []
precision_ = 0
num = 1
for i in range(len(r) - 1):
if r[i] == r[i + 1]:
precision_ += p[i]
num += 1
else:
precision_ += p[i]
precisions.append(precision_ / num)
recalls.append(r[i])
precision_ = 0
num = 1
return np.asarray(precisions).round(2), np.asarray(recalls)
def calc_precisions_hamming_radius(qB, rB, query_L, retrieval_L, hamming_gas=1): def calc_precisions_hamming_radius(qB, rB, query_L, retrieval_L, hamming_gas=1):
@ -325,42 +266,101 @@ def calc_IF(all_bow):
return IF return IF
# def calc_loss(B, F, G, Sim, gamma1, gamma2, eta): def cal_cosine_dis(f1, f2):
# theta = torch.matmul(F, G.transpose(0, 1)) / 2 f1_norm = np.linalg.norm(f1)
# inter_loss = torch.sum(torch.log(1 + torch.exp(theta)) - Sim * theta) f2_norm = np.linalg.norm(f2, axis=1)
# theta_f = torch.matmul(F, F.transpose(0, 1)) / 2
# intra_img = torch.sum(torch.log(1 + torch.exp(theta_f)) - Sim * theta_f)
# theta_g = torch.matmul(G, G.transpose(0, 1)) / 2
# intra_txt = torch.sum(torch.log(1 + torch.exp(theta_g)) - Sim * theta_g)
# intra_loss = gamma1 * intra_img + gamma2 * intra_txt
# quan_loss = torch.sum(torch.pow(B - F, 2) + torch.pow(B - G, 2)) * eta
# # term3 = torch.sum(torch.pow(F.sum(dim=0), 2) + torch.pow(G.sum(dim=0), 2))
# # loss = term1 + gamma * term2 + eta * term3
# loss = inter_loss + intra_loss + quan_loss
# return loss
similiarity = np.dot(f1, f2.T)/(f1_norm * f2_norm)
return 1 - similiarity
# if __name__ == '__main__': def cal_hamming_dis(b1, b2):
# qB = torch.Tensor([[1, -1, 1, 1], k = b2.shape[1] # length of hash code
# [-1, -1, -1, 1], dis = 0.5 * (k - np.dot(b1, b2.transpose()))
# [1, 1, -1, 1], return dis
# [1, 1, 1, -1]])
# rB = torch.Tensor([[1, -1, 1, -1], def cal_map(query_feats, query_label, retrieval_feats, retrieval_label, top_k=500, dist_method='hamming'):
# [-1, -1, 1, -1], """
# [-1, -1, 1, -1], Calculate MAP (Mean Average Precision)
# [1, 1, -1, -1], :param query_binary: binary code of query sample
# [-1, 1, -1, -1], :param query_label: label of qurey sample
# [1, 1, -1, 1]]) :param retrieval_binary: binary code of database
# query_L = torch.Tensor([[0, 1, 0, 0], :param retrieval_label: label of database
# [1, 1, 0, 0], :param top_k:
# [1, 0, 0, 1], :return:
# [0, 1, 0, 1]]) """
# retrieval_L = torch.Tensor([[1, 0, 0, 1], query_number = query_label.shape[0]
# [1, 1, 0, 0], top_k_map = 0
# [0, 1, 1, 0],
# [0, 0, 1, 0], dist_func = cal_hamming_dis if dist_method == 'hamming' else cal_cosine_dis
# [1, 0, 0, 0],
# [0, 0, 1, 0]]) for query_index in range(query_number):
# # (1, N)
# map = calc_map_k(qB, rB, query_L, retrieval_L) ground_truth = (np.dot(query_label[query_index, :], retrieval_label.transpose()) > 0).astype(np.float32)
# print(map) hamming_dis = dist_func(query_feats[query_index, :], retrieval_feats) # (1, N)
# sort hamming distance
sort_index = np.argsort(hamming_dis)
# resort ground truth
ground_truth = ground_truth[sort_index]
# get top K ground truth
top_k_gnd = ground_truth[0:top_k]
top_k_sum = np.sum(top_k_gnd).astype(int) # the number of correct retrieval in top K
if top_k_sum == 0:
continue
count = np.linspace(1, top_k_sum, int(top_k_sum))
top_k_index = np.asarray(np.where(top_k_gnd == 1)) + 1.0
top_k_map += np.mean(count / top_k_index) # average precision of per class
return top_k_map / query_number # mean of average precision of all class
def cal_pr(retrieval_binary, query_binary, retrieval_label, query_label, interval=0.1):
r_arr = np.array([i * interval for i in range(1, int(1/interval) + 1)])
p_arr = np.zeros(len(r_arr))
query_number = query_label.shape[0]
for query_index in range(query_number):
ground_truth = (np.dot(query_label[query_index, :], retrieval_label.transpose()) > 0).astype(
np.float32) # (1, N)
hamming_dis = cal_hamming_dis(query_binary[query_index, :], retrieval_binary) # (1, N)
# sort hamming distance
sort_index = np.argsort(hamming_dis)
ground_truth = ground_truth[sort_index]
tp_num = len(np.where(ground_truth == 1)[0])
r_num_arr = (tp_num * r_arr).astype(np.int32)
tp_cum = np.cumsum(ground_truth)
total_num_arr = np.array([np.where(tp_cum == i)[0][0] + 1 for i in r_num_arr])
p_arr += r_num_arr/total_num_arr
p_arr /= query_number
return np.array(list(zip(r_arr, p_arr)))
def cal_top_n(retrieval_binary, query_binary, retrieval_label, query_label, top_n=None):
if top_n is None:
top_n = range(100, 1001, 100)
top_n = np.array(top_n)
top_n_p = np.zeros(len(top_n))
query_number = query_label.shape[0]
for query_index in range(query_number):
ground_truth = (np.dot(query_label[query_index, :], retrieval_label.transpose()) > 0).astype(
np.float32) # (1, N)
hamming_dis = cal_hamming_dis(query_binary[query_index, :], retrieval_binary) # (1, N)
# sort hamming distance
sort_index = np.argsort(hamming_dis)
ground_truth = ground_truth[sort_index]
ground_truth = ground_truth[:top_n[-1]]
tp_cum = np.cumsum(ground_truth)
tp_num_arr = tp_cum[top_n - 1]
top_n_p += tp_num_arr/top_n
top_n_p /= query_number
return np.array(list(zip(top_n, top_n_p)))

View File

@ -26,7 +26,7 @@ def get_args():
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--query-num", type=int, default=5120) parser.add_argument("--query-num", type=int, default=5120)
parser.add_argument("--train-num", type=int, default=10240) parser.add_argument("--train-num", type=int, default=1024)
parser.add_argument("--lr-decay-freq", type=int, default=5) parser.add_argument("--lr-decay-freq", type=int, default=5)
parser.add_argument("--display-step", type=int, default=50) parser.add_argument("--display-step", type=int, default=50)
parser.add_argument("--seed", type=int, default=1814) parser.add_argument("--seed", type=int, default=1814)