From 43c7bd0fa31d8caea22ce1be880060bd9a1a5f83 Mon Sep 17 00:00:00 2001
From: leewlving
Date: Wed, 17 Jan 2024 17:12:26 +0800
Subject: [PATCH] add target

---
 model.py          | 216 +++++++++++++++++++++++++++++++---------------
 targetedAttack.py | 152 ++++++++++++++++++++++++++++++++
 2 files changed, 300 insertions(+), 68 deletions(-)

diff --git a/model.py b/model.py
index 095ed0f..e9e4462 100644
--- a/model.py
+++ b/model.py
@@ -35,26 +35,53 @@ def get_stylegan_generator(cfg):
     return generator
 
-# class GanAttack(nn.Module):
-
-# def get_stylegan_inverter(cfg):
-
-#     # ensure_checkpoint_exists(ckpt_path)
-#     path=cfg.paths.inverter_cfg
-#     ckpt = torch.load(path, map_location='cuda:0')
-#     opts = ckpt['opts']
-#     opts['checkpoint_path'] = path
-#     if 'learn_in_w' not in opts:
-#         opts['learn_in_w'] = False
-#     if 'output_size' not in opts:
-#         opts['output_size'] = 1024
-
-#     net = pSp(Namespace(**opts))
-#     net.eval()
-#     net.cuda()
-#     return net
+class LabelEncoder(nn.Module):
+    """Expands a 512-d label feature into a 3-channel 1024x1024 spatial map."""
+
+    def __init__(self, nf=307):
+        super(LabelEncoder, self).__init__()
+        self.nf = nf
+        curr_dim = nf
+        self.size = 64
+
+        self.fc = nn.Sequential(
+            nn.Linear(512, curr_dim * self.size * self.size), nn.ReLU(True))
+
+        # Four stride-2 transposed convolutions: 64x64 -> 1024x1024,
+        # halving the channel count at each stage.
+        transform = []
+        for i in range(4):
+            transform += [
+                nn.ConvTranspose2d(curr_dim,
+                                   curr_dim // 2,
+                                   kernel_size=4,
+                                   stride=2,
+                                   padding=1,
+                                   bias=False),
+                nn.InstanceNorm2d(curr_dim // 2, affine=False),
+                nn.ReLU(inplace=True)
+            ]
+            curr_dim = curr_dim // 2
+
+        transform += [
+            nn.Conv2d(curr_dim,
+                      3,
+                      kernel_size=3,
+                      stride=1,
+                      padding=1,
+                      bias=False)
+        ]
+        self.transform = nn.Sequential(*transform)
+
+    def forward(self, label_feature):
+        label_feature = self.fc(label_feature)
+        label_feature = label_feature.view(label_feature.size(0), self.nf, self.size, self.size)
+        label_feature = self.transform(label_feature)
+        return label_feature
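For reference, a quick shape walkthrough of the added `LabelEncoder` — a standalone sketch assuming the defaults above (`nf=307`, `size=64`, a 512-d input feature):

```python
import torch
from model import LabelEncoder

enc = LabelEncoder(nf=307)
feat = torch.randn(2, 512)   # batch of two 512-d label features
out = enc(feat)
# fc: 512 -> 307*64*64, reshaped to (2, 307, 64, 64); the four stride-2
# ConvTranspose2d stages grow 64 -> 128 -> 256 -> 512 -> 1024 spatially
# while channels shrink 307 -> 153 -> 76 -> 38 -> 19; the final 3x3 conv
# maps to 3 channels.
print(out.shape)             # torch.Size([2, 3, 1024, 1024])
```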
+
 
-class GanAttack(nn.Module):
+class TargetedGanAttack(nn.Module):
 
     def __init__(self, cfg,prompt):
         super().__init__()
@@ -62,40 +89,70 @@ class GanAttack(nn.Module):
         # self.generator.eval()
         # self.inverter=inverter
         # self.images_resize=images_resize
-        self.generator=get_stylegan_generator(cfg)
+        # self.generator=get_stylegan_generator(cfg)
         self.prompt=prompt
         text_len=self.prompt.shape[1]
-        self.mlp=nn.Sequential(
-            nn.Linear(text_len+512, 4096),
-            nn.ReLU(inplace=True),
-            nn.Linear(4096, 512)
-        )
-        self.noise_mlp=nn.Sequential(
-            nn.Linear(10,1024),
-            nn.ReLU(inplace=True),
-            nn.Linear(1024,512)
-        )
-        self.noises=self.generator.make_noise()
-        for i,j in enumerate(self.noises):
-            if i>9:
-                j.detach()
-
+        # LazyLinear infers its input width (text_len + 512) on the first call.
+        self.mlp=nn.LazyLinear(4096)
+        self.label_encoder=LabelEncoder()
+        encoder_lis = [
+            # for a 29 x 64 x 64 input: -> 32 x 62 x 62
+            nn.Conv2d(29, 32, kernel_size=3, stride=1, padding=0, bias=True),
+            nn.InstanceNorm2d(32),
+            nn.ReLU(),
+            # -> 64 x 30 x 30
+            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True),
+            nn.InstanceNorm2d(64),
+            nn.ReLU(),
+            # -> 128 x 14 x 14
+            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True),
+            nn.InstanceNorm2d(128),
+            nn.ReLU(),
+        ]
+        bottle_neck_lis = [ResnetBlock(128),
+                           ResnetBlock(128),
+                           ResnetBlock(128),
+                           ResnetBlock(128),]
+        decoder_lis = [
+            # 128 x 14 x 14 -> 64 x 29 x 29
+            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, bias=False),
+            nn.InstanceNorm2d(64),
+            nn.ReLU(),
+            # -> 32 x 59 x 59
+            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, bias=False),
+            nn.InstanceNorm2d(32),
+            nn.ReLU(),
+            # -> 29 x 64 x 64
+            nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0, bias=False),
+            nn.Tanh()
+        ]
+        self.encoder = nn.Sequential(*encoder_lis)
+        self.bottle_neck = nn.Sequential(*bottle_neck_lis)
+        self.decoder = nn.Sequential(*decoder_lis)
 
-    def forward(self, img,detailcode):
-        batch_size=img.shape[0]
+    def forward(self, detailcode,noise,label):
+        label_feat=self.label_encoder(label)
+        batch_size=detailcode.shape[0]
         prompt=self.prompt.repeat(batch_size,18,1).to(device)
         # print(prompt.shape)
         x_prompt=torch.cat([detailcode,prompt],dim=2)
         x_prompt=self.mlp(x_prompt)
-        x=x_prompt+x
-        adv=self.noise_mlp(self.noises)
-        self.noises=adv+self.noises
-        result_images, _ =self.generator([detailcode],input_is_latent=True,
-            randomize_noise=False,noises=self.noises)
+        # Mix the 64x64 noise map (index 7 of StyleGAN2's make_noise() list)
+        # with the prompt-conditioned code and the label feature.
+        noise1=noise[7].squeeze().repeat(batch_size,1).to(device)
+        mixed_feature = torch.cat([x_prompt,noise1,label_feat], dim=1)
+        mixed_feature=self.encoder(mixed_feature)
+        mixed_feature=self.bottle_neck(mixed_feature)
+        mixed_feature=self.decoder(mixed_feature)
 
-        return result_images,x
+        return mixed_feature
 
 class CLIPLoss(torch.nn.Module):
     def __init__(self):
@@ -113,33 +170,56 @@ class CLIPLoss(torch.nn.Module):
         return similarity
 
-# class VggLoss(torch.nn.Module):
-#     def __init__(self):
-#         super(VggLoss, self).__init__()
-#         self.model=models.vgg11(pretrained=True)
-#         self.model.features=nn.Sequential()
-
-#         # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
-#         # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
-
-#     def forward(self, image1, image2):
-#         # image=normalize(image)
-#         with torch.no_grad:
-#             feature1=self.model(image1)
-#             feature2=self.model(image2)
-#             feature1=torch.flatten(feature1)
-#             feature2=torch.flatten(feature2)
-#             similarity = F.cosine_similarity(feature1,feature2)
-#         return similarity
+class ResnetBlock(nn.Module):
+    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False):
+        super(ResnetBlock, self).__init__()
+        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
+
+    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+        conv_block = []
+        p = 0
+        if padding_type == 'reflect':
+            conv_block += [nn.ReflectionPad2d(1)]
+        elif padding_type == 'replicate':
+            conv_block += [nn.ReplicationPad2d(1)]
+        elif padding_type == 'zero':
+            p = 1
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+                       norm_layer(dim),
+                       nn.ReLU(True)]
+        if use_dropout:
+            conv_block += [nn.Dropout(0.5)]
+
+        p = 0
+        if padding_type == 'reflect':
+            conv_block += [nn.ReflectionPad2d(1)]
+        elif padding_type == 'replicate':
+            conv_block += [nn.ReplicationPad2d(1)]
+        elif padding_type == 'zero':
+            p = 1
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
+                       norm_layer(dim)]
+
+        return nn.Sequential(*conv_block)
+
+    def forward(self, x):
+        out = x + self.conv_block(x)
+        return out
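The encoder/bottleneck/decoder stack above is shape-preserving for a 29-channel 64x64 input; a minimal check of the arithmetic, using only the `ResnetBlock` defined in this patch:

```python
import torch
from model import ResnetBlock

# Encoder: the 3x3 s1 conv takes 64 -> 62, then two 3x3 s2 convs take
# 62 -> 30 -> 14, so the bottleneck runs at 128 x 14 x 14. Decoder: two
# 3x3 s2 transposed convs take 14 -> 29 -> 59, and the final 6x6 s1
# transposed conv restores 64, with 29 output channels matching the input.
block = ResnetBlock(128)
x = torch.randn(2, 128, 14, 14)
out = block(x)                   # out = x + conv_block(x)
assert out.shape == x.shape      # residual blocks preserve shape
```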
 
-@hydra.main(version_base=None, config_path="./config", config_name="config")
-def test(cfg):
-    prompt=torch.randn([1,1024]).to(device)
-    model=GanAttack(cfg,prompt).to(device)
-    data=torch.randn([2,3,256,256]).to(device)
-    result_images,x=model(data)
-    print(result_images.shape)
-    print(x.shape)
-if __name__ == "__main__":
-    test()
+# @hydra.main(version_base=None, config_path="./config", config_name="config")
+# def test(cfg):
+#     prompt=torch.randn([1,1024]).to(device)
+#     model=GanAttack(cfg,prompt).to(device)
+#     data=torch.randn([2,3,256,256]).to(device)
+#     result_images,x=model(data)
+#     print(result_images.shape)
+#     print(x.shape)
+# if __name__ == "__main__":
+#     test()
\ No newline at end of file
diff --git a/targetedAttack.py b/targetedAttack.py
index e69de29..99201b8 100644
--- a/targetedAttack.py
+++ b/targetedAttack.py
@@ -0,0 +1,152 @@
+import os
+import sys
+import time
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import models
+from omegaconf import DictConfig, OmegaConf
+import hydra
+import lpips
+
+from data.dataset import get_dataset, get_adv_dataset
+from utils import get_model, set_requires_grad, unnormalize
+from stylegan.model import Generator
+from model import CLIPLoss, TargetedGanAttack
+from prompt import get_prompt
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+def get_stylegan_generator(cfg):
+    ckpt_path = cfg.paths.stylegan
+    g_ema = Generator(1024, 512, 8)
+    g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
+    g_ema.eval()
+    g_ema = g_ema.cuda()
+    mean_latent = g_ema.mean_latent(4096)
+
+    return g_ema, mean_latent
+
+
+@hydra.main(version_base=None, config_path="./config", config_name="config")
+def main(cfg: DictConfig) -> None:
+    train_dataloader, test_dataloader, train_dataset, test_dataset = get_adv_dataset(cfg)
+
+    classifier = get_model(cfg)
+    classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
+    classifier.eval()
+    g_ema, _ = get_stylegan_generator(cfg)
+    prompt = get_prompt(cfg)
+
+    model = TargetedGanAttack(cfg, prompt).to(device)
+
+    num_epochs = cfg.optim.num_epochs
+    criterion = nn.CrossEntropyLoss()
+    max_loss = nn.MarginRankingLoss(0.1)
+    clip_loss = CLIPLoss().to(device)
+    loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
+
+    model.train()
+    optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
+    start_time = time.time()
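As a reminder of what the classifier term below computes: `nn.MarginRankingLoss(0.1)` evaluates `max(0, -y * (x1 - x2) + margin)`, so with `y = 1` it pushes the clean cross-entropy above the adversarial one by at least the margin. A standalone sketch with made-up numbers:

```python
import torch
import torch.nn as nn

max_loss = nn.MarginRankingLoss(0.1)

clean_ce = torch.tensor(2.0)   # cross-entropy on clean images (illustrative)
adv_ce = torch.tensor(0.5)     # cross-entropy on generated images (illustrative)
y = torch.ones_like(clean_ce)

print(max_loss(clean_ce, adv_ce, y))  # tensor(0.)     — already separated by more than 0.1
print(max_loss(adv_ce, clean_ce, y))  # tensor(1.6000) — 2.0 - 0.5 + 0.1
```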
+
+    for epoch in range(num_epochs):
+        model.train()
+
+        running_loss = 0
+        running_corrects = 0
+
+        for i, (inputs, labels, detail_code) in enumerate(train_dataloader):
+            inputs = inputs.to(device)
+            labels = labels.to(device)
+            detail_code = detail_code.to(device)
+
+            optimizer.zero_grad()
+            # forward() takes (detail_code, noise, label). LabelEncoder expects
+            # a 512-d label feature, so integer class labels would still need
+            # an embedding step that this patch does not include.
+            generated_img = model(detail_code, g_ema.make_noise(), labels)
+
+            loss_vgg = loss_fn_vgg(inputs, generated_img).mean()
+            loss_clip = clip_loss(generated_img, prompt)
+            adv_outputs = classifier(generated_img)
+            clean_outputs = classifier(inputs)
+            adv_preds = criterion(adv_outputs, labels)
+            clean_preds = criterion(clean_outputs, labels)
+            # Rank the clean cross-entropy above the adversarial one by the margin.
+            loss_classifier = max_loss(clean_preds, adv_preds, torch.ones_like(clean_preds))
+
+            loss = loss_vgg + cfg.optim.beta * loss_clip + cfg.optim.delta * loss_classifier
+            loss.backward()
+            optimizer.step()
+
+            _, preds = torch.max(adv_outputs, 1)
+            running_loss += loss.item() * inputs.size(0)
+            running_corrects += torch.sum(preds == labels.data)
+
+        epoch_loss = running_loss / len(train_dataset)
+        epoch_acc = running_corrects / len(train_dataset) * 100.
+        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
+
+        model.eval()
+        print('Evaluating!')
+        with torch.no_grad():
+            running_corrects = 0
+
+            for i, (inputs, labels, detail_code) in enumerate(test_dataloader):
+                inputs = inputs.to(device)
+                labels = labels.to(device)
+                detail_code = detail_code.to(device)
+
+                generated_img = model(detail_code, g_ema.make_noise(), labels)
+                outputs = classifier(generated_img)
+                _, preds = torch.max(outputs, 1)
+
+                running_corrects += torch.sum(preds == labels.data)
+
+            epoch_acc = running_corrects / len(test_dataset) * 100.
+            print('[Test  ] Acc: {:.4f}% Time: {:.4f}s'.format(epoch_acc, time.time() - start_time))
+
+        save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
+        torch.save(model.state_dict(), save_path)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
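To reuse a trained attack later, the checkpoint written above can be restored into a fresh `TargetedGanAttack`. A minimal sketch, assuming the same Hydra config used for training; `load_attack` is a hypothetical helper, not part of this patch, and the `nn.LazyLinear` weights should materialize from the loaded state dict:

```python
import torch
from model import TargetedGanAttack
from prompt import get_prompt

def load_attack(cfg, device="cuda"):
    # Hypothetical reload of the checkpoint saved by targetedAttack.py.
    prompt = get_prompt(cfg)
    model = TargetedGanAttack(cfg, prompt).to(device)
    ckpt = '{}/stylegan_{}_{}_{}.pth'.format(
        cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset, cfg.prompt)
    model.load_state_dict(torch.load(ckpt, map_location=device))
    model.eval()
    return model
```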