80 lines
3.1 KiB
Python
80 lines
3.1 KiB
Python
import torch
|
|
import torchvision
|
|
from torchvision import datasets, models, transforms
|
|
from torch.utils.data import Dataset
|
|
from PIL import Image
|
|
import torch.nn as nn
|
|
import pathlib
|
|
import os
|
|
|
|
# Training-time preprocessing: fixed-size resize, light augmentation
# (random horizontal flip), then tensor conversion and per-channel
# normalization to roughly [-1, 1] (mean 0.5, std 0.5 on each channel).
_train_steps = [
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),  # data augmentation
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # normalization
]
transforms_train = transforms.Compose(_train_steps)
|
|
|
|
# Evaluation-time preprocessing: identical to the training pipeline but
# without augmentation, so test results are deterministic.
_test_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transforms_test = transforms.Compose(_test_steps)
|
|
|
|
class ImageDataset(Dataset):
    """Dataset over an ImageFolder-style directory tree.

    Expects ``data_path`` to contain ``train/<class>/<image>`` and
    ``test/<class>/<image>`` subtrees.  Class indices are always derived
    from the *train* split so that train and test label mappings agree.
    """

    def __init__(self, data_path, mode, transform=None):
        """
        Args:
            data_path: root directory containing ``train/`` and ``test/``.
            mode: ``'train'`` selects the train split; any other value
                selects the test split (matches the original behavior).
            transform: optional callable applied to each loaded PIL image.
        """
        self.path = data_path
        self.mode = mode
        self.transform = transform

        data_dir = pathlib.Path(data_path)
        split = "train" if self.mode == "train" else "test"
        # NOTE: these paths already include data_path, because they are
        # produced by globbing from data_dir itself.
        self.image_path = [str(p) for p in data_dir.glob(split + "/*/*")]

        # Map class-directory names (from the train split) to integer
        # labels; sorted for a deterministic ordering.
        label_names = sorted(item.name for item in data_dir.glob("train/*/"))
        label_to_index = {name: idx for idx, name in enumerate(label_names)}
        self.image_label = [
            label_to_index[pathlib.Path(p).parent.name] for p in self.image_path
        ]

    def __getitem__(self, index):
        """Return ``(image, image_path, label)`` for the given index.

        ``label`` is a LongTensor of shape ``(1,)``, as in the original code.
        """
        # BUG FIX: the stored paths are already rooted at self.path, so the
        # previous os.path.join(self.path, ...) duplicated the prefix whenever
        # data_path was a relative path. Open the stored path directly
        # (behavior is unchanged for absolute data_path, where os.path.join
        # simply returned the second, absolute argument).
        img = Image.open(self.image_path[index])
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.image_label[index]])
        image_path = self.image_path[index]
        return img, image_path, label

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.image_path)
|
|
|
|
def get_dataset(config):
    """Build ImageFolder-backed train/test dataloaders.

    Args:
        config: expects ``dataset``, ``paths.gender_dataset`` or
            ``paths.identity_dataset``, ``classifier.batch_size`` and
            ``optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``.
    """
    # Pick the dataset root based on the configured task.
    root = (config.paths.gender_dataset
            if config.dataset == 'gender_dataset'
            else config.paths.identity_dataset)

    train_dataset = datasets.ImageFolder(os.path.join(root, 'train'), transforms_train)
    test_dataset = datasets.ImageFolder(os.path.join(root, 'test'), transforms_test)

    # Shuffle only the training loader; the test loader stays in order.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.classifier.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.classifier.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset
|
|
|
|
def get_adv_dataset(config):
    """Build train/test dataloaders over :class:`ImageDataset`.

    Unlike :func:`get_dataset`, this uses the custom ``ImageDataset`` (which
    also yields the image path per sample, useful for adversarial pipelines).

    Args:
        config: expects ``dataset``, ``paths.gender_dataset`` or
            ``paths.identity_dataset``, ``optim.batch_size`` and
            ``optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``.
    """
    if config.dataset == 'gender_dataset':
        path = config.paths.gender_dataset
    else:
        path = config.paths.identity_dataset

    train_dataset = ImageDataset(path, 'train', transforms_train)
    test_dataset = ImageDataset(path, 'test', transforms_test)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # BUG FIX: the test loader previously wrapped *train_dataset* with
    # shuffle=True, so any "test" evaluation silently ran on shuffled
    # training data. Wrap the test split, unshuffled, consistent with
    # get_dataset().
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.optim.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )

    return train_dataloader, test_dataloader, train_dataset, test_dataset
|
|
|
|
|
|
|
|
|