debug
This commit is contained in:
parent
19bec9bfc2
commit
c655bb59e9
|
|
@ -35,7 +35,7 @@ def dataloader(captionFile: str,
|
|||
maxWords=77,
|
||||
imageResolution=224,
|
||||
query_num=5000,
|
||||
train_num=10000,
|
||||
train_num=1000,
|
||||
seed=None,
|
||||
npy=False):
|
||||
if captionFile.endswith("mat"):
|
||||
|
|
|
|||
5
main.py
5
main.py
|
|
@ -4,8 +4,9 @@ from train.hash_train import Trainer
|
|||
if __name__ == "__main__":
|
||||
|
||||
engine=Trainer()
|
||||
# engine.test()
|
||||
adv_images, texts, adv_labels= engine.train_epoch()
|
||||
engine.test()
|
||||
engine.train_epoch()
|
||||
|
||||
|
||||
# engine.train()
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import numpy as np
|
|||
from .base import TrainBase
|
||||
from torch.nn import functional as F
|
||||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
||||
from utils.calc_utils import calc_map_k_matrix as calc_map_k
|
||||
from utils.calc_utils import cal_map, cal_pr
|
||||
from dataset.dataloader import dataloader
|
||||
import open_clip
|
||||
# from transformers import BertModel
|
||||
|
|
@ -112,11 +112,11 @@ class Trainer(TrainBase):
|
|||
return text_representation, text_var_representation
|
||||
|
||||
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100, temperature=0.05):
|
||||
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
|
||||
|
||||
delta = torch.zeros_like(image,requires_grad=True)
|
||||
# one=torch.zeros_like(positive)
|
||||
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
||||
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
||||
for i in range(num_iter):
|
||||
self.model.zero_grad()
|
||||
anchor=self.model.encode_image(image+delta)
|
||||
|
|
@ -130,25 +130,26 @@ class Trainer(TrainBase):
|
|||
delta.data = delta - alpha * delta.grad.detach().sign()
|
||||
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
|
||||
delta.grad.zero_()
|
||||
|
||||
return delta.detach()
|
||||
adv_code=self.model.encode_image(image+delta)
|
||||
return delta.detach() , adv_code
|
||||
|
||||
def train_epoch(self):
|
||||
self.change_state(mode="valid")
|
||||
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
||||
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||||
all_loss = 0
|
||||
times = 0
|
||||
adv_images=[]
|
||||
adv_labels=[]
|
||||
texts=[]
|
||||
adv_codes=[]
|
||||
adv_label=[]
|
||||
q_label=[]
|
||||
for image, text, label, index in self.train_loader:
|
||||
self.global_step += 1
|
||||
times += 1
|
||||
print(times)
|
||||
q_label.append(label.numpy())
|
||||
image.float()
|
||||
image = image.to(self.rank, non_blocking=True)
|
||||
text = text.to(self.rank, non_blocking=True)
|
||||
index = index.numpy()
|
||||
# image_anchor=self.image_representation(label.detach().cpu().numpy())
|
||||
negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||
negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
|
||||
|
|
@ -158,13 +159,39 @@ class Trainer(TrainBase):
|
|||
positive_mean=negetive_mean.flip(dims=[0])
|
||||
positive_var=negative_var.flip(dims=[0])
|
||||
positive_code=self.model.encode_text(text.flip(dims=[0]))
|
||||
delta=self.target_adv(image,negetive_code,negetive_mean,negative_var,
|
||||
delta, adv_code=self.target_adv(image,negetive_code,negetive_mean,negative_var,
|
||||
positive_code,positive_mean,positive_var)
|
||||
adv_image=delta+image
|
||||
adv_images.append(adv_image)
|
||||
adv_labels.append(target_label)
|
||||
texts.append(text)
|
||||
return adv_images, texts, adv_labels
|
||||
adv_codes.append(adv_code.cpu().detach().numpy())
|
||||
adv_label.append(target_label.numpy())
|
||||
adv_img=np.concatenate(adv_codes,axis=0)
|
||||
adv_labels=np.concatenate(adv_label, axis=0)
|
||||
query_labels=np.concatenate(q_label, axis=0)
|
||||
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||
|
||||
|
||||
|
||||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||
retrieval_labels = self.retrieval_labels.numpy()
|
||||
|
||||
mAP = cal_map(adv_img,query_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
|
||||
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
|
||||
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
|
||||
# pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels)
|
||||
self.logger.info(f">>>>>> MAP: {mAP}, MAP_t: {mAP_t}")
|
||||
result_dict = {
|
||||
'adv_img': adv_img,
|
||||
'r_txt': retrieval_txt,
|
||||
'adv_l': adv_labels,
|
||||
'r_l': retrieval_labels,
|
||||
'q_l':query_labels
|
||||
# 'pr': pr,
|
||||
# 'pr_t': pr_t
|
||||
}
|
||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
|
||||
self.logger.info(">>>>>> save all data!")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -210,25 +237,7 @@ class Trainer(TrainBase):
|
|||
|
||||
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
||||
|
||||
def get_adv_code(self, adv_data_list,text_list):
    """
    Encode adversarial images and their texts into two code buffers.

    Entry ``i`` of ``adv_data_list`` is passed through ``self.img_model``;
    entry ``i`` of ``text_list`` is passed through ``self.bert`` (first
    element of its output) and then ``self.txt_model``. The results are
    written to row ``i`` of the returned buffers.

    :param adv_data_list: indexable collection of image tensors
    :param text_list: indexable collection of text tensors, read at the same
        indices as ``adv_data_list``
    :return: ``(img_buffer, text_buffer)`` on ``self.rank``, with
        ``len(adv_data_list)`` / ``len(text_list)`` rows and
        ``self.args.output_dim`` columns
    """
    img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
    text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)

    # Pure inference: the outputs are only copied into the buffers, so there
    # is no reason to record an autograd graph for every forward pass.
    with torch.no_grad():
        for i, adv_image in enumerate(tqdm(adv_data_list)):
            image = adv_image.to(self.rank, non_blocking=True)
            text = text_list[i].to(self.rank, non_blocking=True)
            image_hash = self.img_model(image)
            text_feat = self.bert(text)[0]
            text_hash = self.txt_model(text_feat)
            # NOTE(review): each output must broadcast into a single row of
            # length output_dim - presumably one sample per list entry; confirm.
            img_buffer[i, :] = image_hash
            text_buffer[i, :] = text_hash

    return img_buffer, text_buffer
|
||||
|
||||
|
||||
|
||||
def valid_attack(self,adv_images, texts, adv_labels):
|
||||
|
|
@ -239,20 +248,11 @@ class Trainer(TrainBase):
|
|||
|
||||
def test(self, mode_name="i2t"):
|
||||
self.logger.info("Valid Clean.")
|
||||
# if self.args.pretrained == "":
|
||||
# raise RuntimeError("test step must load a model! please set the --pretrained argument.")
|
||||
# self.change_state(mode="valid")
|
||||
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# print("map map")
|
||||
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
|
||||
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
|
||||
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
|
||||
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||
|
||||
|
||||
query_img = query_img.cpu().detach().numpy()
|
||||
query_txt = query_txt.cpu().detach().numpy()
|
||||
|
|
@ -260,7 +260,12 @@ class Trainer(TrainBase):
|
|||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||
query_labels = self.query_labels.numpy()
|
||||
retrieval_labels = self.retrieval_labels.numpy()
|
||||
|
||||
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
|
||||
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)
|
||||
# pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
|
||||
# pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)
|
||||
self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
|
||||
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
|
||||
result_dict = {
|
||||
'q_img': query_img,
|
||||
'q_txt': query_txt,
|
||||
|
|
@ -268,32 +273,34 @@ class Trainer(TrainBase):
|
|||
'r_txt': retrieval_txt,
|
||||
'q_l': query_labels,
|
||||
'r_l': retrieval_labels
|
||||
# 'pr_i2t': pr_i2t,
|
||||
# 'pr_t2i': pr_t2i
|
||||
}
|
||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
|
||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
|
||||
self.logger.info(">>>>>> save all data!")
|
||||
|
||||
|
||||
def valid(self, epoch):
    """
    Evaluate retrieval mAP for the current epoch and track the best results.

    Encodes the query and retrieval sets, computes the four mAP values
    (i->t, t->i, i->i, t->t) with ``calc_map_k``, and whenever i->t or t->i
    beats the stored maximum records the epoch and saves the codes through
    ``self.save_mat``. All values are written to the logger.

    :param epoch: index of the epoch being validated (used for logging and
        for ``best_epoch_i`` / ``best_epoch_t``)
    """
    self.logger.info("Valid.")
    self.change_state(mode="valid")

    # The "select" hash layer uses this class's get_code; every other setting
    # falls back to the parent implementation.
    encode = self.get_code if self.args.hash_layer == "select" else super().get_code
    query_img, query_txt = encode(self.query_loader, self.args.query_num)
    retrieval_img, retrieval_txt = encode(self.retrieval_loader, self.args.retrieval_num)

    mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
    mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)

    if self.max_mapi2t < mAPi2t:
        self.best_epoch_i = epoch
        self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
    self.max_mapi2t = max(self.max_mapi2t, mAPi2t)

    if self.max_mapt2i < mAPt2i:
        self.best_epoch_t = epoch
        self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
    self.max_mapt2i = max(self.max_mapt2i, mAPt2i)

    # Adjacent literals instead of a backslash continuation inside the
    # string, so source indentation can never end up in the log message.
    self.logger.info(
        f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, "
        f"MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}"
    )
