diff --git a/GanAttack.py b/GanAttack.py index b86dd42..ea62f4d 100644 --- a/GanAttack.py +++ b/GanAttack.py @@ -61,7 +61,9 @@ def main(cfg: DictConfig) -> None: # set_requires_grad(model.mlp.parameters()) for p in (model.mlp.parameters()): p.requires_grad =True - optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum) + for p in (model.noise_mlp.parameters()): + p.requires_grad =True + optimizer = optim.SGD(model.mlp.parameters()+model.noise_mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum) start_time = time.time() for epoch in range(num_epochs): diff --git a/model.py b/model.py index 25a4ee0..095ed0f 100644 --- a/model.py +++ b/model.py @@ -70,8 +70,15 @@ class GanAttack(nn.Module): nn.ReLU(inplace=True), nn.Linear(4096, 512) ) - # basecode_layer = int(np.log2(cfg.basecode_spatial_size) - 2) * 2 - # self.basecode_layer=basecode_layer = f'x{basecode_layer-1:02d}' + self.noise_mlp=nn.Sequential( + nn.Linear(10,1024), + nn.ReLU(inplace=True), + nn.Linear(1024,512) + ) + self.noises=self.generator.make_noise() + for i,j in enumerate(self.noises): + if i>9: + j.detach() @@ -82,9 +89,10 @@ class GanAttack(nn.Module): x_prompt=torch.cat([detailcode,prompt],dim=2) x_prompt=self.mlp(x_prompt) x=x_prompt+x - noises=self.generator.make_noise() - result_images=self.generator.synthesis(detailcode,randomize_noise=False, - basecode_layer=self.basecode_layer,basecode=x)['image'] + adv=self.noise_mlp(self.noises) + self.noises=adv+self.noises + result_images, _ =self.generator([detailcode],input_is_latent=True, + randomize_noise=False,noises=self.noises) return result_images,x diff --git a/stylegan/model.py b/stylegan/model.py index 13acc7f..c8c302a 100755 --- a/stylegan/model.py +++ b/stylegan/model.py @@ -5,7 +5,7 @@ import torch from torch import nn from torch.nn import functional as F -from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d +from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d import numpy as np torch.manual_seed(0)