From bd2dc4a7cc14ee23c84916c3ecdb95dd189b67a7 Mon Sep 17 00:00:00 2001 From: liwenyun Date: Sun, 3 Dec 2023 16:21:47 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20a.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- a.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 a.py diff --git a/a.py b/a.py new file mode 100644 index 0000000..3d8667d --- /dev/null +++ b/a.py @@ -0,0 +1,31 @@ +class ImageDataset(Dataset): + def __init__(self, data_path, mode, transform=None): + self.path=data_path + data_dir=pathlib.Path(data_path) + self.mode=mode + self.transform=transform + self.image_path=list(data_dir.glob("NWPU-RESISC45/*/*")) + self.image_path=[str(path) for path in self.image_path] + random.seed(1) + random.shuffle(self.image_path) + if self.mode == 'database': + self.image_path=self.image_path[12000:] + elif self.mode == 'test': + self.image_path=self.image_path[7000:12000] + else: + self.image_path=self.image_path[:7000] + lable_names = sorted(item.name for item in data_dir.glob("NWPU-RESISC45/*/")) + lable_to_index = dict((name, index) for index, name in enumerate(lable_names)) + self.image_label=[lable_to_index[pathlib.Path(path).parent.name] for path in self.image_path] + + + def __getitem__(self, index): + img = Image.open(os.path.join(self.path, self.image_path[index])) + img = img.convert('RGB') + if self.transform is not None: + img = self.transform(img) + label = torch.LongTensor([self.image_label[index]]) + return img, label, index + + def __len__(self): + return len(self.image_path) \ No newline at end of file