GanAttack/model.py

131 lines
4.3 KiB
Python

import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.nn.functional as F
import os
import clip
from pixel2style2pixel.models.psp import pSp
from argparse import Namespace
from utils import normalize
import hydra
from omegaconf import DictConfig, OmegaConf
import sys
import os
# Make the vendored pixel2style2pixel checkout importable by path.
# NOTE(review): this append executes after the
# `from pixel2style2pixel.models.psp import pSp` import near the top of the
# file, so it cannot influence that import; if the vendored package needs this
# path entry for its own absolute imports, it must move above that line —
# confirm.
sys.path.append('./pixel2style2pixel')

# Device shared by the helpers in this module: the first CUDA GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def get_prompt(cfg):
    """Encode ``cfg.prompt`` with the CLIP RN50 text encoder.

    A fresh RN50 CLIP model is loaded on the module-level ``device`` on every
    call. The prompt is tokenized with ``clip.tokenize`` and encoded under
    ``torch.no_grad()``, so the returned embedding carries no autograd history.

    Args:
        cfg: config object exposing a ``prompt`` attribute (text to encode).

    Returns:
        The tensor produced by ``encode_text``, on ``device``. Presumably of
        shape (num_prompts, 1024) for RN50 — ``test`` uses a (1, 1024) random
        stand-in; confirm.
    """
    clip_model, _ = clip.load("RN50", device=device)
    tokens = clip.tokenize(cfg.prompt).to(device)
    with torch.no_grad():
        embedding = clip_model.encode_text(tokens)
    return embedding
def get_stylegan_inverter(cfg):
    """Build a pSp (pixel2style2pixel) network from a checkpoint file.

    The checkpoint's ``opts`` dict is read, pointed back at the checkpoint via
    ``checkpoint_path``, and completed with fallback values for options that
    older checkpoints may lack. The resulting ``pSp`` network is put in eval
    mode and moved to the module-level ``device``.

    Args:
        cfg: config object; ``cfg.paths.inverter_cfg`` is the path of the pSp
            checkpoint (loaded with ``torch.load`` despite the ``_cfg`` name).

    Returns:
        The ``pSp`` network, in eval mode, on ``device``.
    """
    ckpt_path = cfg.paths.inverter_cfg
    # Only ``opts`` is needed from this load, so keep it on the CPU: this avoids
    # a redundant GPU copy of the weights and works on CPU-only machines,
    # unlike the previous hard-coded 'cuda:0'.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    opts = ckpt['opts']
    opts['checkpoint_path'] = ckpt_path
    # Fallbacks for options missing from older checkpoints.
    opts.setdefault('learn_in_w', False)
    opts.setdefault('output_size', 1024)
    net = pSp(Namespace(**opts))
    net.eval()
    # Follow the module-level device choice instead of an unconditional .cuda().
    net.to(device)
    return net
class GanAttack(nn.Module):
    """Text-conditioned latent edit of images through a pSp encoder/decoder.

    ``forward`` inverts the input with the pSp encoder (plus the pSp average
    latent), concatenates every per-layer latent code with a fixed prompt
    embedding, and maps that through an MLP to a 512-dim residual added back to
    the code. The edited codes are decoded and pooled by the pSp network.
    """

    def __init__(self, cfg, prompt):
        """
        Args:
            cfg: config passed to ``get_stylegan_inverter`` to build the pSp net.
            prompt: fixed text embedding; assumed shape (1, text_dim) — the
                MLP input width is ``text_dim + 512``.
        """
        super().__init__()
        self.net = get_stylegan_inverter(cfg)
        # Plain attribute (not a parameter/buffer): it is not trained and not
        # saved in state_dict; forward moves it to the codes' device.
        self.prompt = prompt
        text_len = self.prompt.shape[-1]
        self.mlp = nn.Sequential(
            nn.Linear(text_len + 512, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 512),
        )

    def forward(self, img):
        """Edit the latent codes of ``img`` and decode them.

        Args:
            img: image batch accepted by the pSp encoder; its first dimension
                is the batch size.

        Returns:
            ``(result_images, x)``: the decoder output after ``face_pool`` and
            the edited latent codes of shape (batch, n_layers, 512).
        """
        codes = self.net.encoder(img)
        codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
        batch_size, n_layers = codes.shape[0], codes.shape[1]

        # Broadcast the (1, text_dim) prompt to every sample and every style
        # layer. The layer count comes from the encoder output rather than a
        # hard-coded 18, so non-1024px pSp models work too. Casting to the
        # codes' dtype keeps the MLP input uniform (e.g. an fp16 CLIP output).
        prompt = self.prompt.to(device=codes.device, dtype=codes.dtype)
        prompt = prompt.repeat(batch_size, n_layers, 1)

        delta = self.mlp(torch.cat([codes, prompt], dim=2))
        x = codes + delta

        # The decoder takes a *list* of style tensors (cf. the pSp reference
        # call ``decoder([codes], ...)``). Passing the bare tensor makes the
        # generator treat each batch element as a separate style and mix
        # samples 0 and 1 instead of decoding each sample.
        im, _ = self.net.decoder(
            [x], input_is_latent=True, randomize_noise=False, return_latents=False
        )
        result_images = self.net.face_pool(im)
        return result_images, x
