add targeted
This commit is contained in:
parent
43c7bd0fa3
commit
1c446414d0
|
|
@ -22,16 +22,12 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
def get_stylegan_generator(cfg):
|
def get_stylegan_generator(cfg):
|
||||||
|
generator=Generator(1024, 512, 8)
|
||||||
# ensure_checkpoint_exists(ckpt_path)
|
checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
|
||||||
ckpt_path=cfg.paths.stylegan
|
generator.load_state_dict(checkpoint['g_ema'])
|
||||||
g_ema = Generator(1024, 512, 8)
|
generator.to(device)
|
||||||
g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False)
|
generator.eval()
|
||||||
g_ema.eval()
|
return generator
|
||||||
g_ema = g_ema.cuda()
|
|
||||||
mean_latent = g_ema.mean_latent(4096)
|
|
||||||
|
|
||||||
return g_ema, mean_latent
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue