new update 2 modality
This commit is contained in:
parent
053a58b07a
commit
73c901a18c
|
|
@ -112,21 +112,19 @@ class Trainer(TrainBase):
|
|||
text_var_representation[str(centroid.astype(int))]= text_var[i]
|
||||
return text_representation, text_var_representation
|
||||
|
||||
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
|
||||
def target_adv(self, image, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var
|
||||
,epsilon=0.03125, alpha=3/255, num_iter=1500):
|
||||
|
||||
delta = torch.zeros_like(image,requires_grad=True)
|
||||
# one=torch.zeros_like(positive)
|
||||
# alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
|
||||
for i in range(num_iter):
|
||||
self.model.zero_grad()
|
||||
anchor=self.model.encode_image(image+delta)
|
||||
loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity())
|
||||
negative_dist=(anchor-negetive_mean)**2 / negative_var
|
||||
positive_dist=(anchor-positive_mean)**2 /positive_var
|
||||
negatives=torch.exp(negative_dist / temperature)
|
||||
positives= torch.exp(positive_dist / temperature)
|
||||
loss= torch.log(positives/(positives+negatives)).mean() + beta* loss1
|
||||
negatives=torch.exp(negative_dist / self.args.temperature)
|
||||
positives= torch.exp(positive_dist / self.args.temperature)
|
||||
loss= torch.log(positives/(positives+negatives)).mean() + self.args.beta* loss1
|
||||
loss.backward(retain_graph=True)
|
||||
delta.data = delta - alpha * delta.grad.detach().sign()
|
||||
delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
|
||||
|
|
@ -192,7 +190,6 @@ class Trainer(TrainBase):
|
|||
'r_txt': retrieval_txt,
|
||||
'adv_l': adv_labels,
|
||||
'r_l': retrieval_labels
|
||||
# 'q_l':query_labels
|
||||
# 'pr': pr,
|
||||
# 'pr_t': pr_t
|
||||
}
|
||||
|
|
@ -204,28 +201,9 @@ class Trainer(TrainBase):
|
|||
|
||||
|
||||
|
||||
def train(self):
|
||||
self.logger.info("Start train.")
|
||||
|
||||
for epoch in range(self.args.epochs):
|
||||
self.train_epoch(epoch)
|
||||
self.valid(epoch)
|
||||
self.save_model(epoch)
|
||||
|
||||
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
||||
|
||||
|
||||
|
||||
def make_hash_code(self, code: list) -> torch.Tensor:
|
||||
|
||||
code = torch.stack(code)
|
||||
# print(code.shape)
|
||||
code = code.permute(1, 0, 2)
|
||||
hash_code = torch.argmax(code, dim=-1)
|
||||
hash_code[torch.where(hash_code == 0)] = -1
|
||||
hash_code = hash_code.float()
|
||||
|
||||
return hash_code
|
||||
|
||||
def get_code(self, data_loader, length: int):
|
||||
|
||||
|
|
@ -247,9 +225,7 @@ class Trainer(TrainBase):
|
|||
|
||||
|
||||
|
||||
def valid_attack(self,adv_images, texts, adv_labels):
|
||||
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -287,49 +263,6 @@ class Trainer(TrainBase):
|
|||
self.logger.info(">>>>>> save all data!")
|
||||
|
||||
|
||||
# def valid(self, epoch):
|
||||
# self.logger.info("Valid.")
|
||||
# self.change_state(mode="valid")
|
||||
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||
# retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||
# # print("get all code")
|
||||
# mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# # print("map map")
|
||||
# mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# if self.max_mapi2t < mAPi2t:
|
||||
# self.best_epoch_i = epoch
|
||||
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
|
||||
# self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
|
||||
# if self.max_mapt2i < mAPt2i:
|
||||
# self.best_epoch_t = epoch
|
||||
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
|
||||
# self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
|
||||
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
|
||||
# MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
|
||||
|
||||
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
|
||||
|
||||
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
query_img = query_img.cpu().detach().numpy()
|
||||
query_txt = query_txt.cpu().detach().numpy()
|
||||
retrieval_img = retrieval_img.cpu().detach().numpy()
|
||||
retrieval_txt = retrieval_txt.cpu().detach().numpy()
|
||||
query_labels = self.query_labels.numpy()
|
||||
retrieval_labels = self.retrieval_labels.numpy()
|
||||
|
||||
result_dict = {
|
||||
'q_img': query_img,
|
||||
'q_txt': query_txt,
|
||||
'r_img': retrieval_img,
|
||||
'r_txt': retrieval_txt,
|
||||
'q_l': query_labels,
|
||||
'r_l': retrieval_labels
|
||||
}
|
||||
scio.savemat(os.path.join(save_dir, str(self.args.output_dim) + "-ours-" + self.args.dataset + "-" + mode_name + ".mat"), result_dict)
|
||||
self.logger.info(f">>>>>> save best {mode_name} data!")
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -244,6 +244,15 @@ class Trainer(TrainBase):
|
|||
|
||||
return import_scores.sum(dim=-1)
|
||||
|
||||
def adv_loss(self,anchor, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var):
|
||||
loss1=F.triplet_margin_with_distance_loss(anchor, positive_code,negetive_code, distance_function=nn.CosineSimilarity(),reduction='none')
|
||||
negative_dist=(anchor-negetive_mean)**2 / negative_var
|
||||
positive_dist=(anchor-positive_mean)**2 /positive_var
|
||||
negatives=torch.exp(negative_dist / self.args.temperature)
|
||||
positives= torch.exp(positive_dist / self.args.temperature)
|
||||
loss= torch.log(positives/(positives+negatives)) + self.args.beta* loss1
|
||||
return loss
|
||||
|
||||
|
||||
def target_adv(self, text_tokens, negetive_code,negetive_mean,negative_var, positive_code,positive_mean,positive_var,
|
||||
beta=10 ,epsilon=0.03125, alpha=3/255, num_iter=1500, temperature=0.05):
|
||||
|
|
@ -296,8 +305,7 @@ class Trainer(TrainBase):
|
|||
replace_text_input = self.clip_tokenizer(replace_texts).to(device)
|
||||
replace_embeds = self.model.encode_text(replace_text_input)
|
||||
|
||||
criterion = torch.nn.KLDivLoss(reduction='none')
|
||||
loss = criterion(replace_embeds.log_softmax(dim=-1), clean_embeds[i].softmax(dim=-1).repeat(len(replace_embeds), 1))
|
||||
loss = self.adv_loss(replace_embeds, negetive_code,negetive_mean,negative_var,positive_code,positive_mean,positive_var)
|
||||
loss = loss.sum(dim=-1)
|
||||
candidate_idx = loss.argmax()
|
||||
final_words[top_index[0]] = available_substitutes[candidate_idx]
|
||||
|
|
@ -375,28 +383,11 @@ class Trainer(TrainBase):
|
|||
|
||||
|
||||
|
||||
def train(self):
|
||||
self.logger.info("Start train.")
|
||||
|
||||
for epoch in range(self.args.epochs):
|
||||
self.train_epoch(epoch)
|
||||
self.valid(epoch)
|
||||
self.save_model(epoch)
|
||||
|
||||
self.logger.info(f">>>>>>> FINISHED >>>>>> Best epoch, I-T: {self.best_epoch_i}, mAP: {self.max_mapi2t}, T-I: {self.best_epoch_t}, mAP: {self.max_mapt2i}")
|
||||
|
||||
|
||||
|
||||
def make_hash_code(self, code: list) -> torch.Tensor:
|
||||
|
||||
code = torch.stack(code)
|
||||
# print(code.shape)
|
||||
code = code.permute(1, 0, 2)
|
||||
hash_code = torch.argmax(code, dim=-1)
|
||||
hash_code[torch.where(hash_code == 0)] = -1
|
||||
hash_code = hash_code.float()
|
||||
|
||||
return hash_code
|
||||
|
||||
def get_code(self, data_loader, length: int):
|
||||
|
||||
|
|
@ -456,28 +447,6 @@ class Trainer(TrainBase):
|
|||
self.logger.info(">>>>>> save all data!")
|
||||
|
||||
|
||||
# def valid(self, epoch):
|
||||
# self.logger.info("Valid.")
|
||||
# self.change_state(mode="valid")
|
||||
# query_img, query_txt = self.get_code(self.query_loader, self.args.query_num) if self.args.hash_layer == "select" else super().get_code(self.query_loader, self.args.query_num)
|
||||
# retrieval_img, retrieval_txt = self.get_code(self.retrieval_loader, self.args.retrieval_num) if self.args.hash_layer == "select" else super().get_code(self.retrieval_loader, self.args.retrieval_num)
|
||||
# # print("get all code")
|
||||
# mAPi2t = calc_map_k(query_img, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# # print("map map")
|
||||
# mAPt2i = calc_map_k(query_txt, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# mAPi2i = calc_map_k(query_img, retrieval_img, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# mAPt2t = calc_map_k(query_txt, retrieval_txt, self.query_labels, self.retrieval_labels, None, self.rank)
|
||||
# if self.max_mapi2t < mAPi2t:
|
||||
# self.best_epoch_i = epoch
|
||||
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t")
|
||||
# self.max_mapi2t = max(self.max_mapi2t, mAPi2t)
|
||||
# if self.max_mapt2i < mAPt2i:
|
||||
# self.best_epoch_t = epoch
|
||||
# self.save_mat(query_img, query_txt, retrieval_img, retrieval_txt, mode_name="t2i")
|
||||
# self.max_mapt2i = max(self.max_mapt2i, mAPt2i)
|
||||
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}], MAP(i->t): {mAPi2t}, MAP(t->i): {mAPt2i}, MAP(t->t): {mAPt2t}, MAP(i->i): {mAPi2i}, \
|
||||
# MAX MAP(i->t): {self.max_mapi2t}, MAX MAP(t->i): {self.max_mapt2i}")
|
||||
|
||||
def save_mat(self, query_img, query_txt, retrieval_img, retrieval_txt, mode_name="i2t"):
|
||||
|
||||
save_dir = os.path.join(self.args.save_dir, "PR_cruve")
|
||||
|
|
|
|||
Loading…
Reference in New Issue