update to style
This commit is contained in:
parent
ef7d4347e3
commit
e9fe02f929
18
GanAttack.py
18
GanAttack.py
|
|
@ -1,6 +1,5 @@
|
|||
import sys
|
||||
import os
|
||||
sys.path.append('./pixel2style2pixel')
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
|
@ -74,15 +73,18 @@ def main(cfg: DictConfig) -> None:
|
|||
running_loss = 0
|
||||
running_corrects = 0
|
||||
|
||||
for i, (inputs, labels) in enumerate(train_dataloader):
|
||||
for i, (inputs, labels,base_code,detail_code) in enumerate(train_dataloader):
|
||||
inputs = inputs.to(device)
|
||||
labels = labels.to(device)
|
||||
codes = model.net.encoder(inputs)
|
||||
base_code=base_code.to(device)
|
||||
detail_code=detail_code.to(device)
|
||||
# codes = model.net.encoder(inputs)
|
||||
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
||||
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||
optimizer.zero_grad()
|
||||
generated_img,adv_latent_codes=model(inputs)
|
||||
# generated_img,adv_latent_codes=model(inputs)
|
||||
# loss_vgg=vgg_loss(inputs,generated_img)
|
||||
loss_l1=F.l1_loss(codes,adv_latent_codes)
|
||||
loss_l1=F.l1_loss(base_code,adv_latent_codes)
|
||||
loss_clip=clip_loss(generated_img,prompt)
|
||||
adv_outputs = classifier(generated_img)
|
||||
clean_outputs=classifier(inputs)
|
||||
|
|
@ -112,11 +114,13 @@ def main(cfg: DictConfig) -> None:
|
|||
running_loss = 0. #test_dataloader
|
||||
running_corrects = 0
|
||||
|
||||
for i, (inputs, labels) in enumerate(test_dataloader):
|
||||
for i, (inputs, labels,base_code,detail_code) in enumerate(test_dataloader):
|
||||
inputs = inputs.to(device)
|
||||
labels = labels.to(device)
|
||||
base_code=base_code.to(device)
|
||||
detail_code=detail_code.to(device)
|
||||
|
||||
generated_img,adv_latent_codes=model(inputs)
|
||||
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
||||
outputs = classifier(generated_img)
|
||||
preds=criterion(outputs ,labels)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ paths:
|
|||
adv_embedding: pretrained_models
|
||||
|
||||
prompt: red lipstick
|
||||
resolution: 1024
|
||||
# available attributes
|
||||
# ['Blond_Hair', 'Wavy_Hair', 'Young', 'Eyeglasses', 'Heavy_Makeup', 'Rosy_Cheeks',
|
||||
# 'Chubby', 'Mouth_Slightly_Open', 'Bushy_Eyebrows', 'Wearing_Lipstick', 'Smiling',
|
||||
|
|
|
|||
64
model.py
64
model.py
|
|
@ -5,10 +5,10 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import os
|
||||
import clip
|
||||
from pixel2style2pixel.models.psp import pSp
|
||||
# from pixel2style2pixel.models.psp import pSp
|
||||
from argparse import Namespace
|
||||
from utils import normalize
|
||||
|
||||
from stylegan.stylegan2_generator import StyleGAN2Generator
|
||||
import hydra
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
import sys
|
||||
|
|
@ -25,33 +25,46 @@ def get_prompt(cfg):
|
|||
with torch.no_grad():
|
||||
prompt = model.encode_text(text)
|
||||
return prompt
|
||||
|
||||
def get_stylegan_generator(cfg):
|
||||
# model, preprocess = clip.load("RN50", device=device)
|
||||
resolution=cfg.resolution
|
||||
generator=StyleGAN2Generator(resolution=resolution)
|
||||
checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
|
||||
generator.load_state_dict(checkpoint['generator'])
|
||||
generator.to(device)
|
||||
generator.eval()
|
||||
return generator
|
||||
|
||||
|
||||
# class GanAttack(nn.Module):
|
||||
|
||||
def get_stylegan_inverter(cfg):
|
||||
# def get_stylegan_inverter(cfg):
|
||||
|
||||
# ensure_checkpoint_exists(ckpt_path)
|
||||
path=cfg.paths.inverter_cfg
|
||||
ckpt = torch.load(path, map_location='cuda:0')
|
||||
opts = ckpt['opts']
|
||||
opts['checkpoint_path'] = path
|
||||
if 'learn_in_w' not in opts:
|
||||
opts['learn_in_w'] = False
|
||||
if 'output_size' not in opts:
|
||||
opts['output_size'] = 1024
|
||||
# # ensure_checkpoint_exists(ckpt_path)
|
||||
# path=cfg.paths.inverter_cfg
|
||||
# ckpt = torch.load(path, map_location='cuda:0')
|
||||
# opts = ckpt['opts']
|
||||
# opts['checkpoint_path'] = path
|
||||
# if 'learn_in_w' not in opts:
|
||||
# opts['learn_in_w'] = False
|
||||
# if 'output_size' not in opts:
|
||||
# opts['output_size'] = 1024
|
||||
|
||||
net = pSp(Namespace(**opts))
|
||||
net.eval()
|
||||
net.cuda()
|
||||
return net
|
||||
# net = pSp(Namespace(**opts))
|
||||
# net.eval()
|
||||
# net.cuda()
|
||||
# return net
|
||||
|
||||
class GanAttack(nn.Module):
|
||||
def __init__(self, cfg,prompt):
|
||||
super().__init__()
|
||||
|
||||
self.net= get_stylegan_inverter(cfg)
|
||||
# self.net= get_stylegan_inverter(cfg)
|
||||
# self.generator.eval()
|
||||
# self.inverter=inverter
|
||||
# self.images_resize=images_resize
|
||||
self.generator=get_stylegan_generator(cfg)
|
||||
self.prompt=prompt
|
||||
text_len=self.prompt.shape[1]
|
||||
self.mlp=nn.Sequential(
|
||||
|
|
@ -62,24 +75,15 @@ class GanAttack(nn.Module):
|
|||
|
||||
|
||||
|
||||
def forward(self, img):
|
||||
codes = self.net.encoder(img)
|
||||
codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
|
||||
# result_images, result_latent = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=False)
|
||||
# result_images = self.net.face_pool(result_images)
|
||||
# _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
|
||||
x=codes
|
||||
def forward(self, img,basecode,detailcode):
|
||||
batch_size=img.shape[0]
|
||||
# print(img.shape)
|
||||
# print(self.prompt.shape)
|
||||
# print(codes.shape)
|
||||
prompt=self.prompt.repeat(batch_size,18,1).to(device)
|
||||
# print(prompt.shape)
|
||||
x_prompt=torch.cat([codes,prompt],dim=2)
|
||||
x_prompt=torch.cat([basecode,prompt],dim=2)
|
||||
x_prompt=self.mlp(x_prompt)
|
||||
x=x_prompt+x
|
||||
im,_=self.net.decoder(x, input_is_latent=True, randomize_noise=False, return_latents=False)
|
||||
result_images = self.net.face_pool(im)
|
||||
result_images=self.generator.synthesis(detailcode,randomize_noise=False,basecode=x)['image']
|
||||
|
||||
|
||||
return result_images,x
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,18 @@
|
|||
# python3.7
|
||||
"""Contains the synchronizing operator."""
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
__all__ = ['all_gather']
|
||||
|
||||
|
||||
def all_gather(tensor):
|
||||
"""Gathers tensor from all devices and does averaging."""
|
||||
if not dist.is_initialized():
|
||||
return tensor
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
|
||||
dist.all_gather(tensor_list, tensor, async_op=False)
|
||||
return torch.mean(torch.stack(tensor_list, dim=0), dim=0)
|
||||
56
utils.py
56
utils.py
|
|
@ -1,7 +1,9 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from PIL import Image
|
||||
from torchvision import models
|
||||
import numpy as np
|
||||
|
||||
def get_model(config):
|
||||
if config.dataset == 'gender_dataset':
|
||||
|
|
@ -81,3 +83,57 @@ def set_requires_grad( nets, requires_grad=False):
|
|||
if net is not None:
|
||||
for param in net.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
def preprocess(images, channel_order='RGB'):
|
||||
"""Preprocesses the input images if needed.
|
||||
This function assumes the input numpy array is with shape [batch_size,
|
||||
height, width, channel]. Here, `channel = 3` for color image and
|
||||
`channel = 1` for grayscale image. The returned images are with shape
|
||||
[batch_size, channel, height, width].
|
||||
NOTE: The channel order of input images is always assumed as `RGB`.
|
||||
Args:
|
||||
images: The raw inputs with dtype `numpy.uint8` and range [0, 255].
|
||||
Returns:
|
||||
The preprocessed images with dtype `numpy.float32` and range
|
||||
[-1, 1].
|
||||
"""
|
||||
# input : numpy, np.uint8, 0~255, RGB, BHWC
|
||||
# output : numpy, np.float32, -1~1, RGB, BCHW
|
||||
|
||||
image_channels = 3
|
||||
max_val = 1.0
|
||||
min_val = -1.0
|
||||
|
||||
if image_channels == 3 and channel_order == 'BGR':
|
||||
images = images[:, :, :, ::-1]
|
||||
images = images / 255.0 * (max_val - min_val) + min_val
|
||||
images = images.astype(np.float32).transpose(0, 3, 1, 2)
|
||||
return images
|
||||
|
||||
def postprocess(images):
|
||||
"""Post-processes images from `torch.Tensor` to `numpy.ndarray`."""
|
||||
# input : tensor, -1~1, RGB, BCHW
|
||||
# output : np.uint8, 0~255, BGR, BHWC
|
||||
|
||||
images = images.detach().cpu().numpy()
|
||||
images = (images + 1.) * 255. / 2.
|
||||
images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
|
||||
images = images.transpose(0, 2, 3, 1)[:,:,:,[2,1,0]]
|
||||
return images
|
||||
|
||||
def Lanczos_resizing(image_target, resizing_tuple=(256,256)):
|
||||
# input : -1~1, RGB, BCHW, Tensor
|
||||
# output : -1~1, RGB, BCHW, Tensor
|
||||
image_target_resized = image_target.clone().cpu().numpy()
|
||||
image_target_resized = (image_target_resized + 1.) * 255. / 2.
|
||||
image_target_resized = np.clip(image_target_resized + 0.5, 0, 255).astype(np.uint8)
|
||||
|
||||
image_target_resized = image_target_resized.transpose(0, 2, 3, 1)
|
||||
tmps = []
|
||||
for i in range(image_target_resized.shape[0]):
|
||||
tmp = image_target_resized[i]
|
||||
tmp = Image.fromarray(tmp) # PIL, 0~255, uint8, RGB, HWC
|
||||
tmp = np.array(tmp.resize(resizing_tuple, PIL.Image.LANCZOS))
|
||||
tmp = torch.from_numpy(preprocess(tmp[np.newaxis,:])).cuda()
|
||||
tmps.append(tmp)
|
||||
return torch.cat(tmps, dim=0)
|
||||
Loading…
Reference in New Issue