optimize the model
This commit is contained in:
parent
cc87c09e86
commit
86fe6261b8
23
GanAttack.py
23
GanAttack.py
|
|
@ -11,9 +11,7 @@ from utils import get_model,set_requires_grad,unnormalize
|
|||
# from GanInverter.inference.two_stage_inference import TwoStageInference
|
||||
# from GanInverter.models.stylegan2.model import Generator
|
||||
# from models import GanAttack
|
||||
from argparse import Namespace
|
||||
# from pixel2style2pixel.scripts.align_all_parallel import align_face
|
||||
from pixel2style2pixel.models.psp import pSp
|
||||
# from pixel2style2pixel.models.stylegan2.model import Generator
|
||||
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -39,22 +37,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|||
|
||||
# return g_ema, mean_latent
|
||||
|
||||
def get_stylegan_inverter(cfg):
|
||||
|
||||
# ensure_checkpoint_exists(ckpt_path)
|
||||
path=cfg.paths.inverter_cfg
|
||||
ckpt = torch.load(path, map_location='cuda:0')
|
||||
opts = ckpt['opts']
|
||||
opts['checkpoint_path'] = path
|
||||
if 'learn_in_w' not in opts:
|
||||
opts['learn_in_w'] = False
|
||||
if 'output_size' not in opts:
|
||||
opts['output_size'] = 1024
|
||||
|
||||
net = pSp(Namespace(**opts))
|
||||
net.eval()
|
||||
net.cuda()
|
||||
return net
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
|
|
@ -67,8 +50,8 @@ def main(cfg: DictConfig) -> None:
|
|||
# g_ema, _=get_stylegan_generator(cfg)
|
||||
prompt=get_prompt(cfg)
|
||||
|
||||
net=get_stylegan_inverter(cfg)
|
||||
model=GanAttack(net,prompt).to(device)
|
||||
# net=get_stylegan_inverter(cfg)
|
||||
model=GanAttack(cfg,prompt).to(device)
|
||||
|
||||
|
||||
|
||||
|
|
@ -93,7 +76,7 @@ def main(cfg: DictConfig) -> None:
|
|||
for i, (inputs, labels) in enumerate(train_dataloader):
|
||||
inputs = inputs.to(device)
|
||||
labels = labels.to(device)
|
||||
codes = net.encoder(inputs)
|
||||
codes = model.net.encoder(inputs)
|
||||
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||
optimizer.zero_grad()
|
||||
generated_img,adv_latent_codes=model(inputs)
|
||||
|
|
|
|||
28
model.py
28
model.py
|
|
@ -5,6 +5,8 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import os
|
||||
import clip
|
||||
from pixel2style2pixel.models.psp import pSp
|
||||
from argparse import Namespace
|
||||
from utils import normalize
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
|
@ -19,23 +21,41 @@ def get_prompt(cfg):
|
|||
return prompt
|
||||
# class GanAttack(nn.Module):
|
||||
|
||||
def get_stylegan_inverter(cfg):
|
||||
|
||||
# ensure_checkpoint_exists(ckpt_path)
|
||||
path=cfg.paths.inverter_cfg
|
||||
ckpt = torch.load(path, map_location='cuda:0')
|
||||
opts = ckpt['opts']
|
||||
opts['checkpoint_path'] = path
|
||||
if 'learn_in_w' not in opts:
|
||||
opts['learn_in_w'] = False
|
||||
if 'output_size' not in opts:
|
||||
opts['output_size'] = 1024
|
||||
|
||||
net = pSp(Namespace(**opts))
|
||||
net.eval()
|
||||
net.cuda()
|
||||
return net
|
||||
|
||||
class GanAttack(nn.Module):
|
||||
def __init__(self, net,prompt):
|
||||
def __init__(self, cfg,prompt):
|
||||
super().__init__()
|
||||
|
||||
self.net= net
|
||||
self.net= get_stylegan_inverter(cfg)
|
||||
# self.generator.eval()
|
||||
# self.inverter=inverter
|
||||
# self.images_resize=images_resize
|
||||
self.prompt=prompt
|
||||
text_len=self.prompt.shape[1]
|
||||
self.mlp=nn.Sequential(
|
||||
nn.Linear(1024+512, 4096),
|
||||
nn.Linear(text_len+512, 4096),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(4096, 512)
|
||||
)
|
||||
|
||||
|
||||
|
||||
def forward(self, img):
|
||||
codes = self.net.encoder(img)
|
||||
codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
|
||||
|
|
@ -48,7 +68,7 @@ class GanAttack(nn.Module):
|
|||
# print(self.prompt.shape)
|
||||
# print(codes.shape)
|
||||
prompt=self.prompt.repeat(batch_size,18,1).to(device)
|
||||
print(prompt.shape)
|
||||
# print(prompt.shape)
|
||||
x_prompt=torch.cat([codes,prompt],dim=2)
|
||||
x_prompt=self.mlp(x_prompt)
|
||||
x=x_prompt+x
|
||||
|
|
|
|||
Loading…
Reference in New Issue