"""Dataset and dataloader helpers for the gender / identity image datasets.

Two loader factories are provided:

* ``get_dataset``     -- ``torchvision.datasets.ImageFolder`` based loaders
  yielding ``(image, label)`` pairs.
* ``get_adv_dataset`` -- ``ImageDataset`` based loaders yielding
  ``(image, image_path, label)`` triples.
"""

import os
import pathlib

import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms

# Training pipeline: resize to 256x256, random horizontal flip, tensor
# conversion, then per-channel normalisation with mean 0.5 and std 0.5.
transforms_train = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),  # data augmentation
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # normalization
])

# Evaluation pipeline: identical to the training one minus the random flip.
transforms_test = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


class ImageDataset(Dataset):
    """Image-folder dataset that also returns the path of every sample.

    The expected layout is ``<data_path>/<split>/<class_name>/<image>`` where
    ``<split>`` is ``train`` or ``test``.

    Args:
        data_path: Root directory containing the ``train`` and ``test``
            sub-directories.
        mode: ``'train'`` selects the training split; any other value selects
            the test split.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Raises:
        KeyError: If an image sits in a class directory that has no
            counterpart under ``<data_path>/train``.
    """

    def __init__(self, data_path, mode, transform=None):
        self.path = data_path
        data_dir = pathlib.Path(data_path)
        self.mode = mode
        self.transform = transform

        pattern = "train/*/*" if self.mode == 'train' else "test/*/*"
        # Paths produced by glob() already carry the ``data_path`` prefix.
        self.image_path = [str(path) for path in data_dir.glob(pattern)]

        # The class -> index mapping is always derived from the *train* split
        # so both splits share the same label indices.  Only directories are
        # considered: relying on a trailing slash in the glob pattern is not
        # portable across Python versions and could let stray files under
        # ``train/`` shift the indices.
        label_names = sorted(
            item.name for item in data_dir.glob("train/*") if item.is_dir()
        )
        label_to_index = {name: index for index, name in enumerate(label_names)}
        self.image_label = [
            label_to_index[pathlib.Path(path).parent.name]
            for path in self.image_path
        ]

    def __getitem__(self, index):
        """Return ``(image, image_path, label)`` for the sample at ``index``.

        ``label`` is a ``torch.LongTensor`` holding a single class index.
        """
        image_path = self.image_path[index]
        # Open the stored path directly: it already includes ``self.path``,
        # so joining the two again would duplicate the prefix whenever
        # ``data_path`` is relative.
        with Image.open(image_path) as raw:
            # convert() returns an independent, fully loaded image, so the
            # underlying file can be closed as soon as the block exits.
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.image_label[index]])
        return img, image_path, label

    def __len__(self):
        """Return the number of images found for the selected split."""
        return len(self.image_path)


def _dataset_root(config):
    """Return the dataset root directory selected by ``config.dataset``."""
    if config.dataset == 'gender_dataset':
        return config.paths.gender_dataset
    return config.paths.identity_dataset


def get_dataset(config):
    """Build ``ImageFolder`` datasets and loaders for classifier training.

    Reads ``config.dataset``, ``config.paths``, ``config.classifier.batch_size``
    and ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``.
        The training loader is shuffled; the test loader is not.
    """
    path = _dataset_root(config)
    train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
    test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.classifier.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.classifier.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset


def get_adv_dataset(config):
    """Build ``ImageDataset`` datasets and loaders that also yield image paths.

    Reads ``config.dataset``, ``config.paths``, ``config.optim.batch_size``
    and ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``.
        Each batch is an ``(images, image_paths, labels)`` triple.
    """
    path = _dataset_root(config)
    train_dataset = ImageDataset(path, 'train', transforms_train)
    test_dataset = ImageDataset(path, 'test', transforms_test)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # The test loader must wrap the *test* dataset; it previously wrapped
    # ``train_dataset`` so the test split was never iterated.
    # NOTE(review): this loader is shuffled, unlike the one in get_dataset --
    # confirm that is intended.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset