This commit is contained in:
leewlving 2024-05-20 15:23:06 +08:00
parent 79bd5e6ac8
commit e79be260b4
5 changed files with 574 additions and 106 deletions

310
model/GAN.py Normal file
View File

@ -0,0 +1,310 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
from model.spectral_norm import spectral_norm as SpectralNorm
class PrototypeNet(nn.Module):
    """Maps a label vector to a shared feature, a hash code, and a class prediction."""

    def __init__(self, bit, num_classes):
        super(PrototypeNet, self).__init__()
        # Shared feature extractor over the label vector.
        self.feature = nn.Sequential(
            nn.Linear(num_classes, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 512),
        )
        # Hash head: bit-dimensional code squashed into (-1, 1).
        self.hashing = nn.Sequential(nn.Linear(512, bit), nn.Tanh())
        # Classification head: per-class scores in (0, 1).
        self.classifier = nn.Sequential(nn.Linear(512, num_classes), nn.Sigmoid())

    def forward(self, label):
        """Return (feature, hash_code, class_scores) for a batch of label vectors."""
        shared = self.feature(label)
        return shared, self.hashing(shared), self.classifier(shared)
class Discriminator(nn.Module):
    """
    Discriminator network with PatchGAN.
    Reference: https://github.com/yunjey/stargan/blob/master/model.py
    """

    def __init__(self, num_classes, image_size=224, conv_dim=64, repeat_num=5):
        super(Discriminator, self).__init__()
        # Stack of spectrally-normalized strided convs; each halves the spatial
        # size and (after the first) doubles the channel count.
        blocks = [
            SpectralNorm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.01),
        ]
        channels = conv_dim
        for _ in range(1, repeat_num):
            blocks.append(SpectralNorm(
                nn.Conv2d(channels, channels * 2, kernel_size=4, stride=2, padding=1)))
            blocks.append(nn.LeakyReLU(0.01))
            channels *= 2
        self.main = nn.Sequential(*blocks)
        # Final conv collapses the remaining grid to 1x1 and emits
        # num_classes + 1 logits (class scores plus one real/fake unit).
        final_kernel = int(image_size / (2 ** repeat_num))
        self.fc = nn.Conv2d(channels, num_classes + 1, kernel_size=final_kernel, bias=False)

    def forward(self, x):
        features = self.main(x)
        return self.fc(features).squeeze()
class Generator(nn.Module):
    """Generator: Encoder-Decoder Architecture.
    Reference: https://github.com/yunjey/stargan/blob/master/model.py
    """

    def __init__(self):
        super(Generator, self).__init__()
        # Fuses the label feature with the input image (channel concat -> 6 ch).
        self.label_encoder = LabelEncoder()

        channels = 64
        # Image encoder: 7x7 stem, two stride-2 downsamplers, three residual blocks.
        encoder_layers = [
            nn.Conv2d(6, channels, kernel_size=7, stride=1, padding=3, bias=True),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
        ]
        for _ in range(2):
            encoder_layers += [
                nn.Conv2d(channels, channels * 2, kernel_size=4, stride=2,
                          padding=1, bias=True),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(inplace=True),
            ]
            channels *= 2
        encoder_layers += [
            ResidualBlock(dim_in=channels, dim_out=channels, net_mode='t')
            for _ in range(3)
        ]
        self.image_encoder = nn.Sequential(*encoder_layers)

        # Decoder: three residual blocks, then two stride-2 upsamplers.
        decoder_layers = [
            ResidualBlock(dim_in=channels, dim_out=channels, net_mode='t')
            for _ in range(3)
        ]
        for _ in range(2):
            decoder_layers += [
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=4,
                                   stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(inplace=True),
            ]
            channels //= 2

        # Residual head: decoded features concatenated with the original image
        # are mapped back to a 3-channel image in (-1, 1).
        self.residual = nn.Sequential(
            nn.Conv2d(channels + 3, 3, kernel_size=3, stride=1, padding=1,
                      bias=False),
            nn.Tanh(),
        )
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, label_feature):
        """Return (adversarial image, label-mixed input feature)."""
        mixed_feature = self.label_encoder(x, label_feature)
        encoded = self.image_encoder(mixed_feature)
        decoded = self.decoder(encoded)
        adv_x = self.residual(torch.cat([decoded, x], dim=1))
        return adv_x, mixed_feature
