更新 data/dataset.py

This commit is contained in:
liwenyun 2023-12-12 16:40:38 +08:00
parent c10ee107cc
commit 08eee04d1e
1 changed files with 89 additions and 79 deletions

View File

@ -6,6 +6,7 @@ from PIL import Image
import torch.nn as nn
import pathlib
import os
import numpy as np
transforms_train = transforms.Compose([
transforms.Resize((256, 256)),
@ -21,9 +22,11 @@ transforms_test = transforms.Compose([
])
class ImageDataset(Dataset):
def __init__(self, data_path, mode, transform=None):
def __init__(self, data_path, base_code,detail_code, mode, transform=None):
self.path=data_path
data_dir=pathlib.Path(data_path)
self.base_dir=base_code
self.detail_dir=detail_code
self.mode=mode
self.transform=transform
if self.mode == 'train':
@ -41,11 +44,15 @@ class ImageDataset(Dataset):
def __getitem__(self, index):
img = Image.open(os.path.join(self.path, self.image_path[index]))
img = img.convert('RGB')
base_code=np.load(os.path.join(self.base_dir, self.image_path[index].replace('.jpg', '.npy')))
detail_code=np.load(os.path.join(self.detail_dir, self.image_path[index].replace('.jpg', '.npy')))
base_code=torch.from_numpy(base_code)
detail_code=torch.from_numpy(detail_code)
if self.transform is not None:
img = self.transform(img)
label = torch.LongTensor([self.image_label[index]])
image_path=self.image_path[index]
return img, image_path ,label
# image_path=self.image_path[index]
return img, label,base_code,detail_code
def __len__(self):
return len(self.image_path)
@ -55,6 +62,7 @@ def get_dataset(config):
path=config.paths.gender_dataset
else:
path=config.paths.identity_dataset
train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)
@ -67,11 +75,13 @@ def get_adv_dataset(config):
path=config.paths.gender_dataset
else:
path=config.paths.identity_dataset
train_dataset = ImageDataset(path,'train',transforms_train)
test_dataset= ImageDataset(path,'test',transforms_test)
base_code=config.paths.base_code
detail_code=config.paths.detail_code
train_dataset = ImageDataset(path,base_code,detail_code,'train',transforms_train)
test_dataset= ImageDataset(path,base_code,detail_code,'test',transforms_test)
train_dataloader= torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers)
return train_dataloader,test_dataloader,train_dataset,test_dataset