|
||||
# def valid(self, epoch):
|
||||
# self.logger.info("Valid.")
|
||||
# self.change_state(mode="valid")
|
||||
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||
# retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||
# # print("get all code")
|
||||
# mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# # print("map map")
|
||||
# mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# if self.max_mapi2t < mAPi2t:
|
||||
# self.best_epoch_i = epoch
|
||||
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
|
||||
# self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
|
||||
# if self.max_mapt2i < mAPt2i:
|
||||
# self.best_epoch_t = epoch
|
||||
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
|
||||
# self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
|
||||
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
|
||||
# MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
|
||||
|
||||
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ def calc_cosineDist(B1, B2):
|
|||
B1 = B1.unsqueeze(0)
|
||||
dot_product = B1.mm(B2.transpose(0, 1))
|
||||
norms = torch.sqrt(torch.einsum('ii->i', dot_product))
|
||||
distH = dot_product/(norms[None]*norms[..., None])
|
||||
distH = 0.5 * (q - B1.mm(B2.transpose(0, 1)))
|
||||
return distH
|
||||
|
||||
|
||||
|
|
@ -174,66 +174,7 @@ def calc_precisions_hash(qB, rB, query_L, retrieval_L):
|
|||
recalls[i] += (recall_num / total_right)
|
||||
return precisions, recalls
|
||||
|
||||
def calc_precisions_hash_my(qB, rB, *, Gnd, num_query, num_retrieval):
    """
    Build a precision curve over growing Hamming radii.

    For every radius ``0..max_distance`` the function accumulates how many
    (query, retrieval) pairs lie within that radius and how many of them are
    marked relevant in ``Gnd``. Radii whose rounded "recall" value is equal
    are merged by averaging their precision.

    :param qB: query codes, ``torch.Tensor`` or ``numpy.ndarray``
    :param rB: retrieval codes, ``torch.Tensor`` or ``numpy.ndarray``
    :param Gnd: ground-truth matrix, indexed with the (row, column) positions
        of the distance matrix; non-zero entries count as relevant
        (presumably shape (num_query, num_retrieval) - confirm with caller)
    :param num_query: number of queries, used to normalise the recall axis
    :param num_retrieval: number of retrieval items, same purpose
    :return: ``(precisions, recalls)`` as numpy arrays, one entry per merged
        recall level
    """
    if not isinstance(qB, torch.Tensor):
        qB = torch.from_numpy(qB)
    if not isinstance(rB, torch.Tensor):
        rB = torch.from_numpy(rB)
    if not isinstance(Gnd, torch.Tensor):
        Gnd = torch.from_numpy(Gnd)

    hamm = calc_hammingDist(qB, rB)
    # Compact integer copy on the CPU. int16 rather than uint8: a byte wraps
    # around as soon as a distance exceeds 255 (codes longer than 255 bits).
    hamm = hamm.to(device="cpu", dtype=torch.int16)
    max_hamm = int(torch.max(hamm))

    gnd = Gnd

    precisions = np.zeros([max_hamm + 1])
    recalls = np.zeros([max_hamm + 1])

    right_num = 0   # relevant pairs found within the current radius
    recall_num = 0  # all pairs found within the current radius
    for radius in range(max_hamm + 1):
        recall = torch.nonzero(hamm == radius)
        right = gnd[recall.split(1, dim=1)]
        recall_num += recall.shape[0]
        del recall
        right_num += torch.nonzero(right).shape[0]
        del right
        precisions[radius] += (right_num / (recall_num + 1e-8))
        # Fraction of all query/retrieval pairs returned so far (not divided
        # by the number of relevant pairs).
        recalls[radius] += (recall_num / num_retrieval / num_query)

    p = precisions.round(2)
    r = recalls.round(2)

    merged_precisions = []
    merged_recalls = []
    run_sum = 0
    run_len = 1
    # NOTE(review): the loop stops one element early, so the final run of
    # equal recall values is never emitted - kept as in the original; confirm
    # whether that is intended.
    for i in range(len(r) - 1):
        run_sum += p[i]
        if r[i] == r[i + 1]:
            run_len += 1
            continue
        merged_precisions.append(run_sum / run_len)
        merged_recalls.append(r[i])
        run_sum = 0
        run_len = 1

    return np.asarray(merged_precisions).round(2), np.asarray(merged_recalls)
