add noise
This commit is contained in:
parent
62fcf3ce63
commit
da57cfb967
|
|
@ -61,7 +61,9 @@ def main(cfg: DictConfig) -> None:
|
|||
# set_requires_grad(model.mlp.parameters())
|
||||
for p in (model.mlp.parameters()):
|
||||
p.requires_grad =True
|
||||
optimizer = optim.SGD(model.mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
||||
for p in (model.noise_mlp.parameters()):
|
||||
p.requires_grad =True
|
||||
optimizer = optim.SGD(model.mlp.parameters()+model.noise_mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
|
|
|
|||
18
model.py
18
model.py
|
|
@ -70,8 +70,15 @@ class GanAttack(nn.Module):
|
|||
nn.ReLU(inplace=True),
|
||||
nn.Linear(4096, 512)
|
||||
)
|
||||
# basecode_layer = int(np.log2(cfg.basecode_spatial_size) - 2) * 2
|
||||
# self.basecode_layer=basecode_layer = f'x{basecode_layer-1:02d}'
|
||||
self.noise_mlp=nn.Sequential(
|
||||
nn.Linear(10,1024),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(1024,512)
|
||||
)
|
||||
self.noises=self.generator.make_noise()
|
||||
for i,j in enumerate(self.noises):
|
||||
if i>9:
|
||||
j.detach()
|
||||
|
||||
|
||||
|
||||
|
|
@ -82,9 +89,10 @@ class GanAttack(nn.Module):
|
|||
x_prompt=torch.cat([detailcode,prompt],dim=2)
|
||||
x_prompt=self.mlp(x_prompt)
|
||||
x=x_prompt+x
|
||||
noises=self.generator.make_noise()
|
||||
result_images=self.generator.synthesis(detailcode,randomize_noise=False,
|
||||
basecode_layer=self.basecode_layer,basecode=x)['image']
|
||||
adv=self.noise_mlp(self.noises)
|
||||
self.noises=adv+self.noises
|
||||
result_images, _ =self.generator([detailcode],input_is_latent=True,
|
||||
randomize_noise=False,noises=self.noises)
|
||||
|
||||
|
||||
return result_images,x
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
||||
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
|
|
|||
Loading…
Reference in New Issue