From 08eee04d1ec0820de06c68deba429b502492143d Mon Sep 17 00:00:00 2001 From: liwenyun Date: Tue, 12 Dec 2023 16:40:38 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20data/dataset.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/dataset.py | 168 +++++++++++++++++++++++++----------------------- 1 file changed, 89 insertions(+), 79 deletions(-) diff --git a/data/dataset.py b/data/dataset.py index 375a211..5cd16c4 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -1,80 +1,90 @@ -import torch -import torchvision -from torchvision import datasets, models, transforms -from torch.utils.data import Dataset -from PIL import Image -import torch.nn as nn -import pathlib -import os - -transforms_train = transforms.Compose([ - transforms.Resize((256, 256)), - transforms.RandomHorizontalFlip(), # data augmentation - transforms.ToTensor(), - transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # normalization -]) - -transforms_test = transforms.Compose([ - transforms.Resize((256, 256)), - transforms.ToTensor(), - transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) -]) - -class ImageDataset(Dataset): - def __init__(self, data_path, mode, transform=None): - self.path=data_path - data_dir=pathlib.Path(data_path) - self.mode=mode - self.transform=transform - if self.mode == 'train': - self.image_path=list(data_dir.glob("train/*/*")) - self.image_path=[str(path) for path in self.image_path] - else: - self.image_path=list(data_dir.glob("test/*/*")) - self.image_path=[str(path) for path in self.image_path] - - lable_names = sorted(item.name for item in data_dir.glob("train/*/")) - lable_to_index = dict((name, index) for index, name in enumerate(lable_names)) - self.image_label=[lable_to_index[pathlib.Path(path).parent.name] for path in self.image_path] - - - def __getitem__(self, index): - img = Image.open(os.path.join(self.path, self.image_path[index])) - img = img.convert('RGB') - if self.transform is not None: - img 
= self.transform(img) - label = torch.LongTensor([self.image_label[index]]) - image_path=self.image_path[index] - return img, image_path ,label - - def __len__(self): - return len(self.image_path) - -def get_dataset(config): - if config.dataset == 'gender_dataset': - path=config.paths.gender_dataset - else: - path=config.paths.identity_dataset - train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train) - test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test) - - train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.classifier.batch_size, shuffle=True, num_workers=config.optim.num_workers) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.classifier.batch_size, shuffle=False, num_workers=config.optim.num_workers) - return train_dataloader,test_dataloader,train_dataset,test_dataset - -def get_adv_dataset(config): - if config.dataset == 'gender_dataset': - path=config.paths.gender_dataset - else: - path=config.paths.identity_dataset - train_dataset = ImageDataset(path,'train',transforms_train) - test_dataset= ImageDataset(path,'test',transforms_test) - - train_dataloader= torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers) - test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers) - - return train_dataloader,test_dataloader,train_dataset,test_dataset - - - +import torch +import torchvision +from torchvision import datasets, models, transforms +from torch.utils.data import Dataset +from PIL import Image +import torch.nn as nn +import pathlib +import os +import numpy as np + +transforms_train = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomHorizontalFlip(), # data augmentation + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 
normalization +]) + +transforms_test = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) +]) + +class ImageDataset(Dataset): + def __init__(self, data_path, base_code,detail_code, mode, transform=None): + self.path=data_path + data_dir=pathlib.Path(data_path) + self.base_dir=base_code + self.detail_dir=detail_code + self.mode=mode + self.transform=transform + if self.mode == 'train': + self.image_path=list(data_dir.glob("train/*/*")) + self.image_path=[str(path) for path in self.image_path] + else: + self.image_path=list(data_dir.glob("test/*/*")) + self.image_path=[str(path) for path in self.image_path] + + lable_names = sorted(item.name for item in data_dir.glob("train/*/")) + lable_to_index = dict((name, index) for index, name in enumerate(lable_names)) + self.image_label=[lable_to_index[pathlib.Path(path).parent.name] for path in self.image_path] + + + def __getitem__(self, index): + img = Image.open(os.path.join(self.path, self.image_path[index])) + img = img.convert('RGB') + base_code=np.load(os.path.join(self.base_dir, self.image_path[index].replace('.jpg', '.npy'))) + detail_code=np.load(os.path.join(self.detail_dir, self.image_path[index].replace('.jpg', '.npy'))) + base_code=torch.from_numpy(base_code) + detail_code=torch.from_numpy(detail_code) + if self.transform is not None: + img = self.transform(img) + label = torch.LongTensor([self.image_label[index]]) + # image_path=self.image_path[index] + return img, label,base_code,detail_code + + def __len__(self): + return len(self.image_path) + +def get_dataset(config): + if config.dataset == 'gender_dataset': + path=config.paths.gender_dataset + else: + path=config.paths.identity_dataset + + train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train) + test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test) + + train_dataloader = torch.utils.data.DataLoader(train_dataset, 
batch_size=config.classifier.batch_size, shuffle=True, num_workers=config.optim.num_workers) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.classifier.batch_size, shuffle=False, num_workers=config.optim.num_workers) + return train_dataloader,test_dataloader,train_dataset,test_dataset + +def get_adv_dataset(config): + if config.dataset == 'gender_dataset': + path=config.paths.gender_dataset + else: + path=config.paths.identity_dataset + base_code=config.paths.base_code + detail_code=config.paths.detail_code + train_dataset = ImageDataset(path,base_code,detail_code,'train',transforms_train) + test_dataset= ImageDataset(path,base_code,detail_code,'test',transforms_test) + + train_dataloader= torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers) + + return train_dataloader,test_dataloader,train_dataset,test_dataset + + + \ No newline at end of file