add init
This commit is contained in:
parent
79bd5e6ac8
commit
e79be260b4
|
|
@ -0,0 +1,310 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.optim import lr_scheduler
|
||||||
|
|
||||||
|
from model.spectral_norm import spectral_norm as SpectralNorm
|
||||||
|
|
||||||
|
|
||||||
|
class PrototypeNet(nn.Module):
    """Maps a label vector to a feature, a hash code, and class scores.

    Three heads share one feature extractor:
      - feature:    label -> 512-d embedding,
      - hashing:    embedding -> `bit`-d code in (-1, 1) via tanh,
      - classifier: embedding -> per-class scores in (0, 1) via sigmoid.
    """

    def __init__(self, bit, num_classes):
        super(PrototypeNet, self).__init__()
        # Label -> 512-d feature embedding.
        self.feature = nn.Sequential(
            nn.Linear(num_classes, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 512),
        )
        # Feature -> hash code, squashed into (-1, 1).
        self.hashing = nn.Sequential(nn.Linear(512, bit), nn.Tanh())
        # Feature -> per-class scores in (0, 1).
        self.classifier = nn.Sequential(nn.Linear(512, num_classes), nn.Sigmoid())

    def forward(self, label):
        """Return (feature, hash_code, class_scores) for a batch of labels."""
        feat = self.feature(label)
        return feat, self.hashing(feat), self.classifier(feat)
|
||||||
|
|
||||||
|
|
||||||
|
class Discriminator(nn.Module):
    """
    Discriminator network with PatchGAN.
    Reference: https://github.com/yunjey/stargan/blob/master/model.py
    """

    def __init__(self, num_classes, image_size=224, conv_dim=64, repeat_num=5):
        super(Discriminator, self).__init__()
        blocks = [
            SpectralNorm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.01),
        ]
        channels = conv_dim
        # Each stage halves the spatial size and doubles the channel count.
        for _ in range(1, repeat_num):
            blocks.append(SpectralNorm(
                nn.Conv2d(channels, channels * 2, kernel_size=4, stride=2, padding=1)))
            blocks.append(nn.LeakyReLU(0.01))
            channels = channels * 2
        # After repeat_num stride-2 convolutions the feature map is
        # image_size / 2**repeat_num wide; fc covers it entirely, producing
        # one score per class plus one real/fake score.
        kernel_size = int(image_size / (2 ** repeat_num))
        self.main = nn.Sequential(*blocks)
        self.fc = nn.Conv2d(channels, num_classes + 1,
                            kernel_size=kernel_size, bias=False)

    def forward(self, x):
        """Return squeezed (num_classes + 1) scores for a batch of images."""
        features = self.main(x)
        return self.fc(features).squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
