update to style

This commit is contained in:
Li Wenyun 2023-12-12 17:36:08 +08:00
parent ef7d4347e3
commit e9fe02f929
6 changed files with 1125 additions and 38 deletions

View File

@ -1,6 +1,5 @@
import sys
import os
sys.path.append('./pixel2style2pixel')
import torch
import torch.nn as nn
import torch.optim as optim
@ -74,15 +73,18 @@ def main(cfg: DictConfig) -> None:
running_loss = 0
running_corrects = 0
for i, (inputs, labels) in enumerate(train_dataloader):
for i, (inputs, labels,base_code,detail_code) in enumerate(train_dataloader):
inputs = inputs.to(device)
labels = labels.to(device)
codes = model.net.encoder(inputs)
base_code=base_code.to(device)
detail_code=detail_code.to(device)
# codes = model.net.encoder(inputs)
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
optimizer.zero_grad()
generated_img,adv_latent_codes=model(inputs)
# generated_img,adv_latent_codes=model(inputs)
# loss_vgg=vgg_loss(inputs,generated_img)
loss_l1=F.l1_loss(codes,adv_latent_codes)
loss_l1=F.l1_loss(base_code,adv_latent_codes)
loss_clip=clip_loss(generated_img,prompt)
adv_outputs = classifier(generated_img)
clean_outputs=classifier(inputs)
@ -112,11 +114,13 @@ def main(cfg: DictConfig) -> None:
running_loss = 0. #test_dataloader
running_corrects = 0
for i, (inputs, labels) in enumerate(test_dataloader):
for i, (inputs, labels,base_code,detail_code) in enumerate(test_dataloader):
inputs = inputs.to(device)
labels = labels.to(device)
base_code=base_code.to(device)
detail_code=detail_code.to(device)
generated_img,adv_latent_codes=model(inputs)
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
outputs = classifier(generated_img)
preds=criterion(outputs ,labels)

View File

@ -19,6 +19,7 @@ paths:
adv_embedding: pretrained_models
prompt: red lipstick
resolution: 1024
# available attributes
# ['Blond_Hair', 'Wavy_Hair', 'Young', 'Eyeglasses', 'Heavy_Makeup', 'Rosy_Cheeks',
# 'Chubby', 'Mouth_Slightly_Open', 'Bushy_Eyebrows', 'Wearing_Lipstick', 'Smiling',

View File

@ -5,10 +5,10 @@ import torch.nn as nn
import torch.nn.functional as F
import os
import clip
from pixel2style2pixel.models.psp import pSp
# from pixel2style2pixel.models.psp import pSp
from argparse import Namespace
from utils import normalize
from stylegan.stylegan2_generator import StyleGAN2Generator
import hydra
from omegaconf import DictConfig, OmegaConf
import sys
@ -25,33 +25,46 @@ def get_prompt(cfg):
with torch.no_grad():
prompt = model.encode_text(text)
return prompt
def get_stylegan_generator(cfg):
    """Build the StyleGAN2 generator described by *cfg* and return it in eval mode.

    Loads the 'generator' state dict from cfg.paths.stylegan and moves the
    model onto the module-level `device`.
    """
    gen = StyleGAN2Generator(resolution=cfg.resolution)
    state = torch.load(cfg.paths.stylegan, map_location=device)
    gen.load_state_dict(state['generator'])
    gen.to(device)
    gen.eval()
    return gen
