Update data/dataset.py
This commit is contained in:
parent
c10ee107cc
commit
08eee04d1e
168
data/dataset.py
168
data/dataset.py
|
|
@ -1,80 +1,90 @@
|
|||
import torch
|
||||
import torchvision
|
||||
from torchvision import datasets, models, transforms
|
||||
from torch.utils.data import Dataset
|
||||
from PIL import Image
|
||||
import torch.nn as nn
|
||||
import pathlib
|
||||
import os
|
||||
|
||||
# Training-time preprocessing: resize to a fixed 256x256, mirror left-right at
# random (augmentation), convert to tensor, normalize each channel to [-1, 1].
_train_steps = [
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transforms_train = transforms.Compose(_train_steps)
|
||||
|
||||
# Evaluation-time preprocessing: identical to the training pipeline minus the
# random flip, so test images are deterministic.
_test_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transforms_test = transforms.Compose(_test_steps)
|
||||
|
||||
class ImageDataset(Dataset):
    """Dataset over a ``train/<class>/*`` / ``test/<class>/*`` directory tree.

    Each item is ``(image_tensor, image_path, label)``.  The label vocabulary
    is always derived from the sorted directory names under ``train/`` so the
    train and test splits share one class-index mapping.
    """

    def __init__(self, data_path, mode, transform=None):
        self.path = data_path
        self.mode = mode
        self.transform = transform
        root = pathlib.Path(data_path)
        split = "train" if self.mode == "train" else "test"
        self.image_path = [str(p) for p in root.glob(split + "/*/*")]
        # Class names come from train/ even in test mode, keeping indices stable.
        class_names = sorted(entry.name for entry in root.glob("train/*/"))
        index_of = {name: i for i, name in enumerate(class_names)}
        self.image_label = [
            index_of[pathlib.Path(p).parent.name] for p in self.image_path
        ]

    def __getitem__(self, index):
        # NOTE(review): entries of self.image_path already start with data_path
        # (they come from data_dir.glob), so this join looks redundant unless
        # the configured paths are absolute — confirm against the data layout.
        img = Image.open(os.path.join(self.path, self.image_path[index]))
        img = img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.image_label[index]])
        return img, self.image_path[index], label

    def __len__(self):
        return len(self.image_path)
|
||||
|
||||
def get_dataset(config):
    """Build ImageFolder train/test loaders for the configured dataset.

    Selects the dataset root from ``config.paths`` based on ``config.dataset``
    and returns ``(train_dataloader, test_dataloader, train_dataset,
    test_dataset)``.
    """
    if config.dataset == 'gender_dataset':
        root = config.paths.gender_dataset
    else:
        root = config.paths.identity_dataset

    train_set = datasets.ImageFolder(os.path.join(root, 'train'), transforms_train)
    test_set = datasets.ImageFolder(os.path.join(root, 'test'), transforms_test)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=config.classifier.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # Test data is iterated in a fixed order.
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=config.classifier.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_loader, test_loader, train_set, test_set
|
||||
|
||||
def get_adv_dataset(config):
    """Build ImageDataset train/test loaders (items include the image path).

    Selects the dataset root from ``config.paths`` based on ``config.dataset``
    and returns ``(train_dataloader, test_dataloader, train_dataset,
    test_dataset)``.
    """
    if config.dataset == 'gender_dataset':
        path = config.paths.gender_dataset
    else:
        path = config.paths.identity_dataset
    train_dataset = ImageDataset(path, 'train', transforms_train)
    test_dataset = ImageDataset(path, 'test', transforms_test)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # BUG FIX: the test loader previously wrapped train_dataset with
    # shuffle=True, so "test" evaluation re-used shuffled training data.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.optim.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )

    return train_dataloader, test_dataloader, train_dataset, test_dataset
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision import datasets, models, transforms
|
||||
from torch.utils.data import Dataset
|
||||
from PIL import Image
|
||||
import torch.nn as nn
|
||||
import pathlib
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# Training-time preprocessing: fixed 256x256 resize, random horizontal flip
# (augmentation), tensor conversion, then per-channel normalization to [-1, 1].
_train_pipeline = [
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transforms_train = transforms.Compose(_train_pipeline)
|
||||
|
||||
# Evaluation-time preprocessing: same as training but without the random flip,
# so evaluation is deterministic.
_test_pipeline = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transforms_test = transforms.Compose(_test_pipeline)
|
||||
|
||||
class ImageDataset(Dataset):
    """Dataset over a ``train/<class>/*`` / ``test/<class>/*`` directory tree
    with per-image precomputed latent codes.

    Each item is ``(image_tensor, label, base_code, detail_code)``, where the
    codes are loaded from ``.npy`` files mirroring the image paths.  The label
    vocabulary is always derived from the sorted directory names under
    ``train/`` so train and test share one class-index mapping.
    """

    def __init__(self, data_path, base_code, detail_code, mode, transform=None):
        self.path = data_path
        self.base_dir = base_code
        self.detail_dir = detail_code
        self.mode = mode
        self.transform = transform
        root = pathlib.Path(data_path)
        split = "train" if self.mode == "train" else "test"
        self.image_path = [str(p) for p in root.glob(split + "/*/*")]
        # Class names come from train/ even in test mode, keeping indices stable.
        class_names = sorted(entry.name for entry in root.glob("train/*/"))
        index_of = {name: i for i, name in enumerate(class_names)}
        self.image_label = [
            index_of[pathlib.Path(p).parent.name] for p in self.image_path
        ]

    def __getitem__(self, index):
        rel = self.image_path[index]
        # NOTE(review): rel already starts with data_path (it comes from
        # data_dir.glob), so the joins below look redundant / inconsistent
        # unless the configured paths are absolute — confirm data layout.
        img = Image.open(os.path.join(self.path, rel)).convert('RGB')
        # NOTE(review): str.replace swaps the FIRST '.jpg' anywhere in the
        # path, not necessarily the extension — safe only if no directory
        # name contains '.jpg'.
        code_name = rel.replace('.jpg', '.npy')
        base_code = torch.from_numpy(np.load(os.path.join(self.base_dir, code_name)))
        detail_code = torch.from_numpy(np.load(os.path.join(self.detail_dir, code_name)))
        if self.transform is not None:
            img = self.transform(img)
        label = torch.LongTensor([self.image_label[index]])
        return img, label, base_code, detail_code

    def __len__(self):
        return len(self.image_path)
|
||||
|
||||
def get_dataset(config):
    """Build ImageFolder train/test loaders for the configured dataset.

    Selects the dataset root from ``config.paths`` based on ``config.dataset``
    and returns ``(train_dataloader, test_dataloader, train_dataset,
    test_dataset)``.
    """
    if config.dataset == 'gender_dataset':
        root = config.paths.gender_dataset
    else:
        root = config.paths.identity_dataset

    train_set = datasets.ImageFolder(os.path.join(root, 'train'), transforms_train)
    test_set = datasets.ImageFolder(os.path.join(root, 'test'), transforms_test)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=config.classifier.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # Test data is iterated in a fixed order.
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=config.classifier.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )
    return train_loader, test_loader, train_set, test_set
|
||||
|
||||
def get_adv_dataset(config):
    """Build ImageDataset train/test loaders (items include latent codes).

    Selects the dataset root and latent-code directories from ``config.paths``
    based on ``config.dataset`` and returns ``(train_dataloader,
    test_dataloader, train_dataset, test_dataset)``.
    """
    if config.dataset == 'gender_dataset':
        path = config.paths.gender_dataset
    else:
        path = config.paths.identity_dataset
    base_code = config.paths.base_code
    detail_code = config.paths.detail_code
    train_dataset = ImageDataset(path, base_code, detail_code, 'train', transforms_train)
    test_dataset = ImageDataset(path, base_code, detail_code, 'test', transforms_test)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.optim.batch_size,
        shuffle=True,
        num_workers=config.optim.num_workers,
    )
    # BUG FIX: the test loader previously wrapped train_dataset, so "test"
    # evaluation re-used the training data instead of test_dataset.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.optim.batch_size,
        shuffle=False,
        num_workers=config.optim.num_workers,
    )

    return train_dataloader, test_dataloader, train_dataset, test_dataset
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue