add dataset

This commit is contained in:
Li Wenyun 2023-12-02 22:02:58 +08:00
parent 2044233dd1
commit 2ca63fd5a8
3 changed files with 122 additions and 0 deletions

View File

@ -0,0 +1,73 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset
import time
import hydra
def get_model(config):
    """Build the classifier selected by ``config.classifier.model``.

    The backbone's final classification layer is replaced so that it emits
    one logit per class: 2 when ``config.dataset`` is ``'gender_dataset'``
    and 307 for any other dataset value (presumably the number of identities
    in the identity dataset -- confirm against the data).

    Args:
        config: Configuration object exposing ``dataset`` and
            ``classifier.model``.

    Returns:
        The model with its replaced head, moved to the ``'cuda'`` device.

    Raises:
        ValueError: If ``config.classifier.model`` is not one of
            ``'resnet18'``, ``'resnet101'``, ``'mnasnet'`` or
            ``'densenet121'``.
    """
    num_class = 2 if config.dataset == 'gender_dataset' else 307
    arch = config.classifier.model

    if arch == 'resnet18':
        model = models.resnet18(pretrained=False)
        # Read the head's input width from the model instead of hard-coding it.
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'resnet101':
        model = models.resnet101(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'mnasnet':
        model = models.mnasnet1_0(pretrained=True)
        # The MNASNet head is indexed as classifier[1]; only that layer is
        # swapped out.
        num_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_features, num_class)
    elif arch == 'densenet121':
        model = models.densenet121(pretrained=True)
        num_features = model.classifier.in_features
        # DenseNet's head attribute is `classifier`, not `fc`. Assigning to
        # `fc` would leave the original 1000-way output in place.
        model.classifier = nn.Linear(num_features, num_class)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported classifier model: {arch!r}; expected one of "
            "'resnet18', 'resnet101', 'mnasnet', 'densenet121'"
        )

    # Transfer execution to GPU.
    return model.to('cuda')
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: build the model and the data loaders from ``cfg``.

    The configuration is loaded by Hydra from ``./config`` using the config
    named ``config``. The epoch count and a start timestamp are captured as
    well; nothing in this function consumes them yet.
    """
    net = get_model(cfg)
    loaders = get_dataset(cfg)
    train_loader, test_loader = loaders
    epoch_count = cfg.classifier.num_epochs
    started_at = time.time()


if __name__ == "__main__":
    main()

18
config/config.yaml Normal file
View File

@ -0,0 +1,18 @@
# Which dataset to use. The Python code in this commit compares this value
# against 'gender_dataset' and treats any other value as the identity dataset.
dataset: gender_dataset
# Classifier settings. `model` selects the backbone in get_model() and
# `num_epochs` is read by the training entry point.
classifier:
model: resnet18
# presumably optimizer hyperparameters; not read by the code visible in this
# commit -- confirm where they are consumed.
lr: 0.01
momentum: 0.9
num_epochs: 200
# Dataset root directories; get_dataset() joins 'train' / 'test' onto these.
paths:
gender_dataset: ./CelebA_HQ_face_gender_dataset
identity_dataset: ./CelebA_HQ_face_identity_dataset
# NOTE(review): plaintext credential checked into version control and not read
# by the code visible in this commit -- confirm it is needed; if so, load it
# from an environment variable or a secret store instead.
password: secret
# Data-loading settings; `batch_size` and `num_workers` are passed to both
# DataLoaders in get_dataset().
optim:
batch_size: 8
# NOTE(review): duplicates classifier.num_epochs; the visible entry point reads
# classifier.num_epochs, so this copy may be unused -- confirm.
num_epochs: 200
num_workers : 4

31
data/dataset.py Normal file
View File

@ -0,0 +1,31 @@
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import os
# Training pipeline: resize to 256x256, random horizontal flip as data
# augmentation, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
transforms_train = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Evaluation pipeline: identical to the training one minus the random flip.
transforms_test = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
def get_dataset(config):
    """Create the train and test data loaders for the configured dataset.

    The dataset root is ``config.paths.gender_dataset`` when
    ``config.dataset`` equals ``'gender_dataset'`` and
    ``config.paths.identity_dataset`` otherwise. Images are read with
    ``ImageFolder`` from the ``train`` and ``test`` subdirectories of that
    root.

    Args:
        config: Configuration object exposing ``dataset``, ``paths`` and
            ``optim.batch_size`` / ``optim.num_workers``.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. Only the training
        loader shuffles its samples.
    """
    root = (
        config.paths.gender_dataset
        if config.dataset == 'gender_dataset'
        else config.paths.identity_dataset
    )

    train_set = datasets.ImageFolder(os.path.join(root, 'train'), transforms_train)
    test_set = datasets.ImageFolder(os.path.join(root, 'test'), transforms_test)

    # Both loaders share the same batching and worker settings.
    loader_kwargs = {
        'batch_size': config.optim.batch_size,
        'num_workers': config.optim.num_workers,
    }
    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader