import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.nn.functional as F
import os
import sys
import clip
# from pixel2style2pixel.models.psp import pSp
from argparse import Namespace
from utils import normalize
from stylegan.model import Generator
import hydra
from omegaconf import DictConfig, OmegaConf
import numpy as np

# Single device shared by every helper and module in this file.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_prompt(cfg):
    """Encode the configured text prompt into CLIP text features.

    Args:
        cfg: config with a ``prompt`` attribute (string or list of strings
            accepted by ``clip.tokenize``).

    Returns:
        Tensor of CLIP RN50 text embeddings, one row per prompt.
    """
    model, _preprocess = clip.load("RN50", device=device)
    tokens = clip.tokenize(cfg.prompt).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    return text_features


def get_stylegan_generator(cfg):
    """Load a pretrained StyleGAN generator in eval mode.

    Args:
        cfg: config with ``paths.stylegan`` pointing at a checkpoint that
            contains a ``'g_ema'`` state dict.

    Returns:
        The ``Generator`` moved to ``device`` and set to eval mode.
    """
    generator = Generator(1024, 512, 8)
    # SECURITY NOTE: torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
    generator.load_state_dict(checkpoint['g_ema'])
    generator.to(device)
    generator.eval()
    return generator


class LabelEncoder(nn.Module):
    """Decode a 512-d label embedding into a 3-channel spatial feature map.

    A linear layer expands the embedding to an ``nf x 64 x 64`` volume, four
    transposed convolutions upsample it 16x (64 -> 1024) while halving the
    channel count each step, and a final conv projects to 3 channels.
    """

    def __init__(self, nf=307):
        super(LabelEncoder, self).__init__()
        self.nf = nf
        self.size = 64
        self.fc = nn.Sequential(
            nn.Linear(512, nf * self.size * self.size),
            nn.ReLU(True),
        )
        layers = []
        channels = nf
        for _ in range(4):
            layers += [
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=4,
                                   stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(channels // 2, affine=False),
                nn.ReLU(inplace=True),
            ]
            channels //= 2
        layers.append(
            nn.Conv2d(channels, 3, kernel_size=3, stride=1, padding=1,
                      bias=False))
        self.transform = nn.Sequential(*layers)

    def forward(self, label_feature):
        """Expand a (B, 512) label embedding to a (B, 3, 1024, 1024) map."""
        label_feature = self.fc(label_feature)
        label_feature = label_feature.view(
            label_feature.size(0), self.nf, self.size, self.size)
        return self.transform(label_feature)


class ResnetBlock(nn.Module):
    """Residual block: x + conv_block(x), with configurable padding/norm."""

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm2d,
                 use_dropout=False, use_bias=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,
                         use_bias):
        """Assemble pad -> conv -> norm -> ReLU [-> dropout] -> pad -> conv -> norm."""

        def padding(block):
            # Returns the Conv2d padding amount; appends an explicit padding
            # layer for 'reflect'/'replicate'.
            if padding_type == 'reflect':
                block.append(nn.ReflectionPad2d(1))
                return 0
            if padding_type == 'replicate':
                block.append(nn.ReplicationPad2d(1))
                return 0
            if padding_type == 'zero':
                return 1
            raise NotImplementedError(
                'padding [%s] is not implemented' % padding_type)

        conv_block = []
        p = padding(conv_block)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            conv_block.append(nn.Dropout(0.5))
        p = padding(conv_block)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
        ]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class _GanAttackBase(nn.Module):
    """Shared implementation for the targeted/untargeted GAN-attack generators.

    Mixes a CLIP text prompt into StyleGAN detail codes, fuses in a noise map
    and a label feature map, and runs the result through an
    encoder -> resnet-bottleneck -> decoder pipeline. The decoder emits 29
    channels that are split into an 18-channel "style" output and a final
    1-channel "noisy" output.

    Submodule attribute names (mlp, label_encoder, encoder, bottle_neck,
    decoder) are kept identical to the historical per-class definitions so
    existing state_dicts remain loadable.
    """

    def __init__(self, prompt):
        super().__init__()
        self.prompt = prompt
        # Lazy layers infer their input sizes on the first forward pass.
        self.mlp = nn.LazyLinear(1024)
        self.label_encoder = LabelEncoder()
        self.encoder = nn.Sequential(
            nn.LazyConv2d(32, kernel_size=3, stride=1, padding=0, bias=True),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0, bias=True),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
        )
        self.bottle_neck = nn.Sequential(
            *[ResnetBlock(128) for _ in range(4)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0,
                               bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=0,
                               bias=False),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 29, kernel_size=6, stride=1, padding=0,
                               bias=False),
            nn.Tanh(),
        )

    def forward(self, detailcode, noise, label):
        """Produce (style, noisy) adversarial outputs.

        Args:
            detailcode: StyleGAN detail codes; assumed (batch, 18, C) so the
                prompt can be concatenated along dim=2 — TODO confirm.
            noise: indexable collection of StyleGAN noise tensors; only
                ``noise[7]`` is used.
            label: (batch, 512) label embedding fed to the LabelEncoder —
                assumed from ``nn.Linear(512, ...)``; verify against caller.

        Returns:
            Tuple ``(style, noisy)``: the first 18 of 19 channel-chunks of the
            decoder output re-concatenated, and the last chunk.
        """
        label_feat = self.label_encoder(label)
        batch_size = detailcode.shape[0]
        text_prompt = self.prompt.repeat(batch_size, 18, 1).to(device)
        x_prompt = torch.cat([detailcode, text_prompt], dim=2)
        x_prompt = self.mlp(x_prompt)
        noise1 = noise[7].squeeze().repeat(batch_size, 1).to(device)
        mixed_feature = torch.cat([x_prompt, noise1, label_feat], dim=1)
        mixed_feature = self.encoder(mixed_feature)
        mixed_feature = self.bottle_neck(mixed_feature)
        mixed_feature = self.decoder(mixed_feature)
        chunks = torch.chunk(mixed_feature, chunks=19, dim=1)
        style = torch.cat(chunks[:18], dim=1)
        noisy = chunks[-1]
        return style, noisy


class TargetedGanAttack(_GanAttackBase):
    """Targeted variant of the GAN attack generator."""

    def __init__(self, cfg, prompt):
        # cfg is accepted for backward compatibility with existing callers;
        # the historical implementation never used it.
        super().__init__(prompt)


class UnTargetedGanAttack(_GanAttackBase):
    """Untargeted variant of the GAN attack generator (same architecture)."""

    def __init__(self, prompt):
        super().__init__(prompt)


class CLIPLoss(torch.nn.Module):
    """CLIP-similarity loss: 1 - (image/text CLIP logit / 100)."""

    def __init__(self):
        super(CLIPLoss, self).__init__()
        # Use the module-level device (was hard-coded to "cuda", which
        # crashed on CPU-only machines and disagreed with the rest of the
        # file).
        self.model, self.preprocess = clip.load("RN50", device=device)
        self.model.eval()
        # CLIP RN50 expects 224x224 inputs.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))

    def forward(self, image, text):
        """Return 1 - scaled CLIP similarity between image and tokenized text."""
        image = normalize(image)
        image = self.face_pool(image)
        # CLIP logits are cosine similarity * 100; rescale to ~[0, 1].
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity