From a3b086bd19db450c8e9a4c17d15e522d57768580 Mon Sep 17 00:00:00 2001 From: leewlving Date: Thu, 18 Jan 2024 17:47:08 +0800 Subject: [PATCH] add correct --- model.py | 89 +++++++++++++++++++++++++++++++++++++-------- targetedAttack.py | 21 ++++------- untargetedAttack.py | 4 +- 3 files changed, 82 insertions(+), 32 deletions(-) diff --git a/model.py b/model.py index e9e4462..a1879c7 100644 --- a/model.py +++ b/model.py @@ -84,24 +84,12 @@ class LabelEncoder(nn.Module): class TargetedGanAttack(nn.Module): def __init__(self, cfg,prompt): super().__init__() - - # self.net= get_stylegan_inverter(cfg) - # self.generator.eval() - # self.inverter=inverter - # self.images_resize=images_resize - # self.generator=get_stylegan_generator(cfg) self.prompt=prompt - text_len=self.prompt.shape[1] - # self.mlp=nn.Sequential( - # nn.Linear(text_len+512, 4096), - # nn.ReLU(inplace=True), - # nn.Linear(4096, 512) - # ) - self.mlp=nn.LazyLinear(4096) + self.mlp=nn.LazyLinear(1024) self.label_encoder=LabelEncoder() encoder_lis = [ # MNIST:1*28*28 - nn.Conv2d(29, 32, kernel_size=3, stride=1, padding=0, bias=True), + nn.LazyConv2d( 32, kernel_size=3, stride=1, padding=0, bias=True), nn.InstanceNorm2d(32), nn.ReLU(), # 8*26*26 @@ -150,9 +138,76 @@ class TargetedGanAttack(nn.Module): mixed_feature=self.encoder(mixed_feature) mixed_feature=self.bottle_neck(mixed_feature) mixed_feature=self.decoder(mixed_feature) - + output=torch.chunk(mixed_feature,chunks=19,dim=1) + style=torch.cat(output[:18],dim=1) + noisy=output[-1] - return mixed_feature + return style, noisy + + + +class UnTargetedGanAttack(nn.Module): + def __init__(self, prompt): + super().__init__() + self.prompt=prompt + self.mlp=nn.LazyLinear(1024) + self.label_encoder=LabelEncoder() + encoder_lis = [ + # MNIST:1*28*28 + nn.LazyConv2d( 32, kernel_size=3, stride=1, padding=0, bias=True), + nn.InstanceNorm2d(32), + nn.ReLU(), + # 8*26*26 + nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True), + nn.InstanceNorm2d(64), + 
nn.ReLU(), + # 16*12*12 + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True), + nn.InstanceNorm2d(128), + nn.ReLU(), + # 32*5*5 + ] + bottle_neck_lis = [ResnetBlock(128), + ResnetBlock(128), + ResnetBlock(128), + ResnetBlock(128),] + decoder_lis = [ + nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, bias=False), + nn.InstanceNorm2d(64), + nn.ReLU(), + # state size. 16 x 11 x 11 + nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, bias=False), + nn.InstanceNorm2d(32), + nn.ReLU(), + # state size. 8 x 23 x 23 + nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0, bias=False), + nn.Tanh() + # state size. image_nc x 28 x 28 + ] + self.encoder = nn.Sequential(*encoder_lis) + self.bottle_neck = nn.Sequential(*bottle_neck_lis) + self.decoder = nn.Sequential(*decoder_lis) + + + def forward(self, detailcode,noise,label): + label_feat=self.label_encoder(label) + batch_size=detailcode.shape[0] + prompt=self.prompt.repeat(batch_size,18,1).to(device) + # print(prompt.shape) + x_prompt=torch.cat([detailcode,prompt],dim=2) + x_prompt=self.mlp(x_prompt) + noise1=noise[7].squeeze().repeat(batch_size,1).to(device) + # noise_temp.append(noise[7].squeeze()) + mixed_feature = torch.cat([x_prompt,noise1,label_feat], dim=1) + # x=x_prompt+x + mixed_feature=self.encoder(mixed_feature) + mixed_feature=self.bottle_neck(mixed_feature) + mixed_feature=self.decoder(mixed_feature) + output=torch.chunk(mixed_feature,chunks=19,dim=1) + style=torch.cat(output[:18],dim=1) + noisy=output[-1] + + return style, noisy class CLIPLoss(torch.nn.Module): def __init__(self): @@ -212,6 +267,8 @@ class ResnetBlock(nn.Module): out = x + self.conv_block(x) return out + + # @hydra.main(version_base=None, config_path="./config", config_name="config") # def test(cfg): # prompt=torch.randn([1,1024]).to(device) diff --git a/targetedAttack.py b/targetedAttack.py index 77a7d5a..e41d8ae 100644 --- a/targetedAttack.py +++ b/targetedAttack.py @@ -27,8 +27,6 @@ def 
generate_labels(labels,NUM_CLASSES=307): while labels[i]==rand_v: rand_v = torch.randint(0, NUM_CLASSES-1) targets[i] = rand_v - # targets = targets.astype(np.int32) - return targets def cw_loss(outputs, labels): @@ -58,7 +56,7 @@ def main(cfg: DictConfig) -> None: classifier=get_model(cfg) classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset))) classifier.eval() - g_ema, _=get_stylegan_generator(cfg) + g_ema=get_stylegan_generator(cfg) prompt=get_prompt(cfg) # net=get_stylegan_inverter(cfg) @@ -87,25 +85,20 @@ def main(cfg: DictConfig) -> None: inputs = inputs.to(device) labels = labels.to(device) target=generate_labels(labels) + noise=g_ema.make_noise() # base_code=base_code.to(device) detail_code=detail_code.to(device) # codes = model.net.encoder(inputs) - generated_img,adv_latent_codes=model(inputs,detail_code,target) - + style, noisy=model(inputs,detail_code,target) + detail=torch.clamp(style,-cfg.epsilon,cfg.epsilon)+detail_code + noisy=[ torch.clamp(single, -cfg.noise_epsilon,cfg.noise_epsilon) for single in noisy ] + noise[:8]=noisy + generated_img=g_ema(detail,input_is_latent=True,noise=noise,randomize_noise=False) optimizer.zero_grad() loss_vgg=loss_fn_vgg(inputs,generated_img) loss_clip=clip_loss(generated_img,prompt) adv_outputs = classifier(generated_img) loss_classifier=cw_loss(adv_outputs,target) - # clean_outputs=classifier(inputs) - # adv_preds=criterion(adv_outputs,labels) - # clean_preds=criterion(clean_outputs,labels) - # loss_classifier=max_loss(clean_preds,adv_preds,torch.ones_like(preds)) - # _, preds = torch.max(classifier(generated_img), 1) - # loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels)) - -# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier -# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier 
loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier loss.backward() optimizer.step() diff --git a/untargetedAttack.py b/untargetedAttack.py index a9cfd17..72704de 100644 --- a/untargetedAttack.py +++ b/untargetedAttack.py @@ -7,7 +7,7 @@ from torchvision import models from omegaconf import DictConfig, OmegaConf from data.dataset import get_dataset,get_adv_dataset from utils import get_model,set_requires_grad,unnormalize -from model import GanAttack,CLIPLoss +from model import UnTargetedGanAttack,CLIPLoss import lpips from prompt import get_prompt import torch.nn.functional as F @@ -56,7 +56,7 @@ def main(cfg: DictConfig) -> None: # net=get_stylegan_inverter(cfg) - model=GanAttack(cfg,prompt).to(device) + model=UnTargetedGanAttack(prompt).to(device)