add targeted

This commit is contained in:
leewlving 2024-01-17 17:37:25 +08:00
parent 43c7bd0fa3
commit 1c446414d0
1 changed files with 6 additions and 10 deletions

View File

@ -22,16 +22,12 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_stylegan_generator(cfg):
# ensure_checkpoint_exists(ckpt_path)
ckpt_path=cfg.paths.stylegan
g_ema = Generator(1024, 512, 8)
g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
g_ema.eval()
g_ema = g_ema.cuda()
mean_latent = g_ema.mean_latent(4096)
return g_ema, mean_latent
generator=Generator(1024, 512, 8)
checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
generator.load_state_dict(checkpoint['g_ema'])
generator.to(device)
generator.eval()
return generator