# class GanAttack(nn.Module):
def get_stylegan_inverter(cfg):
# def get_stylegan_inverter(cfg):
# ensure_checkpoint_exists(ckpt_path)
path=cfg.paths.inverter_cfg
ckpt = torch.load(path, map_location='cuda:0')
opts = ckpt['opts']
opts['checkpoint_path'] = path
if 'learn_in_w' not in opts:
opts['learn_in_w'] = False
if 'output_size' not in opts:
opts['output_size'] = 1024
# # ensure_checkpoint_exists(ckpt_path)
# path=cfg.paths.inverter_cfg
# ckpt = torch.load(path, map_location='cuda:0')
# opts = ckpt['opts']
# opts['checkpoint_path'] = path
# if 'learn_in_w' not in opts:
# opts['learn_in_w'] = False
# if 'output_size' not in opts:
# opts['output_size'] = 1024
net = pSp(Namespace(**opts))
net.eval()
net.cuda()
return net
# net = pSp(Namespace(**opts))
# net.eval()
# net.cuda()
# return net
class GanAttack(nn.Module):
def __init__(self, cfg,prompt):
super().__init__()
self.net= get_stylegan_inverter(cfg)
# self.net= get_stylegan_inverter(cfg)
# self.generator.eval()
# self.inverter=inverter
# self.images_resize=images_resize
self.generator=get_stylegan_generator(cfg)
self.prompt=prompt
text_len=self.prompt.shape[1]
self.mlp=nn.Sequential(
@ -62,24 +75,15 @@ class GanAttack(nn.Module):
def forward(self, img):
codes = self.net.encoder(img)
codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
# result_images, result_latent = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=False)
# result_images = self.net.face_pool(result_images)
# _, _, _, refine_images, latent_codes, _=self.inverter(img,self.images_resize,img_path)
x=codes
def forward(self, img,basecode,detailcode):
batch_size=img.shape[0]
# print(img.shape)
# print(self.prompt.shape)
# print(codes.shape)
prompt=self.prompt.repeat(batch_size,18,1).to(device)
# print(prompt.shape)
x_prompt=torch.cat([codes,prompt],dim=2)
x_prompt=torch.cat([basecode,prompt],dim=2)
x_prompt=self.mlp(x_prompt)
x=x_prompt+x
im,_=self.net.decoder(x, input_is_latent=True, randomize_noise=False, return_latents=False)
result_images = self.net.face_pool(im)
result_images=self.generator.synthesis(detailcode,randomize_noise=False,basecode=x)['image']
return result_images,x

File diff suppressed because it is too large Load Diff

18
stylegan/sync_op.py Normal file
View File

@ -0,0 +1,18 @@
# python3.7
"""Contains the synchronizing operator."""
import torch
import torch.distributed as dist
__all__ = ['all_gather']
def all_gather(tensor):
    """Average *tensor* across all distributed workers.

    When torch.distributed has not been initialized (single-process run),
    the input tensor is returned unchanged.
    """
    if not dist.is_initialized():
        return tensor
    world = dist.get_world_size()
    gathered = [torch.ones_like(tensor) for _ in range(world)]
    dist.all_gather(gathered, tensor, async_op=False)
    return torch.stack(gathered, dim=0).mean(dim=0)

View File

@ -1,7 +1,9 @@
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import models
import numpy as np
def get_model(config):
if config.dataset == 'gender_dataset':
@ -81,3 +83,57 @@ def set_requires_grad( nets, requires_grad=False):
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
def preprocess(images, channel_order='RGB'):
    """Convert raw uint8 images to network-ready float32 tensors.

    Input is a numpy array of shape [batch, height, width, channel] with
    dtype uint8 and range [0, 255]; output has shape
    [batch, channel, height, width], dtype float32, and range [-1, 1].

    Args:
        images: The raw uint8 inputs, BHWC layout.
        channel_order: 'RGB' (default) leaves channels alone; 'BGR' flips
            the channel axis before scaling.

    Returns:
        The preprocessed float32 BCHW array in [-1, 1].
    """
    # input : numpy, np.uint8, 0~255, BHWC
    # output : numpy, np.float32, -1~1, BCHW
    lo, hi = -1.0, 1.0
    if channel_order == 'BGR':
        images = images[:, :, :, ::-1]
    scaled = images / 255.0 * (hi - lo) + lo
    return scaled.astype(np.float32).transpose(0, 3, 1, 2)
def postprocess(images):
    """Convert a [-1, 1] RGB BCHW `torch.Tensor` to uint8 BGR BHWC numpy images."""
    # input : tensor, -1~1, RGB, BCHW
    # output : np.uint8, 0~255, BGR, BHWC
    arr = images.detach().cpu().numpy()
    arr = (arr + 1.) * 255. / 2.
    arr = np.clip(arr + 0.5, 0, 255).astype(np.uint8)
    # BCHW -> BHWC, then swap channel order RGB -> BGR
    arr = arr.transpose(0, 2, 3, 1)
    return arr[:, :, :, [2, 1, 0]]
def Lanczos_resizing(image_target, resizing_tuple=(256,256)):
    """Resize a batch of images with PIL's Lanczos filter.

    Args:
        image_target: float tensor in [-1, 1], RGB, BCHW layout.
        resizing_tuple: (width, height) target size passed to PIL resize.

    Returns:
        CUDA float tensor in [-1, 1], RGB, BCHW, each image resized.
    """
    # input : -1~1, RGB, BCHW, Tensor
    # output : -1~1, RGB, BCHW, Tensor
    # Convert tensor batch -> uint8 BHWC so PIL can operate on it.
    image_target_resized = image_target.clone().cpu().numpy()
    image_target_resized = (image_target_resized + 1.) * 255. / 2.
    image_target_resized = np.clip(image_target_resized + 0.5, 0, 255).astype(np.uint8)
    image_target_resized = image_target_resized.transpose(0, 2, 3, 1)
    tmps = []
    for i in range(image_target_resized.shape[0]):
        tmp = image_target_resized[i]
        tmp = Image.fromarray(tmp)  # PIL, 0~255, uint8, RGB, HWC
        # BUG FIX: the module imports `from PIL import Image`, so the bare
        # name `PIL` was undefined and `PIL.Image.LANCZOS` raised NameError;
        # reference the filter through the imported `Image` name instead.
        tmp = np.array(tmp.resize(resizing_tuple, Image.LANCZOS))
        tmp = torch.from_numpy(preprocess(tmp[np.newaxis, :])).cuda()
        tmps.append(tmp)
    return torch.cat(tmps, dim=0)