88 lines
2.5 KiB
Python
88 lines
2.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torchvision import models
|
|
from omegaconf import DictConfig, OmegaConf
|
|
from data.dataset import get_dataset
|
|
from utils import get_model
|
|
import time
|
|
|
|
import hydra
|
|
|
|
# Select GPU 0 when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train an image classifier and evaluate it after every epoch.

    Hydra resolves ``cfg`` from ``./config/config.yaml`` (plus CLI overrides).
    The model/dataset come from the project's ``get_model``/``get_dataset``
    factories; after ``cfg.classifier.num_epochs`` epochs the final weights
    are written to ``{cfg.paths.classifier}/{model}_{dataset}.pth``.

    Args:
        cfg: Hydra/OmegaConf config. Fields read here:
            ``classifier.num_epochs``, ``classifier.lr``,
            ``classifier.momentum``, ``classifier.model``,
            ``paths.classifier``, ``dataset``.
    """
    model = get_model(cfg)
    # get_model's device placement is not visible here; .to(device) is a
    # no-op if the model is already on `device` and prevents a CPU-model /
    # CUDA-batch mismatch otherwise.
    model = model.to(device)
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_dataset(cfg)

    num_epochs = cfg.classifier.num_epochs
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)

    start_time = time.time()
    for epoch in range(num_epochs):
        _train_one_epoch(model, train_dataloader, len(train_dataset), criterion, optimizer, epoch, start_time)
        _evaluate(model, test_dataloader, len(test_dataset), criterion, epoch, start_time)

    save_path = '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)
    torch.save(model.state_dict(), save_path)


def _train_one_epoch(model, dataloader, num_samples, criterion, optimizer, epoch, start_time):
    """Run one optimization pass over the training set and print loss/accuracy."""
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss.item() is the batch mean; scale by batch size so the epoch
        # average below is correct even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data).item()

    epoch_loss = running_loss / num_samples
    epoch_acc = running_corrects / num_samples * 100.
    print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))


def _evaluate(model, dataloader, num_samples, criterion, epoch, start_time):
    """Evaluate on the test set (no gradients) and print loss/accuracy."""
    model.eval()
    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data).item()

    epoch_loss = running_loss / num_samples
    epoch_acc = running_corrects / num_samples * 100.
    print('[Test #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: Hydra parses CLI overrides and injects the resolved
# config into main() when this file is executed directly.
if __name__ == "__main__":
    main()