# GanAttack/classifier_training.py
# (file-viewer metadata from the original listing: 73 lines, 1.6 KiB, Python)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset
import time
import hydra
def get_model(config, device=None) -> nn.Module:
    """Build a torchvision classifier whose output head matches the dataset.

    Args:
        config: Hydra/OmegaConf config. Reads ``config.dataset`` and
            ``config.classifier.model``. The model must be one of
            ``'resnet18'``, ``'resnet101'``, ``'mnasnet'`` or ``'densenet121'``.
        device: Device to move the model to. Defaults to ``'cuda'`` when
            CUDA is available, otherwise ``'cpu'``.

    Returns:
        The model, with its final linear layer replaced so it emits
        ``num_class`` logits, placed on ``device``.

    Raises:
        ValueError: If ``config.classifier.model`` is not a supported name.
    """
    # 'gender_dataset' is binary; every other dataset gets 307 classes.
    # TODO(review): confirm 307 matches the dataset built in data.dataset.
    num_class = 2 if config.dataset == 'gender_dataset' else 307

    # NOTE(review): ``pretrained=`` is deprecated since torchvision 0.13 in
    # favour of ``weights=``; kept as-is for compatibility with older installs.
    arch = config.classifier.model
    if arch == 'resnet18':
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'resnet101':
        model = models.resnet101(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_class)
    elif arch == 'mnasnet':
        model = models.mnasnet1_0(pretrained=True)
        # Index 1 of ``classifier`` is the final Linear layer; replace only it.
        num_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_features, num_class)
    elif arch == 'densenet121':
        model = models.densenet121(pretrained=True)
        # DenseNet's output head is ``classifier`` (it has no ``fc``).
        # Assigning ``fc`` would add an unused layer and keep the 1000-way head.
        num_features = model.classifier.in_features
        model.classifier = nn.Linear(num_features, num_class)
    else:
        raise ValueError(
            f"Unsupported classifier model {arch!r}; expected one of "
            "'resnet18', 'resnet101', 'mnasnet', 'densenet121'"
        )

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return model.to(device)
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point for classifier training.

    Hydra composes ``cfg`` from the ``config`` config under ``./config``.
    The function builds the classifier and the train/test dataloaders,
    reads the configured epoch count and records a start timestamp.
    """
    net = get_model(cfg)
    train_loader, test_loader = get_dataset(cfg)
    epochs = cfg.classifier.num_epochs
    t_start = time.time()
    # NOTE(review): nothing below uses ``net``, the loaders, ``epochs`` or
    # ``t_start``. The training/evaluation loop (and the ``optim`` import it
    # would need) appears to be missing from this version of the file.
if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator builds and passes ``cfg``.
    main()