update to style

This commit is contained in:
Li Wenyun 2023-12-12 17:36:08 +08:00
parent ef7d4347e3
commit e9fe02f929
6 changed files with 1125 additions and 38 deletions

View File

@ -1,6 +1,5 @@
import sys import sys
import os import os
sys.path.append('./pixel2style2pixel')
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
@ -74,15 +73,18 @@ def main(cfg: DictConfig) -> None:
running_loss = 0 running_loss = 0
running_corrects = 0 running_corrects = 0
for i, (inputs, labels) in enumerate(train_dataloader): for i, (inputs, labels,base_code,detail_code) in enumerate(train_dataloader):
inputs = inputs.to(device) inputs = inputs.to(device)
labels = labels.to(device) labels = labels.to(device)
codes = model.net.encoder(inputs) base_code=base_code.to(device)
detail_code=detail_code.to(device)
# codes = model.net.encoder(inputs)
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path) # _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
optimizer.zero_grad() optimizer.zero_grad()
generated_img,adv_latent_codes=model(inputs) # generated_img,adv_latent_codes=model(inputs)
# loss_vgg=vgg_loss(inputs,generated_img) # loss_vgg=vgg_loss(inputs,generated_img)
loss_l1=F.l1_loss(codes,adv_latent_codes) loss_l1=F.l1_loss(base_code,adv_latent_codes)
loss_clip=clip_loss(generated_img,prompt) loss_clip=clip_loss(generated_img,prompt)
adv_outputs = classifier(generated_img) adv_outputs = classifier(generated_img)
clean_outputs=classifier(inputs) clean_outputs=classifier(inputs)
@ -112,11 +114,13 @@ def main(cfg: DictConfig) -> None:
running_loss = 0. #test_dataloader running_loss = 0. #test_dataloader
running_corrects = 0 running_corrects = 0
for i, (inputs, labels) in enumerate(test_dataloader): for i, (inputs, labels,base_code,detail_code) in enumerate(test_dataloader):
inputs = inputs.to(device) inputs = inputs.to(device)
labels = labels.to(device) labels = labels.to(device)
base_code=base_code.to(device)
detail_code=detail_code.to(device)
generated_img,adv_latent_codes=model(inputs) generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
outputs = classifier(generated_img) outputs = classifier(generated_img)
preds=criterion(outputs ,labels) preds=criterion(outputs ,labels)

View File

@ -19,6 +19,7 @@ paths:
adv_embedding: pretrained_models adv_embedding: pretrained_models
prompt: red lipstick prompt: red lipstick
resolution: 1024
# available attributes # available attributes
# ['Blond_Hair', 'Wavy_Hair', 'Young', 'Eyeglasses', 'Heavy_Makeup', 'Rosy_Cheeks', # ['Blond_Hair', 'Wavy_Hair', 'Young', 'Eyeglasses', 'Heavy_Makeup', 'Rosy_Cheeks',
# 'Chubby', 'Mouth_Slightly_Open', 'Bushy_Eyebrows', 'Wearing_Lipstick', 'Smiling', # 'Chubby', 'Mouth_Slightly_Open', 'Bushy_Eyebrows', 'Wearing_Lipstick', 'Smiling',

View File

