282 lines
10 KiB
Python
282 lines
10 KiB
Python
import torch
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import os
|
|
import clip
|
|
# from pixel2style2pixel.models.psp import pSp
|
|
from argparse import Namespace
|
|
from utils import normalize
|
|
from stylegan.model import Generator
|
|
import hydra
|
|
from omegaconf import DictConfig, OmegaConf
|
|
import sys
|
|
import os
|
|
import numpy as np
|
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
def get_prompt(cfg):
    """Encode the configured text prompt into CLIP text features.

    Args:
        cfg: config object exposing ``cfg.prompt`` (a string or list of
            strings accepted by ``clip.tokenize``).

    Returns:
        torch.Tensor: CLIP RN50 text embedding(s) for the prompt, on the
        module-level ``device``.
    """
    # NOTE(review): the CLIP model is re-loaded on every call; if this is
    # invoked more than once, hoist the load to module level.
    model, _preprocess = clip.load("RN50", device=device)  # preprocess unused here
    tokens = clip.tokenize(cfg.prompt).to(device)
    with torch.no_grad():
        # The text encoder is used as a frozen feature extractor only.
        text_features = model.encode_text(tokens)
    return text_features
|
|
|
|
def get_stylegan_generator(cfg):
    """Build a 1024px StyleGAN generator and restore its EMA weights.

    The checkpoint path comes from ``cfg.paths.stylegan``; the generator is
    moved to the module-level ``device`` and put into eval mode.
    """
    ckpt = torch.load(cfg.paths.stylegan, map_location=device)
    netG = Generator(1024, 512, 8)
    netG.load_state_dict(ckpt['g_ema'])  # exponential-moving-average weights
    return netG.to(device).eval()