|
||||
|
||||
|
||||
def calc_precisions_hamming_radius(qB, rB, query_L, retrieval_L, hamming_gas=1):
|
||||
|
|
@ -325,42 +266,101 @@ def calc_IF(all_bow):
|
|||
return IF
|
||||
|
||||
|
||||
# def calc_loss(B, F, G, Sim, gamma1, gamma2, eta):
|
||||
# theta = torch.matmul(F, G.transpose(0, 1)) / 2
|
||||
# inter_loss = torch.sum(torch.log(1 + torch.exp(theta)) - Sim * theta)
|
||||
# theta_f = torch.matmul(F, F.transpose(0, 1)) / 2
|
||||
# intra_img = torch.sum(torch.log(1 + torch.exp(theta_f)) - Sim * theta_f)
|
||||
# theta_g = torch.matmul(G, G.transpose(0, 1)) / 2
|
||||
# intra_txt = torch.sum(torch.log(1 + torch.exp(theta_g)) - Sim * theta_g)
|
||||
# intra_loss = gamma1 * intra_img + gamma2 * intra_txt
|
||||
# quan_loss = torch.sum(torch.pow(B - F, 2) + torch.pow(B - G, 2)) * eta
|
||||
# # term3 = torch.sum(torch.pow(F.sum(dim=0), 2) + torch.pow(G.sum(dim=0), 2))
|
||||
# # loss = term1 + gamma * term2 + eta * term3
|
||||
# loss = inter_loss + intra_loss + quan_loss
|
||||
# return loss
|
||||
def cal_cosine_dis(f1, f2):
    """
    Cosine distance (``1 - cosine similarity``) between a query and a database.

    :param f1: query feature vector; its norm is taken over all elements, so
        a single 1-D vector is expected
    :param f2: 2-D database matrix, one row per item (norms along axis 1)
    :return: ``1 - similarity`` for every database row. A pair that involves
        an all-zero vector gets similarity 0 (distance 1) instead of NaN, so
        the result can always be ranked with ``argsort``.
    """
    f1 = np.asarray(f1)
    f2 = np.asarray(f2)
    f1_norm = np.linalg.norm(f1)
    f2_norm = np.linalg.norm(f2, axis=1)

    norm_product = f1_norm * f2_norm
    has_norm = norm_product > 0
    # Divide by 1 where the norm product is zero to avoid the 0/0 warning;
    # those entries are overwritten with similarity 0 right after.
    safe_product = np.where(has_norm, norm_product, 1.0)
    similiarity = np.where(has_norm, np.dot(f1, f2.T) / safe_product, 0.0)
    return 1 - similiarity
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# qB = torch.Tensor([[1, -1, 1, 1],
|
||||
# [-1, -1, -1, 1],
|
||||
# [1, 1, -1, 1],
|
||||
# [1, 1, 1, -1]])
|
||||
# rB = torch.Tensor([[1, -1, 1, -1],
|
||||
# [-1, -1, 1, -1],
|
||||
# [-1, -1, 1, -1],
|
||||
# [1, 1, -1, -1],
|
||||
# [-1, 1, -1, -1],
|
||||
# [1, 1, -1, 1]])
|
||||
# query_L = torch.Tensor([[0, 1, 0, 0],
|
||||
# [1, 1, 0, 0],
|
||||
# [1, 0, 0, 1],
|
||||
# [0, 1, 0, 1]])
|
||||
# retrieval_L = torch.Tensor([[1, 0, 0, 1],
|
||||
# [1, 1, 0, 0],
|
||||
# [0, 1, 1, 0],
|
||||
# [0, 0, 1, 0],
|
||||
# [1, 0, 0, 0],
|
||||
# [0, 0, 1, 0]])
|
||||
#
|
||||
# map = calc_map_k(qB, rB, query_L, retrieval_L)
|
||||
# print(map)
|
||||
def cal_hamming_dis(b1, b2):
    """
    Distance ``0.5 * (k - b1 . b2^T)`` where ``k`` is the code length.

    For codes whose entries are -1/+1 this equals the Hamming distance
    (number of differing positions).

    :param b1: query code (array or sequence)
    :param b2: database codes, one row per item; a single 1-D code is also
        accepted
    :return: distance from ``b1`` to every row of ``b2`` (a scalar when both
        inputs are 1-D)
    """
    b1 = np.asarray(b1)
    b2 = np.asarray(b2)
    k = b2.shape[-1]  # length of hash code (last axis, so 1-D input works too)
    return 0.5 * (k - np.dot(b1, b2.transpose()))
