add target
This commit is contained in:
parent
06e2774cd7
commit
43c7bd0fa3
216
model.py
216
model.py
|
|
@ -35,26 +35,53 @@ def get_stylegan_generator(cfg):
|
||||||
return generator
|
return generator
|
||||||
|
|
||||||
|
|
||||||
# class GanAttack(nn.Module):
|
class LabelEncoder(nn.Module):
    """Decode a flat label embedding into a 3-channel, image-shaped tensor.

    The embedding is projected by a linear layer to ``nf * size * size``
    activations, reshaped to ``(batch, nf, size, size)``, passed through
    four stride-2 transposed convolutions (each halves the channel count
    and doubles the spatial side), and finally mapped to 3 channels by a
    3x3 convolution. The output therefore has shape
    ``(batch, 3, size * 16, size * 16)``.

    Args:
        nf: Channel count of the reshaped feature map. Must be at least
            ``2 ** NUM_UPSAMPLES`` so that repeated halving never reaches
            zero channels.
        size: Spatial side length of the reshaped feature map.
        in_dim: Length of the incoming label embedding.

    Raises:
        ValueError: If ``nf`` is too small to be halved ``NUM_UPSAMPLES``
            times.
    """

    # Number of stride-2 transposed-convolution stages.
    NUM_UPSAMPLES = 4

    def __init__(self, nf=307, size=64, in_dim=512):
        super(LabelEncoder, self).__init__()
        if nf // (2 ** self.NUM_UPSAMPLES) < 1:
            raise ValueError(
                'nf must be at least %d, got %r' % (2 ** self.NUM_UPSAMPLES, nf))

        self.nf = nf
        self.size = size

        # Kept as a Sequential so parameter names stay ``fc.0.*``.
        self.fc = nn.Sequential(
            nn.Linear(in_dim, nf * size * size),
            nn.ReLU(True),
        )

        layers = []
        channels = nf
        for _ in range(self.NUM_UPSAMPLES):
            halved = channels // 2
            layers += [
                # kernel 4 / stride 2 / padding 1 exactly doubles H and W.
                nn.ConvTranspose2d(channels, halved, kernel_size=4,
                                   stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(halved, affine=False),
                nn.ReLU(inplace=True),
            ]
            channels = halved

        # Final projection to 3 channels; spatial size is preserved.
        layers.append(
            nn.Conv2d(channels, 3, kernel_size=3, stride=1, padding=1,
                      bias=False))
        self.transform = nn.Sequential(*layers)

    def forward(self, label_feature):
        """Map ``(batch, in_dim)`` embeddings to ``(batch, 3, H, W)`` maps."""
        flat = self.fc(label_feature)
        grid = flat.view(flat.size(0), self.nf, self.size, self.size)
        return self.transform(grid)
|
||||||
|
|
||||||
class GanAttack(nn.Module):
|
class TargetedGanAttack(nn.Module):
|
||||||
def __init__(self, cfg,prompt):
|
def __init__(self, cfg,prompt):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
@ -62,40 +89,70 @@ class GanAttack(nn.Module):
|
||||||
# self.generator.eval()
|
# self.generator.eval()
|
||||||
# self.inverter=inverter
|
# self.inverter=inverter
|
||||||
# self.images_resize=images_resize
|
# self.images_resize=images_resize
|
||||||
self.generator=get_stylegan_generator(cfg)
|
# self.generator=get_stylegan_generator(cfg)
|
||||||
self.prompt=prompt
|
self.prompt=prompt
|
||||||
text_len=self.prompt.shape[1]
|
text_len=self.prompt.shape[1]
|
||||||
self.mlp=nn.Sequential(
|
# self.mlp=nn.Sequential(
|
||||||
nn.Linear(text_len+512, 4096),
|
# nn.Linear(text_len+512, 4096),
|
||||||
nn.ReLU(inplace=True),
|
# nn.ReLU(inplace=True),
|
||||||
nn.Linear(4096, 512)
|
# nn.Linear(4096, 512)
|
||||||
)
|
# )
|
||||||
self.noise_mlp=nn.Sequential(
|
self.mlp=nn.LazyLinear(4096)
|
||||||
nn.Linear(10,1024),
|
self.label_encoder=LabelEncoder()
|
||||||
nn.ReLU(inplace=True),
|
encoder_lis = [
|
||||||
nn.Linear(1024,512)
|
# MNIST:1*28*28
|
||||||
)
|
nn.Conv2d(29, 32, kernel_size=3, stride=1, padding=0, bias=True),
|
||||||
self.noises=self.generator.make_noise()
|
nn.InstanceNorm2d(32),
|
||||||
for i,j in enumerate(self.noises):
|
nn.ReLU(),
|
||||||
if i>9:
|
# 8*26*26
|
||||||
j.detach()
|
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True),
|
||||||
|
nn.InstanceNorm2d(64),
|
||||||
|
nn.ReLU(),
|
||||||
|
# 16*12*12
|
||||||
|
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True),
|
||||||
|
nn.InstanceNorm2d(128),
|
||||||
|
nn.ReLU(),
|
||||||
|
# 32*5*5
|
||||||
|
]
|
||||||
|
bottle_neck_lis = [ResnetBlock(128),
|
||||||
|
ResnetBlock(128),
|
||||||
|
ResnetBlock(128),
|
||||||
|
ResnetBlock(128),]
|
||||||
|
decoder_lis = [
|
||||||
|
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, bias=False),
|
||||||
|
nn.InstanceNorm2d(64),
|
||||||
|
nn.ReLU(),
|
||||||
|
# state size. 16 x 11 x 11
|
||||||
|
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, bias=False),
|
||||||
|
nn.InstanceNorm2d(32),
|
||||||
|
nn.ReLU(),
|
||||||
|
# state size. 8 x 23 x 23
|
||||||
|
nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0, bias=False),
|
||||||
|
nn.Tanh()
|
||||||
|
# state size. image_nc x 28 x 28
|
||||||
|
]
|
||||||
|
self.encoder = nn.Sequential(*encoder_lis)
|
||||||
|
self.bottle_neck = nn.Sequential(*bottle_neck_lis)
|
||||||
|
self.decoder = nn.Sequential(*decoder_lis)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, img,detailcode):
|
def forward(self, detailcode,noise,label):
|
||||||
batch_size=img.shape[0]
|
label_feat=self.label_encoder(label)
|
||||||
|
batch_size=detailcode.shape[0]
|
||||||
prompt=self.prompt.repeat(batch_size,18,1).to(device)
|
prompt=self.prompt.repeat(batch_size,18,1).to(device)
|
||||||
# print(prompt.shape)
|
# print(prompt.shape)
|
||||||
x_prompt=torch.cat([detailcode,prompt],dim=2)
|
x_prompt=torch.cat([detailcode,prompt],dim=2)
|
||||||
x_prompt=self.mlp(x_prompt)
|
x_prompt=self.mlp(x_prompt)
|
||||||
x=x_prompt+x
|
noise1=noise[7].squeeze().repeat(batch_size,1).to(device)
|
||||||
adv=self.noise_mlp(self.noises)
|
# noise_temp.append(noise[7].squeeze())
|
||||||
self.noises=adv+self.noises
|
mixed_feature = torch.cat([x_prompt,noise1,label_feat], dim=1)
|
||||||
result_images, _ =self.generator([detailcode],input_is_latent=True,
|
# x=x_prompt+x
|
||||||
randomize_noise=False,noises=self.noises)
|
mixed_feature=self.encoder(mixed_feature)
|
||||||
|
mixed_feature=self.bottle_neck(mixed_feature)
|
||||||
|
mixed_feature=self.decoder(mixed_feature)
|
||||||
|
|
||||||
|
|
||||||
return result_images,x
|
return mixed_feature
|
||||||
|
|
||||||
class CLIPLoss(torch.nn.Module):
|
class CLIPLoss(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -113,33 +170,56 @@ class CLIPLoss(torch.nn.Module):
|
||||||
return similarity
|
return similarity
|
||||||
|
|
||||||
|
|
||||||
# class VggLoss(torch.nn.Module):
|
class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)``.

    ``conv_block`` is two 3x3 convolutions over ``dim`` channels, each
    followed by ``norm_layer(dim)``, with a ReLU (and optional dropout)
    between them. Padding keeps the spatial size unchanged so the skip
    connection can be added directly.

    Args:
        dim: Number of input and output channels.
        padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        norm_layer: Callable taking the channel count and returning a
            normalisation module.
        use_dropout: If true, insert ``nn.Dropout(0.5)`` after the ReLU.
        use_bias: Whether the convolutions carry a bias term.

    Raises:
        NotImplementedError: If ``padding_type`` is not one of the
            supported names.
    """

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return ``(explicit_pad_layers, conv_padding)`` for a padding name."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            # Zero padding is done by the convolution itself.
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the pad -> conv -> norm -> ReLU [-> dropout] -> pad -> conv -> norm stack."""
        # Resolved once up front so an invalid name fails before any layer is built.
        pad_layers, p = self._padding(padding_type)

        conv_block = list(pad_layers)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        # Fresh padding modules for the second convolution (no module sharing).
        pad_layers, p = self._padding(padding_type)
        conv_block += pad_layers
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Apply the residual connection around ``conv_block``."""
        out = x + self.conv_block(x)
        return out
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
# @hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||||
def test(cfg):
|
# def test(cfg):
|
||||||
prompt=torch.randn([1,1024]).to(device)
|
# prompt=torch.randn([1,1024]).to(device)
|
||||||
model=GanAttack(cfg,prompt).to(device)
|
# model=GanAttack(cfg,prompt).to(device)
|
||||||
data=torch.randn([2,3,256,256]).to(device)
|
# data=torch.randn([2,3,256,256]).to(device)
|
||||||
result_images,x=model(data)
|
# result_images,x=model(data)
|
||||||
print(result_images.shape)
|
# print(result_images.shape)
|
||||||
print(x.shape)
|
# print(x.shape)
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
test()
|
# test()
|
||||||
|
|
||||||
|
|
@ -0,0 +1,152 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torchvision import models
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
from data.dataset import get_dataset,get_adv_dataset
|
||||||
|
from utils import get_model,set_requires_grad,unnormalize
|
||||||
|
from stylegan.model import Generator
|
||||||
|
from model import CLIPLoss,TargetedGanAttack
|
||||||
|
import lpips
|
||||||
|
from prompt import get_prompt
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import time
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_stylegan_generator(cfg):
    """Build the StyleGAN generator and load its weights from the config path.

    Args:
        cfg: Config object; ``cfg.paths.stylegan`` is the checkpoint file,
            whose ``"g_ema"`` entry holds the generator state dict.

    Returns:
        A tuple ``(g_ema, mean_latent)``: the generator in eval mode on the
        module-level ``device``, and the result of ``g_ema.mean_latent(4096)``.
    """
    ckpt_path = cfg.paths.stylegan
    g_ema = Generator(1024, 512, 8)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    checkpoint = torch.load(ckpt_path, map_location=device)
    # strict=False: keys missing from / extra in the checkpoint are ignored.
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
    g_ema.eval()

    # Use the module-level device rather than an unconditional .cuda(), so
    # the CPU fallback chosen at import time is honoured.
    g_ema = g_ema.to(device)
    mean_latent = g_ema.mean_latent(4096)

    return g_ema, mean_latent
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train ``TargetedGanAttack`` against a fixed classifier, evaluating each epoch.

    Per batch the objective is
    ``LPIPS(inputs, generated) + cfg.optim.beta * CLIP(generated, prompt)
    + cfg.optim.delta * margin_ranking(clean_ce, adv_ce)``.
    Only the attack model's parameters are handed to the optimizer; the
    classifier is kept in eval mode. After every epoch the attack model's
    ``state_dict`` is written under ``cfg.paths.pretrained_models``.

    Args:
        cfg: Hydra config providing ``paths``, ``classifier``, ``dataset``,
            ``optim`` and ``prompt`` entries read below.
    """
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)

    # NOTE(review): the classifier is never moved to `device` here;
    # presumably get_model() already does that - confirm.
    classifier = get_model(cfg)
    classifier_ckpt = '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)
    classifier.load_state_dict(torch.load(classifier_ckpt, map_location=device))
    classifier.eval()

    # Loaded but not used below; kept because the original loads it here.
    g_ema, _ = get_stylegan_generator(cfg)
    prompt = get_prompt(cfg)

    model = TargetedGanAttack(cfg, prompt).to(device)

    num_epochs = cfg.optim.num_epochs
    criterion = nn.CrossEntropyLoss()
    max_loss = nn.MarginRankingLoss(0.1)
    clip_loss = CLIPLoss().to(device)
    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)

    optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
    start_time = time.time()

    save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)

    for epoch in range(num_epochs):
        model.train()

        running_loss = 0.
        running_corrects = 0

        for inputs, labels, detail_code in train_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            detail_code = detail_code.to(device)

            optimizer.zero_grad()

            # NOTE(review): TargetedGanAttack.forward in model.py is declared
            # as forward(detailcode, noise, label) and returns one tensor;
            # this call passes (inputs, detail_code) and unpacks two values.
            # Left as written - reconcile with the model before running.
            generated_img, _ = model(inputs, detail_code)

            # .mean() collapses per-sample values to the scalar that
            # backward() requires; it is a no-op if already scalar.
            loss_vgg = loss_fn_vgg(inputs, generated_img).mean()
            loss_clip = clip_loss(generated_img, prompt).mean()

            adv_outputs = classifier(generated_img)
            clean_outputs = classifier(inputs)
            adv_ce = criterion(adv_outputs, labels)
            clean_ce = criterion(clean_outputs, labels)
            # Target of ones, shaped like the (scalar) losses being ranked.
            # The original used an undefined name `preds` here.
            loss_classifier = max_loss(clean_ce, adv_ce, torch.ones_like(clean_ce))

            loss = loss_vgg + cfg.optim.beta * loss_clip + cfg.optim.delta * loss_classifier
            loss.backward()
            optimizer.step()

            # Accuracy is counted on the classifier's predictions for the
            # generated images.
            preds = adv_outputs.argmax(dim=1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects / len(train_dataset) * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

        model.eval()
        print('Evaluating!')
        with torch.no_grad():
            running_corrects = 0

            for inputs, labels, detail_code in test_dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                detail_code = detail_code.to(device)

                generated_img, _ = model(inputs, detail_code)
                outputs = classifier(generated_img)
                # Predicted class index per sample (the original compared
                # the scalar cross-entropy loss against the labels).
                preds = outputs.argmax(dim=1)

                running_corrects += (preds == labels).sum().item()

            epoch_acc = running_corrects / len(test_dataset) * 100.
            print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format(epoch_acc, time.time() - start_time))

        # Original indentation was not recoverable; saving every epoch
        # overwrites the same file, so the final checkpoint is the same
        # either way and an interrupted run still leaves one behind.
        torch.save(model.state_dict(), save_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue