41 lines
1.1 KiB
Python
41 lines
1.1 KiB
Python
from abc import ABC, abstractmethod

import torchvision.transforms as transforms
|
|
|
|
|
|
class TransformsConfig(ABC):
	"""Abstract base class for dataset transform configurations.

	Subclasses implement ``get_transforms`` to return the transform
	pipelines used by their data loaders.

	NOTE: the original class inherited from ``object``, which made
	``@abstractmethod`` inert — incomplete subclasses could still be
	instantiated. Inheriting from ``ABC`` enforces the contract.
	"""

	def __init__(self, opts):
		# Options object; stored for subclasses that need it.
		self.opts = opts

	@abstractmethod
	def get_transforms(self):
		"""Return a dict mapping pipeline names to transforms (or None)."""
|
|
|
|
|
|
class EncodeTransforms(TransformsConfig):
	"""Transform configuration for the encoder pipelines.

	Provides 256x256 resize + normalization pipelines for training,
	testing and inference, plus a no-resize variant for direct
	application to pre-sized inputs.
	"""

	def __init__(self, opts):
		super(EncodeTransforms, self).__init__(opts)

	def get_transforms(self):
		"""Return the named transform pipelines used by the data loaders."""
		# These transforms hold no mutable state, so single instances can
		# safely be shared across the pipelines below.
		resize = transforms.Resize((256, 256))
		to_tensor = transforms.ToTensor()
		normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

		return {
			'transform_gt_train': transforms.Compose([
				resize,
				transforms.RandomHorizontalFlip(0.5),
				to_tensor,
				normalize,
			]),
			'transform_source': None,
			'transform_test': transforms.Compose([
				resize,
				to_tensor,
				normalize,
			]),
			'transform_inference': transforms.Compose([
				resize,
				to_tensor,
				normalize,
			]),
			# No Resize here (it was deliberately left out upstream), so
			# inputs keep their native size — presumably already 256x256.
			'transform_apply': transforms.Compose([
				to_tensor,
				normalize,
			]),
		}