|
||||
|
||||
def cal_map(query_feats, query_label, retrieval_feats, retrieval_label, top_k=500, dist_method='hamming'):
    """
    Calculate MAP (Mean Average Precision) over the top-k ranked results.

    :param query_feats: codes/features of the query samples, one row each
    :param query_label: label rows of the query samples
    :param retrieval_feats: codes/features of the database, one row each
    :param retrieval_label: label rows of the database; a database item is
        relevant to a query when the dot product of their label rows is > 0
    :param top_k: ranking depth that is evaluated; ``None`` evaluates the
        whole ranking
    :param dist_method: ``'hamming'`` ranks with ``cal_hamming_dis``; any
        other value ranks with ``cal_cosine_dis``
    :return: mean of the per-query average precision. Queries without a
        relevant item in the top k contribute 0; an empty query set gives 0.0.
    """
    query_number = query_label.shape[0]
    if query_number == 0:
        # Nothing to average - avoid dividing by zero below.
        return 0.0

    dist_func = cal_hamming_dis if dist_method == 'hamming' else cal_cosine_dis

    ap_sum = 0.0
    for query_index in range(query_number):
        relevant = np.dot(query_label[query_index, :], retrieval_label.transpose()) > 0
        distance = dist_func(query_feats[query_index, :], retrieval_feats)

        # A stable sort keeps tied distances in database order, so the score
        # is reproducible from run to run.
        ranking = np.argsort(distance, kind='stable')
        top_k_hits = relevant[ranking][:top_k]

        # 1-based ranks at which relevant items appear inside the top k
        hit_ranks = np.flatnonzero(top_k_hits) + 1.0
        if hit_ranks.size == 0:
            continue
        hit_counts = np.arange(1, hit_ranks.size + 1)
        ap_sum += np.mean(hit_counts / hit_ranks)  # average precision of this query

    return ap_sum / query_number
