debug
This commit is contained in:
parent
19bec9bfc2
commit
c655bb59e9
|
|
@ -35,7 +35,7 @@ def dataloader(captionFile: str,
|
||||||
maxWords=77,
|
maxWords=77,
|
||||||
imageResolution=224,
|
imageResolution=224,
|
||||||
query_num=5000,
|
query_num=5000,
|
||||||
train_num=10000,
|
train_num=1000,
|
||||||
seed=None,
|
seed=None,
|
||||||
npy=False):
|
npy=False):
|
||||||
if captionFile.endswith("mat"):
|
if captionFile.endswith("mat"):
|
||||||
|
|
|
||||||
5
main.py
5
main.py
|
|
@ -4,8 +4,9 @@ from train.hash_train import Trainer
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
engine=Trainer()
|
engine=Trainer()
|
||||||
# engine.test()
|
engine.test()
|
||||||
adv_images, texts, adv_labels= engine.train_epoch()
|
engine.train_epoch()
|
||||||
|
|
||||||
|
|
||||||
# engine.train()
|
# engine.train()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import numpy as np
|
||||||
from .base import TrainBase
|
from .base import TrainBase
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
||||||
from utils.calc_utils import calc_map_k_matrix as calc_map_k
|
from utils.calc_utils import cal_map, cal_pr
|
||||||
from dataset.dataloader import dataloader
|
from dataset.dataloader import dataloader
|
||||||
import open_clip
|
import open_clip
|
||||||
# from transformers import BertModel
|
# from transformers import BertModel
|
||||||
|
|
@ -112,11 +112,11 @@ class Trainer(TrainBase):
|
||||||
return text_representation, text_var_representation
|
return text_representation, text_var_representation
|
||||||
|
|
||||||
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||||
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=100, temperature=0.05):
|
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
|
||||||
|
|
||||||
delta = torch.zeros_like(image,requires_grad=True)
|
delta = torch.zeros_like(image,requires_grad=True)
|
||||||
# one=torch.zeros_like(positive)
|
# one=torch.zeros_like(positive)
|
||||||
alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
||||||
for i in range(num_iter):
|
for i in range(num_iter):
|
||||||
self.model.zero_grad()
|
self.model.zero_grad()
|
||||||
anchor=self.model.encode_image(image+delta)
|
anchor=self.model.encode_image(image+delta)
|
||||||
|
|
@ -130,25 +130,26 @@ class Trainer(TrainBase):
|
||||||
delta.data = delta - alpha * delta.grad.detach().sign()
|
delta.data = delta - alpha * delta.grad.detach().sign()
|
||||||
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
|
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
|
||||||
delta.grad.zero_()
|
delta.grad.zero_()
|
||||||
|
adv_code=self.model.encode_image(image+delta)
|
||||||
return delta.detach()
|
return delta.detach() , adv_code
|
||||||
|
|
||||||
def train_epoch(self):
|
def train_epoch(self):
|
||||||
self.change_state(mode="valid")
|
self.change_state(mode="valid")
|
||||||
# self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||||||
all_loss = 0
|
all_loss = 0
|
||||||
times = 0
|
times = 0
|
||||||
adv_images=[]
|
adv_codes=[]
|
||||||
adv_labels=[]
|
adv_label=[]
|
||||||
texts=[]
|
q_label=[]
|
||||||
for image, text, label, index in self.train_loader:
|
for image, text, label, index in self.train_loader:
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
times += 1
|
times += 1
|
||||||
|
print(times)
|
||||||
|
q_label.append(label.numpy())
|
||||||
image.float()
|
image.float()
|
||||||
image = image.to(self.rank, non_blocking=True)
|
image = image.to(self.rank, non_blocking=True)
|
||||||
text = text.to(self.rank, non_blocking=True)
|
text = text.to(self.rank, non_blocking=True)
|
||||||
index = index.numpy()
|
index = index.numpy()
|
||||||
# image_anchor=self.image_representation(label.detach().cpu().numpy())
|
|
||||||
negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
negetive_mean=np.stack([self.text_mean[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||||
negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
negative_var=np.stack([self.text_var[str(i.astype(int))] for i in label.detach().cpu().numpy()])
|
||||||
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
|
negetive_mean=torch.from_numpy(negetive_mean).to(self.rank, non_blocking=True)
|
||||||
|
|
@ -158,13 +159,39 @@ class Trainer(TrainBase):
|
||||||
positive_mean=negetive_mean.flip(dims=[0])
|
positive_mean=negetive_mean.flip(dims=[0])
|
||||||
positive_var=negative_var.flip(dims=[0])
|
positive_var=negative_var.flip(dims=[0])
|
||||||
positive_code=self.model.encode_text(text.flip(dims=[0]))
|
positive_code=self.model.encode_text(text.flip(dims=[0]))
|
||||||
delta=self.target_adv(image,negetive_code,negetive_mean,negative_var,
|
delta, adv_code=self.target_adv(image,negetive_code,negetive_mean,negative_var,
|
||||||
positive_code,positive_mean,positive_var)
|
positive_code,positive_mean,positive_var)
|
||||||
adv_image=delta+image
|
adv_codes.append(adv_code.cpu().detach().numpy())
|
||||||
adv_images.append(adv_image)
|
adv_label.append(target_label.numpy())
|
||||||
adv_labels.append(target_label)
|
adv_img=np.concatenate(adv_codes,axis=0)
|
||||||
texts.append(text)
|
adv_labels=np.concatenate(adv_label, axis=0)
|
||||||
return adv_images, texts, adv_labels
|
query_labels=np.concatenate(q_label, axis=0)
|
||||||
|
_, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||||
|
retrieval_labels = self.retrieval_labels.numpy()
|
||||||
|
|
||||||
|
mAP = cal_map(adv_img,query_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
|
||||||
|
mAP_t=cal_map(adv_img,adv_labels,retrieval_txt,retrieval_labels,dist_method='cosine')
|
||||||
|
# pr=cal_pr(retrieval_txt,adv_img,query_labels,retrieval_labels)
|
||||||
|
# pr_t=cal_pr(retrieval_txt,adv_img,adv_labels,retrieval_labels)
|
||||||
|
self.logger.info(f">>>>>> MAP: {mAP}, MAP_t: {mAP_t}")
|
||||||
|
result_dict = {
|
||||||
|
'adv_img': adv_img,
|
||||||
|
'r_txt': retrieval_txt,
|
||||||
|
'adv_l': adv_labels,
|
||||||
|
'r_l': retrieval_labels,
|
||||||
|
'q_l':query_labels
|
||||||
|
# 'pr': pr,
|
||||||
|
# 'pr_t': pr_t
|
||||||
|
}
|
||||||
|
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-adv-" + self.args.dataset + ".mat"), result_dict)
|
||||||
|
self.logger.info(">>>>>> save all data!")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -210,25 +237,7 @@ class Trainer(TrainBase):
|
||||||
|
|
||||||
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
||||||
|
|
||||||
def get_adv_code(self, adv_data_list,text_list):
|
|
||||||
|
|
||||||
img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
|
|
||||||
text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)
|
|
||||||
|
|
||||||
for i in tqdm(range(len(adv_data_list))):
|
|
||||||
image = adv_data_list[i].to(self.rank, non_blocking=True)
|
|
||||||
text = text_list[i].to(self.rank, non_blocking=True)
|
|
||||||
# index = index.numpy()
|
|
||||||
image_hash=self.img_model(image)
|
|
||||||
text_feat=self.bert(text)[0]
|
|
||||||
text_hash=self.txt_model(text_feat)
|
|
||||||
# text_hash = self.make_hash_code(text_hash)
|
|
||||||
# image_hash.to(self.rank)
|
|
||||||
# text_hash.to(self.rank)
|
|
||||||
img_buffer[i, :] = image_hash.data
|
|
||||||
text_buffer[i, :] = text_hash.data
|
|
||||||
|
|
||||||
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
|
||||||
|
|
||||||
|
|
||||||
def valid_attack(self,adv_images, texts, adv_labels):
|
def valid_attack(self,adv_images, texts, adv_labels):
|
||||||
|
|
@ -239,20 +248,11 @@ class Trainer(TrainBase):
|
||||||
|
|
||||||
def test(self, mode_name="i2t"):
|
def test(self, mode_name="i2t"):
|
||||||
self.logger.info("Valid Clean.")
|
self.logger.info("Valid Clean.")
|
||||||
# if self.args.pretrained == "":
|
|
||||||
# raise RuntimeError("test step must load a model! please set the --pretrained argument.")
|
|
||||||
# self.change_state(mode="valid")
|
|
||||||
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num)
|
||||||
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||||
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
|
||||||
# print("map map")
|
|
||||||
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
|
||||||
mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
|
||||||
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
|
||||||
self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
|
|
||||||
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}")
|
|
||||||
|
|
||||||
query_img = query_img.cpu().detach().numpy()
|
query_img = query_img.cpu().detach().numpy()
|
||||||
query_txt = query_txt.cpu().detach().numpy()
|
query_txt = query_txt.cpu().detach().numpy()
|
||||||
|
|
@ -260,7 +260,12 @@ class Trainer(TrainBase):
|
||||||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||||
query_labels = self.query_labels.numpy()
|
query_labels = self.query_labels.numpy()
|
||||||
retrieval_labels = self.retrieval_labels.numpy()
|
retrieval_labels = self.retrieval_labels.numpy()
|
||||||
|
mAPi2t = cal_map(query_img,query_labels,retrieval_txt,retrieval_labels)
|
||||||
|
mAPt2i =cal_map(query_txt,query_labels,retrieval_img,retrieval_labels)
|
||||||
|
# pr_i2t=cal_pr(retrieval_txt,query_img,query_labels,retrieval_labels)
|
||||||
|
# pr_t2i=cal_pr(retrieval_img,query_txt,query_labels,retrieval_labels)
|
||||||
|
self.max_mapt2i = max(self.max_mapt2i, mAPi2t)
|
||||||
|
self.logger.info(f">>>>>> MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}")
|
||||||
result_dict = {
|
result_dict = {
|
||||||
'q_img': query_img,
|
'q_img': query_img,
|
||||||
'q_txt': query_txt,
|
'q_txt': query_txt,
|
||||||
|
|
@ -268,32 +273,34 @@ class Trainer(TrainBase):
|
||||||
'r_txt': retrieval_txt,
|
'r_txt': retrieval_txt,
|
||||||
'q_l': query_labels,
|
'q_l': query_labels,
|
||||||
'r_l': retrieval_labels
|
'r_l': retrieval_labels
|
||||||
|
# 'pr_i2t': pr_i2t,
|
||||||
|
# 'pr_t2i': pr_t2i
|
||||||
}
|
}
|
||||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
|
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + ".mat"), result_dict)
|
||||||
self.logger.info(">>>>>> save all data!")
|
self.logger.info(">>>>>> save all data!")
|
||||||
|
|
||||||
|
|
||||||
def valid(self, epoch):
|
# def valid(self, epoch):
|
||||||
self.logger.info("Valid.")
|
# self.logger.info("Valid.")
|
||||||
self.change_state(mode="valid")
|
# self.change_state(mode="valid")
|
||||||
query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||||
retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
# retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||||
# print("get all code")
|
# # print("get all code")
|
||||||
mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
# mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
# print("map map")
|
# # print("map map")
|
||||||
mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
# mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
# mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
# mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||||
if self.max_mapi2t < mAPi2t:
|
# if self.max_mapi2t < mAPi2t:
|
||||||
self.best_epoch_i = epoch
|
# self.best_epoch_i = epoch
|
||||||
self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
|
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
|
||||||
self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
|
# self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
|
||||||
if self.max_mapt2i < mAPt2i:
|
# if self.max_mapt2i < mAPt2i:
|
||||||
self.best_epoch_t = epoch
|
# self.best_epoch_t = epoch
|
||||||
self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
|
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
|
||||||
self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
|
# self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
|
||||||
self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
|
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
|
||||||
MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
|
# MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
|
||||||
|
|
||||||
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
|
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ def calc_cosineDist(B1, B2):
|
||||||
B1 = B1.unsqueeze(0)
|
B1 = B1.unsqueeze(0)
|
||||||
dot_product = B1.mm(B2.transpose(0, 1))
|
dot_product = B1.mm(B2.transpose(0, 1))
|
||||||
norms = torch.sqrt(torch.einsum('ii->i', dot_product))
|
norms = torch.sqrt(torch.einsum('ii->i', dot_product))
|
||||||
distH = dot_product/(norms[None]*norms[..., None])
|
distH = 0.5 * (q - B1.mm(B2.transpose(0, 1)))
|
||||||
return distH
|
return distH
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -174,66 +174,7 @@ def calc_precisions_hash(qB, rB, query_L, retrieval_L):
|
||||||
recalls[i] += (recall_num / total_right)
|
recalls[i] += (recall_num / total_right)
|
||||||
return precisions, recalls
|
return precisions, recalls
|
||||||
|
|
||||||
def calc_precisions_hash_my(qB, rB, *, Gnd, num_query, num_retrieval):
|
|
||||||
if not isinstance(qB, torch.Tensor):
|
|
||||||
qB = torch.from_numpy(qB)
|
|
||||||
if not isinstance(rB, torch.Tensor):
|
|
||||||
rB = torch.from_numpy(rB)
|
|
||||||
if not isinstance(Gnd, torch.Tensor):
|
|
||||||
Gnd = torch.from_numpy(Gnd)
|
|
||||||
|
|
||||||
def CalcHammingDist_np(B1, B2):
|
|
||||||
q = B2.shape[1]
|
|
||||||
distH = 0.5 * (q - np.dot(B1, B2.transpose()))
|
|
||||||
return distH
|
|
||||||
bit = qB.shape[1]
|
|
||||||
# if isinstance(qB, np.ndarray):
|
|
||||||
# hamm = CalcHammingDist_np(qB, rB)
|
|
||||||
# else:
|
|
||||||
hamm = calc_hammingDist(qB, rB)
|
|
||||||
hamm = hamm.type(torch.ByteTensor)
|
|
||||||
total_num = [0] * (bit + 1)
|
|
||||||
max_hamm = int(torch.max(hamm))
|
|
||||||
|
|
||||||
gnd = Gnd
|
|
||||||
|
|
||||||
total_right = torch.sum(gnd>0)
|
|
||||||
precisions = np.zeros([max_hamm + 1])
|
|
||||||
recalls = np.zeros([max_hamm + 1])
|
|
||||||
|
|
||||||
right_num = 0
|
|
||||||
recall_num = 0
|
|
||||||
for i, radius in enumerate(range(0, max_hamm+1)):
|
|
||||||
recall = torch.nonzero(hamm == radius)
|
|
||||||
right = gnd[recall.split(1, dim=1)]
|
|
||||||
recall_num += recall.shape[0]
|
|
||||||
del recall
|
|
||||||
right_num += torch.nonzero(right).shape[0]
|
|
||||||
del right
|
|
||||||
precisions[i] += (right_num / (recall_num + 1e-8))
|
|
||||||
recalls[i] += (recall_num / num_retrieval / num_query)
|
|
||||||
# recalls[i] += (recall_num / total_right)
|
|
||||||
p = precisions.round(2)
|
|
||||||
r = recalls.round(2)
|
|
||||||
# return p, r
|
|
||||||
|
|
||||||
precisions = []
|
|
||||||
recalls = []
|
|
||||||
|
|
||||||
precision_ = 0
|
|
||||||
num = 1
|
|
||||||
for i in range(len(r) - 1):
|
|
||||||
if r[i] == r[i + 1]:
|
|
||||||
precision_ += p[i]
|
|
||||||
num += 1
|
|
||||||
else:
|
|
||||||
precision_ += p[i]
|
|
||||||
precisions.append(precision_ / num)
|
|
||||||
recalls.append(r[i])
|
|
||||||
precision_ = 0
|
|
||||||
num = 1
|
|
||||||
|
|
||||||
return np.asarray(precisions).round(2), np.asarray(recalls)
|
|
||||||
|
|
||||||
|
|
||||||
def calc_precisions_hamming_radius(qB, rB, query_L, retrieval_L, hamming_gas=1):
|
def calc_precisions_hamming_radius(qB, rB, query_L, retrieval_L, hamming_gas=1):
|
||||||
|
|
@ -325,42 +266,101 @@ def calc_IF(all_bow):
|
||||||
return IF
|
return IF
|
||||||
|
|
||||||
|
|
||||||
# def calc_loss(B, F, G, Sim, gamma1, gamma2, eta):
|
def cal_cosine_dis(f1, f2):
|
||||||
# theta = torch.matmul(F, G.transpose(0, 1)) / 2
|
f1_norm = np.linalg.norm(f1)
|
||||||
# inter_loss = torch.sum(torch.log(1 + torch.exp(theta)) - Sim * theta)
|
f2_norm = np.linalg.norm(f2, axis=1)
|
||||||
# theta_f = torch.matmul(F, F.transpose(0, 1)) / 2
|
|
||||||
# intra_img = torch.sum(torch.log(1 + torch.exp(theta_f)) - Sim * theta_f)
|
|
||||||
# theta_g = torch.matmul(G, G.transpose(0, 1)) / 2
|
|
||||||
# intra_txt = torch.sum(torch.log(1 + torch.exp(theta_g)) - Sim * theta_g)
|
|
||||||
# intra_loss = gamma1 * intra_img + gamma2 * intra_txt
|
|
||||||
# quan_loss = torch.sum(torch.pow(B - F, 2) + torch.pow(B - G, 2)) * eta
|
|
||||||
# # term3 = torch.sum(torch.pow(F.sum(dim=0), 2) + torch.pow(G.sum(dim=0), 2))
|
|
||||||
# # loss = term1 + gamma * term2 + eta * term3
|
|
||||||
# loss = inter_loss + intra_loss + quan_loss
|
|
||||||
# return loss
|
|
||||||
|
|
||||||
|
similiarity = np.dot(f1, f2.T)/(f1_norm * f2_norm)
|
||||||
|
return 1 - similiarity
|
||||||
|
|
||||||
# if __name__ == '__main__':
|
def cal_hamming_dis(b1, b2):
|
||||||
# qB = torch.Tensor([[1, -1, 1, 1],
|
k = b2.shape[1] # length of hash code
|
||||||
# [-1, -1, -1, 1],
|
dis = 0.5 * (k - np.dot(b1, b2.transpose()))
|
||||||
# [1, 1, -1, 1],
|
return dis
|
||||||
# [1, 1, 1, -1]])
|
|
||||||
# rB = torch.Tensor([[1, -1, 1, -1],
|
def cal_map(query_feats, query_label, retrieval_feats, retrieval_label, top_k=500, dist_method='hamming'):
|
||||||
# [-1, -1, 1, -1],
|
"""
|
||||||
# [-1, -1, 1, -1],
|
Calculate MAP (Mean Average Precision)
|
||||||
# [1, 1, -1, -1],
|
:param query_binary: binary code of query sample
|
||||||
# [-1, 1, -1, -1],
|
:param query_label: label of qurey sample
|
||||||
# [1, 1, -1, 1]])
|
:param retrieval_binary: binary code of database
|
||||||
# query_L = torch.Tensor([[0, 1, 0, 0],
|
:param retrieval_label: label of database
|
||||||
# [1, 1, 0, 0],
|
:param top_k:
|
||||||
# [1, 0, 0, 1],
|
:return:
|
||||||
# [0, 1, 0, 1]])
|
"""
|
||||||
# retrieval_L = torch.Tensor([[1, 0, 0, 1],
|
query_number = query_label.shape[0]
|
||||||
# [1, 1, 0, 0],
|
top_k_map = 0
|
||||||
# [0, 1, 1, 0],
|
|
||||||
# [0, 0, 1, 0],
|
dist_func = cal_hamming_dis if dist_method == 'hamming' else cal_cosine_dis
|
||||||
# [1, 0, 0, 0],
|
|
||||||
# [0, 0, 1, 0]])
|
for query_index in range(query_number):
|
||||||
#
|
# (1, N)
|
||||||
# map = calc_map_k(qB, rB, query_L, retrieval_L)
|
ground_truth = (np.dot(query_label[query_index, :], retrieval_label.transpose()) > 0).astype(np.float32)
|
||||||
# print(map)
|
hamming_dis = dist_func(query_feats[query_index, :], retrieval_feats) # (1, N)
|
||||||
|
|
||||||
|
# sort hamming distance
|
||||||
|
sort_index = np.argsort(hamming_dis)
|
||||||
|
|
||||||
|
# resort ground truth
|
||||||
|
ground_truth = ground_truth[sort_index]
|
||||||
|
|
||||||
|
# get top K ground truth
|
||||||
|
top_k_gnd = ground_truth[0:top_k]
|
||||||
|
top_k_sum = np.sum(top_k_gnd).astype(int) # the number of correct retrieval in top K
|
||||||
|
if top_k_sum == 0:
|
||||||
|
continue
|
||||||
|
count = np.linspace(1, top_k_sum, int(top_k_sum))
|
||||||
|
|
||||||
|
top_k_index = np.asarray(np.where(top_k_gnd == 1)) + 1.0
|
||||||
|
top_k_map += np.mean(count / top_k_index) # average precision of per class
|
||||||
|
|
||||||
|
return top_k_map / query_number # mean of average precision of all class
|
||||||
|
|
||||||
|
def cal_pr(retrieval_binary, query_binary, retrieval_label, query_label, interval=0.1):
|
||||||
|
r_arr = np.array([i * interval for i in range(1, int(1/interval) + 1)])
|
||||||
|
p_arr = np.zeros(len(r_arr))
|
||||||
|
|
||||||
|
query_number = query_label.shape[0]
|
||||||
|
|
||||||
|
for query_index in range(query_number):
|
||||||
|
ground_truth = (np.dot(query_label[query_index, :], retrieval_label.transpose()) > 0).astype(
|
||||||
|
np.float32) # (1, N)
|
||||||
|
hamming_dis = cal_hamming_dis(query_binary[query_index, :], retrieval_binary) # (1, N)
|
||||||
|
|
||||||
|
# sort hamming distance
|
||||||
|
sort_index = np.argsort(hamming_dis)
|
||||||
|
ground_truth = ground_truth[sort_index]
|
||||||
|
tp_num = len(np.where(ground_truth == 1)[0])
|
||||||
|
r_num_arr = (tp_num * r_arr).astype(np.int32)
|
||||||
|
|
||||||
|
tp_cum = np.cumsum(ground_truth)
|
||||||
|
total_num_arr = np.array([np.where(tp_cum == i)[0][0] + 1 for i in r_num_arr])
|
||||||
|
p_arr += r_num_arr/total_num_arr
|
||||||
|
p_arr /= query_number
|
||||||
|
|
||||||
|
return np.array(list(zip(r_arr, p_arr)))
|
||||||
|
|
||||||
|
def cal_top_n(retrieval_binary, query_binary, retrieval_label, query_label, top_n=None):
|
||||||
|
if top_n is None:
|
||||||
|
top_n = range(100, 1001, 100)
|
||||||
|
|
||||||
|
top_n = np.array(top_n)
|
||||||
|
top_n_p = np.zeros(len(top_n))
|
||||||
|
query_number = query_label.shape[0]
|
||||||
|
|
||||||
|
for query_index in range(query_number):
|
||||||
|
ground_truth = (np.dot(query_label[query_index, :], retrieval_label.transpose()) > 0).astype(
|
||||||
|
np.float32) # (1, N)
|
||||||
|
hamming_dis = cal_hamming_dis(query_binary[query_index, :], retrieval_binary) # (1, N)
|
||||||
|
|
||||||
|
# sort hamming distance
|
||||||
|
sort_index = np.argsort(hamming_dis)
|
||||||
|
ground_truth = ground_truth[sort_index]
|
||||||
|
ground_truth = ground_truth[:top_n[-1]]
|
||||||
|
|
||||||
|
tp_cum = np.cumsum(ground_truth)
|
||||||
|
tp_num_arr = tp_cum[top_n - 1]
|
||||||
|
top_n_p += tp_num_arr/top_n
|
||||||
|
|
||||||
|
top_n_p /= query_number
|
||||||
|
return np.array(list(zip(top_n, top_n_p)))
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ def get_args():
|
||||||
parser.add_argument("--batch-size", type=int, default=8)
|
parser.add_argument("--batch-size", type=int, default=8)
|
||||||
parser.add_argument("--num-workers", type=int, default=4)
|
parser.add_argument("--num-workers", type=int, default=4)
|
||||||
parser.add_argument("--query-num", type=int, default=5120)
|
parser.add_argument("--query-num", type=int, default=5120)
|
||||||
parser.add_argument("--train-num", type=int, default=10240)
|
parser.add_argument("--train-num", type=int, default=1024)
|
||||||
parser.add_argument("--lr-decay-freq", type=int, default=5)
|
parser.add_argument("--lr-decay-freq", type=int, default=5)
|
||||||
parser.add_argument("--display-step", type=int, default=50)
|
parser.add_argument("--display-step", type=int, default=50)
|
||||||
parser.add_argument("--seed", type=int, default=1814)
|
parser.add_argument("--seed", type=int, default=1814)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue