Stylesttack/a.py

31 lines
1.3 KiB
Python

class ImageDataset(Dataset):
    """NWPU-RESISC45 image dataset with a fixed, seeded train/test/database split.

    Images are discovered under ``<data_path>/NWPU-RESISC45/<class>/<file>``.
    The complete file list is sorted, shuffled with a private RNG seeded by
    ``seed``, and then sliced according to ``mode``:

    * ``'test'``      -> images ``[train_size, train_size + test_size)``
    * ``'database'``  -> images ``[train_size + test_size, end)``
    * anything else   -> images ``[0, train_size)`` (the training split)

    Each image's label is the index of its parent directory name within the
    sorted list of class directories under ``NWPU-RESISC45``.

    Args:
        data_path: Root directory that contains the ``NWPU-RESISC45`` folder.
        mode: ``'database'``, ``'test'``, or any other value for the train split.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        seed: Seed of the shuffle that defines the splits (default 1).
        train_size: Number of images in the train split (default 7000).
        test_size: Number of images in the test split (default 5000).
    """

    _ROOT = "NWPU-RESISC45"

    def __init__(self, data_path, mode, transform=None, *, seed: int = 1,
                 train_size: int = 7000, test_size: int = 5000):
        self.path = data_path
        self.mode = mode
        self.transform = transform
        root = pathlib.Path(data_path) / self._ROOT

        # Sort before shuffling: glob order is filesystem-dependent, so without
        # this the same seed could produce different splits on different machines.
        image_path = sorted(str(path) for path in root.glob("*/*"))
        # A private RNG keeps the split reproducible without reseeding the
        # global ``random`` module for the rest of the program.
        random.Random(seed).shuffle(image_path)

        test_end = train_size + test_size
        if mode == 'database':
            image_path = image_path[test_end:]
        elif mode == 'test':
            image_path = image_path[train_size:test_end]
        else:
            image_path = image_path[:train_size]
        self.image_path = image_path

        # Only directories are classes; a stray file beside them must not
        # shift every label index.
        label_names = sorted(p.name for p in root.glob("*") if p.is_dir())
        label_to_index = {name: idx for idx, name in enumerate(label_names)}
        self.image_label = [
            label_to_index[pathlib.Path(path).parent.name] for path in image_path
        ]

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for the sample at ``index``.

        ``image`` is the RGB image after ``transform`` (if any) and ``label``
        is a 1-element ``torch.long`` tensor.
        """
        # Entries of ``image_path`` already include ``data_path`` (they come
        # from globbing it), so open them directly; re-joining with
        # ``self.path`` doubled the prefix for relative roots.
        with Image.open(self.image_path[index]) as src:
            # convert() returns a new image, so it stays valid after the
            # file is closed by the ``with`` block.
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = torch.tensor([self.image_label[index]], dtype=torch.long)
        return img, label, index

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.image_path)