|
||||
|
||||
def cal_pr(retrieval_binary, query_binary, retrieval_label, query_label, interval=0.1):
    """
    Calculate a precision-recall curve sampled at fixed recall levels.

    Note the argument order: retrieval first, then query, for both the codes
    and the labels.

    :param retrieval_binary: codes of the database, one row per item
    :param query_binary: codes of the query samples, one row per query
    :param retrieval_label: label rows of the database
    :param query_label: label rows of the queries; an item is relevant to a
        query when the dot product of their label rows is > 0
    :param interval: spacing of the recall levels
        (``interval, 2*interval, ...`` for ``int(1/interval)`` levels)
    :return: array of ``(recall, precision)`` rows, precision averaged over
        all queries. A recall level that rounds down to zero required hits
        contributes precision 0; an empty query set gives precision 0.
    """
    r_arr = np.array([i * interval for i in range(1, int(1/interval) + 1)])
    p_arr = np.zeros(len(r_arr))

    query_number = query_label.shape[0]
    if query_number == 0:
        # Nothing to average - avoid the 0/0 that would fill the curve with NaN.
        return np.array(list(zip(r_arr, p_arr)))

    for query_index in range(query_number):
        relevant = np.dot(query_label[query_index, :], retrieval_label.transpose()) > 0
        distance = cal_hamming_dis(query_binary[query_index, :], retrieval_binary)

        # Stable sort: tied distances keep database order (reproducible).
        ranked_relevant = relevant[np.argsort(distance, kind='stable')]

        tp_num = int(np.count_nonzero(ranked_relevant))
        # number of hits required to reach each recall level
        r_num_arr = (tp_num * r_arr).astype(np.int32)

        tp_cum = np.cumsum(ranked_relevant)
        # Rank (1-based) at which the running hit count first reaches each
        # target. searchsorted also handles a target of 0 when the very first
        # result is a hit, where an exact-match lookup finds nothing.
        total_num_arr = np.searchsorted(tp_cum, r_num_arr, side='left') + 1
        p_arr += r_num_arr / total_num_arr

    p_arr /= query_number
    return np.array(list(zip(r_arr, p_arr)))
