# GanAttack/model.py
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.nn.functional as F
import os
import clip
# from pixel2style2pixel.models.psp import pSp
from argparse import Namespace
from utils import normalize
from stylegan.model import Generator
import hydra
from omegaconf import DictConfig, OmegaConf
import sys
import os
import numpy as np
# Run everything on the first CUDA GPU when available, otherwise on CPU.
# Shared module-wide: the builders and forward passes below all move
# tensors/modules to this device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_prompt(cfg):
    """Encode the text prompt from the config with CLIP's RN50 text encoder.

    Loads CLIP, tokenizes ``cfg.prompt`` and returns its text embedding
    (computed under ``no_grad``) on the module-level ``device``.
    """
    clip_model, _preprocess = clip.load("RN50", device=device)
    tokens = clip.tokenize(cfg.prompt).to(device)
    with torch.no_grad():
        text_embedding = clip_model.encode_text(tokens)
    return text_embedding
def get_stylegan_generator(cfg):
    """Build a 1024px StyleGAN generator and load its EMA weights.

    Reads the checkpoint at ``cfg.paths.stylegan``, restores the ``'g_ema'``
    state dict, and returns the generator on ``device`` in eval mode.
    """
    gen = Generator(1024, 512, 8)
    ckpt = torch.load(cfg.paths.stylegan, map_location=device)
    gen.load_state_dict(ckpt['g_ema'])
    gen = gen.to(device)
    gen.eval()
    return gen
