This commit is contained in:
leewlving 2024-05-20 15:23:06 +08:00
parent 79bd5e6ac8
commit e79be260b4
5 changed files with 574 additions and 106 deletions

310
model/GAN.py Normal file
View File

@ -0,0 +1,310 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
from model.spectral_norm import spectral_norm as SpectralNorm
class PrototypeNet(nn.Module):
    """Maps a label vector to a shared feature, a hash code, and class scores.

    The label is lifted into a 512-d feature space; from that feature a
    `bit`-dimensional hash code (Tanh-bounded to [-1, 1]) and per-class
    sigmoid predictions are produced.
    """

    def __init__(self, bit, num_classes):
        super(PrototypeNet, self).__init__()
        # Label -> 512-d feature through one hidden expansion layer.
        self.feature = nn.Sequential(
            nn.Linear(num_classes, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 512),
        )
        # Feature -> continuous hash code in [-1, 1].
        self.hashing = nn.Sequential(nn.Linear(512, bit), nn.Tanh())
        # Feature -> independent per-class probabilities.
        self.classifier = nn.Sequential(nn.Linear(512, num_classes), nn.Sigmoid())

    def forward(self, label):
        """Return (feature, hash_code, class_scores) for a batch of label vectors."""
        feat = self.feature(label)
        return feat, self.hashing(feat), self.classifier(feat)
class Discriminator(nn.Module):
    """
    Discriminator network with PatchGAN.
    Reference: https://github.com/yunjey/stargan/blob/master/model.py
    """

    def __init__(self, num_classes, image_size=224, conv_dim=64, repeat_num=5):
        super(Discriminator, self).__init__()
        # Stem: first strided conv halves the spatial size, 3 -> conv_dim channels.
        blocks = [
            SpectralNorm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.01),
        ]
        channels = conv_dim
        # Each subsequent stage doubles the channels and halves the resolution.
        for _ in range(repeat_num - 1):
            blocks.append(SpectralNorm(
                nn.Conv2d(channels, channels * 2, kernel_size=4, stride=2, padding=1)))
            blocks.append(nn.LeakyReLU(0.01))
            channels *= 2
        self.main = nn.Sequential(*blocks)
        # After repeat_num halvings the feature map is image_size / 2**repeat_num
        # on a side; a conv of exactly that kernel collapses it to 1x1.
        final_kernel = int(image_size / (2 ** repeat_num))
        self.fc = nn.Conv2d(channels, num_classes + 1,
                            kernel_size=final_kernel, bias=False)

    def forward(self, x):
        features = self.main(x)
        return self.fc(features).squeeze()