class LabelEncoder(nn.Module):
    """Projects a 512-d label feature into image space and concatenates it with the image."""

    def __init__(self, nf=128):
        super(LabelEncoder, self).__init__()
        self.nf = nf
        self.size = 14
        # Lift the 512-d feature into an (nf, 14, 14) spatial map.
        self.fc = nn.Sequential(
            nn.Linear(512, nf * self.size * self.size),
            nn.ReLU(True),
        )
        # Four stride-2 deconvolutions (14 -> 224 spatial, nf -> nf/16 channels),
        # then a 3x3 conv down to 3 channels.
        layers = []
        channels = nf
        for _ in range(4):
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=4,
                                   stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(channels // 2, affine=False),
                nn.ReLU(inplace=True),
            ]
            channels //= 2
        layers.append(
            nn.Conv2d(channels, 3, kernel_size=3, stride=1, padding=1, bias=False))
        self.transform = nn.Sequential(*layers)

    def forward(self, image, label_feature):
        """Return the 6-channel concatenation of the projected label map and the image."""
        spatial = self.fc(label_feature)
        spatial = spatial.view(spatial.size(0), self.nf, self.size, self.size)
        spatial = self.transform(spatial)
        return torch.cat((spatial, image), dim=1)
class ResidualBlock(nn.Module):
    """Residual Block.

    Two 3x3 conv + InstanceNorm layers with an additive skip connection.
    ``net_mode`` selects whether InstanceNorm is affine: 'p' (or None)
    enables affine parameters, 't' disables them.

    Raises:
        ValueError: if ``net_mode`` is not 'p', 't', or None.
    """

    def __init__(self, dim_in, dim_out, net_mode=None):
        # super().__init__ first so module registration works in all branches.
        super(ResidualBlock, self).__init__()
        if net_mode == 'p' or net_mode is None:
            use_affine = True
        elif net_mode == 't':
            use_affine = False
        else:
            # Bug fix: the original left use_affine undefined for any other
            # mode, crashing later with UnboundLocalError; fail fast instead.
            raise ValueError("net_mode must be 'p', 't', or None, got %r" % (net_mode,))
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1,
                      bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1,
                      bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine))

    def forward(self, x):
        """Return ``x + main(x)``; requires dim_in == dim_out."""
        return x + self.main(x)
