update to style
This commit is contained in:
parent
ef7d4347e3
commit
e9fe02f929
18
GanAttack.py
18
GanAttack.py
|
|
@ -1,6 +1,5 @@
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
sys.path.append('./pixel2style2pixel')
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
|
@ -74,15 +73,18 @@ def main(cfg: DictConfig) -> None:
|
||||||
running_loss = 0
|
running_loss = 0
|
||||||
running_corrects = 0
|
running_corrects = 0
|
||||||
|
|
||||||
for i, (inputs, labels) in enumerate(train_dataloader):
|
for i, (inputs, labels,base_code,detail_code) in enumerate(train_dataloader):
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
labels = labels.to(device)
|
labels = labels.to(device)
|
||||||
codes = model.net.encoder(inputs)
|
base_code=base_code.to(device)
|
||||||
|
detail_code=detail_code.to(device)
|
||||||
|
# codes = model.net.encoder(inputs)
|
||||||
|
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
||||||
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
generated_img,adv_latent_codes=model(inputs)
|
# generated_img,adv_latent_codes=model(inputs)
|
||||||
# loss_vgg=vgg_loss(inputs,generated_img)
|
# loss_vgg=vgg_loss(inputs,generated_img)
|
||||||
loss_l1=F.l1_loss(codes,adv_latent_codes)
|
loss_l1=F.l1_loss(base_code,adv_latent_codes)
|
||||||
loss_clip=clip_loss(generated_img,prompt)
|
loss_clip=clip_loss(generated_img,prompt)
|
||||||
adv_outputs = classifier(generated_img)
|
adv_outputs = classifier(generated_img)
|
||||||
clean_outputs=classifier(inputs)
|
clean_outputs=classifier(inputs)
|
||||||
|
|
@ -112,11 +114,13 @@ def main(cfg: DictConfig) -> None:
|
||||||
running_loss = 0. #test_dataloader
|
running_loss = 0. #test_dataloader
|
||||||
running_corrects = 0
|
running_corrects = 0
|
||||||
|
|
||||||
for i, (inputs, labels) in enumerate(test_dataloader):
|
for i, (inputs, labels,base_code,detail_code) in enumerate(test_dataloader):
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
labels = labels.to(device)
|
labels = labels.to(device)
|
||||||
|
base_code=base_code.to(device)
|
||||||
|
detail_code=detail_code.to(device)
|
||||||
|
|
||||||
generated_img,adv_latent_codes=model(inputs)
|
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
||||||
outputs = classifier(generated_img)
|
outputs = classifier(generated_img)
|
||||||
preds=criterion(outputs ,labels)
|
preds=criterion(outputs ,labels)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ paths:
|
||||||
adv_embedding: pretrained_models
|
adv_embedding: pretrained_models
|
||||||
|
|
||||||
prompt: red lipstick
|
prompt: red lipstick
|
||||||
|
resolution: 1024
|
||||||
# available attributes
|
# available attributes
|
||||||
# ['Blond_Hair', 'Wavy_Hair', 'Young', 'Eyeglasses', 'Heavy_Makeup', 'Rosy_Cheeks',
|
# ['Blond_Hair', 'Wavy_Hair', 'Young', 'Eyeglasses', 'Heavy_Makeup', 'Rosy_Cheeks',
|
||||||
# 'Chubby', 'Mouth_Slightly_Open', 'Bushy_Eyebrows', 'Wearing_Lipstick', 'Smiling',
|
# 'Chubby', 'Mouth_Slightly_Open', 'Bushy_Eyebrows', 'Wearing_Lipstick', 'Smiling',
|
||||||
|
|
|
||||||
64
model.py
64
model.py
|
|
@ -5,10 +5,10 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import os
|
import os
|
||||||
import clip
|
import clip
|
||||||
from pixel2style2pixel.models.psp import pSp
|
# from pixel2style2pixel.models.psp import pSp
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from utils import normalize
|
from utils import normalize
|
||||||
|
from stylegan.stylegan2_generator import StyleGAN2Generator
|
||||||
import hydra
|
import hydra
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -25,33 +25,46 @@ def get_prompt(cfg):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
prompt = model.encode_text(text)
|
prompt = model.encode_text(text)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
def get_stylegan_generator(cfg):
    """Build a StyleGAN2 generator and load its pretrained weights.

    Args:
        cfg: Configuration object. This function reads ``cfg.resolution``
            (passed to ``StyleGAN2Generator``) and ``cfg.paths.stylegan``
            (path of the checkpoint file).

    Returns:
        The ``StyleGAN2Generator`` instance, moved to the module-level
        ``device`` and switched to evaluation mode.
    """
    generator = StyleGAN2Generator(resolution=cfg.resolution)

    # The checkpoint is expected to store the generator weights under the
    # 'generator' key.
    checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
    generator.load_state_dict(checkpoint['generator'])

    generator.to(device)
    generator.eval()
    return generator