class Generator(nn.Module):
    """Generator: Encoder-Decoder Architecture.
    Reference: https://github.com/yunjey/stargan/blob/master/model.py
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Fuses the 512-d label feature with the input image
        # (channel concatenation -> 6-channel input to the image encoder).
        self.label_encoder = LabelEncoder()

        dim = 64
        # Image encoder stem on the 6-channel fused input.
        enc = [
            nn.Conv2d(6, dim, kernel_size=7, stride=1, padding=3, bias=True),
            nn.InstanceNorm2d(dim),
            nn.ReLU(inplace=True),
        ]
        # Two down-sampling stages, each doubling the channel count.
        for _ in range(2):
            enc.extend([
                nn.Conv2d(dim, dim * 2, kernel_size=4, stride=2, padding=1, bias=True),
                nn.InstanceNorm2d(dim * 2),
                nn.ReLU(inplace=True),
            ])
            dim *= 2
        # Three residual bottleneck blocks (net_mode='t': non-affine norms).
        enc.extend(ResidualBlock(dim_in=dim, dim_out=dim, net_mode='t')
                   for _ in range(3))
        self.image_encoder = nn.Sequential(*enc)

        # Decoder: mirrored bottleneck, then two up-sampling stages.
        dec = [ResidualBlock(dim_in=dim, dim_out=dim, net_mode='t')
               for _ in range(3)]
        for _ in range(2):
            dec.extend([
                nn.ConvTranspose2d(dim, dim // 2, kernel_size=4, stride=2,
                                   padding=1, bias=False),
                nn.InstanceNorm2d(dim // 2),
                nn.ReLU(inplace=True),
            ])
            dim //= 2
        self.decoder = nn.Sequential(*dec)

        # Output head: decoded features concatenated with the clean image
        # (hence dim + 3 input channels) -> 3-channel Tanh-bounded residual.
        self.residual = nn.Sequential(
            nn.Conv2d(dim + 3, 3, kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x, label_feature):
        mixed_feature = self.label_encoder(x, label_feature)
        decoded = self.decoder(self.image_encoder(mixed_feature))
        adv_x = self.residual(torch.cat([decoded, x], dim=1))
        return adv_x, mixed_feature
class LabelEncoder(nn.Module):
    """Decodes a 512-d label feature into an image-sized map and stacks it
    onto the input image along the channel axis.
    """

    def __init__(self, nf=128):
        super(LabelEncoder, self).__init__()
        self.nf = nf
        # Spatial side of the seed feature map produced by the fc layer.
        self.size = 14
        # 512-d feature -> flattened nf x size x size seed map.
        self.fc = nn.Sequential(
            nn.Linear(512, nf * self.size * self.size),
            nn.ReLU(True),
        )
        # Four up-sampling stages: each doubles the resolution (14 -> 224)
        # and halves the channel count (nf -> nf / 16).
        stages = []
        dim = nf
        for _ in range(4):
            stages.append(nn.ConvTranspose2d(dim, dim // 2, kernel_size=4,
                                             stride=2, padding=1, bias=False))
            stages.append(nn.InstanceNorm2d(dim // 2, affine=False))
            stages.append(nn.ReLU(inplace=True))
            dim //= 2
        # Final projection down to a 3-channel image-shaped tensor.
        stages.append(nn.Conv2d(dim, 3, kernel_size=3, stride=1, padding=1,
                                bias=False))
        self.transform = nn.Sequential(*stages)

    def forward(self, image, label_feature):
        seed = self.fc(label_feature)
        seed = seed.view(seed.size(0), self.nf, self.size, self.size)
        label_map = self.transform(seed)
        # Channel-wise concatenation: (B, 3 + 3, H, W).
        return torch.cat((label_map, image), dim=1)
class ResidualBlock(nn.Module):
    """Residual Block: two 3x3 conv + instance-norm layers with a skip.

    Args:
        dim_in: input channel count.
        dim_out: output channel count (must equal dim_in for the skip
            connection ``x + main(x)`` to be shape-valid).
        net_mode: 'p' (or None) selects affine instance norm; 't' selects
            non-affine instance norm.

    Raises:
        ValueError: if net_mode is not one of 'p', 't', or None.
    """

    def __init__(self, dim_in, dim_out, net_mode=None):
        super(ResidualBlock, self).__init__()
        # Bug fix: the original left use_affine unbound for unknown modes,
        # which surfaced later as an UnboundLocalError; fail fast instead.
        if net_mode == 'p' or (net_mode is None):
            use_affine = True
        elif net_mode == 't':
            use_affine = False
        else:
            raise ValueError(
                "net_mode must be 'p', 't' or None, got %r" % (net_mode,))
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1,
                      bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1,
                      bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine))

    def forward(self, x):
        """Return x + main(x)."""
        return x + self.main(x)
class GANLoss(nn.Module):
    """Define different GAN objectives.

    Abstracts away building the target tensor: the class-label tensor is
    extended with one extra real/fake indicator column sized like the
    discriminator output.
    """

    def __init__(self, gan_mode, target_real_label=0.0, target_fake_label=1.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the GAN objective; supports 'vanilla',
                'lsgan', and 'wgangp'.
            target_real_label (float) -- indicator value appended for real images.
            target_fake_label (float) -- indicator value appended for fake images.

        Note: do not use sigmoid as the last layer of the Discriminator.
        LSGAN needs no sigmoid; vanilla GAN handles it via BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # Buffers move with the module across devices but are not trained.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            # Wasserstein loss is computed directly from the critic output.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, label, target_is_real):
        """Append the real/fake indicator column to the class-label tensor.

        Parameters:
            label (tensor) -- per-sample class labels, shape (B, C).
            target_is_real (bool) -- whether the ground truth is "real".

        Returns:
            A (B, C + 1) target tensor.
        """
        indicator = self.real_label if target_is_real else self.fake_label
        column = indicator.expand(label.size(0), 1)
        return torch.cat([label, column], dim=-1)

    def __call__(self, prediction, label, target_is_real):
        """Compute the loss from the discriminator output and ground truth.

        Parameters:
            prediction (tensor) -- discriminator output.
            label (tensor) -- class labels (unused for 'wgangp').
            target_is_real (bool) -- whether the ground truth is "real".

        Returns:
            The calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target = self.get_target_tensor(label, target_is_real)
            return self.loss(prediction, target)
        if self.gan_mode == 'wgangp':
            mean = prediction.mean()
            return -mean if target_is_real else mean
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler.

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a
                              subclass of BaseOptions. opt.lr_policy is the
                              learning rate policy name:
                              linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs>
    epochs and linearly decay the rate to zero over the next
    <opt.n_epochs_decay> epochs. For other schedulers (step, plateau, and
    cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError: if opt.lr_policy is not a recognized policy.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # Scale factor stays 1.0 through the first n_epochs, then decays
            # linearly to 0 over the following n_epochs_decay epochs.
            lr_l = 1.0 - max(0, epoch + opt.epoch_count -
                             opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=opt.lr_decay_iters,
                                        gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.2,
                                                   threshold=0.01,
                                                   patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=opt.n_epochs,
                                                   eta_min=0)
    else:
        # Bug fix: the original *returned* a NotImplementedError instance
        # built with comma-separated args; raise a formatted error instead.
        raise NotImplementedError(
            'learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler

89
model/spectral_norm.py Normal file
View File

@ -0,0 +1,89 @@
import torch
from torch.nn import Parameter
def l2normalize(v, eps=1e-12):
    """Scale tensor v to unit L2 norm; eps guards against division by zero."""
    norm = v.norm()
    return v / (norm + eps)
class SpectralNorm(object):
    """Forward-pre-hook object implementing spectral normalization.

    Rescales the wrapped module's weight by an estimate of its largest
    singular value, obtained via power iteration (Miyato et al., 2018).
    The raw weight is stored as ``weight_bar``; ``weight_u`` / ``weight_v``
    hold the running singular-vector estimates.
    """
    def __init__(self):
        # Only the parameter literally named "weight" is normalized.
        self.name = "weight"
        #print(self.name)
        # One power-iteration refinement per forward pass.
        self.power_iterations = 1
    def compute_weight(self, module):
        """Return the spectrally normalized weight: w_bar / sigma."""
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")
        # Power iteration runs on the weight flattened to 2-D
        # (first dim x everything else).
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # Updates go through .data so they are not tracked by autograd.
            v.data = l2normalize(
                torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        # Estimated top singular value of w.
        sigma = u.dot(w.view(height, -1).mv(v))
        return w / sigma.expand_as(w)
    @staticmethod
    def apply(module):
        """Attach spectral normalization to `module`'s 'weight' parameter.

        Registers ``weight_bar`` (raw weight) plus the power-iteration
        vectors, removes the original parameter, and installs this object
        as a forward pre-hook so the normalized weight is recomputed
        before every forward().
        """
        name = "weight"
        fn = SpectralNorm()
        try:
            # Already applied: reuse the existing auxiliary parameters.
            u = getattr(module, name + "_u")
            v = getattr(module, name + "_v")
            w = getattr(module, name + "_bar")
        except AttributeError:
            # First application: create u, v and the raw-weight copy.
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]
            u = Parameter(w.data.new(height).normal_(0, 1),
                          requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
            w_bar = Parameter(w.data)
            #del module._parameters[name]
            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)
        # remove w from parameter list
        del module._parameters[name]
        setattr(module, name, fn.compute_weight(module))
        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)
        return fn
    def remove(self, module):
        """Undo apply(): restore a plain trainable `weight` Parameter."""
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))
    def __call__(self, module, inputs):
        # Forward pre-hook: refresh the normalized weight in place.
        setattr(module, self.name, self.compute_weight(module))