class GANLoss(nn.Module):
"""Define different GAN objectives.
The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input.
"""
def __init__(self, gan_mode, target_real_label=0.0, target_fake_label=1.0):
""" Initialize the GANLoss class.
Parameters:
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
target_real_label (bool) - - label for a real image
target_fake_label (bool) - - label of a fake image
Note: Do not use sigmoid as the last layer of Discriminator.
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
"""
super(GANLoss, self).__init__()
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
self.gan_mode = gan_mode
if gan_mode == 'lsgan':
self.loss = nn.MSELoss()
elif gan_mode == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode in ['wgangp']:
self.loss = None
else:
raise NotImplementedError('gan mode %s not implemented' % gan_mode)
def get_target_tensor(self, label, target_is_real):
"""Create label tensors with the same size as the input.
Parameters:
prediction (tensor) - - tpyically the prediction from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Returns:
A label tensor filled with ground truth label, and with the size of the input
"""
if target_is_real:
real_label = self.real_label.expand(label.size(0), 1)
target_tensor = torch.cat([label, real_label], dim=-1)
else:
fake_label = self.fake_label.expand(label.size(0), 1)
target_tensor = torch.cat([label, fake_label], dim=-1)
return target_tensor
def __call__(self, prediction, label, target_is_real):
"""Calculate loss given Discriminator's output and grount truth labels.
Parameters:
prediction (tensor) - - tpyically the prediction output from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Returns:
the calculated loss.
"""
if self.gan_mode in ['lsgan', 'vanilla']:
target_tensor = self.get_target_tensor(label, target_is_real)
loss = self.loss(prediction, target_tensor)
elif self.gan_mode == 'wgangp':
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
return loss
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler
    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
    For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
    and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError: if opt.lr_policy is not a supported policy name.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # Constant LR for the first n_epochs, then linear decay to zero
            # over the following n_epochs_decay epochs.
            lr_l = 1.0 - max(0, epoch + opt.epoch_count -
                             opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=opt.lr_decay_iters,
                                        gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.2,
                                                   threshold=0.01,
                                                   patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=opt.n_epochs,
                                                   eta_min=0)
    else:
        # Bug fix: the original *returned* a NotImplementedError instance
        # instead of raising it, and used a logging-style ',' argument
        # instead of '%' string formatting.
        raise NotImplementedError(
            'learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler

89
model/spectral_norm.py Normal file
View File

@ -0,0 +1,89 @@
import torch
from torch.nn import Parameter
def l2normalize(v, eps=1e-12):
    """Scale *v* to unit L2 norm; *eps* guards against division by zero."""
    return v / (eps + v.norm())
class SpectralNorm(object):
    """Forward-pre-hook object implementing spectral normalization of a
    module's ``weight`` via power iteration.

    Reparameterizes ``weight`` as ``weight_bar / sigma``, where ``sigma``
    estimates the largest singular value of the weight viewed as a 2-D matrix.
    """

    def __init__(self):
        # Attribute being normalized; fixed to "weight".
        self.name = "weight"
        #print(self.name)
        # One power-iteration refinement step per forward pass.
        self.power_iterations = 1

    def compute_weight(self, module):
        """Return the spectrally-normalized weight for *module*.

        Refines the left/right singular-vector estimates ``weight_u`` /
        ``weight_v`` in place, then divides ``weight_bar`` by the estimated
        top singular value.
        """
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")

        # Treat w as a (height x width) matrix by flattening trailing dims.
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # One power-iteration round: v <- normalize(W^T u), u <- normalize(W v).
            v.data = l2normalize(
                torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        # Estimated top singular value: sigma = u^T W v.
        sigma = u.dot(w.view(height, -1).mv(v))
        return w / sigma.expand_as(w)

    @staticmethod
    def apply(module):
        """Attach spectral normalization to *module*'s weight and return the hook."""
        name = "weight"
        fn = SpectralNorm()
        try:
            # Already set up on this module: reuse the existing u/v/bar params.
            u = getattr(module, name + "_u")
            v = getattr(module, name + "_v")
            w = getattr(module, name + "_bar")
        except AttributeError:
            # First attachment: create random singular-vector estimates and
            # move the raw weight into "weight_bar".
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]
            u = Parameter(w.data.new(height).normal_(0, 1),
                          requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
            w_bar = Parameter(w.data)
            #del module._parameters[name]
            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)
            # remove w from parameter list
            del module._parameters[name]
        setattr(module, name, fn.compute_weight(module))
        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)
        return fn

    def remove(self, module):
        """Detach spectral normalization, restoring a plain ``weight`` Parameter."""
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        # Forward pre-hook: refresh the normalized weight before each forward.
        setattr(module, self.name, self.compute_weight(module))
def spectral_norm(module):
    """Apply spectral normalization to *module*'s weight parameter.

    Registers a SpectralNorm forward-pre-hook on the module and returns the
    same module, so the call can be used inline when building layers.
    """
    SpectralNorm.apply(module)
    return module
