GanAttack/GanInverter/options/base_options.py

63 lines
3.4 KiB
Python

import abc
import yaml
from argparse import ArgumentParser
def str2bool(arg):
ua = str(arg).upper()
if 'TRUE'.startswith(ua):
return True
elif 'FALSE'.startswith(ua):
return False
else:
raise Exception('Error!')
class BaseOptions:
def __init__(self):
self.config_parser = self.parser = ArgumentParser()
self.initialize()
def initialize(self):
self.parser.add_argument('-c', '--config', default='', type=str, nargs='+', metavar='FILE', help='YAML config file '
'specifying default '
'arguments')
self.parser.add_argument('--exp_dir', default='one_shot', type=str, help='Path to experiment output directory')
self.parser.add_argument('--resolution', default=1024, type=int, help='Resolution of generator')
# Data options
self.parser.add_argument('--transform_type', default='encodetransforms', type=str, help='Type of dataset trans')
self.parser.add_argument('--train_dataset_path', default=None, type=str)
self.parser.add_argument('--test_dataset_path', default='../test', type=str)
self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training.')
self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference.')
self.parser.add_argument('--workers', default=0, type=int, help='Number of train dataloader workers.')
self.parser.add_argument('--test_workers', default=0, type=int,
help='Number of test/inference dataloader workers.')
self.parser.add_argument('--auto_resume', default='False', type=str2bool,
help='Whether to automatically resume. During training, the last checkpoint will be '
'used to resume while the best checkpoint will be used in '
'inference/evaluation/editing process.')
# Model options
self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to model checkpoint.')
self.parser.add_argument('--stylegan_weights', default="", type=str, help='Path to StyleGAN model weights.')
self.parser.add_argument('--encoder_type', default='Encoder4Editing', type=str, help='Which encoder to use')
self.parser.add_argument('--start_from_latent_avg', action='store_true', help='Whether to add average latent '
'vector to generate codes from '
'encoder.')
self.parser.add_argument('--learn_in_w', action='store_true', help='Whether to learn in w space instead of w+')
self.parser.add_argument('--layers', default=50, type=int, help='Number of layers of backbone')
def parse(self):
opts = self.config_parser.parse_args()
if opts.config:
for config in opts.config:
with open(config, 'r') as f:
cfg = yaml.safe_load(f)
self.parser.set_defaults(**cfg)
opts = self.parser.parse_args()
return opts