import math
import os

import matplotlib

matplotlib.use('Agg')

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import autograd, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from loguru import logger

from configs import transforms_config
from criteria import id_loss, moco_loss, w_norm
from criteria.lpips.lpips import LPIPS
from datasets.images_dataset import ImagesDataset
from models.discriminator import LatentCodesDiscriminator
from models.encoder import Encoder
from models.latent_codes_pool import LatentCodesPool
from models.stylegan2.model import Generator
from utils import common, train_utils
from utils.ranger import Ranger
from utils.train_utils import get_train_progressive_stage, load_train_checkpoint, requires_grad

class EncoderTrainer:
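    """Trains an encoder that inverts images into the latent space of a frozen
    StyleGAN2 generator, with optional latent-code discriminator, W-pool,
    progressive training, SNCD anchor loss, and the usual reconstruction
    losses (L2, LPIPS, ID/MoCo, W-norm)."""
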
    train_dataset = None
    test_dataset = None
    train_dataloader = None
    test_dataloader = None
    optimizer = None
    discriminator_optimizer = None
    mse_loss = None
    lpips_loss = None
    id_loss = None
    w_norm_loss = None
    moco_loss = None

    def __init__(self, opts):
        self.opts = opts
        self.global_step = opts.start_step
        self.device = 'cuda'
        self.opts.device = self.device
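        # Number of style vectors in W+ for a StyleGAN2 generator:
        # 2 * log2(resolution) - 2, e.g. 18 for 1024x1024 and 14 for 256x256.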
        self.opts.n_styles = int(math.log(opts.resolution, 2)) * 2 - 2

        # resume from checkpoint
        checkpoint = load_train_checkpoint(opts)

        # initialize encoder and decoder
        latent_avg = None
        self.decoder = Generator(opts.resolution, 512, 8).to(self.device)
        self.decoder.train()
        if checkpoint is not None:
            self.load_from_train_checkpoint(checkpoint)
        else:
            decoder_checkpoint = torch.load(opts.stylegan_weights, map_location='cpu')
            self.decoder.load_state_dict(decoder_checkpoint['g_ema'])
            latent_avg = decoder_checkpoint['latent_avg']
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)).to(self.device)
        if latent_avg is None:
            latent_avg = self.decoder.mean_latent(int(1e5))[0].detach() if checkpoint is None else None
        self.encoder = Encoder(opts, checkpoint, latent_avg, device=self.device).to(self.device)
        self.latent_avg = latent_avg  # kept for the W-norm regularizer in calc_loss

        # initialize discriminator
        if self.opts.w_discriminator_lambda > 0:
            dims = 512
            self.discriminator = LatentCodesDiscriminator(dims, 4).to(self.device)
            if opts.dist:
                self.discriminator = DistributedDataParallel(
                    self.discriminator,
                    device_ids=[torch.cuda.current_device()])
            self.real_w_pool = LatentCodesPool(opts.w_pool_size)
            self.fake_w_pool = LatentCodesPool(opts.w_pool_size)

        # initialize sncd
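        # Each anchor is the mean direction of a large sample of generator
        # style-space codes after L2 normalization, one anchor per style layer;
        # calc_loss later maximizes the cosine similarity of encoded codes to
        # these anchors.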
        if self.opts.sncd_lambda > 0:
            self.anchor_codes = []
            with torch.no_grad():
                w = self.decoder.w_sample(int(1e5))
                s_plus = self.decoder.get_style_space(w, split=True)
                for s in s_plus:
                    self.anchor_codes.append((s / s.norm(2, dim=-1, keepdim=True)).mean(dim=0, keepdim=True))

        self.configure_loss()

        self.configure_datasets()

        # Initialize logger
        self.log_dir = os.path.join(opts.exp_dir, 'logs')
        if opts.rank == 0:
            os.makedirs(self.log_dir, exist_ok=True)
            if self.opts.use_wandb:
                from utils.wandb_utils import WBLogger
                self.wb_logger = WBLogger(self.opts)

        # initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        if opts.rank == 0:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            self.opts.save_interval = self.opts.max_steps

        self.configure_optimizers(checkpoint)
        self.progressive_stage = get_train_progressive_stage(self.opts.progressive_steps, self.global_step)

    def configure_datasets(self):
        transforms_dict = transforms_config.EncodeTransforms(self.opts).get_transforms()
        self.train_dataset = train_dataset = ImagesDataset(source_root=self.opts.train_dataset_path,
                                                           target_root=self.opts.train_dataset_path,
                                                           source_transform=transforms_dict['transform_source'],
                                                           target_transform=transforms_dict['transform_gt_train'],
                                                           opts=self.opts)
        self.test_dataset = test_dataset = ImagesDataset(source_root=self.opts.test_dataset_path,
                                                         target_root=self.opts.test_dataset_path,
                                                         source_transform=transforms_dict['transform_source'],
                                                         target_transform=transforms_dict['transform_test'],
                                                         opts=self.opts)

        # set dataloader
        train_batch_size = self.opts.batch_size // self.opts.gpu_num
        test_batch_size = self.opts.test_batch_size // self.opts.gpu_num
        assert self.opts.batch_size == train_batch_size * self.opts.gpu_num, \
            'Train batch size is not a multiple of gpu num.'
        assert self.opts.test_batch_size == test_batch_size * self.opts.gpu_num, \
            'Test batch size is not a multiple of gpu num.'
        if self.opts.dist:
            train_sampler = DistributedSampler(
                self.train_dataset,
                shuffle=True,
                drop_last=True,
                seed=self.opts.seed
            )
            self.train_dataloader = DataLoader(
                self.train_dataset,
                sampler=train_sampler,
                batch_size=train_batch_size,
                num_workers=int(self.opts.workers // self.opts.gpu_num),
            )
            test_sampler = DistributedSampler(
                self.test_dataset,
                shuffle=False,
                drop_last=False,
                seed=self.opts.seed
            )
            self.test_dataloader = DataLoader(
                self.test_dataset,
                sampler=test_sampler,
                batch_size=test_batch_size,
                num_workers=int(self.opts.test_workers // self.opts.gpu_num)
            )
        else:
            self.train_dataloader = DataLoader(self.train_dataset,
                                               batch_size=train_batch_size,
                                               shuffle=True,
                                               num_workers=int(self.opts.workers),
                                               drop_last=True)
            self.test_dataloader = DataLoader(self.test_dataset,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              num_workers=int(self.opts.test_workers),
                                              drop_last=True)
        if self.opts.rank == 0:
            logger.info(f"Number of train samples: {len(train_dataset)}, train_batch_size per GPU: {train_batch_size}.")
            logger.info(f"Number of test samples: {len(test_dataset)}, test_batch_size per GPU: {test_batch_size}.")

    def configure_loss(self):
        self.mse_loss = nn.MSELoss().to(self.device).eval()
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        if self.opts.id_lambda > 0:
            self.id_loss = id_loss.IDLoss().to(self.device).eval()
        if self.opts.w_norm_lambda > 0:
            self.w_norm_loss = w_norm.WNormLoss(start_from_latent_avg=self.opts.start_from_latent_avg)
        if self.opts.moco_lambda > 0:
            self.moco_loss = moco_loss.MocoLoss().to(self.device).eval()

    def configure_optimizers(self, checkpoint):
        requires_grad(self.decoder, False)
        betas = (self.opts.optim_beta1, self.opts.optim_beta2)
        if self.opts.optimizer == 'adam':
            optimizer = torch.optim.Adam(self.encoder.parameters(), lr=self.opts.learning_rate,
                                         weight_decay=self.opts.weight_decay, betas=betas)
        elif self.opts.optimizer == 'adamw':
            optimizer = torch.optim.AdamW(self.encoder.parameters(), lr=self.opts.learning_rate,
                                          weight_decay=self.opts.weight_decay, betas=betas)
        elif self.opts.optimizer == 'sgd':
            optimizer = torch.optim.SGD(self.encoder.parameters(), lr=self.opts.learning_rate,
                                        weight_decay=self.opts.weight_decay)
        else:
            optimizer = Ranger(self.encoder.parameters(), lr=self.opts.learning_rate,
                               weight_decay=self.opts.weight_decay, betas=betas)
        if checkpoint is not None:
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                logger.warning('Optimizer state dict is not in checkpoint!')

        if self.opts.w_discriminator_lambda > 0:
            self.discriminator_optimizer = torch.optim.Adam(list(self.discriminator.parameters()),
                                                            lr=self.opts.discriminator_lr)
            if checkpoint is not None:
                if 'discriminator_optimizer_state_dict' in checkpoint:
                    self.discriminator_optimizer.load_state_dict(checkpoint['discriminator_optimizer_state_dict'])
                else:
                    logger.warning('Discriminator optimizer state dict is not in checkpoint!')
        self.optimizer = optimizer

    def inverse(self, x):
        codes = self.encoder(x)
        images, result_latent = self.decoder([codes], input_is_latent=True, randomize_noise=True, return_latents=True)
        images = self.face_pool(images)
        return images, result_latent

    def train(self):
        self.encoder.train()
        while self.global_step < self.opts.max_steps:
            for batch_idx, batch in enumerate(self.train_dataloader):
                self.encoder.set_progressive_stage(self.progressive_stage)
                loss_dict = {}
                if self.is_training_discriminator():
                    loss_dict = self.train_discriminator(batch)
                self.progressive_stage = get_train_progressive_stage(self.opts.progressive_steps, self.global_step)

                x, y = batch
                x, y = x.to(self.device).float(), y.to(self.device).float()
                y_hat, latent = self.inverse(x)
                loss, encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
                loss_dict = {**loss_dict, **encoder_loss_dict}
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Logging related
                if self.global_step % self.opts.image_interval == 0 or (
                        self.global_step < 1000 and self.global_step % 25 == 0):
                    self.parse_and_log_images(id_logs, x, y, y_hat, title='images/train/faces')
                if self.global_step % self.opts.board_interval == 0:
                    self.print_metrics(loss_dict, prefix='train')
                    self.log_metrics(loss_dict, prefix='train')

                # Log images of first batch to wandb
                if self.opts.use_wandb and batch_idx == 0 and self.opts.rank == 0:
                    self.wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="train", step=self.global_step,
                                                       opts=self.opts)

                # Validation related
                val_loss_dict = None
                if (self.global_step % self.opts.val_interval == 0 and self.global_step != 0) or \
                        self.global_step == self.opts.max_steps:
                    val_loss_dict = self.validate()
                    if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
                        self.best_val_loss = val_loss_dict['loss']
                        self.save_checkpoint(val_loss_dict, is_best=True)

                if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
                    if val_loss_dict is not None:
                        self.save_checkpoint(val_loss_dict, is_best=False)
                        self.save_checkpoint(val_loss_dict, is_best=False, is_last=True)
                    else:
                        self.save_checkpoint(loss_dict, is_best=False)
                        self.save_checkpoint(loss_dict, is_best=False, is_last=True)

                if self.global_step == self.opts.max_steps:
                    logger.info('OMG, finished training!')
                    break

                self.global_step += 1

    def validate(self):
        self.encoder.eval()
        agg_loss_dict = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            x, y = batch
            cur_loss_dict = {}
            if self.is_training_discriminator():
                cur_loss_dict = self.validate_discriminator(batch)
            with torch.no_grad():
                x, y = x.to(self.device).float(), y.to(self.device).float()
                y_hat, latent = self.inverse(x)
                loss, cur_encoder_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
            cur_loss_dict = {**cur_loss_dict, **cur_encoder_loss_dict}
            agg_loss_dict.append(cur_loss_dict)

            # Logging related
            self.parse_and_log_images(id_logs, x, y, y_hat, title='images/test/faces',
                                      subscript='{:04d}'.format(batch_idx))

            # Log images of first batch to wandb
            if self.opts.use_wandb and batch_idx == 0 and self.opts.rank == 0:
                self.wb_logger.log_images_to_wandb(x, y, y_hat, id_logs, prefix="test", step=self.global_step,
                                                   opts=self.opts)

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
        self.log_metrics(loss_dict, prefix='test')
        self.print_metrics(loss_dict, prefix='test')

        self.encoder.train()
        return loss_dict

    def calc_loss(self, x, y, y_hat, latent):
        loss_dict = {}
        loss = 0.0
        id_logs = None
        if self.is_training_discriminator():  # Adversarial loss
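            # Non-saturating GAN loss on the encoder: for each style vector of the
            # predicted W+ code, softplus(-D(w)) rewards codes that the latent
            # discriminator scores as "real" W samples.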
            loss_disc = 0.
            dims_to_discriminate = list(range(self.opts.n_styles))

            for i in dims_to_discriminate:
                w = latent[:, i, :]
                fake_pred = self.discriminator(w)
                loss_disc += F.softplus(-fake_pred).mean()
            loss_disc /= len(dims_to_discriminate)
            loss_dict['encoder_discriminator_loss'] = float(loss_disc)
            loss += self.opts.w_discriminator_lambda * loss_disc

        if self.opts.progressive_steps and self.opts.delta_norm_lambda > 0.:  # delta regularization loss
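            # Delta regularization (as in e4e-style progressive encoders): every style
            # vector unlocked so far is penalized by its distance to the first w,
            # keeping the W+ code close to a single W vector.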
            total_delta_loss = 0
            deltas_latent_dims = list(range(self.opts.n_styles))

            first_w = latent[:, 0, :]
            for i in range(1, self.progressive_stage + 1):
                curr_dim = deltas_latent_dims[i]
                delta = latent[:, curr_dim, :] - first_w
                delta_loss = torch.norm(delta, self.opts.delta_norm, dim=1).mean()
                loss_dict[f"delta{i}_loss"] = float(delta_loss)
                total_delta_loss += delta_loss
            loss_dict['total_delta_loss'] = float(total_delta_loss)
            loss += self.opts.delta_norm_lambda * total_delta_loss

        if self.opts.sncd_lambda > 0:  # SNCD cosine loss on style-space codes
            dims_to_discriminate = list(range(self.opts.n_styles)) if not self.is_progressive_training() else \
                list(range(self.progressive_stage + 1))
            latent_s = self.decoder.get_style_space(latent, split=True)
            latent_s = [s / s.norm(2, dim=-1, keepdim=True) for s in latent_s]
            similarity = [s0 @ s1.T for s0, s1 in zip(latent_s, self.anchor_codes)]
            sncd_loss = 0
            for dim in dims_to_discriminate:
                closs = -similarity[dim].mean()
                loss_dict[f'sncd_loss_{dim}'] = float(closs)
                sncd_loss += closs
            loss_dict['total_sncd_loss'] = float(sncd_loss)
            loss += sncd_loss * self.opts.sncd_lambda

        if self.opts.id_lambda > 0:  # Similarity loss
            loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x)
            loss_dict['loss_id'] = float(loss_id)
            loss_dict['id_improve'] = float(sim_improvement)
            loss += loss_id * self.opts.id_lambda

        if self.opts.l2_lambda > 0:
            loss_l2 = F.mse_loss(y_hat, y)
            loss_dict['loss_l2'] = float(loss_l2)
            loss += loss_l2 * self.opts.l2_lambda

        if self.opts.lpips_lambda > 0:
            loss_lpips = self.lpips_loss(y_hat, y)
            loss_dict['loss_lpips'] = float(loss_lpips)
            loss += loss_lpips * self.opts.lpips_lambda

        if self.opts.w_norm_lambda > 0:
            loss_w_norm = self.w_norm_loss(latent, self.latent_avg)
            loss_dict['loss_w_norm'] = float(loss_w_norm)
            loss += loss_w_norm * self.opts.w_norm_lambda

        if self.opts.moco_lambda > 0:
            loss_moco, sim_improvement, id_logs = self.moco_loss(y_hat, y, x)
            loss_dict['loss_moco'] = float(loss_moco)
            loss_dict['id_improve'] = float(sim_improvement)
            loss += loss_moco * self.opts.moco_lambda

        loss_dict['loss'] = float(loss)
        return loss, loss_dict, id_logs

    def get_dims_to_discriminate(self):
        return list(range(self.opts.n_styles))[:self.progressive_stage + 1]

    def is_progressive_training(self):
        return self.opts.progressive_steps is not None

    def is_training_discriminator(self):
        return self.opts.w_discriminator_lambda > 0

    @staticmethod
    def discriminator_loss(real_pred, fake_pred, loss_dict):
        real_loss = F.softplus(-real_pred).mean()
        fake_loss = F.softplus(fake_pred).mean()

        loss_dict['d_real_loss'] = float(real_loss)
        loss_dict['d_fake_loss'] = float(fake_loss)

        return real_loss + fake_loss

    @staticmethod
    def discriminator_r1_loss(real_pred, real_w):
        grad_real, = autograd.grad(outputs=real_pred.sum(), inputs=real_w, create_graph=True)
        grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

        return grad_penalty

    def train_discriminator(self, batch):
        loss_dict = {}
        x, _ = batch
        x = x.to(self.device).float()
        requires_grad(self.discriminator, True)

        with torch.no_grad():
            real_w, fake_w = self.sample_real_and_fake_latents(x)

        real_pred = self.discriminator(real_w)
        fake_pred = self.discriminator(fake_w)
        loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
        loss_dict['discriminator_loss'] = float(loss)

        self.discriminator_optimizer.zero_grad()
        loss.backward()
        self.discriminator_optimizer.step()

        # r1 regularization
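        # Lazy R1 regularization (StyleGAN2 style): the gradient penalty on real codes
        # is applied only every d_reg_every steps and rescaled by d_reg_every to keep
        # its effective strength; the `0 * real_pred[0]` term is a common trick that
        # keeps real_pred attached to the graph so every output participates in backward.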
        d_regularize = self.global_step % self.opts.d_reg_every == 0
        if d_regularize:
            real_w = real_w.detach()
            real_w.requires_grad = True
            real_pred = self.discriminator(real_w)
            r1_loss = self.discriminator_r1_loss(real_pred, real_w)

            self.discriminator.zero_grad()
            r1_final_loss = self.opts.r1 / 2 * r1_loss * self.opts.d_reg_every + 0 * real_pred[0]
            r1_final_loss.backward()
            self.discriminator_optimizer.step()
            loss_dict['discriminator_r1_loss'] = float(r1_final_loss)

        # Reset to previous state
        requires_grad(self.discriminator, False)

        return loss_dict

    def validate_discriminator(self, test_batch):
        with torch.no_grad():
            loss_dict = {}
            x, _ = test_batch
            x = x.to(self.device).float()
            real_w, fake_w = self.sample_real_and_fake_latents(x)

            real_pred = self.discriminator(real_w)
            fake_pred = self.discriminator(fake_w)
            loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
            loss_dict['discriminator_loss'] = float(loss)
            return loss_dict

    def sample_real_and_fake_latents(self, x):
        sample_z = torch.randn(self.opts.batch_size, 512, device=self.device)
        real_w = self.decoder.get_latent(sample_z)
        fake_w = self.encoder(x)
        if self.is_progressive_training():  # When progressive training, feed only unique w's
            dims_to_discriminate = self.get_dims_to_discriminate()
            fake_w = fake_w[:, dims_to_discriminate, :]
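        # Optionally route the codes through a history pool of previously seen latents
        # (an image-pool-style buffer), so the discriminator is not updated on the
        # newest batch alone.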
        if self.opts.use_w_pool:
            real_w = self.real_w_pool.query(real_w)
            fake_w = self.fake_w_pool.query(fake_w)
        if fake_w.ndim == 3:
            fake_w = fake_w[:, 0, :]
        return real_w, fake_w

    def save_checkpoint(self, loss_dict, is_best, is_last=False):
        if self.opts.rank == 0:
            if is_best:
                save_name = 'best_model.pt'
            elif is_last:
                save_name = 'last.pt'
            else:
                save_name = 'iteration_{}.pt'.format(self.global_step)
            save_dict = self.__get_save_dict(is_last)
            checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
            torch.save(save_dict, checkpoint_path)
            with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
                if is_best:
                    f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(
                        self.global_step, self.best_val_loss, loss_dict))
                else:
                    f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))

    def log_metrics(self, metrics_dict, prefix):
        if self.opts.use_wandb and self.opts.rank == 0:
            self.wb_logger.log(prefix, metrics_dict, self.global_step)

    def print_metrics(self, metrics_dict, prefix):
        if self.opts.rank == 0:
            logger.info('Metrics for {}, step {}'.format(prefix, self.global_step))
            for key, value in metrics_dict.items():
                logger.info('\t{} = {}'.format(key, value))

    def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=2):
        display_count = min(display_count, y.shape[0])
        if self.opts.rank == 0:
            im_data = []
            for i in range(display_count):
                cur_im_data = {
                    'input_face': common.log_input_image(x[i], self.opts),
                    'target_face': common.tensor2im(y[i]),
                    'output_face': common.tensor2im(y_hat[i]),
                }
                if id_logs is not None:
                    for key in id_logs[i]:
                        cur_im_data[key] = id_logs[i][key]
                im_data.append(cur_im_data)
            self.log_images(title, im_data=im_data, subscript=subscript)

    def log_images(self, name, im_data, subscript=None, log_latest=False):
        fig = common.vis_faces(im_data)
        step = self.global_step
        if log_latest:
            step = 0
        if subscript:
            path = os.path.join(self.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step))
        else:
            path = os.path.join(self.log_dir, name, '{:04d}.jpg'.format(step))
        if self.opts.rank == 0:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            fig.savefig(path)
        plt.close(fig)

    def load_from_train_checkpoint(self, ckpt):
        # load training status
        logger.info('Loading previous training data...')
        self.global_step = ckpt.get('global_step', -1) + 1
        self.best_val_loss = ckpt.get('best_val_loss', 0.)
        logger.info(f'Start from step: {self.global_step}')

        # load stylegan
        self.decoder.load_state_dict(ckpt['decoder'], strict=True)

        if self.opts.w_discriminator_lambda > 0:
            self.discriminator.load_state_dict(ckpt['discriminator_state_dict'], strict=False)

    def __get_save_dict(self, is_last):
        save_dict = {'encoder': self.encoder.state_dict(), 'decoder': self.decoder.state_dict(),
                     'opts': vars(self.opts), 'global_step': self.global_step}
        if is_last:
            save_dict['optimizer'] = self.optimizer.state_dict()
            save_dict['best_val_loss'] = self.best_val_loss
        if self.opts.w_discriminator_lambda > 0:
            save_dict['discriminator_state_dict'] = self.discriminator.state_dict()
            if is_last:
                save_dict['discriminator_optimizer_state_dict'] = self.discriminator_optimizer.state_dict()
        return save_dict