class Generator(nn.Module):
    """Generator: Encoder-Decoder Architecture.
    Reference: https://github.com/yunjey/stargan/blob/master/model.py
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Label Encoder: expands the label feature into an image-sized map
        # and concatenates it with the input image (3 + 3 = 6 channels).
        self.label_encoder = LabelEncoder()

        # Image Encoder: 7x7 stem over the 6-channel mixed input.
        channels = 64
        encoder_layers = [
            nn.Conv2d(6, channels, kernel_size=7, stride=1, padding=3, bias=True),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
        ]
        # Down Sampling: two stride-2 stages, doubling channels each time.
        for _ in range(2):
            encoder_layers.extend([
                nn.Conv2d(channels, channels * 2,
                          kernel_size=4, stride=2, padding=1, bias=True),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(inplace=True),
            ])
            channels = channels * 2
        # Bottleneck: residual blocks at the lowest resolution.
        for _ in range(3):
            encoder_layers.append(
                ResidualBlock(dim_in=channels, dim_out=channels, net_mode='t'))
        self.image_encoder = nn.Sequential(*encoder_layers)

        # Decoder: bottleneck residual blocks, then up-sampling.
        decoder_layers = []
        for _ in range(3):
            decoder_layers.append(
                ResidualBlock(dim_in=channels, dim_out=channels, net_mode='t'))
        # Up Sampling: two stride-2 transposed convolutions, halving channels.
        for _ in range(2):
            decoder_layers.extend([
                nn.ConvTranspose2d(channels, channels // 2,
                                   kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(inplace=True),
            ])
            channels = channels // 2
        # Final head: decoded map concatenated with the input image
        # (channels + 3) is mapped back to a 3-channel image in (-1, 1).
        self.residual = nn.Sequential(
            nn.Conv2d(channels + 3, 3,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.Tanh())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, label_feature):
        """Return (adversarial_image, mixed_feature) for image x and label feature."""
        mixed_feature = self.label_encoder(x, label_feature)
        encoded = self.image_encoder(mixed_feature)
        decoded = self.decoder(encoded)
        # Re-attach the clean image so the head sees both decode and input.
        adv_x = self.residual(torch.cat([decoded, x], dim=1))
        return adv_x, mixed_feature
|
||||||
|
|
||||||
|
|
||||||
|
class LabelEncoder(nn.Module):
    """Expands a 512-d label feature into a 3-channel image-sized map and
    concatenates it with the input image along the channel axis."""

    def __init__(self, nf=128):
        super(LabelEncoder, self).__init__()
        self.nf = nf
        # Spatial size of the seed feature map; four stride-2 up-samplings
        # grow it 16x (14 -> 224).
        self.size = 14

        # 512-d vector -> nf x size x size feature map.
        self.fc = nn.Sequential(
            nn.Linear(512, nf * self.size * self.size), nn.ReLU(True))

        channels = nf
        layers = []
        for _ in range(4):
            layers.extend([
                nn.ConvTranspose2d(channels, channels // 2,
                                   kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(channels // 2, affine=False),
                nn.ReLU(inplace=True),
            ])
            channels = channels // 2
        # Project down to a 3-channel label map.
        layers.append(
            nn.Conv2d(channels, 3, kernel_size=3, stride=1, padding=1, bias=False))
        self.transform = nn.Sequential(*layers)

    def forward(self, image, label_feature):
        """Return cat([transform(fc(label_feature)), image], dim=1)."""
        feat = self.fc(label_feature)
        feat = feat.view(feat.size(0), self.nf, self.size, self.size)
        feat = self.transform(feat)
        # Stack the label map on top of the image: (B, 3+3, H, W).
        return torch.cat((feat, image), dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
    """Residual Block.

    Two 3x3 conv + InstanceNorm layers with an identity skip connection.

    Args:
        dim_in: input channel count.
        dim_out: output channel count (must equal dim_in for the skip add).
        net_mode: 'p' or None -> affine InstanceNorm; 't' -> non-affine.

    Raises:
        ValueError: if net_mode is not 'p', 't' or None.
    """

    def __init__(self, dim_in, dim_out, net_mode=None):
        if net_mode == 'p' or (net_mode is None):
            use_affine = True
        elif net_mode == 't':
            use_affine = False
        else:
            # Bug fix: an unknown mode previously left use_affine unbound and
            # crashed later with UnboundLocalError; fail fast instead.
            raise ValueError(
                "net_mode must be 'p', 't' or None, got %r" % (net_mode,))
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine))

    def forward(self, x):
        # Identity skip connection around the conv stack.
        return x + self.main(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=0.0, target_fake_label=1.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) - - label for a real image
            target_fake_label (float) - - label of a fake image

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # Buffers follow the module across devices but are not trained.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        criterion_by_mode = {'lsgan': nn.MSELoss, 'vanilla': nn.BCEWithLogitsLoss}
        if gan_mode in criterion_by_mode:
            self.loss = criterion_by_mode[gan_mode]()
        elif gan_mode == 'wgangp':
            # WGAN-GP uses the raw critic mean, no pointwise criterion.
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, label, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            label (tensor) - - class-label part of the target
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            `label` with the real/fake ground-truth column appended (dim=-1).
        """
        fill = self.real_label if target_is_real else self.fake_label
        column = fill.expand(label.size(0), 1)
        return torch.cat([label, column], dim=-1)

    def __call__(self, prediction, label, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            label (tensor) - - class labels (unused for wgangp)
            target_is_real (bool) - - if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode == 'wgangp':
            mean = prediction.mean()
            return -mean if target_is_real else mean
        target_tensor = self.get_target_tensor(label, target_is_real)
        return self.loss(prediction, target_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a
            subclass of BaseOptions. opt.lr_policy is the name of the learning
            rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.n_epochs>
    epochs and linearly decay the rate to zero over the next
    <opt.n_epochs_decay> epochs. For other schedulers (step, plateau, and
    cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError: if opt.lr_policy is not one of the policies above.
    """
    if opt.lr_policy == 'linear':

        def lambda_rule(epoch):
            # Multiplier is 1.0 for the first n_epochs, then decays linearly
            # to 0 over n_epochs_decay epochs.
            lr_l = 1.0 - max(0, epoch + opt.epoch_count -
                             opt.n_epochs) / float(opt.n_epochs_decay + 1)
            return lr_l

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=opt.lr_decay_iters,
                                        gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.2,
                                                   threshold=0.01,
                                                   patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=opt.n_epochs,
                                                   eta_min=0)
    else:
        # Bug fix: the original *returned* a NotImplementedError instance
        # (built with ','-style args instead of %-formatting), so callers got
        # an exception object where a scheduler was expected. Raise instead.
        raise NotImplementedError(
            'learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
|
||||||
|
|
@ -0,0 +1,89 @@
|
||||||
|
import torch
|
||||||
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
def l2normalize(v, eps=1e-12):
    """Return `v` scaled to unit L2 norm; `eps` guards against division by zero."""
    denom = v.norm() + eps
    return v / denom
|
||||||
|
|
||||||
|
|
||||||
|
class SpectralNorm(object):
    """Spectral normalization of a module's `weight` via power iteration.

    Installed as a forward pre-hook: before every forward pass the hook
    recomputes `weight` as `weight_bar / sigma`, where sigma is the largest
    singular value estimated from the `weight_u` / `weight_v` vectors.
    """

    def __init__(self):
        # The parameter this hook normalizes; fixed to "weight" here.
        self.name = "weight"
        #print(self.name)
        # Power-iteration steps performed per forward pass.
        self.power_iterations = 1

    def compute_weight(self, module):
        """Return the spectrally normalized weight tensor for `module`."""
        u = getattr(module, self.name + "_u")
        v = getattr(module, self.name + "_v")
        w = getattr(module, self.name + "_bar")

        height = w.data.shape[0]
        # Power iteration on the (height x -1)-reshaped weight matrix:
        # refines u and v toward the leading left/right singular vectors.
        # Updates go through .data so they stay out of the autograd graph.
        for _ in range(self.power_iterations):
            v.data = l2normalize(
                torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        # Estimated spectral norm (largest singular value) of the weight.
        sigma = u.dot(w.view(height, -1).mv(v))
        return w / sigma.expand_as(w)

    @staticmethod
    def apply(module):
        """Attach spectral normalization to `module` (must own a `weight`).

        Registers `weight_u`, `weight_v` and `weight_bar` parameters (on first
        application), replaces `weight` with the normalized tensor, and
        installs the hook so the weight is refreshed before every forward().
        """
        name = "weight"
        fn = SpectralNorm()

        try:
            # Already applied once: reuse the existing auxiliary parameters.
            u = getattr(module, name + "_u")
            v = getattr(module, name + "_v")
            w = getattr(module, name + "_bar")
        except AttributeError:
            # First application: create the singular-vector estimates and a
            # raw copy of the weight to normalize from.
            w = getattr(module, name)
            height = w.data.shape[0]
            width = w.view(height, -1).data.shape[1]
            u = Parameter(w.data.new(height).normal_(0, 1),
                          requires_grad=False)
            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
            w_bar = Parameter(w.data)

            #del module._parameters[name]

            module.register_parameter(name + "_u", u)
            module.register_parameter(name + "_v", v)
            module.register_parameter(name + "_bar", w_bar)

        # remove w from parameter list
        del module._parameters[name]

        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        """Detach spectral normalization, restoring a plain `weight` parameter."""
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_u']
        del module._parameters[self.name + '_v']
        del module._parameters[self.name + '_bar']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        # Forward pre-hook body: refresh the normalized weight in place.
        setattr(module, self.name, self.compute_weight(module))
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_norm(module):
    """Apply spectral normalization to `module`'s weight and return the module."""
    SpectralNorm.apply(module)
    return module
|
||||||
|
|
||||||
|
|
||||||
|
def remove_spectral_norm(module):
    """Remove spectral normalization previously installed on `module`.

    Raises:
        ValueError: if no SpectralNorm hook on 'weight' is found.
    """
    name = 'weight'
    for key, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[key]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))
|
||||||
|
|
@ -6,13 +6,23 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import scipy.io as scio
|
import scipy.io as scio
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from .base import TrainBase
|
from .base import TrainBase
|
||||||
from model.optimization import BertAdam
|
from model.optimization import BertAdam
|
||||||
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity
|
# from model.GAN import Discriminator, Generator, LabelEncoder, GANLoss
|
||||||
|
from utils import get_args, calc_neighbor, cosine_similarity, euclidean_similarity,find_indices
|
||||||
from utils.calc_utils import calc_map_k_matrix as calc_map_k
|
from utils.calc_utils import calc_map_k_matrix as calc_map_k
|
||||||
from dataset.dataloader import dataloader
|
from dataset.dataloader import dataloader
|
||||||
|
import open_clip
|
||||||
|
# from transformers import BertModel
|
||||||
|
|
||||||
|
def clamp(delta, clean_imgs):
    """Clip the perturbation so that clean_imgs + delta stays inside [0, 1].

    Works on .data, so the returned tensor carries no gradient history.
    """
    perturbed = (delta.data + clean_imgs.data).clamp(0, 1)
    return perturbed - clean_imgs.data
|
||||||
|
|
||||||
class Trainer(TrainBase):
|
class Trainer(TrainBase):
|
||||||
|
|
||||||
|
|
@ -21,32 +31,62 @@ class Trainer(TrainBase):
|
||||||
args = get_args()
|
args = get_args()
|
||||||
super(Trainer, self).__init__(args, rank)
|
super(Trainer, self).__init__(args, rank)
|
||||||
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
self.logger.info("dataset len: {}".format(len(self.train_loader.dataset)))
|
||||||
self.run()
|
image_representation, text_representation=self.generate_mapping()
|
||||||
|
self.image_representation=image_representation
|
||||||
|
self.text_representation=text_representation
|
||||||
|
# self.run()
|
||||||
|
|
||||||
def _init_model(self):
|
def _init_model(self):
|
||||||
self.logger.info("init model.")
|
self.logger.info("init model.")
|
||||||
linear = False
|
# self.generator=Generator()
|
||||||
if self.args.hash_layer == "linear":
|
# linear = False
|
||||||
linear = True
|
# if self.args.hash_layer == "linear":
|
||||||
|
# linear = True
|
||||||
self.logger.info("ViT+GPT!")
|
# self.bert=BertModel.from_pretrained("bert-base-cased", output_hidden_states=True).to(self.rank)
|
||||||
HashModel = DCMHT
|
# self.bert.eval()
|
||||||
self.model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
# self.logger.info("ViT+GPT!")
|
||||||
writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
# HashModel = DCMHT
|
||||||
if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
|
# if self.args.victim_model == 'JDSH':
|
||||||
self.logger.info("load pretrained model.")
|
# from model.JDSH import TxtNet, ImgNet
|
||||||
self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
# # self.img_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
||||||
|
# # writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
||||||
|
# self.img_model=ImgNet(code_len=self.args.output_dim).to(self.rank)
|
||||||
|
# self.txt_model=TxtNet(code_len=self.args.output_dim, txt_feat_len=self.args.txt_dim).to(self.rank)
|
||||||
|
# path=os.path.join(self.args.checkpoints,self.args.victim_model+'/'+str(self.args.output_dim)+'_'+self.args.dataset+'latest.pth')
|
||||||
|
# checkpoint=torch.load(path)
|
||||||
|
# self.img_model.load_state_dict(torch.load(checkpoint['ImgNet'], map_location=f"cuda:{self.rank}"))
|
||||||
|
# self.txt_model.load_state_dict(torch.load(checkpoint['TxtNet'], map_location=f"cuda:{self.rank}"))
|
||||||
|
# self.img_model.eval()
|
||||||
|
# self.txt_model.eval()
|
||||||
|
# elif self.args.victim_model == 'DJSRH':
|
||||||
|
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
||||||
|
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
||||||
|
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
||||||
|
# elif self.args.victim_model == 'SSAH':
|
||||||
|
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
||||||
|
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
||||||
|
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
||||||
|
# elif self.args.victim_model == 'DCHUC':
|
||||||
|
# self.victim_model = HashModel(outputDim=self.args.output_dim, clipPath=self.args.clip_path,
|
||||||
|
# writer=self.writer, logger=self.logger, is_train=self.args.is_train, linear=linear).to(self.rank)
|
||||||
|
# self.victim_model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
||||||
|
|
||||||
|
# if self.args.pretrained != "" and os.path.exists(self.args.pretrained):
|
||||||
|
# self.logger.info("load pretrained model.")
|
||||||
|
# self.model.load_state_dict(torch.load(self.args.pretrained, map_location=f"cuda:{self.rank}"))
|
||||||
|
model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', device=self.device)
|
||||||
|
self.model= model_clip
|
||||||
|
self.model.eval()
|
||||||
self.model.float()
|
self.model.float()
|
||||||
self.optimizer = BertAdam([
|
# self.optimizer = BertAdam([
|
||||||
{'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
|
# {'params': self.model.clip.parameters(), 'lr': self.args.clip_lr},
|
||||||
{'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
|
# {'params': self.model.image_hash.parameters(), 'lr': self.args.lr},
|
||||||
{'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
|
# {'params': self.model.text_hash.parameters(), 'lr': self.args.lr}
|
||||||
], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
|
# ], lr=self.args.lr, warmup=self.args.warmup_proportion, schedule='warmup_cosine',
|
||||||
b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
|
# b1=0.9, b2=0.98, e=1e-6, t_total=len(self.train_loader) * self.args.epochs,
|
||||||
weight_decay=self.args.weight_decay, max_grad_norm=1.0)
|
# weight_decay=self.args.weight_decay, max_grad_norm=1.0)
|
||||||
|
|
||||||
print(self.model)
|
# print(self.model)
|
||||||
|
|
||||||
def _init_dataset(self):
|
def _init_dataset(self):
|
||||||
self.logger.info("init dataset.")
|
self.logger.info("init dataset.")
|
||||||
|
|
@ -89,12 +129,62 @@ class Trainer(TrainBase):
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
shuffle=True
|
shuffle=True
|
||||||
)
|
)
|
||||||
|
    def generate_mapping(self):
        """Build per-label text-embedding centroid and variance lookup tables.

        Encodes every training caption with the model's text encoder, groups
        embeddings by their label row, and returns two dicts keyed by
        label_row.tobytes(): (centroids, per-dimension variances).

        NOTE(review): the caller unpacks this as
        `image_representation, text_representation = self.generate_mapping()`,
        but the method returns (text centroids, text variances) — confirm the
        intended pairing; the image branch is entirely commented out below.
        """
        text_train=[]
        label_train=[]
        # image_train=[]
        # self.change_state(mode="valid")
        for image, text, label, index in self.train_loader:
            # image=image.to(self.device, non_blocking=True)
            text=text.to(self.device, non_blocking=True)
            temp_text=self.model.encode_text(text)
            # temp_image=self.model.encode_image(image)
            # image_train.append(temp_image.cpu().detach().numpy())
            # Detach to numpy so nothing keeps the autograd graph alive.
            text_train.append(temp_text.cpu().detach().numpy())
            label_train.append(label.detach().numpy())
        text_train=np.concatenate(text_train, axis=0)
        # image_train=np.concatenate(image_train, axis=0)
        label_train=np.concatenate(label_train, axis=0)
        # Unique label rows (one group per distinct label vector).
        label_unipue=np.unique(label_train,axis=0)
        # image_centroids =np.stack([image_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
        # Mean and variance of the text embeddings within each label group;
        # find_indices presumably returns the row indices matching a label —
        # TODO confirm against utils.find_indices.
        text_centroids =np.stack([text_train[find_indices(label_train,label_unipue[i])].mean(axis=0) for i in range(len(label_unipue))], axis=0)
        text_var=np.stack([text_train[find_indices(label_train,label_unipue[i])].var(axis=0) for i in range(len(label_unipue))], axis=0)

        text_representation = {}
        text_var_representation = {}
        for i, centroid in enumerate(label_unipue):
            # Key on the raw bytes of the label row so it is hashable.
            text_representation[centroid.tobytes()] = text_centroids[i]
            text_var_representation[centroid.tobytes()]= text_var[i]
        return text_representation, text_var_representation
|
||||||
|
|
||||||
|
    def target_adv(self, image, positive, negative,
                   epsilon=0.03125, alpha=3/255, num_iter=100):
        """PGD-style targeted attack in the image-embedding space.

        Iteratively perturbs `image` so that its embedding (the triplet
        anchor) moves toward `positive` and away from `negative`, keeping the
        perturbation inside an L-inf ball of radius `epsilon` and the
        perturbed image inside [0, 1].

        Parameters:
            image -- batch of clean images (values assumed in [0, 1] — TODO confirm)
            positive / negative -- target embeddings for the triplet loss
            epsilon -- L-inf budget for the perturbation
            alpha -- per-step sign-gradient step size
            num_iter -- number of PGD iterations

        Returns:
            the final perturbation `delta`, detached from the graph.
        """
        delta = torch.zeros_like(image,requires_grad=True)
        # clean_output = self.model.encode_image(image)
        # NOTE(review): `one` is never used below — looks like dead code.
        one=torch.zeros_like(positive)
        alienation_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
        for i in range(num_iter):
            self.model.zero_grad()
            anchor=self.model.encode_image(image+delta)
            loss=alienation_loss(anchor, positive, negative)

            loss.backward(retain_graph=True)
            # Gradient *descent* step on the triplet loss (note the minus):
            # pulls the anchor toward `positive`, pushes it from `negative`.
            delta.data = delta - alpha * delta.grad.detach().sign()
            # Project: keep image+delta in [0, 1] and delta within epsilon.
            delta.data =clamp(delta, image).clamp(-epsilon, epsilon)
            delta.grad.zero_()

        return delta.detach()
|
||||||
|
|
||||||
def train_epoch(self, epoch):
|
def train_epoch(self, epoch):
|
||||||
self.change_state(mode="train")
|
self.change_state(mode="valid")
|
||||||
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
self.logger.info(">>>>>> epochs: %d/%d"%(epoch, self.args.epochs))
|
||||||
all_loss = 0
|
all_loss = 0
|
||||||
times = 0
|
times = 0
|
||||||
|
adv_images=[]
|
||||||
|
adv_labels=[]
|
||||||
|
texts=[]
|
||||||
for image, text, label, index in self.train_loader:
|
for image, text, label, index in self.train_loader:
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
times += 1
|
times += 1
|
||||||
|
|
@ -102,27 +192,29 @@ class Trainer(TrainBase):
|
||||||
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
|
if self.args.dataset not in ["flickr25k", "coco", "nuswide"]:
|
||||||
label = torch.ones([image.shape[0]], dtype=torch.int)
|
label = torch.ones([image.shape[0]], dtype=torch.int)
|
||||||
label = label.diag()
|
label = label.diag()
|
||||||
# print(text.dtype)
|
|
||||||
# text.float()
|
|
||||||
# label.float()
|
|
||||||
image = image.to(self.rank, non_blocking=True)
|
image = image.to(self.rank, non_blocking=True)
|
||||||
text = text.to(self.rank, non_blocking=True)
|
text = text.to(self.rank, non_blocking=True)
|
||||||
# print("text shape:", text.shape)
|
|
||||||
index = index.numpy()
|
index = index.numpy()
|
||||||
|
image_anchor=self.image_representation(label.detach().cpu().numpy())
|
||||||
|
text_anchor=self.text_representation(label.detach().cpu().numpy())
|
||||||
|
negetive_code=np.concatenate([image_anchor,text_anchor],axis=0).mean(axis=0)
|
||||||
|
target_label=label.flip(dims=[0])
|
||||||
|
target_image_anchor=self.image_representation(target_label.detach().cpu().numpy())
|
||||||
|
target_text_anchor=self.text_representation(target_label.detach().cpu().numpy())
|
||||||
|
positive_code=np.concatenate([target_image_anchor,target_text_anchor],axis=0).mean(axis=0)
|
||||||
|
# print("text shape:", text.shape)
|
||||||
|
# index = index.numpy()
|
||||||
# print(text.shape)
|
# print(text.shape)
|
||||||
hash_img, hash_text = self.model(image, text)
|
delta=self.target_adv(image,torch.from_numpy(positive_code).to(self.rank, non_blocking=True),
|
||||||
if self.args.hash_layer == "select":
|
torch.from_numpy(negetive_code).to(self.rank, non_blocking=True))
|
||||||
hash_img = torch.cat(hash_img, dim=-1) if isinstance(hash_img, list) else hash_img.view(hash_img.shape[0], -1)
|
adv_image=delta+image
|
||||||
hash_text = torch.cat(hash_text, dim=-1)if isinstance(hash_text, list) else hash_text.view(hash_text.shape[0], -1)
|
adv_images.append(adv_image)
|
||||||
loss = self.compute_loss(hash_img, hash_text, label, epoch, times)
|
adv_labels.append(target_label)
|
||||||
all_loss += loss
|
texts.append(text)
|
||||||
|
# self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
|
||||||
|
return adv_images, texts, adv_labels
|
||||||
|
|
||||||
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
self.logger.info(f">>>>>> [{epoch}/{self.args.epochs}] loss: {all_loss.data / (len(self.train_loader))}, lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
|
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
self.logger.info("Start train.")
|
self.logger.info("Start train.")
|
||||||
|
|
@ -204,78 +296,40 @@ class Trainer(TrainBase):
|
||||||
image = image.to(self.rank, non_blocking=True)
|
image = image.to(self.rank, non_blocking=True)
|
||||||
text = text.to(self.rank, non_blocking=True)
|
text = text.to(self.rank, non_blocking=True)
|
||||||
index = index.numpy()
|
index = index.numpy()
|
||||||
image_hash = self.model.encode_image(image)
|
image_hash=self.img_model(image)
|
||||||
image_hash = self.make_hash_code(image_hash)
|
text_feat=self.bert(text)[0]
|
||||||
text_hash = self.model.encode_text(text)
|
text_hash=self.txt_model(text_feat)
|
||||||
text_hash = self.make_hash_code(text_hash)
|
|
||||||
# image_hash.to(self.rank)
|
|
||||||
# text_hash.to(self.rank)
|
|
||||||
img_buffer[index, :] = image_hash.data
|
img_buffer[index, :] = image_hash.data
|
||||||
text_buffer[index, :] = text_hash.data
|
text_buffer[index, :] = text_hash.data
|
||||||
|
|
||||||
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
||||||
|
|
||||||
def our_loss(self, image, text, label, epoch, times):
|
def get_adv_code(self, adv_data_list,text_list):
|
||||||
loss = 0
|
|
||||||
|
|
||||||
label_sim = calc_neighbor(label, label)
|
img_buffer = torch.empty(len(adv_data_list), self.args.output_dim, dtype=torch.float).to(self.rank)
|
||||||
if image.is_cuda:
|
text_buffer = torch.empty(len(text_list), self.args.output_dim, dtype=torch.float).to(self.rank)
|
||||||
label_sim = label_sim.to(image.device)
|
|
||||||
intra_similarity, intra_positive_loss, intra_negative_loss = self.similarity_loss(image, text, label_sim)
|
|
||||||
inter_similarity_i, inter_positive_loss_i, inter_negative_loss_i = self.similarity_loss(image, image, label_sim)
|
|
||||||
inter_similarity_t, inter_positive_loss_t, inter_negative_loss_t = self.similarity_loss(text, text, label_sim)
|
|
||||||
|
|
||||||
intra_similarity_loss = (intra_positive_loss + intra_negative_loss) if self.args.similarity_function == "euclidean" else (intra_positive_loss + intra_negative_loss)
|
for i in tqdm(range(len(adv_data_list))):
|
||||||
inter_similarity_loss = inter_positive_loss_t + inter_positive_loss_i + (inter_negative_loss_i + inter_negative_loss_t) if self.args.similarity_function == "euclidean" else inter_positive_loss_t + inter_positive_loss_i + inter_negative_loss_i + inter_negative_loss_t
|
image = adv_data_list[i].to(self.rank, non_blocking=True)
|
||||||
similarity_loss = inter_similarity_loss + intra_similarity_loss
|
text = text_list[i].to(self.rank, non_blocking=True)
|
||||||
|
# index = index.numpy()
|
||||||
|
image_hash=self.img_model(image)
|
||||||
|
text_feat=self.bert(text)[0]
|
||||||
|
text_hash=self.txt_model(text_feat)
|
||||||
|
# text_hash = self.make_hash_code(text_hash)
|
||||||
|
# image_hash.to(self.rank)
|
||||||
|
# text_hash.to(self.rank)
|
||||||
|
img_buffer[i, :] = image_hash.data
|
||||||
|
text_buffer[i, :] = text_hash.data
|
||||||
|
|
||||||
# if self.writer is not None:
|
return img_buffer, text_buffer# img_buffer.to(self.rank), text_buffer.to(self.rank)
|
||||||
# self.writer.add_scalar("intra similarity max", intra_similarity.max(), self.global_step)
|
|
||||||
# self.writer.add_scalar("intra similarity min", intra_similarity.min(), self.global_step)
|
|
||||||
# self.writer.add_scalar("intra positive loss", intra_positive_loss.data, self.global_step)
|
|
||||||
# self.writer.add_scalar("intra negative loss", intra_negative_loss.data, self.global_step)
|
|
||||||
|
|
||||||
# self.writer.add_scalar("inter image similarity max", inter_similarity_i.max(), self.global_step)
|
|
||||||
# self.writer.add_scalar("inter image similarity min", inter_similarity_i.min(), self.global_step)
|
|
||||||
# self.writer.add_scalar("inter image positive loss", inter_positive_loss_i.data, self.global_step)
|
|
||||||
# self.writer.add_scalar("inter image negative loss", inter_negative_loss_i.data, self.global_step)
|
|
||||||
|
|
||||||
# self.writer.add_scalar("inter text similarity max", inter_similarity_t.max(), self.global_step)
|
def valid_attack(self,adv_images, texts, adv_labels):
|
||||||
# self.writer.add_scalar("inter text similarity min", inter_similarity_t.min(), self.global_step)
|
save_dir = os.path.join(self.args.save_dir, "adv_PR_cruve")
|
||||||
# self.writer.add_scalar("inter text positive loss", inter_positive_loss_t.data, self.global_step)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
# self.writer.add_scalar("inter text negative loss", inter_negative_loss_t.data, self.global_step)
|
|
||||||
|
|
||||||
# self.writer.add_scalar("intra similarity loss", intra_similarity_loss.data, self.global_step)
|
|
||||||
# self.writer.add_scalar("inter similarity loss", inter_similarity_loss.data, self.global_step)
|
|
||||||
# self.writer.add_scalar("similarity loss", similarity_loss.data, self.global_step)
|
|
||||||
|
|
||||||
if self.args.hash_layer != "select":
|
|
||||||
quantization_loss = (self.hash_loss(image) + self.hash_loss(text)) / 2
|
|
||||||
loss = similarity_loss + quantization_loss
|
|
||||||
if self.global_step % self.args.display_step == 0:
|
|
||||||
self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
|
|
||||||
f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
|
|
||||||
f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
|
|
||||||
f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
|
|
||||||
f"QUATIZATION LOSS, {quantization_loss.data}, "\
|
|
||||||
f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
|
|
||||||
else:
|
|
||||||
loss = similarity_loss # + self.args.qua_gamma * (image_quantization_loss + text_quantization_loss)
|
|
||||||
if self.global_step % self.args.display_step == 0:
|
|
||||||
self.logger.info(f">>>>>> Display >>>>>> [{epoch}/{self.args.epochs}], [{times}/{len(self.train_loader)}]: all loss: {loss.data}, "\
|
|
||||||
f"SIMILARITY LOSS, Intra, positive: {intra_positive_loss.data}, negitave: {intra_negative_loss.data}, sum: {intra_similarity_loss.data}, " \
|
|
||||||
f"Inter, image positive: {inter_positive_loss_i.data}, image negitave: {inter_negative_loss_i.data}, "\
|
|
||||||
f"text positive: {inter_positive_loss_t.data}, text negitave: {inter_negative_loss_t.data}, sum: {inter_similarity_loss.data}, "\
|
|
||||||
# f"QUATIZATION LOSS, image: {image_quantization_loss.data}, text: {text_quantization_loss.data}, "\
|
|
||||||
f"lr: {'-'.join([str('%.9f'%itm) for itm in sorted(list(set(self.optimizer.get_lr())))])}")
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def compute_loss(self, image, text, label, epoch, times):
|
|
||||||
|
|
||||||
loss = self.our_loss(image, text, label, epoch, times)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def test(self, mode_name="i2t"):
|
def test(self, mode_name="i2t"):
|
||||||
if self.args.pretrained == "":
|
if self.args.pretrained == "":
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ def get_args():
|
||||||
# parser.add_argument("--test-index-file", type=str, default="./data/test/index.mat")
|
# parser.add_argument("--test-index-file", type=str, default="./data/test/index.mat")
|
||||||
# parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat")
|
# parser.add_argument("--test-caption-file", type=str, default="./data/test/captions.mat")
|
||||||
# parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat")
|
# parser.add_argument("--test-label-file", type=str, default="./data/test/label.mat")
|
||||||
|
parser.add_argument("--txt-dim", type=int, default=1024)
|
||||||
parser.add_argument("--output-dim", type=int, default=64)
|
parser.add_argument("--output-dim", type=int, default=64)
|
||||||
parser.add_argument("--epochs", type=int, default=100)
|
parser.add_argument("--epochs", type=int, default=100)
|
||||||
parser.add_argument("--max-words", type=int, default=32)
|
parser.add_argument("--max-words", type=int, default=32)
|
||||||
|
|
|
||||||
|
|
@ -227,3 +227,18 @@ class CrossEn_mean(nn.Module):
|
||||||
# print(logpt.max())
|
# print(logpt.max())
|
||||||
sim_loss = sim_matrix.mean()
|
sim_loss = sim_matrix.mean()
|
||||||
return sim_loss
|
return sim_loss
|
||||||
|
|
||||||
|
def find_indices(array, b):
|
||||||
|
# Create a boolean mask where the first dimension of the array equals b
|
||||||
|
mask = np.all(array[:, :] == b, axis=1)
|
||||||
|
|
||||||
|
# Find the indices where the mask is True
|
||||||
|
indices = np.argwhere(mask).flatten()
|
||||||
|
|
||||||
|
return indices
|
||||||
|
|
||||||
|
# def get_var(data_list):
|
||||||
|
# mean=data_list.mean(axis=0)
|
||||||
|
# sq_var=torch.zeros_like(data_list[0])
|
||||||
|
# for i in data_list:
|
||||||
|
# sq_var+=
|
||||||
Loading…
Reference in New Issue