def spectral_norm(module):
    """Apply spectral normalization to `module`'s weight, in place.

    Returns the same module so the call can be chained inline when
    building layer lists.
    """
    SpectralNorm.apply(module)
    return module
def remove_spectral_norm(module):
    """Strip spectral normalization from `module`, restoring a plain weight.

    Raises:
        ValueError: if no SpectralNorm hook is attached to the module.
    """
    name = 'weight'
    for key, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            # Safe to mutate here: we return immediately after deleting.
            del module._forward_pre_hooks[key]
            return module
    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))

View File

@ -6,13 +6,23 @@ import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import scipy.io as scio
import numpy as np
from .base import TrainBase
from model.optimization import BertAdam
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import calc_map_k_matrix as calc_map_k
from dataset.dataloader import dataloader
import open_clip
# from transformers import BertModel
def clamp(delta, clean_imgs):
    """Project a perturbation so that clean_imgs + delta stays inside [0, 1].

    Returns the adjusted perturbation (not the perturbed image).
    """
    perturbed = (delta.data + clean_imgs.data).clamp(0, 1)
    return perturbed - clean_imgs.data
class Trainer(TrainBase):
@ -21,32 +31,62 @@ class Trainer(TrainBase):
args = get_args()
super(Trainer, self).__init__(args, rank)
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
self.run()
image_representation, text_representation=self.generate_mapping()
self.image_representation=image_representation
self.text_representation=text_representation
# self.run()
def _init_model(self):
    """Build the backbone model used for the attack.

    NOTE(review): this method reads like a mid-refactor diff -- the old
    DCMHT/BertAdam setup appears both live and commented out below, and the
    live BertAdam block references self.model.clip / image_hash / text_hash,
    which the open_clip model assigned just above does not expose. Confirm
    which path is intended; only the open_clip ViT-B-16 assignment appears
    to be the current direction.
    """
    self.logger.info("init model.")
    linear = False
    if self.args.hash_layer == "linear":
        linear = True
    self.logger.info("ViT+GPT!")
    HashModel = DCMHT
    self.model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
                    writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
    if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
        self.logger.info("load pretrained model.")
        self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
    # self.generator=Generator()
    # linear = False
    # if self.args.hash_layer == "linear":
    #     linear = True
    # self.bert=BertModel.from_pretrained("bert-base-cased", output_hidden_states=True).to(self.rank)
    # self.bert.eval()
    # self.logger.info("ViT+GPT!")
    # HashModel = DCMHT
    # if self.args.victim_model == 'JDSH':
    #     from model.JDSH import TxtNet, ImgNet
    #     # self.img_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
    #     #                 writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
    #     self.img_model=ImgNet(code_len=self.args.output_dim).to(self.rank)
    #     self.txt_model=TxtNet(code_len=self.args.output_dim, txt_feat_len=self.args.txt_dim).to(self.rank)
    #     path=os.path.join(self.args.checkpoints,self.args.victim_model+'/'+str(self.args.output_dim)+'_'+self.args.dataset+'latest.pth')
    #     checkpoint=torch.load(path)
    #     self.img_model.load_state_dict(torch.load(checkpoint['ImgNet'], map_location=f"cuda:{self.rank}"))
    #     self.txt_model.load_state_dict(torch.load(checkpoint['TxtNet'], map_location=f"cuda:{self.rank}"))
    #     self.img_model.eval()
    #     self.txt_model.eval()
    # elif self.args.victim_model == 'DJSRH':
    #     self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
    #                     writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
    #     self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
    # elif self.args.victim_model == 'SSAH':
    #     self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
    #                     writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
    #     self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
    # elif self.args.victim_model == 'DCHUC':
    #     self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
    #                     writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
    #     self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
    # if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
    #     self.logger.info("load pretrained model.")
    #     self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
    # Frozen CLIP ViT-B/16 used as the feature extractor for the attack.
    model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=self.device)
    self.model= model_clip
    self.model.eval()
    self.model.float()
    self.optimizer = BertAdam([
                {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
                {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
                {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
                ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
                b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
                weight_decay=self.args.weight_decay, max_grad_norm=1.0)
    # self.optimizer = BertAdam([
    #             {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
    #             {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
    #             {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
    #             ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
    #             b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
    #             weight_decay=self.args.weight_decay, max_grad_norm=1.0)
    print(self.model)
    # print(self.model)
def _init_dataset(self):
self.logger.info("init dataset.")
@ -89,12 +129,62 @@ class Trainer(TrainBase):
pin_memory=True,
shuffle=True
)
def generate_mapping(self):
    """Compute per-label text-embedding statistics over the training set.

    Encodes every training caption with the CLIP text encoder, groups the
    embeddings by (multi-hot) label row, and returns two dicts keyed by
    ``label_row.tobytes()``:
      - the centroid (mean) text embedding for each unique label, and
      - the per-dimension variance for each unique label.

    NOTE(review): __init__ unpacks the result as (image_representation,
    text_representation), but the values returned here are (mean, variance)
    of the *text* embeddings only -- confirm the intended naming.
    """
    text_train=[]
    label_train=[]
    # image_train=[]
    # self.change_state(mode="valid")
    for image, text, label, index in self.train_loader:
        # image=image.to(self.device, non_blocking=True)
        text=text.to(self.device, non_blocking=True)
        temp_text=self.model.encode_text(text)
        # temp_image=self.model.encode_image(image)
        # image_train.append(temp_image.cpu().detach().numpy())
        text_train.append(temp_text.cpu().detach().numpy())
        label_train.append(label.detach().numpy())
    text_train=np.concatenate(text_train, axis=0)
    # image_train=np.concatenate(image_train, axis=0)
    label_train=np.concatenate(label_train, axis=0)
    # Unique label rows ("unipue" is a pre-existing typo kept as-is).
    label_unipue=np.unique(label_train,axis=0)
    # image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
    # Mean and variance of the text embeddings within each label group.
    text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
    text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
    text_representation = {}
    text_var_representation = {}
    for i, centroid in enumerate(label_unipue):
        # Keyed by the raw bytes of the label row so ndarray keys are hashable.
        text_representation[centroid.tobytes()] = text_centroids[i]
        text_var_representation[centroid.tobytes()]= text_var[i]
    return text_representation, text_var_representation
def target_adv(self, image, positive, negative,
               epsilon=0.03125, alpha=3/255, num_iter=100):
    """Craft a targeted adversarial perturbation for `image` via iterative PGD.

    Minimizes a triplet loss so the adversarial embedding moves toward
    `positive` (the target-class anchor) and away from `negative`.

    Args:
        image: clean input batch (assumed to lie in [0, 1] -- TODO confirm
            against the dataloader's normalization).
        positive: embedding the attack should approach.
        negative: embedding the attack should move away from.
        epsilon: L-inf budget for the perturbation.
        alpha: per-step attack step size.
        num_iter: number of PGD iterations.

    Returns:
        The detached perturbation `delta`; add it to `image` to obtain the
        adversarial example.
    """
    delta = torch.zeros_like(image, requires_grad=True)
    alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
    for _ in range(num_iter):
        self.model.zero_grad()
        anchor = self.model.encode_image(image + delta)
        loss = alienation_loss(anchor, positive, negative)
        loss.backward(retain_graph=True)
        # Gradient *descent* on the triplet loss: step against the sign.
        delta.data = delta - alpha * delta.grad.detach().sign()
        # Keep image + delta inside [0, 1], then enforce the L-inf budget.
        delta.data = clamp(delta, image).clamp(-epsilon, epsilon)
        delta.grad.zero_()
    return delta.detach()
def train_epoch(self, epoch):
self.change_state(mode="train")
self.change_state(mode="valid")
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
all_loss = 0
times = 0
adv_images=[]
adv_labels=[]
texts=[]
for image, text, label, index in self.train_loader:
self.global_step += 1
times += 1
@ -102,27 +192,29 @@ class Trainer(TrainBase):
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
label = torch.ones([image.shape[0]], dtype=torch.int)
label = label.diag()
# print(text.dtype)
# text.float()
# label.float()
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
# print("text shape:", text.shape)
index = index.numpy()
image_anchor=self.image_representation(label.detach().cpu().numpy())
text_anchor=self.text_representation(label.detach().cpu().numpy())
negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
target_label=label.flip(dims=[0])
target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
# print("text shape:", text.shape)
# index = index.numpy()
# print(text.shape)
hash_img, hash_text = self.model(image, text)
if self.args.hash_layer == "select":
hash_img = torch.cat(hash_img, dim=-1) if isinstance(hash_img, list) else hash_img.view(hash_img.shape[0], -1)
hash_text = torch.cat(hash_text, dim=-1)if isinstance(hash_text, list) else hash_text.view(hash_text.shape[0], -1)
loss = self.compute_loss(hash_img, hash_text, label, epoch, times)
all_loss += loss
delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
adv_image=delta+image
adv_images.append(adv_image)
adv_labels.append(target_label)
texts.append(text)
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
return adv_images, texts, adv_labels
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
def train(self):
self.logger.info("Start train.")
@ -204,78 +296,40 @@ class Trainer(TrainBase):
image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True)
index = index.numpy()
image_hash = self.model.encode_image(image)
image_hash = self.make_hash_code(image_hash)
text_hash = self.model.encode_text(text)
text_hash = self.make_hash_code(text_hash)
# image_hash.to(self.rank)
# text_hash.to(self.rank)
image_hash=self.img_model(image)
text_feat=self.bert(text)[0]
text_hash=self.txt_model(text_feat)
img_buffer[index, :] = image_hash.data
text_buffer[index, :] = text_hash.data
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def our_loss(self, image, text, label, epoch, times):
    """Hashing objective combining intra- and inter-modal similarity terms.

    Args:
        image: image hash/feature batch.
        text: text hash/feature batch.
        label: ground-truth labels used to build the pairwise similarity matrix.
        epoch: current epoch index (logging only).
        times: current step index within the epoch (logging only).

    Returns:
        The scalar training loss: similarity loss, plus a quantization term
        when the hash layer is not "select".
    """
    loss = 0
    # Pairwise label-similarity matrix; exact encoding is defined by
    # calc_neighbor in utils -- presumably 1 for label-sharing pairs; verify.
    label_sim = calc_neighbor(label, label)
    if image.is_cuda:
        label_sim = label_sim.to(image.device)
    # Cross-modal (image vs text) and within-modal (image-image, text-text)
    # positive/negative similarity terms.
    intra_similarity, intra_positive_loss, intra_negative_loss = self.similarity_loss(image, text, label_sim)
    inter_similarity_i, inter_positive_loss_i, inter_negative_loss_i = self.similarity_loss(image, image, label_sim)
    inter_similarity_t, inter_positive_loss_t, inter_negative_loss_t = self.similarity_loss(text, text, label_sim)
    # NOTE(review): both branches of each conditional expression below are
    # identical, so the similarity_function switch currently has no effect.
    intra_similarity_loss = (intra_positive_loss + intra_negative_loss) if self.args.similarity_function == "euclidean" else (intra_positive_loss + intra_negative_loss)
    inter_similarity_loss = inter_positive_loss_t + inter_positive_loss_i + (inter_negative_loss_i + inter_negative_loss_t) if self.args.similarity_function == "euclidean" else inter_positive_loss_t + inter_positive_loss_i + inter_negative_loss_i + inter_negative_loss_t
    similarity_loss = inter_similarity_loss + intra_similarity_loss
    # if self.writer is not None:
    #     self.writer.add_scalar("intra similarity max", intra_similarity.max(), self.global_step)
    #     self.writer.add_scalar("intra similarity min", intra_similarity.min(), self.global_step)
    #     self.writer.add_scalar("intra positive loss", intra_positive_loss.data, self.global_step)
    #     self.writer.add_scalar("intra negative loss", intra_negative_loss.data, self.global_step)
    #     self.writer.add_scalar("inter image similarity max", inter_similarity_i.max(), self.global_step)
    #     self.writer.add_scalar("inter image similarity min", inter_similarity_i.min(), self.global_step)
    #     self.writer.add_scalar("inter image positive loss", inter_positive_loss_i.data, self.global_step)
    #     self.writer.add_scalar("inter image negative loss", inter_negative_loss_i.data, self.global_step)
    #     self.writer.add_scalar("inter text similarity max", inter_similarity_t.max(), self.global_step)
    #     self.writer.add_scalar("inter text similarity min", inter_similarity_t.min(), self.global_step)
    #     self.writer.add_scalar("inter text positive loss", inter_positive_loss_t.data, self.global_step)
    #     self.writer.add_scalar("inter text negative loss", inter_negative_loss_t.data, self.global_step)
    #     self.writer.add_scalar("intra similarity loss", intra_similarity_loss.data, self.global_step)
    #     self.writer.add_scalar("inter similarity loss", inter_similarity_loss.data, self.global_step)
    #     self.writer.add_scalar("similarity loss", similarity_loss.data, self.global_step)
    if self.args.hash_layer != "select":
        # Quantization pushes continuous codes toward binary values.
        quantization_loss = (self.hash_loss(image) + self.hash_loss(text)) / 2
        loss = similarity_loss + quantization_loss
        if self.global_step % self.args.display_step == 0:
            self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
                f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
                f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
                f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
                f"QUATIZATION LOSS, {quantization_loss.data}, "\
                f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
    else:
        loss = similarity_loss # + self.args.qua_gamma * (image_quantization_loss + text_quantization_loss)
        if self.global_step % self.args.display_step == 0:
            self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
                f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
                f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
                f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
                # f"QUATIZATION LOSS, image: {image_quantization_loss.data}, text: {text_quantization_loss.data}, "\
                f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
    return loss
def compute_loss(self, image, text, label, epoch, times):
loss = self.our_loss(image, text, label, epoch, times)
def get_adv_code(self, adv_data_list,text_list):
return loss
img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)
for i in tqdm(range(len(adv_data_list))):
image = adv_data_list[i].to(self.rank, non_blocking=True)
text = text_list[i].to(self.rank, non_blocking=True)
# index = index.numpy()
image_hash=self.img_model(image)
text_feat=self.bert(text)[0]
text_hash=self.txt_model(text_feat)
# text_hash = self.make_hash_code(text_hash)
# image_hash.to(self.rank)
# text_hash.to(self.rank)
img_buffer[i, :] = image_hash.data
text_buffer[i, :] = text_hash.data
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def valid_attack(self,adv_images, texts, adv_labels):
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
os.makedirs(save_dir, exist_ok=True)
def test(self, mode_name="i2t"):
if self.args.pretrained == "":

View File

@ -18,7 +18,7 @@ def get_args():
# parser.add_argument("--test-index-file", type=str, default="./data/test/index.mat")
# parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat")
# parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat")
parser.add_argument("--txt-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=64)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--max-words", type=int, default=32)

View File

@ -227,3 +227,18 @@ class CrossEn_mean(nn.Module):
# print(logpt.max())
sim_loss = sim_matrix.mean()
return sim_loss
def find_indices(array, b):
    """Return the row indices of 2-D `array` whose rows equal vector `b`."""
    # A row matches only when every element equals the corresponding b value.
    row_matches = (array == b).all(axis=1)
    return np.flatnonzero(row_matches)
# def get_var(data_list):
# mean=data_list.mean(axis=0)
# sq_var=torch.zeros_like(data_list[0])
# for i in data_list:
# sq_var+=