73 lines
1.6 KiB
Python
73 lines
1.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torchvision import models
|
|
from omegaconf import DictConfig, OmegaConf
|
|
from data.dataset import get_dataset
|
|
import time
|
|
|
|
import hydra
|
|
|
|
|
|
def get_model(config, device='cuda'):
    """Build the classifier named by ``config.classifier.model`` and move it to ``device``.

    The model's final layer is replaced so its output size matches the dataset:
    2 classes when ``config.dataset == 'gender_dataset'``, otherwise 307.

    Args:
        config: Hydra/OmegaConf config. Reads ``config.dataset`` and
            ``config.classifier.model``. Supported models are ``'resnet18'``,
            ``'resnet101'``, ``'mnasnet'`` and ``'densenet121'``.
        device: Device passed to ``model.to``. Defaults to ``'cuda'``, matching
            the previous hard-coded behaviour.

    Returns:
        The constructed ``nn.Module``, already moved to ``device``.

    Raises:
        ValueError: If ``config.classifier.model`` is not one of the supported names.
    """
    num_class = 2 if config.dataset == 'gender_dataset' else 307

    model_name = config.classifier.model

    # Replace each architecture's output layer with one sized to num_class.
    if model_name == 'resnet18':
        model = models.resnet18(pretrained=False)
        # Reuse the existing head's input width (512 for ResNet-18) instead of hard-coding it.
        model.fc = nn.Linear(model.fc.in_features, num_class)

    elif model_name == 'resnet101':
        model = models.resnet101(pretrained=False)
        # Reuse the existing head's input width (2048 for ResNet-101).
        model.fc = nn.Linear(model.fc.in_features, num_class)

    elif model_name == 'mnasnet':
        model = models.mnasnet1_0(pretrained=True)
        # The final linear layer lives at classifier[1]; swap only that layer.
        num_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_features, num_class)

    elif model_name == 'densenet121':
        model = models.densenet121(pretrained=True)
        # DenseNet's head is `classifier`, not `fc`. Assigning `model.fc` would
        # only attach an unused layer and keep the original 1000-way head.
        num_features = model.classifier.in_features
        model.classifier = nn.Linear(num_features, num_class)

    else:
        # Without this, `model` is unbound and the caller gets an UnboundLocalError.
        raise ValueError(f"Unsupported classifier model: {model_name!r}")

    # Transfer the model to the requested device (GPU by default).
    return model.to(device)
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: build the model and data loaders from ``cfg``.

    Configuration is loaded by ``hydra.main`` from ``./config/config.yaml``
    (plus any command-line overrides) and passed in as ``cfg``.
    """
    net = get_model(cfg)

    loaders = get_dataset(cfg)
    train_loader, test_loader = loaders

    epochs = cfg.classifier.num_epochs

    # NOTE(review): `epochs`, `t0` and both loaders are never used below.
    # The training/evaluation loop appears to be unfinished; confirm intent.
    t0 = time.time()
if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator on main() loads the
    # config and supplies `cfg` itself.
    main()