更新 data/dataset.py

This commit is contained in:
liwenyun 2023-12-12 16:40:38 +08:00
parent c10ee107cc
commit 08eee04d1e
1 changed file with 89 additions and 79 deletions

View File

@ -1,80 +1,90 @@
import torch
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
from PIL import Image
import torch.nn as nn
import pathlib
import os
# Training-time preprocessing: resize to a fixed 256x256, randomly mirror the
# image as augmentation, convert to a tensor, then normalize each of the three
# channels as (x - 0.5) / 0.5.
transforms_train = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Evaluation-time preprocessing: identical to training minus the random flip,
# so test inputs are deterministic.
transforms_test = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
class ImageDataset(Dataset):
    """Image-folder dataset that also yields each sample's file path.

    ``data_path`` is expected to contain ``train/<class>/<file>`` and
    ``test/<class>/<file>``.  Class indices are assigned from the sorted
    names of the sub-directories of ``train`` for *both* splits, so the two
    splits share one label mapping.

    Args:
        data_path: Root directory of the dataset.
        mode: ``'train'`` selects the ``train`` split; any other value
            selects the ``test`` split.
        transform: Optional callable applied to the RGB ``PIL.Image``.

    Raises:
        KeyError: If a sample sits in a class directory that has no
            counterpart under ``train``.
    """

    def __init__(self, data_path, mode, transform=None):
        self.path = data_path
        self.mode = mode
        self.transform = transform
        data_dir = pathlib.Path(data_path)

        split = 'train' if self.mode == 'train' else 'test'
        # Sorted so the index -> sample mapping is reproducible instead of
        # following whatever order the file system happens to return.
        self.image_path = sorted(str(path) for path in data_dir.glob(split + "/*/*"))

        # Only directories are classes.  A trailing-slash glob pattern also
        # matched plain files on Python < 3.11, which let a stray file under
        # ``train`` shift every label index; filter explicitly instead.
        label_names = sorted(
            item.name for item in data_dir.glob("train/*") if item.is_dir()
        )
        label_to_index = {name: index for index, name in enumerate(label_names)}
        self.image_label = [
            label_to_index[pathlib.Path(path).parent.name]
            for path in self.image_path
        ]

    def __getitem__(self, index):
        """Return ``(image, image_path, label)`` for sample *index*.

        ``label`` is a one-element ``torch.LongTensor``; ``image`` is the
        transformed image (or the RGB ``PIL.Image`` when no transform is set).
        """
        image_path = self.image_path[index]
        # The stored path already includes ``data_path`` (it was produced by
        # globbing under it).  Joining ``data_path`` on again duplicated the
        # prefix whenever the root was a relative path.
        with Image.open(image_path) as raw:
            # convert() returns an independent image, so the file handle can
            # be released as soon as the block exits.
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.image_label[index]])
        return img, image_path, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_path)