|
||||
|
||||
def cal_top_n(retrieval_binary, query_binary, retrieval_label, query_label, top_n=None):
    """
    Calculate the mean precision of the top-n ranked results for several n.

    Note the argument order: retrieval first, then query, for both the codes
    and the labels.

    :param retrieval_binary: codes of the database, one row per item
    :param query_binary: codes of the query samples, one row per query
    :param retrieval_label: label rows of the database
    :param query_label: label rows of the queries; an item is relevant to a
        query when the dot product of their label rows is > 0
    :param top_n: cut-off or iterable of cut-offs, in any order; defaults to
        100, 200, ..., 1000
    :return: array of ``(n, precision@n)`` rows averaged over all queries.
        When the database holds fewer than ``n`` items, all available hits are
        counted and still divided by ``n``. An empty query set gives
        precision 0.
    """
    if top_n is None:
        top_n = range(100, 1001, 100)

    # atleast_1d lets a single integer cut-off be passed directly
    top_n = np.atleast_1d(np.array(top_n))
    top_n_p = np.zeros(len(top_n))
    query_number = query_label.shape[0]
    if query_number == 0:
        # Nothing to average - avoid the 0/0 that would produce NaN.
        return np.array(list(zip(top_n, top_n_p)))

    # Deepest rank that is needed; max() rather than the last element so the
    # cut-offs need not be sorted.
    deepest = int(top_n.max()) if top_n.size else 0

    for query_index in range(query_number):
        relevant = np.dot(query_label[query_index, :], retrieval_label.transpose()) > 0
        distance = cal_hamming_dis(query_binary[query_index, :], retrieval_binary)

        # Stable sort: tied distances keep database order (reproducible).
        ranked_relevant = relevant[np.argsort(distance, kind='stable')]

        tp_cum = np.cumsum(ranked_relevant[:deepest])
        if tp_cum.size == 0:
            continue  # empty database: this query contributes 0
        # Clamp to the last available rank when the database is shorter than n.
        last_rank = np.minimum(top_n, tp_cum.size) - 1
        top_n_p += tp_cum[last_rank] / top_n

    top_n_p /= query_number
    return np.array(list(zip(top_n, top_n_p)))
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ def get_args():
|
|||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--num-workers", type=int, default=4)
|
||||
parser.add_argument("--query-num", type=int, default=5120)
|
||||
parser.add_argument("--train-num", type=int, default=10240)
|
||||
parser.add_argument("--train-num", type=int, default=1024)
|
||||
parser.add_argument("--lr-decay-freq", type=int, default=5)
|
||||
parser.add_argument("--display-step", type=int, default=50)
|
||||
parser.add_argument("--seed", type=int, default=1814)
|
||||
|
|
|
|||
Loading…
Reference in New Issue