GanAttack/GanAttack.py

35 lines
955 B
Python

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset
from utils import get_model
from model.GanInverter.models.stylegan2.model import Generator
from model.GanInverter.inference.two_stage_inference import TwoStageInference
import time
import hydra
# Module-wide compute device: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def get_stylegan_generator(cfg, size=1024, style_dim=512, n_mlp=8, n_mean_latent=4096):
    """Load the pretrained StyleGAN2 EMA generator and compute its mean latent.

    Args:
        cfg: Config object; ``cfg.paths.stylegan`` is the checkpoint path. The
            checkpoint is indexed with ``"g_ema"`` to obtain the generator
            state dict.
        size: First positional argument to ``Generator`` (default 1024, as before).
        style_dim: Second positional argument to ``Generator`` (default 512).
        n_mlp: Third positional argument to ``Generator`` (default 8).
        n_mean_latent: Sample count passed to ``g_ema.mean_latent`` (default 4096).

    Returns:
        Tuple ``(g_ema, mean_latent)``. ``g_ema`` is the generator in eval mode
        on the module-level ``device``. ``mean_latent`` is whatever
        ``g_ema.mean_latent`` returns, computed without autograd tracking.
    """
    ckpt_path = cfg.paths.stylegan
    # Deserialize onto the CPU. A checkpoint saved from a GPU then also opens on
    # CPU-only machines, and it does not occupy GPU memory before being copied
    # into the (CPU-resident) freshly built module.
    state_dict = torch.load(ckpt_path, map_location="cpu")["g_ema"]
    g_ema = Generator(size, style_dim, n_mlp)
    # strict=False silently ignores missing/unexpected keys.
    # NOTE(review): consider inspecting the key lists returned by
    # load_state_dict to catch a mismatched checkpoint.
    g_ema.load_state_dict(state_dict, strict=False)
    g_ema.eval()
    # Follow the module-level ``device`` (CUDA when available, else CPU). An
    # unconditional ``.cuda()`` raises on machines without a GPU.
    g_ema = g_ema.to(device)
    # The mean latent is used as a constant. Computing it under no_grad keeps
    # it from holding an autograd graph through the generator's parameters.
    # Presumably it is reused across many backward passes (e.g. an attack
    # loop), where a graph-attached constant could hit an already-freed graph.
    with torch.no_grad():
        mean_latent = g_ema.mean_latent(n_mean_latent)
    return g_ema, mean_latent
def get_stylegan_inverter(cfg):
    """Construct the two-stage GAN inverter and hand back its inversion callable.

    Args:
        cfg: Config object whose ``cfg.paths.inverter_cfg`` value is forwarded
            unchanged as the ``opts`` argument of ``TwoStageInference``.

    Returns:
        The bound ``inverse`` method of the newly created ``TwoStageInference``
        instance. The instance stays alive through the bound method.
    """
    two_stage = TwoStageInference(opts=cfg.paths.inverter_cfg)
    return two_stage.inverse