class LabelEncoder(nn.Module):
    """Expand a 512-d label/embedding vector into a 3-channel feature map.

    A linear layer lifts the vector to an ``nf x 64 x 64`` tensor, which four
    transposed-conv stages upsample (x2 each, halving channels) before a final
    3x3 conv projects to 3 channels. With the default ``nf=307`` the output is
    roughly 3 x 1024 x 1024 (presumably matched to StyleGAN's resolution —
    TODO confirm against the caller).
    """

    def __init__(self, nf=307):
        super(LabelEncoder, self).__init__()
        self.nf = nf
        self.size = 64
        # Vector -> flat (nf * 64 * 64) activation.
        self.fc = nn.Sequential(
            nn.Linear(512, nf * self.size * self.size),
            nn.ReLU(True),
        )
        # Four x2 upsampling stages, each halving the channel count.
        layers = []
        channels = nf
        for _ in range(4):
            layers.append(
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=4,
                                   stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(channels // 2, affine=False))
            layers.append(nn.ReLU(inplace=True))
            channels //= 2
        # Project down to a 3-channel map.
        layers.append(nn.Conv2d(channels, 3, kernel_size=3, stride=1,
                                padding=1, bias=False))
        self.transform = nn.Sequential(*layers)

    def forward(self, label_feature):
        """Map a (batch, 512) label embedding to a (batch, 3, H, W) feature map."""
        flat = self.fc(label_feature)
        grid = flat.view(flat.size(0), self.nf, self.size, self.size)
        return self.transform(grid)
class TargetedGanAttack(nn.Module):
    """Targeted attack generator fusing detail codes, a CLIP prompt, noise and a label.

    The fused feature goes through a conv encoder, four residual blocks and a
    transposed-conv decoder ending in Tanh over 29 channels, which is split
    into a ``style`` tensor and a ``noisy`` tensor.

    NOTE(review): ``cfg`` is accepted but never read here — confirm it is
    still needed. Also, ``torch.chunk(x, chunks=19)`` on 29 channels actually
    yields 15 chunks (14 of width 2 plus one of width 1), so ``parts[:18]``
    selects every chunk (all 29 channels) and the last 1-channel chunk is
    returned again as ``noisy`` — verify this matches the intended 18+1 split.
    """

    def __init__(self, cfg, prompt):
        super().__init__()
        self.prompt = prompt
        # Lazy layers infer their input sizes on the first forward pass.
        self.mlp = nn.LazyLinear(1024)
        self.label_encoder = LabelEncoder()

        self.encoder = nn.Sequential(
            nn.LazyConv2d(32, kernel_size=3, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
        )
        self.bottle_neck = nn.Sequential(*[ResnetBlock(128) for _ in range(4)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0, bias=False),
            nn.Tanh(),
        )

    def forward(self, detailcode, noise, label):
        """Return ``(style, noisy)`` perturbation tensors for a batch of detail codes."""
        label_feat = self.label_encoder(label)
        bsz = detailcode.shape[0]
        # Broadcast the text prompt across the batch and the 18 style layers.
        text_feat = self.prompt.repeat(bsz, 18, 1).to(device)
        fused = self.mlp(torch.cat([detailcode, text_feat], dim=2))
        # Only the 8th StyleGAN noise map is used — TODO confirm this choice.
        noise_feat = noise[7].squeeze().repeat(bsz, 1).to(device)
        mixed = torch.cat([fused, noise_feat, label_feat], dim=1)
        mixed = self.decoder(self.bottle_neck(self.encoder(mixed)))
        parts = torch.chunk(mixed, chunks=19, dim=1)
        style = torch.cat(parts[:18], dim=1)
        return style, parts[-1]
class UnTargetedGanAttack(nn.Module):
    """Untargeted attack generator — identical network to ``TargetedGanAttack``.

    Fuses detail codes with a CLIP prompt, one StyleGAN noise map and an
    encoded label, then runs an encoder / resnet-bottleneck / decoder stack.

    NOTE(review): ``torch.chunk(x, chunks=19)`` on 29 channels actually yields
    15 chunks (14 of width 2 plus one of width 1), so ``parts[:18]`` selects
    every chunk and the last 1-channel chunk also returns as ``noisy`` —
    verify this matches the intended 18+1 split.
    """

    def __init__(self, prompt):
        super().__init__()
        self.prompt = prompt
        # Lazy layers infer their input sizes on the first forward pass.
        self.mlp = nn.LazyLinear(1024)
        self.label_encoder = LabelEncoder()

        self.encoder = nn.Sequential(
            nn.LazyConv2d(32, kernel_size=3, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
        )
        self.bottle_neck = nn.Sequential(*[ResnetBlock(128) for _ in range(4)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0, bias=False),
            nn.Tanh(),
        )

    def forward(self, detailcode, noise, label):
        """Return ``(style, noisy)`` perturbation tensors for a batch of detail codes."""
        label_feat = self.label_encoder(label)
        bsz = detailcode.shape[0]
        # Broadcast the text prompt across the batch and the 18 style layers.
        text_feat = self.prompt.repeat(bsz, 18, 1).to(device)
        fused = self.mlp(torch.cat([detailcode, text_feat], dim=2))
        # Only the 8th StyleGAN noise map is used — TODO confirm this choice.
        noise_feat = noise[7].squeeze().repeat(bsz, 1).to(device)
        mixed = torch.cat([fused, noise_feat, label_feat], dim=1)
        mixed = self.decoder(self.bottle_neck(self.encoder(mixed)))
        parts = torch.chunk(mixed, chunks=19, dim=1)
        style = torch.cat(parts[:18], dim=1)
        return style, parts[-1]
class CLIPLoss(torch.nn.Module):
    """CLIP image-text similarity loss.

    Normalizes and average-pools the image to 224x224, scores it against the
    tokenized text with a frozen CLIP RN50, and returns
    ``1 - logits / 100`` (CLIP's logit scale is ~100, so this approximates
    one minus the cosine similarity — TODO confirm the trained logit scale).
    """

    def __init__(self):
        super(CLIPLoss, self).__init__()
        # Bug fix: this used a hard-coded device="cuda", which crashed on
        # CPU-only machines even though the rest of this module selects the
        # device dynamically. Use the shared module-level `device` instead.
        self.model, self.preprocess = clip.load("RN50", device=device)
        self.model.eval()
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))

    def forward(self, image, text):
        """Return the (batch of) similarity losses for `image` against `text` tokens."""
        image = normalize(image)
        image = self.face_pool(image)
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)`` with two 3x3 convs and a norm layer.

    ``padding_type`` selects 'reflect' (default), 'replicate' or 'zero'
    padding; anything else raises ``NotImplementedError``. Channel count and
    spatial size are preserved.
    """

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble pad -> conv -> norm -> ReLU (-> dropout) -> pad -> conv -> norm."""

        def add_padding(seq):
            # Append an explicit padding layer when needed; return the conv's
            # own `padding` argument (1 only for 'zero' padding).
            if padding_type == 'reflect':
                seq.append(nn.ReflectionPad2d(1))
                return 0
            if padding_type == 'replicate':
                seq.append(nn.ReplicationPad2d(1))
                return 0
            if padding_type == 'zero':
                return 1
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        layers = []
        p = add_padding(layers)
        layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                   norm_layer(dim),
                   nn.ReLU(True)]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        p = add_padding(layers)
        layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                   norm_layer(dim)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block with a skip connection."""
        return x + self.conv_block(x)
# NOTE(review): stale smoke test — `GanAttack` no longer exists (see
# TargetedGanAttack / UnTargetedGanAttack) and forward() now takes
# (detailcode, noise, label), not a single image batch. Kept for reference:
# @hydra.main(version_base=None, config_path="./config", config_name="config")
# def test(cfg):
#     prompt = torch.randn([1, 1024]).to(device)
#     model = GanAttack(cfg, prompt).to(device)
#     data = torch.randn([2, 3, 256, 256]).to(device)
#     result_images, x = model(data)
#     print(result_images.shape)
#     print(x.shape)
# if __name__ == "__main__":
#     test()