35 lines
955 B
Python
35 lines
955 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torchvision import models
|
|
from omegaconf import DictConfig, OmegaConf
|
|
from data.dataset import get_dataset
|
|
from utils import get_model
|
|
from model.GanInverter.models.stylegan2.model import Generator
|
|
from model.GanInverter.inference.two_stage_inference import TwoStageInference
|
|
import time
|
|
|
|
import hydra
|
|
|
|
# Default compute device for this module: first visible CUDA GPU if any, else CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
def get_stylegan_generator(cfg, size=1024, style_dim=512, n_mlp=8, n_mean_latent=4096):
    """Load the pretrained StyleGAN2 ``g_ema`` generator and its mean latent.

    Args:
        cfg: Config object; ``cfg.paths.stylegan`` is the checkpoint path. The
            checkpoint is expected to be a mapping with a ``"g_ema"`` state dict.
        size, style_dim, n_mlp: Positional arguments forwarded to ``Generator``
            (defaults reproduce the original ``Generator(1024, 512, 8)``).
            Presumably (resolution, latent dim, mapping layers) as in the
            rosinality StyleGAN2 signature -- confirm against model.GanInverter.
        n_mean_latent: Argument forwarded to ``g_ema.mean_latent`` (default 4096).

    Returns:
        ``(g_ema, mean_latent)``: the generator in eval mode, moved to the
        module-level ``device``, and the result of ``g_ema.mean_latent(...)``.
    """
    ckpt_path = cfg.paths.stylegan

    # Load onto the CPU first so that GPU-saved checkpoints also open on hosts
    # without CUDA, and so unused checkpoint entries never occupy GPU memory.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    g_ema = Generator(size, style_dim, n_mlp)
    # strict=False keeps the original tolerance for missing/unexpected keys.
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
    g_ema.eval()
    # Honour the module-level ``device`` (which falls back to CPU) instead of
    # unconditionally calling .cuda(), which fails on CPU-only machines.
    g_ema = g_ema.to(device)

    # The mean latent is used as a fixed value; skip building an autograd graph
    # over the n_mean_latent sampled latents.
    with torch.no_grad():
        mean_latent = g_ema.mean_latent(n_mean_latent)

    return g_ema, mean_latent
|
|
|
|
def get_stylegan_inverter(cfg):
    """Build a ``TwoStageInference`` inverter and hand back its ``inverse`` callable.

    Args:
        cfg: Config object; ``cfg.paths.inverter_cfg`` is passed unchanged as the
            ``opts`` argument (presumably a path to the inverter's options --
            confirm against TwoStageInference).

    Returns:
        The bound ``inverse`` method of a freshly constructed ``TwoStageInference``.
        The bound method keeps the inverter instance alive.
    """
    inverter_opts = cfg.paths.inverter_cfg
    return TwoStageInference(opts=inverter_opts).inverse