"""GAN-based attack: steer StyleGAN2 latent codes toward a CLIP text prompt.

NOTE(review): the original file's newlines were mangled; reformatted here.
Behavior-affecting fixes are marked with `BUG FIX` comments inline.
"""
import os
import sys
from argparse import Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms

import clip
import hydra
from omegaconf import DictConfig, OmegaConf

from stylegan.stylegan2_generator import StyleGAN2Generator
from utils import normalize

# Make the (optional) pSp inverter package importable. The inverter loader
# itself was commented out in the original source and is omitted here.
sys.path.append('./pixel2style2pixel')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_prompt(cfg):
    """Encode ``cfg.prompt`` with CLIP RN50.

    Returns the text embedding tensor (shape (1, 1024) for RN50) on `device`.
    """
    model, _preprocess = clip.load("RN50", device=device)
    text = clip.tokenize(cfg.prompt).to(device)
    with torch.no_grad():
        embedding = model.encode_text(text)
    return embedding


def get_stylegan_generator(cfg):
    """Load a frozen StyleGAN2 generator from ``cfg.paths.stylegan``.

    The checkpoint is expected to hold the weights under the 'generator' key.
    """
    generator = StyleGAN2Generator(resolution=cfg.resolution)
    checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    generator.to(device)
    generator.eval()
    return generator


class GanAttack(nn.Module):
    """Wraps a frozen StyleGAN2 generator with a learned MLP that shifts the
    base latent code toward a CLIP text prompt.

    Args:
        cfg: hydra/omegaconf config with ``resolution`` and ``paths.stylegan``.
        prompt: CLIP text embedding of shape (1, text_len).
    """

    def __init__(self, cfg, prompt):
        super().__init__()
        self.generator = get_stylegan_generator(cfg)
        self.prompt = prompt
        text_len = self.prompt.shape[1]
        # Maps the concatenated [basecode ; prompt] vector to a 512-d offset
        # for the base latent code. 512 matches the StyleGAN2 latent width —
        # presumably basecode is (batch, 18, 512); TODO confirm.
        self.mlp = nn.Sequential(
            nn.Linear(text_len + 512, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 512),
        )

    def forward(self, img, basecode, detailcode):
        """Generate images from a prompt-shifted base code.

        Returns:
            (result_images, shifted_basecode) tuple.
        """
        batch_size = img.shape[0]
        # (1, text_len) -> (batch, 18, text_len): one prompt copy per W+ layer.
        prompt = self.prompt.repeat(batch_size, 18, 1).to(device)
        x_prompt = torch.cat([basecode, prompt], dim=2)
        x_prompt = self.mlp(x_prompt)
        # BUG FIX: the original computed `x = x_prompt + x`, reading `x` before
        # assignment (NameError). The MLP output is a residual applied to the
        # base code.
        x = x_prompt + basecode
        result_images = self.generator.synthesis(
            detailcode, randomize_noise=False, basecode=x
        )['image']
        return result_images, x


class CLIPLoss(torch.nn.Module):
    """1 - scaled CLIP image/text similarity, for prompt-matching losses."""

    def __init__(self):
        super(CLIPLoss, self).__init__()
        self.model, self.preprocess = clip.load("RN50", device="cuda")
        self.model.eval()
        # CLIP RN50 expects 224x224 input.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))

    def forward(self, image, text):
        image = normalize(image)
        image = self.face_pool(image)
        # CLIP returns logits scaled by ~100 (logit_scale); divide back down
        # so the loss stays in a small range.
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity


class VggLoss(torch.nn.Module):
    """Cosine similarity between VGG11 features of two image batches."""

    def __init__(self):
        super(VggLoss, self).__init__()
        self.model = models.vgg11(pretrained=True)
        # BUG FIX: the original cleared `features` (keeping the classifier),
        # which crashes because the classifier then receives raw-pixel input
        # of the wrong size. Clearing the classifier instead yields the pooled
        # conv features — the usual perceptual-feature setup.
        self.model.classifier = nn.Sequential()

    def forward(self, image1, image2):
        # BUG FIX: `torch.no_grad` must be *called*; the bare class is not a
        # usable context manager.
        with torch.no_grad():
            feature1 = self.model(image1)
            feature2 = self.model(image2)
        feature1 = torch.flatten(feature1)
        feature2 = torch.flatten(feature2)
        # BUG FIX: flattened features are 1-D, so compare along dim 0 (the
        # default dim=1 raises IndexError on 1-D tensors).
        similarity = F.cosine_similarity(feature1, feature2, dim=0)
        return similarity


@hydra.main(version_base=None, config_path="./config", config_name="config")
def test(cfg):
    """Smoke test: push random tensors through GanAttack and print shapes."""
    prompt = torch.randn([1, 1024]).to(device)
    model = GanAttack(cfg, prompt).to(device)
    data = torch.randn([2, 3, 256, 256]).to(device)
    # BUG FIX: forward() requires basecode and detailcode; the original call
    # passed only the image batch (TypeError). Shapes assumed to be W+ codes
    # (batch, 18, 512) — TODO confirm against StyleGAN2Generator.synthesis.
    basecode = torch.randn([2, 18, 512]).to(device)
    detailcode = torch.randn([2, 18, 512]).to(device)
    result_images, x = model(data, basecode, detailcode)
    print(result_images.shape)
    print(x.shape)


if __name__ == "__main__":
    test()