add targeted

This commit is contained in:
leewlving 2024-01-17 17:37:25 +08:00
parent 43c7bd0fa3
commit 1c446414d0
1 changed files with 6 additions and 10 deletions

View File

@ -22,16 +22,12 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_stylegan_generator(cfg): def get_stylegan_generator(cfg):
generator=Generator(1024, 512, 8)
# ensure_checkpoint_exists(ckpt_path) checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
ckpt_path=cfg.paths.stylegan generator.load_state_dict(checkpoint['g_ema'])
g_ema = Generator(1024, 512, 8) generator.to(device)
g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False) generator.eval()
g_ema.eval() return generator
g_ema = g_ema.cuda()
mean_latent = g_ema.mean_latent(4096)
return g_ema, mean_latent