GanAttack/data/dataset.py

80 lines
3.1 KiB
Python

import torch
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
from PIL import Image
import torch.nn as nn
import pathlib
import os
# Preprocessing applied to training images: fixed-size resize, random
# horizontal flip as augmentation, tensor conversion, then per-channel
# normalisation computed as (x - 0.5) / 0.5.
transforms_train = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Preprocessing applied to test images: same as the training pipeline but
# without the random flip, so it contains no random step.
transforms_test = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
class ImageDataset(Dataset):
    """Image-folder dataset that also yields each sample's file path.

    Expects the layout ``<data_path>/<split>/<class_name>/<image>`` where
    ``<split>`` is ``train`` or ``test``.  Label indices are always derived
    from the sorted class directory names under ``train``, so both splits
    share the same name-to-index mapping.

    Args:
        data_path: Root directory containing the ``train`` and ``test``
            folders.
        mode: ``'train'`` selects the training split; any other value
            selects the test split.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.

    Raises:
        KeyError: If a sample sits in a class directory that has no
            counterpart under ``train``.
    """

    def __init__(self, data_path, mode, transform=None):
        self.path = data_path
        self.mode = mode
        self.transform = transform
        data_dir = pathlib.Path(data_path)

        split = 'train' if self.mode == 'train' else 'test'
        # Keep regular files only and sort them: glob order depends on the
        # filesystem, and a nested directory cannot be opened as an image.
        self.image_path = sorted(
            str(path)
            for path in data_dir.glob(split + "/*/*")
            if path.is_file()
        )

        # Only directories are classes.  A trailing-slash glob pattern also
        # returns plain files on Python < 3.11, so a stray file directly
        # under train/ would otherwise shift every label index.
        label_names = sorted(
            item.name for item in data_dir.glob("train/*") if item.is_dir()
        )
        label_to_index = {name: index for index, name in enumerate(label_names)}
        self.image_label = [
            label_to_index[pathlib.Path(path).parent.name]
            for path in self.image_path
        ]

    def __getitem__(self, index):
        """Return ``(image, image_path, label)`` for the sample at *index*.

        ``label`` is a ``torch.LongTensor`` holding the single class index.
        """
        image_path = self.image_path[index]
        # The stored paths already start with data_path (they come from
        # globbing under it); joining data_path again duplicated the prefix
        # and broke loading for relative roots.
        with Image.open(image_path) as raw:
            # convert() reads the pixel data, so the file can be closed
            # as soon as the RGB copy exists.
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.image_label[index]])
        return img, image_path, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_path)
def get_dataset(config):
    """Build the ImageFolder datasets and loaders used by the classifier.

    The dataset root is ``config.paths.gender_dataset`` when
    ``config.dataset`` equals ``'gender_dataset'`` and
    ``config.paths.identity_dataset`` otherwise.  Batch size comes from
    ``config.classifier.batch_size`` and worker count from
    ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``;
        the training loader shuffles, the test loader does not.
    """
    wants_gender = config.dataset == 'gender_dataset'
    root = config.paths.gender_dataset if wants_gender else config.paths.identity_dataset

    train_dataset = datasets.ImageFolder(os.path.join(root, 'train'), transforms_train)
    test_dataset = datasets.ImageFolder(os.path.join(root, 'test'), transforms_test)

    # Settings shared by both loaders.
    loader_kwargs = {
        'batch_size': config.classifier.batch_size,
        'num_workers': config.optim.num_workers,
    }
    train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_dataloader, test_dataloader, train_dataset, test_dataset
def get_adv_dataset(config):
    """Build the path-aware datasets and loaders for the adversarial stage.

    Uses ``ImageDataset``, whose samples are ``(image, image_path, label)``
    triples.  The dataset root is ``config.paths.gender_dataset`` when
    ``config.dataset`` equals ``'gender_dataset'`` and
    ``config.paths.identity_dataset`` otherwise.  Batch size comes from
    ``config.optim.batch_size`` and worker count from
    ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``.
    """
    if config.dataset == 'gender_dataset':
        path = config.paths.gender_dataset
    else:
        path = config.paths.identity_dataset

    train_dataset = ImageDataset(path, 'train', transforms_train)
    test_dataset = ImageDataset(path, 'test', transforms_test)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # The test loader must wrap test_dataset; it previously wrapped
    # train_dataset, so "test" batches were training images.
    # NOTE(review): shuffle=True is kept from the original line, while
    # get_dataset uses shuffle=False for its test loader - confirm which
    # is intended here.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset