31 lines
1.2 KiB
Python
31 lines
1.2 KiB
Python
import torch
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
import torch.nn as nn
|
|
import os
|
|
|
|
# Training-time preprocessing pipeline, applied in order:
#   1. resize every image to exactly 256x256 (aspect ratio is not preserved),
#   2. randomly mirror the image horizontally as data augmentation,
#   3. convert the image to a tensor,
#   4. normalize each of the three channels as (x - 0.5) / 0.5.
transforms_train = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomHorizontalFlip(),  # data augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # normalization
    ]
)
|
|
|
|
# Evaluation-time preprocessing pipeline: the same deterministic steps as the
# training pipeline (256x256 resize, tensor conversion, per-channel
# (x - 0.5) / 0.5 normalization), but without random augmentation.
transforms_test = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
|
|
|
|
|
|
def get_dataset(config):
    """Build train and test DataLoaders for the dataset selected in *config*.

    Images are read with ``torchvision.datasets.ImageFolder`` from the
    ``train`` and ``test`` sub-directories of the selected dataset root.
    Training images go through ``transforms_train`` and test images through
    ``transforms_test``.

    Args:
        config: Configuration object. This function reads ``config.dataset``,
            ``config.paths.gender_dataset``, ``config.paths.identity_dataset``,
            ``config.optim.batch_size`` and ``config.optim.num_workers``.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. Only the training
        loader shuffles. Both loaders share the batch size and worker count.

    Note:
        Any ``config.dataset`` value other than ``'gender_dataset'`` selects
        ``config.paths.identity_dataset``. No unknown-name check is done.
    """
    use_gender = config.dataset == 'gender_dataset'
    root = config.paths.gender_dataset if use_gender else config.paths.identity_dataset

    train_set = datasets.ImageFolder(os.path.join(root, 'train'), transform=transforms_train)
    test_set = datasets.ImageFolder(os.path.join(root, 'test'), transform=transforms_test)

    def _make_loader(dataset, shuffle):
        # Batch size and worker count come from config.optim for both splits.
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=config.optim.batch_size,
            shuffle=shuffle,
            num_workers=config.optim.num_workers,
        )

    return _make_loader(train_set, True), _make_loader(test_set, False)