add dataset
This commit is contained in:
parent
2044233dd1
commit
2ca63fd5a8
|
|
@ -0,0 +1,73 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torchvision import models
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
from data.dataset import get_dataset
|
||||||
|
import time
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(config):
|
||||||
|
if config.dataset == 'gender_dataset':
|
||||||
|
num_class=2
|
||||||
|
else:
|
||||||
|
num_class=307
|
||||||
|
|
||||||
|
# Changing number of model's output classes to 1
|
||||||
|
#for resnet18
|
||||||
|
if config.classifier.model == 'resnet18':
|
||||||
|
model = models.resnet18(pretrained=False)
|
||||||
|
model.fc = nn.Linear(512, num_class)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#for resnet50
|
||||||
|
elif config.classifier.model == 'resnet101':
|
||||||
|
model = models.resnet101(pretrained=False)
|
||||||
|
model.fc = nn.Linear(2048, num_class)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# for densenet 121
|
||||||
|
elif config.classifier.model == 'mnasnet':
|
||||||
|
model = models.mnasnet1_0(pretrained=True)
|
||||||
|
num_features = model.classifier[1].in_features
|
||||||
|
model.classifier[1] = nn.Linear(num_features, num_class)
|
||||||
|
# model.classifier = nn.Linear(1024, 1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#for vgg19_bn
|
||||||
|
elif config.classifier.model == 'densenet121':
|
||||||
|
model = models.densenet121(pretrained=True)
|
||||||
|
num_features = model.classifier.in_features
|
||||||
|
model.fc = nn.Linear(num_features, num_class)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Transfer execution to GPU
|
||||||
|
model = model.to('cuda')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||||
|
def main(cfg: DictConfig) -> None:
|
||||||
|
model=get_model(cfg)
|
||||||
|
train_dataloader,test_dataloader=get_dataset(cfg)
|
||||||
|
num_epochs = cfg.classifier.num_epochs
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
dataset: gender_dataset
|
||||||
|
classifier:
|
||||||
|
model: resnet18
|
||||||
|
lr: 0.01
|
||||||
|
momentum: 0.9
|
||||||
|
num_epochs: 200
|
||||||
|
|
||||||
|
|
||||||
|
paths:
|
||||||
|
gender_dataset: ./CelebA_HQ_face_gender_dataset
|
||||||
|
identity_dataset: ./CelebA_HQ_face_identity_dataset
|
||||||
|
password: secret
|
||||||
|
|
||||||
|
|
||||||
|
optim:
|
||||||
|
batch_size: 8
|
||||||
|
num_epochs: 200
|
||||||
|
num_workers : 4
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torchvision import datasets, models, transforms
|
||||||
|
import torch.nn as nn
|
||||||
|
import os
|
||||||
|
|
||||||
|
transforms_train = transforms.Compose([
|
||||||
|
transforms.Resize((256, 256)),
|
||||||
|
transforms.RandomHorizontalFlip(), # data augmentation
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # normalization
|
||||||
|
])
|
||||||
|
|
||||||
|
transforms_test = transforms.Compose([
|
||||||
|
transforms.Resize((256, 256)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(config):
|
||||||
|
if config.dataset == 'gender_dataset':
|
||||||
|
path=config.paths.gender_dataset
|
||||||
|
else:
|
||||||
|
path=config.paths.identity_dataset
|
||||||
|
train_dataset = datasets.ImageFolder(os.path.join(path, 'train'), transforms_train)
|
||||||
|
test_dataset = datasets.ImageFolder(os.path.join(path, 'test'), transforms_test)
|
||||||
|
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
|
||||||
|
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers)
|
||||||
|
return train_dataloader,test_dataloader
|
||||||
Loading…
Reference in New Issue