import os
import pathlib

import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms

# Training-time preprocessing: fixed 256x256 resize, random horizontal flip for
# augmentation, then per-channel Normalize(mean=0.5, std=0.5), which maps the
# [0, 1] output of ToTensor to [-1, 1].
transforms_train = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),  # data augmentation
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # normalization
])

# Evaluation-time preprocessing: identical to training minus the augmentation.
transforms_test = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


class ImageDataset(Dataset):
    """Class-folder image dataset that also returns a per-image ``.npy`` "detail code".

    Samples are the files under ``<data_path>/train/<class>/`` when ``mode`` is
    ``'train'`` and under ``<data_path>/test/<class>/`` otherwise. Class indices
    are always derived from the sorted class-directory names under ``train/``,
    so both splits share a single label mapping.

    Args:
        data_path: Root directory containing the ``train/`` and ``test/`` splits.
        detail_code: Directory used when locating each image's ``.npy`` file.
        mode: ``'train'`` selects the train split; any other value selects test.
        transform: Optional callable applied to the RGB PIL image.

    Raises:
        KeyError: (at construction) if a sample's class folder does not exist
            under ``train/``.
    """

    def __init__(self, data_path, detail_code, mode, transform=None):
        self.path = data_path
        data_dir = pathlib.Path(data_path)
        self.detail_dir = detail_code
        self.mode = mode
        self.transform = transform

        split = 'train' if self.mode == 'train' else 'test'
        # Keep only regular files, and sort so the sample order does not depend
        # on filesystem iteration order.
        self.image_path = sorted(
            str(path) for path in data_dir.glob(split + "/*/*") if path.is_file()
        )

        # Filter to directories explicitly: before Python 3.11 a trailing "/" in a
        # pathlib glob pattern does not restrict matches to directories, so stray
        # files in train/ (e.g. .DS_Store) would otherwise shift every class index.
        label_names = sorted(
            item.name for item in data_dir.glob("train/*") if item.is_dir()
        )
        label_to_index = {name: index for index, name in enumerate(label_names)}
        self.image_label = [
            label_to_index[pathlib.Path(path).parent.name] for path in self.image_path
        ]

    def __getitem__(self, index):
        """Return ``(img, label, detail_code)`` for sample ``index``.

        ``img`` is the RGB image after ``self.transform`` (if any). ``label`` is a
        one-element ``LongTensor`` holding the class index. ``detail_code`` is the
        array loaded from the matching ``.npy`` file, wrapped with
        ``torch.from_numpy``.
        """
        image_path = self.image_path[index]

        # image_path came from data_dir.glob and already includes data_path, so it
        # is opened directly. Joining it onto self.path again duplicated the root
        # whenever data_path was relative. The context manager closes the file
        # once convert() has loaded the pixels.
        with open(image_path, 'rb') as fh:
            img = Image.open(fh).convert('RGB')

        # NOTE(review): os.path.join discards self.detail_dir when image_path is
        # absolute, so the .npy is then read from beside the image. With a relative
        # data_path, the full image path is nested under detail_dir instead. Also,
        # only '.jpg' names are rewritten to '.npy'. Confirm the intended layout.
        detail_code = np.load(
            os.path.join(self.detail_dir, image_path.replace('.jpg', '.npy'))
        )
        detail_code = torch.from_numpy(detail_code)

        if self.transform is not None:
            img = self.transform(img)

        label = torch.LongTensor([self.image_label[index]])
        return img, label, detail_code

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_path)


def _dataset_root(config):
    """Return the gender dataset path if ``config.dataset == 'gender_dataset'``, else the identity dataset path."""
    if config.dataset == 'gender_dataset':
        return config.paths.gender_dataset
    return config.paths.identity_dataset


def _make_loaders(train_dataset, test_dataset, batch_size, num_workers):
    """Build a shuffled train loader and an unshuffled test loader over the given datasets."""
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_dataloader, test_dataloader


def get_dataset(config):
    """Build ImageFolder datasets and loaders for the configured dataset.

    Reads ``<root>/train`` and ``<root>/test`` with ``datasets.ImageFolder``, where
    ``root`` comes from ``config.paths`` according to ``config.dataset``. Uses
    ``config.classifier.batch_size`` and ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``
    """
    path = _dataset_root(config)
    train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
    test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)
    train_dataloader, test_dataloader = _make_loaders(
        train_dataset, test_dataset,
        config.classifier.batch_size, config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset


def get_adv_dataset(config):
    """Build ImageDataset datasets (image + detail code) and loaders.

    The dataset root comes from ``config.paths`` according to ``config.dataset``,
    and the detail-code directory from ``config.paths.detail_code``. Uses
    ``config.optim.batch_size`` and ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``
    """
    path = _dataset_root(config)
    detail_code = config.paths.detail_code
    train_dataset = ImageDataset(path, detail_code, 'train', transforms_train)
    test_dataset = ImageDataset(path, detail_code, 'test', transforms_test)
    # The test loader must iterate test_dataset; it was previously built over
    # train_dataset, so evaluation silently ran on (augmented) training data.
    train_dataloader, test_dataloader = _make_loaders(
        train_dataset, test_dataset,
        config.optim.batch_size, config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset