add correct
This commit is contained in:
parent
0f7414bc5f
commit
a3b086bd19
87
model.py
87
model.py
|
|
@ -84,24 +84,12 @@ class LabelEncoder(nn.Module):
|
|||
class TargetedGanAttack(nn.Module):
|
||||
def __init__(self, cfg,prompt):
|
||||
super().__init__()
|
||||
|
||||
# self.net= get_stylegan_inverter(cfg)
|
||||
# self.generator.eval()
|
||||
# self.inverter=inverter
|
||||
# self.images_resize=images_resize
|
||||
# self.generator=get_stylegan_generator(cfg)
|
||||
self.prompt=prompt
|
||||
text_len=self.prompt.shape[1]
|
||||
# self.mlp=nn.Sequential(
|
||||
# nn.Linear(text_len+512, 4096),
|
||||
# nn.ReLU(inplace=True),
|
||||
# nn.Linear(4096, 512)
|
||||
# )
|
||||
self.mlp=nn.LazyLinear(4096)
|
||||
self.mlp=nn.LazyLinear(1024)
|
||||
self.label_encoder=LabelEncoder()
|
||||
encoder_lis = [
|
||||
# MNIST:1*28*28
|
||||
nn.Conv2d(29, 32, kernel_size=3, stride=1, padding=0, bias=True),
|
||||
nn.LazyConv2d( 32, kernel_size=3, stride=1, padding=0, bias=True),
|
||||
nn.InstanceNorm2d(32),
|
||||
nn.ReLU(),
|
||||
# 8*26*26
|
||||
|
|
@ -150,9 +138,76 @@ class TargetedGanAttack(nn.Module):
|
|||
mixed_feature=self.encoder(mixed_feature)
|
||||
mixed_feature=self.bottle_neck(mixed_feature)
|
||||
mixed_feature=self.decoder(mixed_feature)
|
||||
output=torch.chunk(mixed_feature,chunks=19,dim=1)
|
||||
style=torch.cat(output[:18],dim=1)
|
||||
noisy=output[-1]
|
||||
|
||||
return style, noisy
|
||||
|
||||
|
||||
return mixed_feature
|
||||
|
||||
class UnTargetedGanAttack(nn.Module):
|
||||
def __init__(self, prompt):
|
||||
super().__init__()
|
||||
self.prompt=prompt
|
||||
self.mlp=nn.LazyLinear(1024)
|
||||
self.label_encoder=LabelEncoder()
|
||||
encoder_lis = [
|
||||
# MNIST:1*28*28
|
||||
nn.LazyConv2d( 32, kernel_size=3, stride=1, padding=0, bias=True),
|
||||
nn.InstanceNorm2d(32),
|
||||
nn.ReLU(),
|
||||
# 8*26*26
|
||||
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True),
|
||||
nn.InstanceNorm2d(64),
|
||||
nn.ReLU(),
|
||||
# 16*12*12
|
||||
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True),
|
||||
nn.InstanceNorm2d(128),
|
||||
nn.ReLU(),
|
||||
# 32*5*5
|
||||
]
|
||||
bottle_neck_lis = [ResnetBlock(128),
|
||||
ResnetBlock(128),
|
||||
ResnetBlock(128),
|
||||
ResnetBlock(128),]
|
||||
decoder_lis = [
|
||||
nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, bias=False),
|
||||
nn.InstanceNorm2d(64),
|
||||
nn.ReLU(),
|
||||
# state size. 16 x 11 x 11
|
||||
nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, bias=False),
|
||||
nn.InstanceNorm2d(32),
|
||||
nn.ReLU(),
|
||||
# state size. 8 x 23 x 23
|
||||
nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0, bias=False),
|
||||
nn.Tanh()
|
||||
# state size. image_nc x 28 x 28
|
||||
]
|
||||
self.encoder = nn.Sequential(*encoder_lis)
|
||||
self.bottle_neck = nn.Sequential(*bottle_neck_lis)
|
||||
self.decoder = nn.Sequential(*decoder_lis)
|
||||
|
||||
|
||||
def forward(self, detailcode,noise,label):
|
||||
label_feat=self.label_encoder(label)
|
||||
batch_size=detailcode.shape[0]
|
||||
prompt=self.prompt.repeat(batch_size,18,1).to(device)
|
||||
# print(prompt.shape)
|
||||
x_prompt=torch.cat([detailcode,prompt],dim=2)
|
||||
x_prompt=self.mlp(x_prompt)
|
||||
noise1=noise[7].squeeze().repeat(batch_size,1).to(device)
|
||||
# noise_temp.append(noise[7].squeeze())
|
||||
mixed_feature = torch.cat([x_prompt,noise1,label_feat], dim=1)
|
||||
# x=x_prompt+x
|
||||
mixed_feature=self.encoder(mixed_feature)
|
||||
mixed_feature=self.bottle_neck(mixed_feature)
|
||||
mixed_feature=self.decoder(mixed_feature)
|
||||
output=torch.chunk(mixed_feature,chunks=19,dim=1)
|
||||
style=torch.cat(output[:18],dim=1)
|
||||
noisy=output[-1]
|
||||
|
||||
return style, noisy
|
||||
|
||||
class CLIPLoss(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
@ -212,6 +267,8 @@ class ResnetBlock(nn.Module):
|
|||
out = x + self.conv_block(x)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
# @hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||
# def test(cfg):
|
||||
# prompt=torch.randn([1,1024]).to(device)
|
||||
|
|
|
|||
|
|
@ -27,8 +27,6 @@ def generate_labels(labels,NUM_CLASSES=307):
|
|||
while labels[i]==rand_v:
|
||||
rand_v = torch.randint(0, NUM_CLASSES-1)
|
||||
targets[i] = rand_v
|
||||
# targets = targets.astype(np.int32)
|
||||
|
||||
return targets
|
||||
|
||||
def cw_loss(outputs, labels):
|
||||
|
|
@ -58,7 +56,7 @@ def main(cfg: DictConfig) -> None:
|
|||
classifier=get_model(cfg)
|
||||
classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
|
||||
classifier.eval()
|
||||
g_ema, _=get_stylegan_generator(cfg)
|
||||
g_ema=get_stylegan_generator(cfg)
|
||||
prompt=get_prompt(cfg)
|
||||
|
||||
# net=get_stylegan_inverter(cfg)
|
||||
|
|
@ -87,25 +85,20 @@ def main(cfg: DictConfig) -> None:
|
|||
inputs = inputs.to(device)
|
||||
labels = labels.to(device)
|
||||
target=generate_labels(labels)
|
||||
noise=g_ema.make_noise()
|
||||
# base_code=base_code.to(device)
|
||||
detail_code=detail_code.to(device)
|
||||
# codes = model.net.encoder(inputs)
|
||||
generated_img,adv_latent_codes=model(inputs,detail_code,target)
|
||||
|
||||
style, noisy=model(inputs,detail_code,target)
|
||||
detail=torch.clamp(style,-cfg.epsilon,cfg.spsilon)+detail_code
|
||||
noisy=[ torch.clamp(single, -cfg.noise_epsilon,cfg.noise_epsilon) for single in noisy ]
|
||||
noise[:8]=noisy
|
||||
generated_img=g_ema(detail,input_is_latent=True,noise=noise,randomize_noise=False)
|
||||
optimizer.zero_grad()
|
||||
loss_vgg=loss_fn_vgg(inputs,generated_img)
|
||||
loss_clip=clip_loss(generated_img,prompt)
|
||||
adv_outputs = classifier(generated_img)
|
||||
loss_classifier=cw_loss(adv_outputs,target)
|
||||
# clean_outputs=classifier(inputs)
|
||||
# adv_preds=criterion(adv_outputs,labels)
|
||||
# clean_preds=criterion(clean_outputs,labels)
|
||||
# loss_classifier=max_loss(clean_preds,adv_preds,torch.ones_like(preds))
|
||||
# _, preds = torch.max(classifier(generated_img), 1)
|
||||
# loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
|
||||
|
||||
# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from torchvision import models
|
|||
from omegaconf import DictConfig, OmegaConf
|
||||
from data.dataset import get_dataset,get_adv_dataset
|
||||
from utils import get_model,set_requires_grad,unnormalize
|
||||
from model import GanAttack,CLIPLoss
|
||||
from model import UnTargetedGanAttack,CLIPLoss
|
||||
import lpips
|
||||
from prompt import get_prompt
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -56,7 +56,7 @@ def main(cfg: DictConfig) -> None:
|
|||
|
||||
# net=get_stylegan_inverter(cfg)
|
||||
|
||||
model=GanAttack(cfg,prompt).to(device)
|
||||
model=UnTargetedGanAttack(cfg,prompt).to(device)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue