# GanAttack/data/dataset.py
#
# 90 lines
# 3.7 KiB
# Python

import torch
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
from PIL import Image
import torch.nn as nn
import pathlib
import os
import numpy as np
# Training-time preprocessing pipeline:
# 1. resize every image to 256x256;
# 2. randomly flip it horizontally (data augmentation);
# 3. convert it to a tensor;
# 4. normalize each RGB channel with mean 0.5 and std 0.5.
transforms_train = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomHorizontalFlip(),  # data augmentation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # normalization
    ]
)

# Evaluation-time preprocessing: identical to the training pipeline minus the
# random flip, so evaluation is deterministic.
transforms_test = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
class ImageDataset(Dataset):
    """Image-classification dataset that also yields a per-image ``.npy`` code.

    Images are discovered under ``<data_path>/train/<class>/<file>`` when
    ``mode == 'train'`` and under ``<data_path>/test/<class>/<file>``
    otherwise.

    Class indices always come from the sorted sub-directory names of
    ``<data_path>/train``. This way the train and test splits share one
    label mapping.

    Args:
        data_path: Root directory containing ``train/`` and ``test/``.
        detail_code: Directory passed to ``os.path.join`` when locating each
            image's ``.npy`` code (stored as ``self.detail_dir``).
        mode: ``'train'`` selects the train split; any other value selects
            the test split.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, data_path, detail_code, mode, transform=None):
        self.path = data_path
        data_dir = pathlib.Path(data_path)
        self.detail_dir = detail_code
        self.mode = mode
        self.transform = transform

        split = 'train' if self.mode == 'train' else 'test'
        # Sorted so the sample order is deterministic; raw glob order depends
        # on the filesystem.
        self.image_path = sorted(str(path) for path in data_dir.glob(split + "/*/*"))

        # Only directories are classes. Before Python 3.11, a "train/*/"
        # pattern also matched plain files, which would shift the indices.
        label_names = sorted(item.name for item in data_dir.glob("train/*") if item.is_dir())
        label_to_index = {name: index for index, name in enumerate(label_names)}
        # Raises KeyError if a test-split class has no matching train directory.
        self.image_label = [label_to_index[pathlib.Path(path).parent.name] for path in self.image_path]

    def __getitem__(self, index):
        """Return ``(img, label, detail_code)`` for the sample at ``index``.

        ``img`` is the RGB image, passed through ``self.transform`` if one is
        set. ``label`` is a LongTensor of shape ``(1,)``. ``detail_code`` is
        ``torch.from_numpy`` of the matching ``.npy`` file.
        """
        image_path = self.image_path[index]
        # image_path already contains data_path (it came from data_dir.glob),
        # so open it directly. Re-joining with self.path duplicated a relative
        # root (e.g. "data/data/train/...").
        # The context manager closes the file; convert() returns a new, fully
        # loaded image that stays valid after the file is closed.
        with Image.open(image_path) as raw:
            img = raw.convert('RGB')
        # NOTE(review): os.path.join discards self.detail_dir whenever
        # image_path is absolute (i.e. data_path is absolute). In that case the
        # .npy is looked up next to the image itself. For a relative data_path
        # it resolves to <detail_dir>/<data_path>/<split>/<class>/<name>.npy.
        # Kept unchanged; confirm which layout is intended.
        detail_code = np.load(os.path.join(self.detail_dir, image_path.replace('.jpg', '.npy')))
        detail_code = torch.from_numpy(detail_code)
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.image_label[index]])
        return img, label, detail_code

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_path)
def get_dataset(config):
    """Build ImageFolder train/test datasets and their DataLoaders.

    The data root is ``config.paths.gender_dataset`` when ``config.dataset``
    equals ``'gender_dataset'``; otherwise it is
    ``config.paths.identity_dataset``.

    Both loaders use ``config.classifier.batch_size`` and
    ``config.optim.num_workers``. Only the training loader shuffles.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``.
    """
    use_gender = config.dataset == 'gender_dataset'
    root = config.paths.gender_dataset if use_gender else config.paths.identity_dataset

    train_set = datasets.ImageFolder(os.path.join(root, 'train'), transforms_train)
    test_set = datasets.ImageFolder(os.path.join(root, 'test'), transforms_test)

    loader_kwargs = dict(
        batch_size=config.classifier.batch_size,
        num_workers=config.optim.num_workers,
    )
    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader, train_set, test_set
def get_adv_dataset(config):
    """Build ImageDataset train/test datasets (with detail codes) and loaders.

    The data root is ``config.paths.gender_dataset`` when ``config.dataset``
    equals ``'gender_dataset'``; otherwise it is
    ``config.paths.identity_dataset``. ``config.paths.detail_code`` is passed
    to both datasets.

    Both loaders use ``config.optim.batch_size`` and
    ``config.optim.num_workers``. Only the training loader shuffles.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``.
    """
    if config.dataset == 'gender_dataset':
        path = config.paths.gender_dataset
    else:
        path = config.paths.identity_dataset
    detail_code = config.paths.detail_code
    train_dataset = ImageDataset(path, detail_code, 'train', transforms_train)
    test_dataset = ImageDataset(path, detail_code, 'test', transforms_test)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # Bug fix: this loader previously wrapped train_dataset. Evaluation then
    # silently ran on the training split, with training-time augmentation.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.optim.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset