diff --git a/GanAttack.py b/GanAttack.py index 31a8bd6..d688dab 100644 --- a/GanAttack.py +++ b/GanAttack.py @@ -11,9 +11,7 @@ from utils import get_model,set_requires_grad,unnormalize # from GanInverter.inference.two_stage_inference import TwoStageInference # from GanInverter.models.stylegan2.model import Generator # from models import GanAttack -from argparse import Namespace # from pixel2style2pixel.scripts.align_all_parallel import align_face -from pixel2style2pixel.models.psp import pSp # from pixel2style2pixel.models.stylegan2.model import Generator from model import GanAttack,CLIPLoss,VggLoss,get_prompt import torch.nn.functional as F @@ -39,22 +37,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # return g_ema, mean_latent -def get_stylegan_inverter(cfg): - # ensure_checkpoint_exists(ckpt_path) - path=cfg.paths.inverter_cfg - ckpt = torch.load(path, map_location='cuda:0') - opts = ckpt['opts'] - opts['checkpoint_path'] = path - if 'learn_in_w' not in opts: - opts['learn_in_w'] = False - if 'output_size' not in opts: - opts['output_size'] = 1024 - - net = pSp(Namespace(**opts)) - net.eval() - net.cuda() - return net @hydra.main(version_base=None, config_path="./config", config_name="config") def main(cfg: DictConfig) -> None: @@ -67,8 +50,8 @@ def main(cfg: DictConfig) -> None: # g_ema, _=get_stylegan_generator(cfg) prompt=get_prompt(cfg) - net=get_stylegan_inverter(cfg) - model=GanAttack(net,prompt).to(device) + # net=get_stylegan_inverter(cfg) + model=GanAttack(cfg,prompt).to(device) @@ -93,7 +76,7 @@ def main(cfg: DictConfig) -> None: for i, (inputs, labels) in enumerate(train_dataloader): inputs = inputs.to(device) labels = labels.to(device) - codes = net.encoder(inputs) + codes = model.net.encoder(inputs) # _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path) optimizer.zero_grad() generated_img,adv_latent_codes=model(inputs) diff --git a/model.py b/model.py index 1feeee4..372d952 100644 --- a/model.py +++ b/model.py @@ -5,6 +5,8 @@ import torch.nn as nn import torch.nn.functional as F import os import clip +from pixel2style2pixel.models.psp import pSp +from argparse import Namespace from utils import normalize device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -18,22 +20,40 @@ def get_prompt(cfg): prompt = model.encode_text(text) return prompt # class GanAttack(nn.Module): + +def get_stylegan_inverter(cfg): + + # ensure_checkpoint_exists(ckpt_path) + path=cfg.paths.inverter_cfg + ckpt = torch.load(path, map_location='cuda:0') + opts = ckpt['opts'] + opts['checkpoint_path'] = path + if 'learn_in_w' not in opts: + opts['learn_in_w'] = False + if 'output_size' not in opts: + opts['output_size'] = 1024 + + net = pSp(Namespace(**opts)) + net.eval() + net.cuda() + return net class GanAttack(nn.Module): - def __init__(self, net,prompt): + def __init__(self, cfg,prompt): super().__init__() - self.net= net + self.net= get_stylegan_inverter(cfg) # self.generator.eval() # self.inverter=inverter # self.images_resize=images_resize self.prompt=prompt text_len=self.prompt.shape[1] self.mlp=nn.Sequential( - nn.Linear(1024+512, 4096), + nn.Linear(text_len+512, 4096), nn.ReLU(inplace=True), nn.Linear(4096, 512) ) + def forward(self, img): @@ -48,7 +68,7 @@ class GanAttack(nn.Module): # print(self.prompt.shape) # print(codes.shape) prompt=self.prompt.repeat(batch_size,18,1).to(device) - print(prompt.shape) + # print(prompt.shape) x_prompt=torch.cat([codes,prompt],dim=2) x_prompt=self.mlp(x_prompt) x=x_prompt+x