48 lines
1.2 KiB
Python
48 lines
1.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torchvision import models
|
|
|
|
def get_model(config, device='cuda'):
    """Build a torchvision classifier whose output layer matches the dataset.

    Args:
        config: Configuration object. Two fields are read:
            ``config.dataset`` -- ``'gender_dataset'`` selects a 2-class head;
            any other value selects a 307-class head.
            ``config.classifier.model`` -- one of ``'resnet18'``,
            ``'resnet101'``, ``'mnasnet'``, ``'densenet121'``.
        device: Device the model is moved to before returning. Defaults to
            ``'cuda'``, matching the previously hard-coded behaviour.

    Returns:
        The model with its final classification layer replaced by an
        ``nn.Linear`` producing ``num_class`` outputs, moved to ``device``.

    Raises:
        ValueError: If ``config.classifier.model`` is not one of the
            supported architectures.
    """
    num_class = 2 if config.dataset == 'gender_dataset' else 307
    arch = config.classifier.model

    # NOTE(review): ``pretrained=`` is deprecated in newer torchvision releases
    # in favour of ``weights=``; kept here for compatibility with older versions.
    if arch == 'resnet18':
        # Randomly initialised (no pretrained weights); 512 is the width of
        # ResNet-18's final feature vector feeding ``fc``.
        model = models.resnet18(pretrained=False)
        model.fc = nn.Linear(512, num_class)
    elif arch == 'resnet101':
        # Randomly initialised; 2048 is the width feeding ResNet-101's ``fc``.
        model = models.resnet101(pretrained=False)
        model.fc = nn.Linear(2048, num_class)
    elif arch == 'mnasnet':
        # Pretrained MNASNet 1.0; index 1 of its ``classifier`` is the Linear
        # layer (index 0 is dropout), so only that layer is swapped.
        model = models.mnasnet1_0(pretrained=True)
        num_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_features, num_class)
    elif arch == 'densenet121':
        # Pretrained DenseNet-121. Its head lives in ``classifier`` -- assigning
        # ``model.fc`` would only attach an unused module and leave the
        # original ImageNet head producing the wrong number of classes.
        model = models.densenet121(pretrained=True)
        num_features = model.classifier.in_features
        model.classifier = nn.Linear(num_features, num_class)
    else:
        # Fail loudly instead of hitting an UnboundLocalError on ``model`` below.
        raise ValueError(
            f"Unsupported classifier model: {arch!r}; expected one of "
            "'resnet18', 'resnet101', 'mnasnet', 'densenet121'"
        )

    return model.to(device)