@ -5,10 +5,10 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import os import os
import clip import clip
from pixel2style2pixel.models.psp import pSp # from pixel2style2pixel.models.psp import pSp
from argparse import Namespace from argparse import Namespace
from utils import normalize from utils import normalize
from stylegan.stylegan2_generator import StyleGAN2Generator
import hydra import hydra
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
import sys import sys
@ -25,33 +25,46 @@ def get_prompt(cfg):
with torch.no_grad(): with torch.no_grad():
prompt = model.encode_text(text) prompt = model.encode_text(text)
return prompt return prompt
def get_stylegan_generator(cfg):
# model, preprocess = clip.load("RN50", device=device)
resolution=cfg.resolution
generator=StyleGAN2Generator(resolution=resolution)
checkpoint = torch.load(cfg.paths.stylegan, map_location=device)
generator.load_state_dict(checkpoint['generator'])
generator.to(device)
generator.eval()
return generator
# class GanAttack(nn.Module): # class GanAttack(nn.Module):
def get_stylegan_inverter(cfg): # def get_stylegan_inverter(cfg):
# ensure_checkpoint_exists(ckpt_path) # # ensure_checkpoint_exists(ckpt_path)
path=cfg.paths.inverter_cfg # path=cfg.paths.inverter_cfg
ckpt = torch.load(path, map_location='cuda:0') # ckpt = torch.load(path, map_location='cuda:0')
opts = ckpt['opts'] # opts = ckpt['opts']
opts['checkpoint_path'] = path # opts['checkpoint_path'] = path
if 'learn_in_w' not in opts: # if 'learn_in_w' not in opts:
opts['learn_in_w'] = False # opts['learn_in_w'] = False
if 'output_size' not in opts: # if 'output_size' not in opts:
opts['output_size'] = 1024 # opts['output_size'] = 1024
net = pSp(Namespace(**opts)) # net = pSp(Namespace(**opts))
net.eval() # net.eval()
net.cuda() # net.cuda()
return net # return net
class GanAttack(nn.Module): class GanAttack(nn.Module):
def __init__(self, cfg,prompt): def __init__(self, cfg,prompt):
super().__init__() super().__init__()
self.net= get_stylegan_inverter(cfg) # self.net= get_stylegan_inverter(cfg)
# self.generator.eval() # self.generator.eval()
# self.inverter=inverter # self.inverter=inverter
# self.images_resize=images_resize # self.images_resize=images_resize
self.generator=get_stylegan_generator(cfg)
self.prompt=prompt self.prompt=prompt
text_len=self.prompt.shape[1] text_len=self.prompt.shape[1]
self.mlp=nn.Sequential( self.mlp=nn.Sequential(
@ -62,24 +75,15 @@ class GanAttack(nn.Module):
def forward(self, img): def forward(self, img,basecode,detailcode):
codes = self.net.encoder(img)
codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
# result_images, result_latent = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=False)
# result_images = self.net.face_pool(result_images)
# _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
x=codes
batch_size=img.shape[0] batch_size=img.shape[0]
# print(img.shape)
# print(self.prompt.shape)
# print(codes.shape)
prompt=self.prompt.repeat(batch_size,18,1).to(device) prompt=self.prompt.repeat(batch_size,18,1).to(device)
# print(prompt.shape) # print(prompt.shape)
x_prompt=torch.cat([codes,prompt],dim=2) x_prompt=torch.cat([basecode,prompt],dim=2)
x_prompt=self.mlp(x_prompt) x_prompt=self.mlp(x_prompt)
x=x_prompt+x x=x_prompt+x
im,_=self.net.decoder(x, input_is_latent=True, randomize_noise=False, return_latents=False) result_images=self.generator.synthesis(detailcode,randomize_noise=False,basecode=x)['image']
result_images = self.net.face_pool(im)
return result_images,x return result_images,x

File diff suppressed because it is too large Load Diff

18
stylegan/sync_op.py Normal file
View File

@ -0,0 +1,18 @@
# python3.7
"""Contains the synchronizing operator."""
import torch
import torch.distributed as dist
__all__ = ['all_gather']
def all_gather(tensor):
    """Averages `tensor` across every device in the process group.

    Falls back to returning the input unchanged when torch.distributed
    has not been initialized (i.e. single-process runs).
    """
    if not dist.is_initialized():
        return tensor
    num_replicas = dist.get_world_size()
    gathered = [torch.ones_like(tensor) for _ in range(num_replicas)]
    dist.all_gather(gathered, tensor, async_op=False)
    return torch.stack(gathered, dim=0).mean(dim=0)

View File

@ -1,7 +1,9 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from PIL import Image
from torchvision import models from torchvision import models
import numpy as np
def get_model(config): def get_model(config):
if config.dataset == 'gender_dataset': if config.dataset == 'gender_dataset':
@ -80,4 +82,58 @@ def set_requires_grad( nets, requires_grad=False):
for net in nets: for net in nets:
if net is not None: if net is not None:
for param in net.parameters(): for param in net.parameters():
param.requires_grad = requires_grad param.requires_grad = requires_grad
def preprocess(images, channel_order='RGB'):
    """Converts raw uint8 images into float32 inputs for the generator.

    Input : numpy array, dtype uint8, range [0, 255], layout [B, H, W, C].
    Output: numpy array, dtype float32, range [-1, 1], layout [B, C, H, W].

    Args:
        images: Batch of color images as described above.
        channel_order: 'RGB' (default) or 'BGR'. BGR inputs have their
            channel axis reversed so downstream code always sees RGB.

    Returns:
        The preprocessed float32 batch in [-1, 1], channel-first.
    """
    num_channels = 3
    lo, hi = -1.0, 1.0
    if num_channels == 3 and channel_order == 'BGR':
        # Flip the channel axis so BGR input becomes RGB.
        images = images[:, :, :, ::-1]
    # Scale [0, 255] -> [lo, hi], then move channels first (BHWC -> BCHW).
    scaled = images / 255.0 * (hi - lo) + lo
    return scaled.astype(np.float32).transpose(0, 3, 1, 2)
def postprocess(images):
    """Converts generator output tensors into displayable uint8 arrays.

    Input : torch.Tensor, range [-1, 1], RGB, layout [B, C, H, W].
    Output: numpy.uint8, range [0, 255], BGR, layout [B, H, W, C].
    """
    arr = images.detach().cpu().numpy()
    # Map [-1, 1] -> [0, 255]; the +0.5 before clipping rounds to nearest.
    arr = (arr + 1.) * 255. / 2.
    arr = np.clip(arr + 0.5, 0, 255).astype(np.uint8)
    # BCHW -> BHWC, then reverse the channel axis (RGB -> BGR).
    return arr.transpose(0, 2, 3, 1)[:, :, :, [2, 1, 0]]
def Lanczos_resizing(image_target, resizing_tuple=(256,256)):
    """Resizes a batch of images with Lanczos filtering via PIL.

    Input : torch.Tensor, range [-1, 1], RGB, layout [B, C, H, W].
    Output: CUDA torch.Tensor, range [-1, 1], RGB, layout [B, C, h, w].

    Args:
        image_target: Batch tensor to resize.
        resizing_tuple: Target (width, height) passed to PIL's resize.

    Returns:
        The resized batch as a single concatenated CUDA tensor.
    """
    # Tensor [-1, 1] BCHW -> uint8 [0, 255] BHWC, the format PIL expects.
    image_target_resized = image_target.clone().cpu().numpy()
    image_target_resized = (image_target_resized + 1.) * 255. / 2.
    image_target_resized = np.clip(image_target_resized + 0.5, 0, 255).astype(np.uint8)
    image_target_resized = image_target_resized.transpose(0, 2, 3, 1)
    tmps = []
    for i in range(image_target_resized.shape[0]):
        tmp = image_target_resized[i]
        tmp = Image.fromarray(tmp)  # PIL, 0~255, uint8, RGB, HWC
        # BUG FIX: the module imports `from PIL import Image`, so the bare
        # name `PIL` was undefined here — `PIL.Image.LANCZOS` raised
        # NameError at runtime. Use the imported `Image` name directly.
        tmp = np.array(tmp.resize(resizing_tuple, Image.LANCZOS))
        tmp = torch.from_numpy(preprocess(tmp[np.newaxis, :])).cuda()
        tmps.append(tmp)
    return torch.cat(tmps, dim=0)