"""CLIP-prompt-guided adversarial image generation with a StyleGAN2 generator.

Pipeline: a text prompt is encoded with CLIP, concatenated to a StyleGAN
W+ "detail code", pushed through an MLP, and added as a residual
perturbation to the latent before decoding with a frozen StyleGAN2
generator. ``CLIPLoss`` scores generated images against the prompt.
"""

import os
import sys
from argparse import Namespace

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms

import clip
import hydra
from omegaconf import DictConfig, OmegaConf

from stylegan.model import Generator
from utils import normalize

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_prompt(cfg):
    """Encode ``cfg.prompt`` into a CLIP text embedding.

    Args:
        cfg: Hydra config providing the ``prompt`` string.

    Returns:
        A ``(1, text_len)`` CLIP text-feature tensor (no gradient).
    """
    model, _preprocess = clip.load("RN50", device=device)
    text = clip.tokenize(cfg.prompt).to(device)
    with torch.no_grad():
        embedding = model.encode_text(text)
    return embedding


def get_stylegan_generator(cfg):
    """Load a pretrained 1024px StyleGAN2 generator in eval mode.

    Args:
        cfg: Hydra config providing ``paths.stylegan`` (checkpoint path
            whose ``'g_ema'`` entry holds the EMA generator weights).

    Returns:
        The generator module, moved to ``device`` and set to ``eval()``.
    """
    generator = Generator(1024, 512, 8)
    checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
    generator.load_state_dict(checkpoint["g_ema"])
    generator.to(device)
    generator.eval()
    return generator


class GanAttack(nn.Module):
    """Perturb StyleGAN latents toward a CLIP text prompt.

    Args:
        cfg: Hydra config (used to locate the StyleGAN checkpoint).
        prompt: ``(1, text_len)`` CLIP text embedding from :func:`get_prompt`.
    """

    def __init__(self, cfg, prompt):
        super().__init__()
        self.generator = get_stylegan_generator(cfg)
        self.prompt = prompt
        text_len = self.prompt.shape[1]
        # Maps [detail code (512) || prompt embedding] -> latent residual.
        self.mlp = nn.Sequential(
            nn.Linear(text_len + 512, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 512),
        )
        # NOTE(review): kept for checkpoint compatibility, but currently
        # unused — see the FIXME in forward() about the noise update.
        self.noise_mlp = nn.Sequential(
            nn.Linear(10, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
        )
        self.noises = self.generator.make_noise()
        # Freeze the high-resolution noise maps (index > 9).
        # BUG FIX: original called j.detach() and discarded the result
        # (a no-op); the detached tensor must be stored back.
        for i, j in enumerate(self.noises):
            if i > 9:
                self.noises[i] = j.detach()

    def forward(self, img, detailcode):
        """Generate prompt-perturbed images from a latent detail code.

        Args:
            img: image batch; only its batch dimension is read here.
            detailcode: W+ latent code, expected ``(batch, 18, 512)``
                (18 style layers for a 1024px generator) — TODO confirm.

        Returns:
            Tuple ``(result_images, x)`` — the generated images and the
            perturbed latent code.
        """
        batch_size = img.shape[0]
        # Broadcast the (1, text_len) prompt over batch and the 18 style
        # layers -> (batch, 18, text_len).
        prompt = self.prompt.repeat(batch_size, 18, 1).to(device)
        x_prompt = torch.cat([detailcode, prompt], dim=2)
        x_prompt = self.mlp(x_prompt)
        # Residual perturbation of the latent.
        # BUG FIX: original read an undefined name `x` here (NameError).
        x = x_prompt + detailcode
        # FIXME(review): the original applied self.noise_mlp to the *list*
        # self.noises (TypeError) and its Linear(10, ...) input dim matches
        # none of the generator noise shapes (1,1,H,W with H,W in 4..1024).
        # The noise perturbation is disabled until the intended noise
        # representation is confirmed.
        # BUG FIX: decode the perturbed latent `x` (the original passed the
        # unmodified detailcode, so the MLP never affected the images).
        result_images, _ = self.generator(
            [x],
            input_is_latent=True,
            randomize_noise=False,
            noises=self.noises,
        )
        return result_images, x


class CLIPLoss(torch.nn.Module):
    """CLIP similarity loss: lower when the image matches the text."""

    def __init__(self):
        super(CLIPLoss, self).__init__()
        # BUG FIX: use the module-level device instead of hard-coded
        # "cuda", which crashed on CPU-only machines.
        self.model, self.preprocess = clip.load("RN50", device=device)
        self.model.eval()
        # CLIP RN50 expects 224x224 inputs.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))

    def forward(self, image, text):
        """Return ``1 - cosine_logit/100`` between image and text."""
        image = normalize(image)
        image = self.face_pool(image)
        # CLIP logits are scaled by 100; rescale to a [0, ~2] loss.
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity


@hydra.main(version_base=None, config_path="./config", config_name="config")
def test(cfg):
    """Smoke test: run a random batch through GanAttack, print shapes."""
    prompt = torch.randn([1, 1024]).to(device)
    model = GanAttack(cfg, prompt).to(device)
    data = torch.randn([2, 3, 256, 256]).to(device)
    # BUG FIX: forward requires a detail code; the original called
    # model(data) with a missing positional argument.
    detailcode = torch.randn([2, 18, 512]).to(device)
    result_images, x = model(data, detailcode)
    print(result_images.shape)
    print(x.shape)


if __name__ == "__main__":
    test()