GanAttack/data/dataset.py

31 lines
1.2 KiB
Python

import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import os
# Preprocessing settings shared by the training and test pipelines.
_IMAGE_SIZE = (256, 256)  # size argument handed to transforms.Resize
_NORM_MEAN = [0.5, 0.5, 0.5]  # per-channel mean handed to transforms.Normalize
_NORM_STD = [0.5, 0.5, 0.5]  # per-channel std handed to transforms.Normalize

# Training pipeline: resize, random horizontal flip (data augmentation),
# conversion to a tensor, then per-channel normalization.
transforms_train = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)

# Test pipeline: identical to the training pipeline except that the random
# flip is left out, so no random augmentation is applied at evaluation time.
transforms_test = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
def get_dataset(config):
    """Build the training and test dataloaders for the configured dataset.

    Args:
        config: Configuration object. The attributes read here are
            ``config.dataset``, ``config.paths.gender_dataset`` or
            ``config.paths.identity_dataset`` (whichever is selected),
            ``config.optim.batch_size`` and ``config.optim.num_workers``.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. The training loader
        reads ``<path>/train`` with ``transforms_train`` and shuffles; the
        test loader reads ``<path>/test`` with ``transforms_test`` and does
        not shuffle.
    """
    # Any value other than 'gender_dataset' falls back to the identity dataset.
    if config.dataset == 'gender_dataset':
        path = config.paths.gender_dataset
    else:
        path = config.paths.identity_dataset

    train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
    test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)

    # Both loaders share the batch size and worker count; only shuffling differs.
    loader_kwargs = {
        'batch_size': config.optim.batch_size,
        'num_workers': config.optim.num_workers,
    }
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, **loader_kwargs
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, shuffle=False, **loader_kwargs
    )
    return train_dataloader, test_dataloader