def get_dataset(config):
    """Build the classifier train/test ``ImageFolder`` datasets and loaders.

    The dataset root is ``config.paths.gender_dataset`` when
    ``config.dataset == 'gender_dataset'`` and ``config.paths.identity_dataset``
    otherwise.  Batch size comes from ``config.classifier.batch_size`` and the
    worker count from ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``;
        only the training loader shuffles.
    """
    use_gender = config.dataset == 'gender_dataset'
    root = config.paths.gender_dataset if use_gender else config.paths.identity_dataset

    train_set = datasets.ImageFolder(os.path.join(root, 'train'), transforms_train)
    test_set = datasets.ImageFolder(os.path.join(root, 'test'), transforms_test)

    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(
        train_set,
        batch_size=config.classifier.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    test_loader = make_loader(
        test_set,
        batch_size=config.classifier.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_loader, test_loader, train_set, test_set
def get_adv_dataset(config):
    """Build the ``ImageDataset`` train/test datasets and their loaders.

    The dataset root is ``config.paths.gender_dataset`` when
    ``config.dataset == 'gender_dataset'`` and ``config.paths.identity_dataset``
    otherwise.  Batch size comes from ``config.optim.batch_size`` and the
    worker count from ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``.
    """
    if config.dataset == 'gender_dataset':
        path = config.paths.gender_dataset
    else:
        path = config.paths.identity_dataset

    train_dataset = ImageDataset(path, 'train', transforms_train)
    test_dataset = ImageDataset(path, 'test', transforms_test)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # The test loader must iterate the held-out split.  It previously wrapped
    # ``train_dataset`` (and shuffled it), so evaluation ran on training data.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.optim.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset
import torch
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset
from PIL import Image
import torch.nn as nn
import pathlib
import os
import numpy as np
# Training-time preprocessing: resize to a fixed 256x256, randomly mirror the
# image as augmentation, convert to a tensor, then normalize each of the three
# channels as (x - 0.5) / 0.5.
transforms_train = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Evaluation-time preprocessing: identical to training minus the random flip,
# so test inputs are deterministic.
transforms_test = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
class ImageDataset(Dataset):
    """Image-folder dataset paired with precomputed base/detail code arrays.

    ``data_path`` is expected to contain ``train/<class>/<file>`` and
    ``test/<class>/<file>``.  Class indices are assigned from the sorted
    names of the sub-directories of ``train`` for *both* splits, so the two
    splits share one label mapping.

    Each image needs two ``.npy`` files that mirror its location relative to
    ``data_path``: ``<base_code>/<split>/<class>/<stem>.npy`` and
    ``<detail_code>/<split>/<class>/<stem>.npy``.

    Args:
        data_path: Root directory of the image dataset.
        base_code: Root directory of the base-code ``.npy`` files.
        detail_code: Root directory of the detail-code ``.npy`` files.
        mode: ``'train'`` selects the ``train`` split; any other value
            selects the ``test`` split.
        transform: Optional callable applied to the RGB ``PIL.Image``.

    Raises:
        KeyError: If a sample sits in a class directory that has no
            counterpart under ``train``.
    """

    def __init__(self, data_path, base_code, detail_code, mode, transform=None):
        self.path = data_path
        self.base_dir = base_code
        self.detail_dir = detail_code
        self.mode = mode
        self.transform = transform
        data_dir = pathlib.Path(data_path)

        split = 'train' if self.mode == 'train' else 'test'
        # Sorted so the index -> sample mapping is reproducible instead of
        # following whatever order the file system happens to return.
        self.image_path = sorted(str(path) for path in data_dir.glob(split + "/*/*"))

        # Only directories are classes.  A trailing-slash glob pattern also
        # matched plain files on Python < 3.11, which let a stray file under
        # ``train`` shift every label index; filter explicitly instead.
        label_names = sorted(
            item.name for item in data_dir.glob("train/*") if item.is_dir()
        )
        label_to_index = {name: index for index, name in enumerate(label_names)}
        self.image_label = [
            label_to_index[pathlib.Path(path).parent.name]
            for path in self.image_path
        ]

    def _code_path(self, code_dir, index):
        """Return the ``.npy`` path under *code_dir* that mirrors sample *index*."""
        # Use the image location relative to the dataset root.  Joining the
        # full image path onto ``code_dir`` silently discarded ``code_dir``
        # for absolute roots, and ``str.replace('.jpg', ...)`` handled neither
        # other image suffixes nor '.jpg' appearing in a directory name.
        relative = pathlib.Path(self.image_path[index]).relative_to(self.path)
        return os.path.join(code_dir, str(relative.with_suffix('.npy')))

    def __getitem__(self, index):
        """Return ``(image, label, base_code, detail_code)`` for sample *index*.

        ``label`` is a one-element ``torch.LongTensor``; both codes are
        tensors created from the loaded NumPy arrays; ``image`` is the
        transformed image (or the RGB ``PIL.Image`` when no transform is set).
        """
        # The stored path already includes ``data_path`` (it was produced by
        # globbing under it).  Joining ``data_path`` on again duplicated the
        # prefix whenever the root was a relative path.
        with Image.open(self.image_path[index]) as raw:
            # convert() returns an independent image, so the file handle can
            # be released as soon as the block exits.
            img = raw.convert('RGB')

        base_code = torch.from_numpy(np.load(self._code_path(self.base_dir, index)))
        detail_code = torch.from_numpy(np.load(self._code_path(self.detail_dir, index)))

        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.image_label[index]])
        return img, label, base_code, detail_code

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_path)
def get_dataset(config):
    """Build the classifier train/test ``ImageFolder`` datasets and loaders.

    The dataset root is ``config.paths.gender_dataset`` when
    ``config.dataset == 'gender_dataset'`` and ``config.paths.identity_dataset``
    otherwise.  Batch size comes from ``config.classifier.batch_size`` and the
    worker count from ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``;
        only the training loader shuffles.
    """
    use_gender = config.dataset == 'gender_dataset'
    root = config.paths.gender_dataset if use_gender else config.paths.identity_dataset

    train_set = datasets.ImageFolder(os.path.join(root, 'train'), transforms_train)
    test_set = datasets.ImageFolder(os.path.join(root, 'test'), transforms_test)

    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(
        train_set,
        batch_size=config.classifier.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    test_loader = make_loader(
        test_set,
        batch_size=config.classifier.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_loader, test_loader, train_set, test_set
def get_adv_dataset(config):
    """Build the ``ImageDataset`` train/test datasets and their loaders.

    The dataset root is ``config.paths.gender_dataset`` when
    ``config.dataset == 'gender_dataset'`` and ``config.paths.identity_dataset``
    otherwise.  The code directories come from ``config.paths.base_code`` and
    ``config.paths.detail_code``; batch size from ``config.optim.batch_size``
    and the worker count from ``config.optim.num_workers``.

    Returns:
        ``(train_dataloader, test_dataloader, train_dataset, test_dataset)``.
    """
    if config.dataset == 'gender_dataset':
        path = config.paths.gender_dataset
    else:
        path = config.paths.identity_dataset
    base_code = config.paths.base_code
    detail_code = config.paths.detail_code

    train_dataset = ImageDataset(path, base_code, detail_code, 'train', transforms_train)
    test_dataset = ImageDataset(path, base_code, detail_code, 'test', transforms_test)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # The test loader must iterate the held-out split.  It previously wrapped
    # ``train_dataset``, so evaluation ran on training data.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.optim.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_dataloader, test_dataloader, train_dataset, test_dataset