def remove_spectral_norm(module):
    """Strip the spectral-norm hook attached to *module*'s 'weight'.

    Restores the plain weight parameter and returns the module; raises
    ValueError when no matching hook is found.
    """
    target = 'weight'
    for hook_id, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == target:
            hook.remove(module)
            del module._forward_pre_hooks[hook_id]
            return module
    raise ValueError("spectral_norm of '{}' not found in {}".format(
        target, module))

View File

@ -6,13 +6,23 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import scipy.io as scio import scipy.io as scio
import numpy as np
from .base import TrainBase from .base import TrainBase
from model.optimization import BertAdam from model.optimization import BertAdam
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity # from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
from utils.calc_utils import calc_map_k_matrix as calc_map_k from utils.calc_utils import calc_map_k_matrix as calc_map_k
from dataset.dataloader import dataloader from dataset.dataloader import dataloader
import open_clip
# from transformers import BertModel
def clamp(delta, clean_imgs):
    """Clip *delta* so that ``clean_imgs + delta`` stays inside [0, 1].

    Returns the adjusted perturbation (a detached-data computation on
    ``.data``, matching the attack loop's in-place update style).
    """
    perturbed = (delta.data + clean_imgs.data).clamp(0, 1)
    return perturbed - clean_imgs.data
class Trainer(TrainBase): class Trainer(TrainBase):
@ -21,32 +31,62 @@ class Trainer(TrainBase):
args = get_args() args = get_args()
super(Trainer, self).__init__(args, rank) super(Trainer, self).__init__(args, rank)
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset))) self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
self.run() image_representation, text_representation=self.generate_mapping()
self.image_representation=image_representation
self.text_representation=text_representation
# self.run()
def _init_model(self): def _init_model(self):
self.logger.info("init model.") self.logger.info("init model.")
linear = False # self.generator=Generator()
if self.args.hash_layer == "linear": # linear = False
linear = True # if self.args.hash_layer == "linear":
# linear = True
self.logger.info("ViT+GPT!") # self.bert=BertModel.from_pretrained("bert-base-cased", output_hidden_states=True).to(self.rank)
HashModel = DCMHT # self.bert.eval()
self.model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path, # self.logger.info("ViT+GPT!")
writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank) # HashModel = DCMHT
if self.args.pretrained != "" and os.path.exists(self.args.pretrained): # if self.args.victim_model == 'JDSH':
self.logger.info("load pretrained model.") # from model.JDSH import TxtNet, ImgNet
self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}")) # # self.img_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
# # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
# self.img_model=ImgNet(code_len=self.args.output_dim).to(self.rank)
# self.txt_model=TxtNet(code_len=self.args.output_dim, txt_feat_len=self.args.txt_dim).to(self.rank)
# path=os.path.join(self.args.checkpoints,self.args.victim_model+'/'+str(self.args.output_dim)+'_'+self.args.dataset+'latest.pth')
# checkpoint=torch.load(path)
# self.img_model.load_state_dict(torch.load(checkpoint['ImgNet'], map_location=f"cuda:{self.rank}"))
# self.txt_model.load_state_dict(torch.load(checkpoint['TxtNet'], map_location=f"cuda:{self.rank}"))
# self.img_model.eval()
# self.txt_model.eval()
# elif self.args.victim_model == 'DJSRH':
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
# elif self.args.victim_model == 'SSAH':
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
# elif self.args.victim_model == 'DCHUC':
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
# if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
# self.logger.info("load pretrained model.")
# self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=self.device)
self.model= model_clip
self.model.eval()
self.model.float() self.model.float()
self.optimizer = BertAdam([ # self.optimizer = BertAdam([
{'params': self.model.clip.parameters(), 'lr': self.args.clip_lr}, # {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
{'params': self.model.image_hash.parameters(), 'lr': self.args.lr}, # {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
{'params': self.model.text_hash.parameters(), 'lr': self.args.lr} # {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine', # ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs, # b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
weight_decay=self.args.weight_decay, max_grad_norm=1.0) # weight_decay=self.args.weight_decay, max_grad_norm=1.0)
print(self.model) # print(self.model)
def _init_dataset(self): def _init_dataset(self):
self.logger.info("init dataset.") self.logger.info("init dataset.")
@ -89,12 +129,62 @@ class Trainer(TrainBase):
pin_memory=True, pin_memory=True,
shuffle=True shuffle=True
) )
def generate_mapping(self):
    """Compute per-class text-embedding centroids and variances over the train set.

    Encodes every training caption with ``self.model.encode_text``, groups the
    embeddings by their (multi-hot) label row, and returns two dicts keyed by
    ``label.tobytes()``: the mean embedding and the per-dimension variance for
    each unique label.

    NOTE(review): the caller unpacks this as
    ``image_representation, text_representation = self.generate_mapping()``,
    but the values returned here are (text mean, text variance) — confirm the
    intended pairing before use.
    """
    text_train=[]
    label_train=[]
    # image_train=[]
    # self.change_state(mode="valid")
    for image, text, label, index in self.train_loader:
        # image=image.to(self.device, non_blocking=True)
        text=text.to(self.device, non_blocking=True)
        temp_text=self.model.encode_text(text)
        # temp_image=self.model.encode_image(image)
        # image_train.append(temp_image.cpu().detach().numpy())
        text_train.append(temp_text.cpu().detach().numpy())
        label_train.append(label.detach().numpy())
    text_train=np.concatenate(text_train, axis=0)
    # image_train=np.concatenate(image_train, axis=0)
    label_train=np.concatenate(label_train, axis=0)
    # Unique label rows; "unipue" is a pre-existing typo kept for consistency.
    label_unipue=np.unique(label_train,axis=0)
    # image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
    # Mean and variance of the text embeddings belonging to each unique label.
    text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
    text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)
    text_representation = {}
    text_var_representation = {}
    for i, centroid in enumerate(label_unipue):
        # Raw label bytes serve as a hashable dict key.
        text_representation[centroid.tobytes()] = text_centroids[i]
        text_var_representation[centroid.tobytes()]= text_var[i]
    return text_representation, text_var_representation
