31 lines
1.3 KiB
Python
31 lines
1.3 KiB
Python
class ImageDataset(Dataset):
    """Image dataset over the NWPU-RESISC45 folder layout.

    Images are discovered under ``<data_path>/NWPU-RESISC45/<class_name>/<file>``.
    The full image list is sorted, shuffled with a fixed seed (1), and split by
    position into three disjoint subsets:

    * ``'database'``: images ``[12000, end)``
    * ``'test'``: images ``[7000, 12000)``
    * any other mode (e.g. ``'train'``): images ``[0, 7000)``

    Each item is ``(image, label, index)``. ``label`` is a 1-element
    ``torch.LongTensor`` holding the class index; class indices are assigned
    to the class folder names in sorted order.
    """

    def __init__(self, data_path, mode, transform=None):
        """Collect the image paths for ``mode`` and their integer class labels.

        Args:
            data_path: Directory that contains the ``NWPU-RESISC45`` folder.
            mode: ``'database'`` or ``'test'``; any other value selects the
                first 7000 shuffled images.
            transform: Optional callable applied to each RGB PIL image in
                ``__getitem__``.
        """
        self.path = data_path
        data_dir = pathlib.Path(data_path)
        self.mode = mode
        self.transform = transform

        root = data_dir / "NWPU-RESISC45"
        # Sort before shuffling: glob order depends on the filesystem, so
        # without it the seeded shuffle (and hence the splits) differs between
        # machines. Only regular files are samples.
        image_paths = sorted(str(path) for path in root.glob("*/*") if path.is_file())
        # Private RNG: shuffling must not reseed the process-wide `random`
        # module that other code (e.g. augmentations) may rely on.
        random.Random(1).shuffle(image_paths)

        if self.mode == 'database':
            image_paths = image_paths[12000:]
        elif self.mode == 'test':
            image_paths = image_paths[7000:12000]
        else:
            image_paths = image_paths[:7000]
        self.image_path = image_paths

        # Only directories are classes; a stray file beside the class folders
        # must not become a "class" and shift every label index.
        label_names = sorted(item.name for item in root.glob("*") if item.is_dir())
        label_to_index = {name: index for index, name in enumerate(label_names)}
        self.image_label = [
            label_to_index[pathlib.Path(path).parent.name] for path in self.image_path
        ]

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for the sample at ``index``.

        ``image`` is the RGB PIL image, passed through ``self.transform`` when
        one was given; ``label`` is a 1-element ``torch.LongTensor``.
        """
        # Entries in image_path already start with data_path (they come from
        # globbing under it), so open them directly; joining onto self.path
        # again would double the prefix whenever data_path is relative.
        with Image.open(self.image_path[index]) as raw:
            # convert() loads the pixel data into a new image, so the result
            # stays valid after the underlying file is closed.
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = torch.LongTensor([self.image_label[index]])
        return img, label, index

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_path)