From 2ca63fd5a85ed10c6d87eb54b5eaed0f737909bc Mon Sep 17 00:00:00 2001 From: Li Wenyun Date: Sat, 2 Dec 2023 22:02:58 +0800 Subject: [PATCH] add dataset --- classifier_training.py | 73 ++++++++++++++++++++++++++++++++++++++++++ config/config.yaml | 18 +++++++++++ data/dataset.py | 31 ++++++++++++++++++ 3 files changed, 122 insertions(+) create mode 100644 config/config.yaml create mode 100644 data/dataset.py diff --git a/classifier_training.py b/classifier_training.py index e69de29..2d981f7 100644 --- a/classifier_training.py +++ b/classifier_training.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import models +from omegaconf import DictConfig, OmegaConf +from data.dataset import get_dataset +import time + +import hydra + + +def get_model(config): + if config.dataset == 'gender_dataset': + num_class=2 + else: + num_class=307 + + # Changing number of model's output classes to 1 + #for resnet18 + if config.classifier.model == 'resnet18': + model = models.resnet18(pretrained=False) + model.fc = nn.Linear(512, num_class) + + + + #for resnet50 + elif config.classifier.model == 'resnet101': + model = models.resnet101(pretrained=False) + model.fc = nn.Linear(2048, num_class) + + + + # for densenet 121 + elif config.classifier.model == 'mnasnet': + model = models.mnasnet1_0(pretrained=True) + num_features = model.classifier[1].in_features + model.classifier[1] = nn.Linear(num_features, num_class) + # model.classifier = nn.Linear(1024, 1) + + + + #for vgg19_bn + elif config.classifier.model == 'densenet121': + model = models.densenet121(pretrained=True) + num_features = model.classifier.in_features + model.fc = nn.Linear(num_features, num_class) + + + + + # Transfer execution to GPU + model = model.to('cuda') + + return model + +@hydra.main(version_base=None, config_path="./config", config_name="config") +def main(cfg: DictConfig) -> None: + model=get_model(cfg) + train_dataloader,test_dataloader=get_dataset(cfg) + num_epochs = cfg.classifier.num_epochs + start_time = time.time() + + + + + + + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/config/config.yaml b/config/config.yaml new file mode 100644 index 0000000..0142c89 --- /dev/null +++ b/config/config.yaml @@ -0,0 +1,18 @@ +dataset: gender_dataset +classifier: + model: resnet18 + lr: 0.01 + momentum: 0.9 + num_epochs: 200 + + +paths: + gender_dataset: ./CelebA_HQ_face_gender_dataset + identity_dataset: ./CelebA_HQ_face_identity_dataset + password: secret + + +optim: + batch_size: 8 + num_epochs: 200 + num_workers : 4 \ No newline at end of file diff --git a/data/dataset.py b/data/dataset.py new file mode 100644 index 0000000..da3b3bf --- /dev/null +++ b/data/dataset.py @@ -0,0 +1,31 @@ +import torch +import torchvision +from torchvision import datasets, models, transforms +import torch.nn as nn +import os + +transforms_train = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomHorizontalFlip(), # data augmentation + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # normalization +]) + +transforms_test = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) +]) + + +def get_dataset(config): + if config.dataset == 'gender_dataset': + path=config.paths.gender_dataset + else: + path=config.paths.identity_dataset + train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train) + test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test) + + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers) + return train_dataloader,test_dataloader \ No newline at end of file