def target_adv(self, image, positive, negative,
               epsilon=0.03125, alpha=3/255, num_iter=100):
    """Craft a targeted adversarial perturbation for *image*.

    Iteratively updates a perturbation ``delta`` by signed gradient descent on
    a triplet loss, pulling the perturbed image's embedding toward *positive*
    and away from *negative*, while keeping ``delta`` within an L-inf ball of
    radius *epsilon* and ``image + delta`` inside [0, 1].
    Returns the final perturbation, detached from the autograd graph.
    """
    delta = torch.zeros_like(image,requires_grad=True)
    # clean_output = self.model.encode_image(image)
    one=torch.zeros_like(positive)  # NOTE(review): unused — candidate for removal
    alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
    for i in range(num_iter):
        self.model.zero_grad()
        anchor=self.model.encode_image(image+delta)
        loss=alienation_loss(anchor, positive, negative)
        loss.backward(retain_graph=True)
        # Signed gradient *descent*: minimize the triplet loss so the anchor
        # moves toward the positive embedding.
        delta.data = delta - alpha * delta.grad.detach().sign()
        # Project back into the feasible set (valid pixels, epsilon ball).
        delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
        delta.grad.zero_()
    return delta.detach()
def train_epoch(self, epoch): def train_epoch(self, epoch):
self.change_state(mode="train") self.change_state(mode="valid")
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs)) self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
all_loss = 0 all_loss = 0
times = 0 times = 0
adv_images=[]
adv_labels=[]
texts=[]
for image, text, label, index in self.train_loader: for image, text, label, index in self.train_loader:
self.global_step += 1 self.global_step += 1
times += 1 times += 1
@ -102,27 +192,29 @@ class Trainer(TrainBase):
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]: if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
label = torch.ones([image.shape[0]], dtype=torch.int) label = torch.ones([image.shape[0]], dtype=torch.int)
label = label.diag() label = label.diag()
# print(text.dtype)
# text.float()
# label.float()
image = image.to(self.rank, non_blocking=True) image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True) text = text.to(self.rank, non_blocking=True)
# print("text shape:", text.shape)
index = index.numpy() index = index.numpy()
image_anchor=self.image_representation(label.detach().cpu().numpy())
text_anchor=self.text_representation(label.detach().cpu().numpy())
negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
target_label=label.flip(dims=[0])
target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
# print("text shape:", text.shape)
# index = index.numpy()
# print(text.shape) # print(text.shape)
hash_img, hash_text = self.model(image, text) delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
if self.args.hash_layer == "select": torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
hash_img = torch.cat(hash_img, dim=-1) if isinstance(hash_img, list) else hash_img.view(hash_img.shape[0], -1) adv_image=delta+image
hash_text = torch.cat(hash_text, dim=-1)if isinstance(hash_text, list) else hash_text.view(hash_text.shape[0], -1) adv_images.append(adv_image)
loss = self.compute_loss(hash_img, hash_text, label, epoch, times) adv_labels.append(target_label)
all_loss += loss texts.append(text)
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
return adv_images, texts, adv_labels
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
def train(self): def train(self):
self.logger.info("Start train.") self.logger.info("Start train.")
@ -204,78 +296,40 @@ class Trainer(TrainBase):
image = image.to(self.rank, non_blocking=True) image = image.to(self.rank, non_blocking=True)
text = text.to(self.rank, non_blocking=True) text = text.to(self.rank, non_blocking=True)
index = index.numpy() index = index.numpy()
image_hash = self.model.encode_image(image) image_hash=self.img_model(image)
image_hash = self.make_hash_code(image_hash) text_feat=self.bert(text)[0]
text_hash = self.model.encode_text(text) text_hash=self.txt_model(text_feat)
text_hash = self.make_hash_code(text_hash)
# image_hash.to(self.rank)
# text_hash.to(self.rank)
img_buffer[index, :] = image_hash.data img_buffer[index, :] = image_hash.data
text_buffer[index, :] = text_hash.data text_buffer[index, :] = text_hash.data
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank) return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def our_loss(self, image, text, label, epoch, times):
loss = 0
label_sim = calc_neighbor(label, label)
if image.is_cuda:
label_sim = label_sim.to(image.device)
intra_similarity, intra_positive_loss, intra_negative_loss = self.similarity_loss(image, text, label_sim)
inter_similarity_i, inter_positive_loss_i, inter_negative_loss_i = self.similarity_loss(image, image, label_sim)
inter_similarity_t, inter_positive_loss_t, inter_negative_loss_t = self.similarity_loss(text, text, label_sim)
intra_similarity_loss = (intra_positive_loss + intra_negative_loss) if self.args.similarity_function == "euclidean" else (intra_positive_loss + intra_negative_loss)
inter_similarity_loss = inter_positive_loss_t + inter_positive_loss_i + (inter_negative_loss_i + inter_negative_loss_t) if self.args.similarity_function == "euclidean" else inter_positive_loss_t + inter_positive_loss_i + inter_negative_loss_i + inter_negative_loss_t
similarity_loss = inter_similarity_loss + intra_similarity_loss
# if self.writer is not None:
# self.writer.add_scalar("intra similarity max", intra_similarity.max(), self.global_step)
# self.writer.add_scalar("intra similarity min", intra_similarity.min(), self.global_step)
# self.writer.add_scalar("intra positive loss", intra_positive_loss.data, self.global_step)
# self.writer.add_scalar("intra negative loss", intra_negative_loss.data, self.global_step)
# self.writer.add_scalar("inter image similarity max", inter_similarity_i.max(), self.global_step)
# self.writer.add_scalar("inter image similarity min", inter_similarity_i.min(), self.global_step)
# self.writer.add_scalar("inter image positive loss", inter_positive_loss_i.data, self.global_step)
# self.writer.add_scalar("inter image negative loss", inter_negative_loss_i.data, self.global_step)
# self.writer.add_scalar("inter text similarity max", inter_similarity_t.max(), self.global_step)
# self.writer.add_scalar("inter text similarity min", inter_similarity_t.min(), self.global_step)
# self.writer.add_scalar("inter text positive loss", inter_positive_loss_t.data, self.global_step)
# self.writer.add_scalar("inter text negative loss", inter_negative_loss_t.data, self.global_step)
# self.writer.add_scalar("intra similarity loss", intra_similarity_loss.data, self.global_step)
# self.writer.add_scalar("inter similarity loss", inter_similarity_loss.data, self.global_step)
# self.writer.add_scalar("similarity loss", similarity_loss.data, self.global_step)
if self.args.hash_layer != "select":
quantization_loss = (self.hash_loss(image) + self.hash_loss(text)) / 2
loss = similarity_loss + quantization_loss
if self.global_step % self.args.display_step == 0:
self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
f"QUATIZATION LOSS, {quantization_loss.data}, "\
f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
else:
loss = similarity_loss # + self.args.qua_gamma * (image_quantization_loss + text_quantization_loss)
if self.global_step % self.args.display_step == 0:
self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
# f"QUATIZATION LOSS, image: {image_quantization_loss.data}, text: {text_quantization_loss.data}, "\
f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
return loss
def compute_loss(self, image, text, label, epoch, times): def get_adv_code(self, adv_data_list,text_list):
loss = self.our_loss(image, text, label, epoch, times)
return loss img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)
for i in tqdm(range(len(adv_data_list))):
image = adv_data_list[i].to(self.rank, non_blocking=True)
text = text_list[i].to(self.rank, non_blocking=True)
# index = index.numpy()
image_hash=self.img_model(image)
text_feat=self.bert(text)[0]
text_hash=self.txt_model(text_feat)
# text_hash = self.make_hash_code(text_hash)
# image_hash.to(self.rank)
# text_hash.to(self.rank)
img_buffer[i, :] = image_hash.data
text_buffer[i, :] = text_hash.data
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
def valid_attack(self,adv_images, texts, adv_labels):
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
os.makedirs(save_dir, exist_ok=True)
def test(self, mode_name="i2t"): def test(self, mode_name="i2t"):
if self.args.pretrained == "": if self.args.pretrained == "":

View File

@ -18,7 +18,7 @@ def get_args():
# parser.add_argument("--test-index-file", type=str, default="./data/test/index.mat") # parser.add_argument("--test-index-file", type=str, default="./data/test/index.mat")
# parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat") # parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat")
# parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat") # parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat")
parser.add_argument("--txt-dim", type=int, default=1024)
parser.add_argument("--output-dim", type=int, default=64) parser.add_argument("--output-dim", type=int, default=64)
parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--max-words", type=int, default=32) parser.add_argument("--max-words", type=int, default=32)

View File

@ -227,3 +227,18 @@ class CrossEn_mean(nn.Module):
# print(logpt.max()) # print(logpt.max())
sim_loss = sim_matrix.mean() sim_loss = sim_matrix.mean()
return sim_loss return sim_loss
def find_indices(array, b):
    """Return the row indices of 2-D *array* whose rows equal vector *b*."""
    row_matches = (array == b).all(axis=1)
    return np.flatnonzero(row_matches)
# def get_var(data_list):
# mean=data_list.mean(axis=0)
# sq_var=torch.zeros_like(data_list[0])
# for i in data_list:
# sq_var+=