class ImageDataset(Dataset):
    """NWPU-RESISC45 images split deterministically into train/test/database.

    Images are discovered at ``<data_path>/NWPU-RESISC45/<class_name>/<file>``.
    The file list is sorted, shuffled with a private RNG seeded by ``seed``,
    and then sliced according to ``mode``:

    * ``'test'``     -> ``[num_train : num_train + num_test]``
    * ``'database'`` -> ``[num_train + num_test :]``
    * any other value (e.g. ``'train'``) -> ``[:num_train]``

    Each label is the index of the image's class directory within the sorted
    list of class directory names.

    Args:
        data_path: Directory containing the ``NWPU-RESISC45`` folder.
        mode: Which split to expose (see above).
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``.
        num_train: Size of the training split (default 7000).
        num_test: Size of the test split (default 5000).
        seed: Seed of the shuffle that defines the splits (default 1).

    Raises:
        FileNotFoundError: If no image files exist under
            ``<data_path>/NWPU-RESISC45/*/``.
    """

    def __init__(self, data_path, mode, transform=None, *,
                 num_train=7000, num_test=5000, seed=1):
        self.path = data_path
        self.mode = mode
        self.transform = transform
        data_dir = pathlib.Path(data_path)

        # Sort before shuffling: glob order follows the filesystem's directory
        # listing, which is not guaranteed to be stable, so without sorting the
        # "seeded" split could differ between machines or copies of the data.
        image_paths = sorted(
            str(p) for p in data_dir.glob("NWPU-RESISC45/*/*") if p.is_file()
        )
        if not image_paths:
            raise FileNotFoundError(
                f"no images found under {data_dir / 'NWPU-RESISC45'}"
            )
        # A private Random instance yields the same permutation as
        # random.seed(seed); random.shuffle(...) but leaves the global RNG alone.
        random.Random(seed).shuffle(image_paths)

        test_end = num_train + num_test
        if mode == 'database':
            image_paths = image_paths[test_end:]
        elif mode == 'test':
            image_paths = image_paths[num_train:test_end]
        else:
            image_paths = image_paths[:num_train]
        self.image_path = image_paths

        # Only directories are classes; a stray file next to the class folders
        # must not become a label and shift the indices of the real classes.
        label_names = sorted(
            p.name for p in data_dir.glob("NWPU-RESISC45/*") if p.is_dir()
        )
        label_to_index = {name: i for i, name in enumerate(label_names)}
        self.image_label = [
            label_to_index[pathlib.Path(path).parent.name]
            for path in self.image_path
        ]

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for the sample at ``index``.

        ``label`` is a 1-element ``torch.long`` tensor holding the class index.
        """
        # Entries of self.image_path already include data_path (they come from
        # globbing under it), so open them as-is; re-joining onto self.path
        # doubled the prefix whenever data_path was relative.
        with Image.open(self.image_path[index]) as raw:
            # convert() builds a new image independent of the file, so the
            # handle can be closed here instead of leaking.
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.tensor([self.image_label[index]], dtype=torch.long)
        return img, label, index

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_path)