add dataset
This commit is contained in:
parent
2044233dd1
commit
2ca63fd5a8
|
|
@ -0,0 +1,73 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torchvision import models
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from data.dataset import get_dataset
|
||||
import time
|
||||
|
||||
import hydra
|
||||
|
||||
|
||||
def get_model(config):
|
||||
if config.dataset == 'gender_dataset':
|
||||
num_class=2
|
||||
else:
|
||||
num_class=307
|
||||
|
||||
# Changing number of model's output classes to 1
|
||||
#for resnet18
|
||||
if config.classifier.model == 'resnet18':
|
||||
model = models.resnet18(pretrained=False)
|
||||
model.fc = nn.Linear(512, num_class)
|
||||
|
||||
|
||||
|
||||
#for resnet50
|
||||
elif config.classifier.model == 'resnet101':
|
||||
model = models.resnet101(pretrained=False)
|
||||
model.fc = nn.Linear(2048, num_class)
|
||||
|
||||
|
||||
|
||||
# for densenet 121
|
||||
elif config.classifier.model == 'mnasnet':
|
||||
model = models.mnasnet1_0(pretrained=True)
|
||||
num_features = model.classifier[1].in_features
|
||||
model.classifier[1] = nn.Linear(num_features, num_class)
|
||||
# model.classifier = nn.Linear(1024, 1)
|
||||
|
||||
|
||||
|
||||
#for vgg19_bn
|
||||
elif config.classifier.model == 'densenet121':
|
||||
model = models.densenet121(pretrained=True)
|
||||
num_features = model.classifier.in_features
|
||||
model.fc = nn.Linear(num_features, num_class)
|
||||
|
||||
|
||||
|
||||
|
||||
# Transfer execution to GPU
|
||||
model = model.to('cuda')
|
||||
|
||||
return model
|
||||
|
||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||
def main(cfg: DictConfig) -> None:
|
||||
model=get_model(cfg)
|
||||
train_dataloader,test_dataloader=get_dataset(cfg)
|
||||
num_epochs = cfg.classifier.num_epochs
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
dataset: gender_dataset
|
||||
classifier:
|
||||
model: resnet18
|
||||
lr: 0.01
|
||||
momentum: 0.9
|
||||
num_epochs: 200
|
||||
|
||||
|
||||
paths:
|
||||
gender_dataset: ./CelebA_HQ_face_gender_dataset
|
||||
identity_dataset: ./CelebA_HQ_face_identity_dataset
|
||||
password: secret
|
||||
|
||||
|
||||
optim:
|
||||
batch_size: 8
|
||||
num_epochs: 200
|
||||
num_workers : 4
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
import torch
|
||||
import torchvision
|
||||
from torchvision import datasets, models, transforms
|
||||
import torch.nn as nn
|
||||
import os
|
||||
|
||||
transforms_train = transforms.Compose([
|
||||
transforms.Resize((256, 256)),
|
||||
transforms.RandomHorizontalFlip(), # data augmentation
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # normalization
|
||||
])
|
||||
|
||||
transforms_test = transforms.Compose([
|
||||
transforms.Resize((256, 256)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
||||
])
|
||||
|
||||
|
||||
def get_dataset(config):
|
||||
if config.dataset == 'gender_dataset':
|
||||
path=config.paths.gender_dataset
|
||||
else:
|
||||
path=config.paths.identity_dataset
|
||||
train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
|
||||
test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
|
||||
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers)
|
||||
return train_dataloader,test_dataloader
|
||||
Loading…
Reference in New Issue