添加 a.py
This commit is contained in:
parent
9b3e14b3c9
commit
bd2dc4a7cc
|
|
@ -0,0 +1,31 @@
|
||||||
|
class ImageDataset(Dataset):
|
||||||
|
def __init__(self, data_path, mode, transform=None):
|
||||||
|
self.path=data_path
|
||||||
|
data_dir=pathlib.Path(data_path)
|
||||||
|
self.mode=mode
|
||||||
|
self.transform=transform
|
||||||
|
self.image_path=list(data_dir.glob("NWPU-RESISC45/*/*"))
|
||||||
|
self.image_path=[str(path) for path in self.image_path]
|
||||||
|
random.seed(1)
|
||||||
|
random.shuffle(self.image_path)
|
||||||
|
if self.mode == 'database':
|
||||||
|
self.image_path=self.image_path[12000:]
|
||||||
|
elif self.mode == 'test':
|
||||||
|
self.image_path=self.image_path[7000:12000]
|
||||||
|
else:
|
||||||
|
self.image_path=self.image_path[:7000]
|
||||||
|
lable_names = sorted(item.name for item in data_dir.glob("NWPU-RESISC45/*/"))
|
||||||
|
lable_to_index = dict((name, index) for index, name in enumerate(lable_names))
|
||||||
|
self.image_label=[lable_to_index[pathlib.Path(path).parent.name] for path in self.image_path]
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
img = Image.open(os.path.join(self.path, self.image_path[index]))
|
||||||
|
img = img.convert('RGB')
|
||||||
|
if self.transform is not None:
|
||||||
|
img = self.transform(img)
|
||||||
|
label = torch.LongTensor([self.image_label[index]])
|
||||||
|
return img, label, index
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_path)
|
||||||
Loading…
Reference in New Issue