|
|
|
|
|
|
class LabelEncoder(nn.Module):
    """Project a 512-d label embedding to a 3-channel spatial feature map.

    The embedding is expanded to an ``nf`` x 64 x 64 tensor via a linear
    layer, upsampled 4 times by transposed convolutions (each doubling the
    spatial size and halving the channel count), then projected to 3
    channels.  For the default ``nf=307`` the output is (B, 3, 1024, 1024).
    """

    def __init__(self, nf=307):
        super(LabelEncoder, self).__init__()
        self.nf = nf
        self.size = 64

        # 512 -> nf * 64 * 64; reshaped into a spatial map in forward().
        self.fc = nn.Sequential(
            nn.Linear(512, nf * self.size * self.size),
            nn.ReLU(True))

        layers = []
        channels = nf
        for _ in range(4):
            # Each stage: 2x spatial upsample, channels halved.
            layers.append(nn.ConvTranspose2d(channels, channels // 2,
                                             kernel_size=4, stride=2,
                                             padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(channels // 2, affine=False))
            layers.append(nn.ReLU(inplace=True))
            channels //= 2
        # Final 3-channel projection; spatial size unchanged.
        layers.append(nn.Conv2d(channels, 3, kernel_size=3, stride=1,
                                padding=1, bias=False))
        self.transform = nn.Sequential(*layers)

    def forward(self, label_feature):
        """Map a (B, 512) label embedding to a (B, 3, H, W) feature map."""
        x = self.fc(label_feature)
        x = x.view(x.size(0), self.nf, self.size, self.size)
        return self.transform(x)
|
|
|
|
class TargetedGanAttack(nn.Module):
    """Produce adversarial style codes and a noise map from style codes, a
    CLIP text prompt, and a label embedding (targeted variant).

    NOTE(review): ``cfg`` is accepted but never used — kept only for
    interface compatibility.  Apart from that parameter this class is
    identical to ``UnTargetedGanAttack``; consider a shared base class.
    """

    def __init__(self, cfg, prompt):
        super().__init__()
        self.prompt = prompt
        # Fuses the concatenated [detailcode | prompt] back to 1024 features.
        self.mlp = nn.LazyLinear(1024)
        self.label_encoder = LabelEncoder()

        # Downsampling trunk; input channel count is resolved lazily on the
        # first forward pass.
        self.encoder = nn.Sequential(
            nn.LazyConv2d(32, kernel_size=3, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
        )
        # Four residual blocks at 128 channels.
        self.bottle_neck = nn.Sequential(*(ResnetBlock(128) for _ in range(4)))
        # Upsampling head ending in a 29-channel Tanh output.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0, bias=False),
            nn.Tanh(),
        )

    def forward(self, detailcode, noise, label):
        """Return ``(style, noisy)`` for a batch of inputs.

        Assumes ``detailcode`` is a batch of 18-layer style codes and
        ``noise`` is an indexable collection of StyleGAN noise tensors
        (only ``noise[7]`` is used) — TODO confirm against the caller.
        """
        label_feat = self.label_encoder(label)
        bsz = detailcode.shape[0]

        # Tile the (single) prompt embedding across batch and style layers.
        text = self.prompt.repeat(bsz, 18, 1).to(device)
        fused = self.mlp(torch.cat([detailcode, text], dim=2))

        noise_map = noise[7].squeeze().repeat(bsz, 1).to(device)
        h = torch.cat([fused, noise_map, label_feat], dim=1)

        h = self.decoder(self.bottle_neck(self.encoder(h)))

        # NOTE(review): the decoder emits 29 channels, so chunk(..., 19)
        # actually yields 15 chunks of <=2 channels; parts[:18] then
        # re-concatenates everything and parts[-1] overlaps the style
        # tensor — verify this split is intended.
        parts = torch.chunk(h, chunks=19, dim=1)
        style = torch.cat(parts[:18], dim=1)
        noisy = parts[-1]
        return style, noisy
|
|
|
|
|
|
|
|
class UnTargetedGanAttack(nn.Module):
    """Produce adversarial style codes and a noise map from style codes, a
    CLIP text prompt, and a label embedding (untargeted variant).

    NOTE(review): identical to ``TargetedGanAttack`` except that no ``cfg``
    is taken; consider a shared base class.
    """

    def __init__(self, prompt):
        super().__init__()
        self.prompt = prompt
        # Fuses the concatenated [detailcode | prompt] back to 1024 features.
        self.mlp = nn.LazyLinear(1024)
        self.label_encoder = LabelEncoder()

        # Downsampling trunk; input channel count is resolved lazily on the
        # first forward pass.
        self.encoder = nn.Sequential(
            nn.LazyConv2d(32, kernel_size=3, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
        )
        # Four residual blocks at 128 channels.
        self.bottle_neck = nn.Sequential(*(ResnetBlock(128) for _ in range(4)))
        # Upsampling head ending in a 29-channel Tanh output.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0, bias=False),
            nn.Tanh(),
        )

    def forward(self, detailcode, noise, label):
        """Return ``(style, noisy)`` for a batch of inputs.

        Assumes ``detailcode`` is a batch of 18-layer style codes and
        ``noise`` is an indexable collection of StyleGAN noise tensors
        (only ``noise[7]`` is used) — TODO confirm against the caller.
        """
        label_feat = self.label_encoder(label)
        bsz = detailcode.shape[0]

        # Tile the (single) prompt embedding across batch and style layers.
        text = self.prompt.repeat(bsz, 18, 1).to(device)
        fused = self.mlp(torch.cat([detailcode, text], dim=2))

        noise_map = noise[7].squeeze().repeat(bsz, 1).to(device)
        h = torch.cat([fused, noise_map, label_feat], dim=1)

        h = self.decoder(self.bottle_neck(self.encoder(h)))

        # NOTE(review): the decoder emits 29 channels, so chunk(..., 19)
        # actually yields 15 chunks of <=2 channels; parts[:18] then
        # re-concatenates everything and parts[-1] overlaps the style
        # tensor — verify this split is intended.
        parts = torch.chunk(h, chunks=19, dim=1)
        style = torch.cat(parts[:18], dim=1)
        noisy = parts[-1]
        return style, noisy
|
|
|
|
class CLIPLoss(torch.nn.Module):
    """CLIP-based dissimilarity loss between an image batch and tokenized text.

    Returns ``1 - similarity`` where similarity comes from CLIP RN50's
    image-text logits (which are scaled by 100, hence the division).
    """

    def __init__(self):
        super(CLIPLoss, self).__init__()
        # BUGFIX: device was hard-coded to "cuda", which crashed on
        # CPU-only hosts even though the module defines a ``device``
        # global with a CPU fallback — use that global instead.
        self.model, self.preprocess = clip.load("RN50", device=device)
        self.model.eval()  # frozen scorer; never trained here
        # CLIP RN50 expects 224x224 inputs.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
        # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
        # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)

    def forward(self, image, text):
        """Return ``1 - CLIP similarity`` for ``image`` vs pre-tokenized ``text``."""
        image = normalize(image)  # CLIP input statistics (see utils.normalize)
        image = self.face_pool(image)
        # model(image, text)[0] is logits_per_image = 100 * cosine similarity.
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity
|
|
|
|
|
|
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)``.

    The conv block is two 3x3 convolutions (spatial size preserved) with a
    configurable padding mode, normalization layer, and optional dropout
    after the first activation.
    """

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble pad -> conv -> norm -> ReLU [-> dropout] -> pad -> conv -> norm.

        Raises:
            NotImplementedError: for a ``padding_type`` other than
                'reflect', 'replicate', or 'zero'.
        """
        explicit_pad = {
            'reflect': nn.ReflectionPad2d,
            'replicate': nn.ReplicationPad2d,
        }
        layers = []
        for first_stage in (True, False):
            if padding_type in explicit_pad:
                layers.append(explicit_pad[padding_type](1))
                pad = 0  # padding handled by the dedicated layer above
            elif padding_type == 'zero':
                pad = 1  # let Conv2d zero-pad itself
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=pad, bias=use_bias))
            layers.append(norm_layer(dim))
            if first_stage:
                layers.append(nn.ReLU(True))
                if use_dropout:
                    layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the conv block's residual."""
        return x + self.conv_block(x)
|
|
|
|
|
|
|
|
# @hydra.main(version_base=None, config_path="./config", config_name="config")
|
|
# def test(cfg):
|
|
# prompt=torch.randn([1,1024]).to(device)
|
|
# model=GanAttack(cfg,prompt).to(device)
|
|
# data=torch.randn([2,3,256,256]).to(device)
|
|
# result_images,x=model(data)
|
|
# print(result_images.shape)
|
|
# print(x.shape)
|
|
# if __name__ == "__main__":
|
|
# test()
|
|
|