# NOTE(review): SOURCE arrived as a unified git diff collapsed onto one line.
# This body is the reconstructed *post-patch* state of the visible region of
# model.py, reformatted, with two concrete bugs in the added code fixed
# (see BUG FIX comments in __init__ and forward).
import torch.nn as nn
import os
import clip

# `torch` itself is used below (torch.device / torch.no_grad / torch.cat);
# it is presumably imported above the visible hunk — a duplicate import is
# harmless in Python, so import it here to be safe.
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Frozen CLIP model, used only to embed the attack prompt text.
model, preprocess = clip.load("ViT-B/32", device=device)


class GanAttack(nn.Module):
    """CLIP-prompt-guided latent-space attack on a StyleGAN generator.

    The input image is inverted into a latent code, the code is shifted by a
    residual predicted by an MLP conditioned on the CLIP embedding of a text
    prompt, and the shifted code is re-synthesized by the (frozen) generator.
    """

    def __init__(self, stylegan_generator, inverter, images_resize, prompt):
        """Build the attack module.

        Args:
            stylegan_generator: pretrained StyleGAN generator; put into eval
                mode and treated as frozen (only the MLP below is trainable).
            inverter: GAN-inversion module mapping images to latent codes.
            images_resize: resize spec forwarded verbatim to the inverter.
            prompt: text prompt(s) accepted by ``clip.tokenize``.
        """
        super().__init__()
        self.generator = stylegan_generator
        self.generator.eval()  # generator is frozen during the attack
        self.inverter = inverter
        self.images_resize = images_resize

        # Embed the prompt once with frozen CLIP. encode_text returns shape
        # (n_prompts, embed_dim); for ViT-B/32 embed_dim == 512.
        tokens = clip.tokenize(prompt).to(device)
        with torch.no_grad():
            self.prompt = model.encode_text(tokens)

        # BUG FIX: the patch used self.prompt.shape[0] (the *number of
        # prompts*, typically 1), so the first Linear expected 1 + 512 inputs
        # while forward() concatenates a 512-d latent with a 512-d text
        # embedding (1024 features). Use the embedding dimension instead.
        text_dim = self.prompt.shape[-1]
        self.mlp = nn.Sequential(
            nn.Linear(text_dim + 512, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 512),
        )

    def forward(self, img, img_path):
        """Invert, shift the latent by a prompt-conditioned residual, resynthesize.

        Returns:
            (img, refine_images, im, x): the input batch, the inverter's
            refined reconstructions, the generator output from the shifted
            latent, and the shifted latent code itself.
        """
        # Invert input images to latent codes. NOTE(review): the cat below
        # assumes latent_codes is (batch, 512) — TODO confirm against the
        # inverter; a W+ code of shape (batch, n_layers, 512) would need
        # flattening or per-layer handling first.
        _, _, _, refine_images, latent_codes, _ = self.inverter(
            img, self.images_resize, img_path
        )
        x = latent_codes
        batch_size = img.shape[0]

        # BUG FIX: the patch called self.prompt.repeat(batch_size); .repeat()
        # on a 2-D tensor needs at least two repeat factors, so that raised a
        # RuntimeError. Repeat along the batch dimension only.
        prompt = self.prompt.repeat(batch_size, 1).to(device)

        # Predict a prompt-conditioned residual and add it to the latent.
        x_prompt = torch.cat([latent_codes, prompt], dim=1)
        x = self.mlp(x_prompt) + x

        # Re-synthesize from the shifted latent with noise fixed for
        # reproducibility.
        im, _ = self.generator(x, input_is_latent=True, randomize_noise=False)
        return img, refine_images, im, x