# GanAttack/model.py
# (145 lines, 4.6 KiB, Python — header retained from original file listing)

import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.nn.functional as F
import os
import clip
# from pixel2style2pixel.models.psp import pSp
from argparse import Namespace
from utils import normalize
from stylegan.model import Generator
import hydra
from omegaconf import DictConfig, OmegaConf
import sys
import os
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_prompt(cfg):
    """Encode the configured text prompt into a CLIP text embedding.

    Loads the CLIP RN50 model, tokenizes ``cfg.prompt`` and returns the
    encoded text features (no gradients are tracked).
    """
    clip_model, _preprocess = clip.load("RN50", device=device)
    tokens = clip.tokenize(cfg.prompt).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(tokens)
    return text_features
def get_stylegan_generator(cfg):
    """Load a pretrained StyleGAN generator from ``cfg.paths.stylegan``.

    Builds a 1024-px generator (512-d latent, 8-layer mapping network),
    restores the EMA weights ('g_ema'), moves it to the module device and
    switches it to eval mode.
    """
    ckpt = torch.load(cfg.paths.stylegan, map_location=device)
    gen = Generator(1024, 512, 8)
    gen.load_state_dict(ckpt['g_ema'])
    gen.to(device)
    gen.eval()
    return gen
# class GanAttack(nn.Module):
# def get_stylegan_inverter(cfg):
# # ensure_checkpoint_exists(ckpt_path)
# path=cfg.paths.inverter_cfg
# ckpt = torch.load(path, map_location='cuda:0')
# opts = ckpt['opts']
# opts['checkpoint_path'] = path
# if 'learn_in_w' not in opts:
# opts['learn_in_w'] = False
# if 'output_size' not in opts:
# opts['output_size'] = 1024
# net = pSp(Namespace(**opts))
# net.eval()
# net.cuda()
# return net
class GanAttack(nn.Module):
    """Adversarial wrapper around a frozen StyleGAN generator.

    A trainable MLP, conditioned on a CLIP text-prompt embedding, produces a
    per-layer offset for the StyleGAN latent code; the perturbed latent is
    rendered with fixed (non-random) per-layer noise.
    """

    def __init__(self, cfg, prompt):
        """
        Args:
            cfg: hydra config providing ``cfg.paths.stylegan``.
            prompt: CLIP text embedding, shape (1, text_len)
                    (e.g. (1, 1024) for RN50).
        """
        super().__init__()
        # Pretrained generator (loaded in eval mode; weights are frozen by
        # never passing them to the optimizer — TODO confirm the training
        # script does not include generator params).
        self.generator = get_stylegan_generator(cfg)
        self.prompt = prompt
        text_len = self.prompt.shape[1]
        # Maps [latent(512) ++ prompt(text_len)] -> 512-d latent offset.
        self.mlp = nn.Sequential(
            nn.Linear(text_len + 512, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 512),
        )
        # NOTE(review): noise_mlp expects 10-d vectors, but StyleGAN noise
        # maps are 4-D tensors of varying spatial size, so the original
        # ``self.noise_mlp(self.noises)`` raised a TypeError. The module is
        # kept (so checkpoints/optimizers stay compatible) but its broken
        # application in forward() is disabled below.
        self.noise_mlp = nn.Sequential(
            nn.Linear(10, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
        )
        # Fixed per-layer noise. BUG FIX: the original ``j.detach()`` was a
        # no-op (its result was discarded); detach properly here. Behavior is
        # unchanged in practice since make_noise() tensors don't require grad.
        self.noises = [
            n if i <= 9 else n.detach()
            for i, n in enumerate(self.generator.make_noise())
        ]

    def forward(self, img, detailcode):
        """Render adversarial images from a perturbed latent code.

        Args:
            img: input image batch; only its batch size is used here.
            detailcode: StyleGAN latent codes, assumed (batch, 18, 512)
                        — TODO confirm against the inverter's output.

        Returns:
            (result_images, x): generated images and the perturbed latent.
        """
        batch_size = img.shape[0]
        # Broadcast the prompt embedding to every one of the 18 style layers.
        prompt = self.prompt.repeat(batch_size, 18, 1).to(device)
        x_prompt = torch.cat([detailcode, prompt], dim=2)
        x_prompt = self.mlp(x_prompt)
        # BUG FIX: the original read an undefined name ``x`` here
        # (``x = x_prompt + x`` -> NameError on every call). The intended
        # adversarial latent is the original code plus the learned offset.
        x = x_prompt + detailcode
        # BUG FIX: the original applied noise_mlp to the *list* of noise maps
        # (TypeError — see __init__ note) and accumulated the result into
        # self.noises across calls. The noise perturbation is disabled until a
        # shape-compatible formulation exists; fixed noise is still used.
        #
        # NOTE(review): the original generated from [detailcode], which would
        # ignore the MLP entirely; presumably the perturbed latent should
        # drive generation — confirm against the training objective.
        result_images, _ = self.generator(
            [x], input_is_latent=True,
            randomize_noise=False, noises=self.noises,
        )
        return result_images, x
class CLIPLoss(torch.nn.Module):
    """CLIP-space loss: 1 - similarity between an image batch and a text.

    Lower values mean the image better matches the text prompt.
    """

    def __init__(self):
        super(CLIPLoss, self).__init__()
        # CONSISTENCY FIX: use the module-level ``device`` instead of the
        # hard-coded "cuda" string, so this also runs on CPU-only hosts and
        # matches the device used everywhere else in this file.
        self.model, self.preprocess = clip.load("RN50", device=device)
        self.model.eval()
        # CLIP RN50 expects 224x224 inputs; pool whatever resolution we get.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))

    def forward(self, image, text):
        """
        Args:
            image: image batch (normalized by utils.normalize before pooling).
            text: pre-tokenized CLIP text tensor.

        Returns:
            1 - logits/100, i.e. one minus the cosine similarity (CLIP scales
            its logits by logit_scale ~= 100).
        """
        image = normalize(image)
        image = self.face_pool(image)
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity
# class VggLoss(torch.nn.Module):
# def __init__(self):
# super(VggLoss, self).__init__()
# self.model=models.vgg11(pretrained=True)
# self.model.features=nn.Sequential()
# # self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda").view(1,3,1,1)
# # self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda").view(1,3,1,1)
# def forward(self, image1, image2):
# # image=normalize(image)
# with torch.no_grad:
# feature1=self.model(image1)
# feature2=self.model(image2)
# feature1=torch.flatten(feature1)
# feature2=torch.flatten(feature2)
# similarity = F.cosine_similarity(feature1,feature2)
# return similarity
@hydra.main(version_base=None, config_path="./config", config_name="config")
def test(cfg):
    """Smoke-test GanAttack: build the model and run one forward pass."""
    prompt = torch.randn([1, 1024]).to(device)
    model = GanAttack(cfg, prompt).to(device)
    data = torch.randn([2, 3, 256, 256]).to(device)
    # BUG FIX: GanAttack.forward takes (img, detailcode); the original called
    # model(data) with a single argument, raising a TypeError. 18 layers of
    # 512-d codes matches the prompt.repeat(batch, 18, 1) inside forward —
    # TODO confirm against the latent shape the inverter actually produces.
    detailcode = torch.randn([2, 18, 512]).to(device)
    result_images, x = model(data, detailcode)
    print(result_images.shape)
    print(x.shape)


if __name__ == "__main__":
    test()