class CLIPLoss(torch.nn.Module):
    """Image/text dissimilarity computed with CLIP RN50.

    ``forward`` normalizes the image with ``utils.normalize``, resizes it to
    224x224 with adaptive average pooling, and returns
    ``1 - model(image, text)[0] / 100``. Assuming CLIP's first output is the
    per-image logits (cosine similarity scaled by 100), this is roughly
    ``1 - cosine similarity`` — confirm against the CLIP version in use.
    """

    def __init__(self):
        super().__init__()
        # Use the module-level device instead of a hard-coded "cuda", so the
        # CPU fallback chosen for ``device`` works here as well.
        self.model, self.preprocess = clip.load("RN50", device=device)
        self.model.eval()
        # Resizes arbitrary-size image batches to 224x224 before encoding.
        self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))

    def forward(self, image, text):
        """
        Args:
            image: image batch; assumed (batch, 3, H, W) in the range expected
                by ``utils.normalize`` — confirm.
            text: CLIP token ids, presumably from ``clip.tokenize``.

        Returns:
            ``1 - model(image, text)[0] / 100``.
        """
        image = normalize(image)
        image = self.face_pool(image)
        similarity = 1 - self.model(image, text)[0] / 100
        return similarity
class VggLoss(torch.nn.Module):
    """Per-sample cosine similarity between VGG-11 features of two images.

    Despite the name, ``forward`` returns a similarity (1 = identical feature
    direction); callers turn it into a loss themselves.
    """

    def __init__(self):
        super().__init__()
        self.model = models.vgg11(pretrained=True)
        # Keep the convolutional ``features`` stack and drop the classification
        # head, so the model outputs pooled, flattened conv features. Emptying
        # ``features`` instead would feed raw 3-channel pixels into a
        # classifier that expects 512*7*7 inputs, which cannot run.
        self.model.classifier = nn.Identity()
        # Fixed feature extractor: freeze the weights but keep autograd on for
        # the inputs, so gradients can flow back into the compared images.
        self.model.requires_grad_(False)
        self.model.eval()

    def forward(self, image1, image2):
        """Compare two image batches in VGG feature space.

        Args:
            image1: image batch, assumed (batch, 3, H, W); the pretrained VGG
                presumably expects ImageNet-normalized inputs — confirm.
            image2: image batch with the same shape as ``image1``.

        Returns:
            Tensor of shape (batch,): cosine similarity per sample pair.
        """
        # Flatten per sample (keep the batch dim) so the similarity is computed
        # per image, not over one vector spanning the whole batch.
        feature1 = torch.flatten(self.model(image1), start_dim=1)
        feature2 = torch.flatten(self.model(image2), start_dim=1)
        return F.cosine_similarity(feature1, feature2, dim=1)
@hydra.main(version_base=None, config_path="./config", config_name="config")
def test(cfg):
    """Smoke-test ``GanAttack`` on random inputs and print the output shapes.

    A random (1, 1024) tensor stands in for the CLIP text embedding, and a
    random batch of two 3x256x256 images is pushed through the model.

    Args:
        cfg: Hydra config named "config" under ./config, supplied by the
            ``hydra.main`` decorator.
    """
    fake_prompt = torch.randn([1, 1024]).to(device)
    attacker = GanAttack(cfg, fake_prompt).to(device)
    batch = torch.randn([2, 3, 256, 256]).to(device)
    images, latents = attacker(batch)
    for tensor in (images, latents):
        print(tensor.shape)


if __name__ == "__main__":
    test()