From 1c446414d00e9d3c8adea0e5180760a655e257ca Mon Sep 17 00:00:00 2001 From: leewlving Date: Wed, 17 Jan 2024 17:37:25 +0800 Subject: [PATCH] add targeted --- targetedAttack.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/targetedAttack.py b/targetedAttack.py index 99201b8..47d99aa 100644 --- a/targetedAttack.py +++ b/targetedAttack.py @@ -22,16 +22,12 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") def get_stylegan_generator(cfg): - - # ensure_checkpoint_exists(ckpt_path) - ckpt_path=cfg.paths.stylegan - g_ema = Generator(1024, 512, 8) - g_ema.load_state_dict(torch.load(ckpt_path)["g_ema"], strict=False) - g_ema.eval() - g_ema = g_ema.cuda() - mean_latent = g_ema.mean_latent(4096) - - return g_ema, mean_latent + generator=Generator(1024, 512, 8) + checkpoint = torch.load(cfg.paths.stylegan, map_location=device) + generator.load_state_dict(checkpoint['g_ema']) + generator.to(device) + generator.eval() + return generator