|
||||||
|
|
||||||
|
|
||||||
# class GanAttack(nn.Module):
|
# class GanAttack(nn.Module):
|
||||||
|
|
||||||
def get_stylegan_inverter(cfg):
|
# def get_stylegan_inverter(cfg):
|
||||||
|
|
||||||
# ensure_checkpoint_exists(ckpt_path)
|
# # ensure_checkpoint_exists(ckpt_path)
|
||||||
path=cfg.paths.inverter_cfg
|
# path=cfg.paths.inverter_cfg
|
||||||
ckpt = torch.load(path, map_location='cuda:0')
|
# ckpt = torch.load(path, map_location='cuda:0')
|
||||||
opts = ckpt['opts']
|
# opts = ckpt['opts']
|
||||||
opts['checkpoint_path'] = path
|
# opts['checkpoint_path'] = path
|
||||||
if 'learn_in_w' not in opts:
|
# if 'learn_in_w' not in opts:
|
||||||
opts['learn_in_w'] = False
|
# opts['learn_in_w'] = False
|
||||||
if 'output_size' not in opts:
|
# if 'output_size' not in opts:
|
||||||
opts['output_size'] = 1024
|
# opts['output_size'] = 1024
|
||||||
|
|
||||||
net = pSp(Namespace(**opts))
|
# net = pSp(Namespace(**opts))
|
||||||
net.eval()
|
# net.eval()
|
||||||
net.cuda()
|
# net.cuda()
|
||||||
return net
|
# return net
|
||||||
|
|
||||||
class GanAttack(nn.Module):
|
class GanAttack(nn.Module):
|
||||||
def __init__(self, cfg,prompt):
|
def __init__(self, cfg,prompt):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.net= get_stylegan_inverter(cfg)
|
# self.net= get_stylegan_inverter(cfg)
|
||||||
# self.generator.eval()
|
# self.generator.eval()
|
||||||
# self.inverter=inverter
|
# self.inverter=inverter
|
||||||
# self.images_resize=images_resize
|
# self.images_resize=images_resize
|
||||||
|
self.generator=get_stylegan_generator(cfg)
|
||||||
self.prompt=prompt
|
self.prompt=prompt
|
||||||
text_len=self.prompt.shape[1]
|
text_len=self.prompt.shape[1]
|
||||||
self.mlp=nn.Sequential(
|
self.mlp=nn.Sequential(
|
||||||
|
|
@ -62,24 +75,15 @@ class GanAttack(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, img, basecode, detailcode):
    """Generate an adversarial image from pre-computed latent codes.

    The text prompt embedding is concatenated to ``basecode``, mapped through
    ``self.mlp`` to a residual, and that residual is added back onto
    ``basecode``. The shifted code is then fed to the StyleGAN generator
    together with ``detailcode``.

    Args:
        img: Input image batch; only its first dimension (batch size) is
            used here.
        basecode: Base latent code of each image.
            # NOTE(review): presumably (batch, 18, dim) to match the
            # prompt tiling below -- confirm against the dataloader.
        detailcode: Detail code passed as the first argument of
            ``self.generator.synthesis``.

    Returns:
        Tuple ``(result_images, x)``: the ``'image'`` entry of the synthesis
        output and the shifted latent code.
    """
    batch_size = img.shape[0]

    # Tile the prompt embedding across the batch and 18 positions of the
    # second dimension so it can be concatenated to the code along dim 2.
    prompt = self.prompt.repeat(batch_size, 18, 1).to(device)
    x_prompt = torch.cat([basecode, prompt], dim=2)
    x_prompt = self.mlp(x_prompt)

    # Residual connection onto the base code. The previous revision added
    # onto `x`, which is no longer assigned anywhere in this method and
    # raised UnboundLocalError.
    x = x_prompt + basecode

    result_images = self.generator.synthesis(
        detailcode, randomize_noise=False, basecode=x)['image']

    return result_images, x
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,18 @@
|
||||||
|
# python3.7
|
||||||
|
"""Contains the synchronizing operator."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
__all__ = ['all_gather']
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather(tensor):
    """Gathers tensor from all devices and does averaging.

    Args:
        tensor: The local ``torch.Tensor`` to synchronize.

    Returns:
        The input tensor itself when the distributed package is unavailable
        or no process group has been initialized; otherwise the element-wise
        mean of the tensors gathered from every process.
    """
    # `is_initialized` is only exposed by builds with distributed support,
    # so check availability first to keep single-process runs working.
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    world_size = dist.get_world_size()
    # Receive buffers, one per process; their contents are overwritten by
    # the collective call.
    tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, async_op=False)
    return torch.mean(torch.stack(tensor_list, dim=0), dim=0)
|
||||||
58
utils.py
58
utils.py
|
|
@ -1,7 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
from PIL import Image
|
||||||
from torchvision import models
|
from torchvision import models
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
def get_model(config):
|
def get_model(config):
|
||||||
if config.dataset == 'gender_dataset':
|
if config.dataset == 'gender_dataset':
|
||||||
|
|
@ -80,4 +82,58 @@ def set_requires_grad( nets, requires_grad=False):
|
||||||
for net in nets:
|
for net in nets:
|
||||||
if net is not None:
|
if net is not None:
|
||||||
for param in net.parameters():
|
for param in net.parameters():
|
||||||
param.requires_grad = requires_grad
|
param.requires_grad = requires_grad
|
||||||
|
|
||||||
|
def preprocess(images, channel_order='RGB'):
    """Preprocesses the input images if needed.

    This function assumes the input numpy array is with shape [batch_size,
    height, width, channel], with `channel = 3` for color images. The
    returned images are with shape [batch_size, channel, height, width].

    Args:
        images: The raw inputs with dtype `numpy.uint8` and range [0, 255].
        channel_order: `'RGB'` (default) leaves the channels untouched;
            `'BGR'` reverses the last axis before scaling.

    Returns:
        The preprocessed images with dtype `numpy.float32` and range
        [-1, 1].
    """
    # input : numpy, np.uint8, 0~255, BHWC
    # output : numpy, np.float32, -1~1, BCHW
    max_val = 1.0
    min_val = -1.0

    if channel_order == 'BGR':
        # Flip the channel axis.
        images = images[:, :, :, ::-1]

    # Linearly map [0, 255] onto [min_val, max_val].
    images = images / 255.0 * (max_val - min_val) + min_val
    # BHWC -> BCHW.
    images = images.astype(np.float32).transpose(0, 3, 1, 2)
    return images
|
||||||
|
|
||||||
|
def postprocess(images):
    """Post-processes images from `torch.Tensor` to `numpy.ndarray`.

    Args:
        images: Tensor in BCHW layout with values in [-1, 1], RGB order.

    Returns:
        `numpy.uint8` array in BHWC layout with values in [0, 255] and the
        channel axis reversed (BGR).
    """
    # input : tensor, -1~1, RGB, BCHW
    # output : np.uint8, 0~255, BGR, BHWC
    images = images.detach().cpu().numpy()
    # Map [-1, 1] onto [0, 255].
    images = (images + 1.) * 255. / 2.
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
    # BCHW -> BHWC, then reverse the channel order.
    images = images.transpose(0, 2, 3, 1)[:, :, :, [2, 1, 0]]
    return images
|
||||||
|
|
||||||
|
def Lanczos_resizing(image_target, resizing_tuple=(256,256)):
    """Resize a batch of image tensors with PIL's Lanczos filter.

    Each image is converted to an 8-bit PIL image, resized, and converted
    back with `preprocess`, so the result is quantized to 256 levels.

    Args:
        image_target: Tensor in BCHW layout with values in [-1, 1], RGB.
        resizing_tuple: Target size handed to `PIL.Image.Image.resize`.

    Returns:
        Tensor in BCHW layout with values in [-1, 1], placed on the CUDA
        device.
    """
    # input : -1~1, RGB, BCHW, Tensor
    # output : -1~1, RGB, BCHW, Tensor
    # Detach first so tensors that require grad can be converted to numpy.
    pixels = image_target.detach().clone().cpu().numpy()
    pixels = (pixels + 1.) * 255. / 2.
    pixels = np.clip(pixels + 0.5, 0, 255).astype(np.uint8)

    # BCHW -> BHWC, the layout PIL expects per image.
    pixels = pixels.transpose(0, 2, 3, 1)
    resized = []
    for frame in pixels:
        pil_image = Image.fromarray(frame)  # PIL, 0~255, uint8, RGB, HWC
        # Only `Image` is imported from PIL in this module, so the filter
        # constant must be reached through it (`PIL.Image` is unbound here).
        frame = np.array(pil_image.resize(resizing_tuple, Image.LANCZOS))
        frame = torch.from_numpy(preprocess(frame[np.newaxis, :])).cuda()
        resized.append(frame)
    return torch.cat(resized, dim=0)
|
||||||
Loading…
Reference in New Issue