add model
This commit is contained in:
parent
9cb2d6bbf8
commit
037d945fe3
39
model.py
39
model.py
|
|
@ -5,25 +5,38 @@ import torch.nn as nn
|
|||
import os
|
||||
import clip
|
||||
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
model, preprocess = clip.load("ViT-B/32", device=device)
|
||||
# class GanAttack(nn.Module):
|
||||
|
||||
class GanAttack(nn.Module):
|
||||
def __init__(self, kernel, factor=2):
|
||||
def __init__(self, stylegan_generator, inverter,images_resize,prompt):
|
||||
super().__init__()
|
||||
|
||||
self.factor = factor
|
||||
kernel = make_kernel(kernel) * (factor ** 2)
|
||||
self.register_buffer('kernel', kernel)
|
||||
self.generator = stylegan_generator
|
||||
self.generator.eval()
|
||||
self.inverter=inverter
|
||||
self.images_resize=images_resize
|
||||
text=clip.tokenize(prompt).to(device)
|
||||
with torch.no_grad():
|
||||
self.prompt = model.encode_text(text)
|
||||
text_len=self.prompt.shape[0]
|
||||
self.mlp=nn.Sequential(
|
||||
nn.Linear(text_len+512, 4096),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(4096, 512)
|
||||
)
|
||||
|
||||
p = kernel.shape[0] - factor
|
||||
|
||||
pad0 = (p + 1) // 2 + factor - 1
|
||||
pad1 = p // 2
|
||||
def forward(self, img,img_path):
|
||||
_, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
|
||||
x=latent_codes
|
||||
batch_size=img.shape[0]
|
||||
prompt=self.prompt.repeat(batch_size).to(device)
|
||||
x_prompt=torch.cat([latent_codes,prompt],dim=1)
|
||||
x_prompt=self.mlp(x_prompt)
|
||||
x=x_prompt+x
|
||||
im,_=self.generator(x,input_is_latent=True, randomize_noise=False)
|
||||
|
||||
self.pad = (pad0, pad1)
|
||||
|
||||
def forward(self, input):
|
||||
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
||||
|
||||
return out
|
||||
return img,refine_images,im,x
|
||||
Loading…
Reference in New Issue