# GanAttack/model.py

import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import os
import clip
# Compute device shared by this module: the first CUDA GPU when available,
# otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def get_prompt(cfg):
    """Encode ``cfg.prompt`` into a CLIP text embedding.

    Loads the CLIP "ViT-B/32" model onto the module-level ``device``,
    tokenizes ``cfg.prompt`` and runs the text encoder with gradient
    tracking disabled. The model is reloaded on every call.

    Args:
        cfg: configuration object exposing a ``prompt`` attribute; presumably
            a string (or list of strings) accepted by ``clip.tokenize`` --
            confirm against callers.

    Returns:
        The tensor returned by ``encode_text``, located on ``device``.
        NOTE(review): for a single prompt this is expected to be
        (1, embed_dim) and fp16 when CLIP runs on CUDA -- confirm.
    """
    clip_model, _ = clip.load("ViT-B/32", device=device)
    tokens = clip.tokenize(cfg.prompt).to(device)
    # Inference only: the embedding is a fixed conditioning vector.
    with torch.no_grad():
        embedding = clip_model.encode_text(tokens)
    return embedding
class GanAttack(nn.Module):
    """Shift StyleGAN latent codes with a text-prompt-conditioned MLP.

    The inverter maps an input image to latent codes. An MLP receives each
    latent code concatenated with the fixed prompt embedding and predicts an
    offset that is added back onto the latent. The StyleGAN generator then
    decodes the shifted latent.

    Args:
        stylegan_generator: generator module. It is switched to eval mode
            here and called as ``generator(x, input_is_latent=True,
            randomize_noise=False)``, returning a 2-tuple whose first item is
            the synthesized image.
        inverter: callable ``inverter(img, images_resize, img_path)`` that
            returns a 6-tuple whose 4th item is the refined images and whose
            5th item is the latent codes.
        images_resize: forwarded unchanged to ``inverter``.
        prompt: prompt embedding tensor (e.g. the result of ``get_prompt``).
            It is flattened, so both a ``(D,)`` and a ``(1, D)`` tensor yield
            a D-dimensional condition vector. A multi-prompt ``(N, D)``
            tensor is flattened into one ``N*D`` vector.
    """

    def __init__(self, stylegan_generator, inverter, images_resize, prompt):
        super().__init__()
        self.generator = stylegan_generator
        self.generator.eval()
        self.inverter = inverter
        self.images_resize = images_resize
        self.prompt = prompt
        # Size by the total number of elements, not shape[0]: a (1, D) CLIP
        # embedding must contribute D features to the MLP input, not 1.
        text_len = self.prompt.numel()
        # Latent width is fixed at 512 (input and output of the offset MLP).
        self.mlp = nn.Sequential(
            nn.Linear(text_len + 512, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 512),
        )

    def forward(self, img, img_path):
        """Invert ``img``, shift its latent codes and regenerate the image.

        Args:
            img: input image batch, passed to the inverter.
            img_path: forwarded unchanged to the inverter.

        Returns:
            ``(img, refine_images, im, x)``: the original input, the
            inverter's refined images, the generator output for the shifted
            latents, and the shifted latents themselves.

        NOTE(review): the latent codes are assumed to have a last dimension
        of 512, e.g. (B, 512) or (B, L, 512) -- confirm against the inverter.
        """
        _, _, _, refine_images, latent_codes, _ = self.inverter(
            img, self.images_resize, img_path
        )
        # Flatten the prompt and match the latents' device and dtype. CLIP's
        # text encoder may emit fp16, while the MLP weights are fp32.
        prompt = self.prompt.reshape(-1).to(
            device=latent_codes.device, dtype=latent_codes.dtype
        )
        # Broadcast the prompt over every leading dim of the latents, so each
        # latent vector (per sample, and per layer for W+ codes) is paired
        # with the same prompt embedding.
        prompt = prompt.expand(*latent_codes.shape[:-1], -1)
        offset = self.mlp(torch.cat([latent_codes, prompt], dim=-1))
        x = latent_codes + offset
        im, _ = self.generator(x, input_is_latent=True, randomize_noise=False)
        return img, refine_images, im, x