"""Train an image classifier selected by the Hydra config and save its weights."""
import os
import time

import hydra
import torch
import torch.nn as nn
import torch.optim as optim
from omegaconf import DictConfig, OmegaConf
from torchvision import models

from data.dataset import get_dataset
from utils import get_model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass of *model* over *dataloader*.

    Returns:
        ``(loss_sum, num_correct)``: the loss summed over samples and the count
        of correct top-1 predictions, both as plain Python numbers.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes a mean-reduced criterion (as the nn.CrossEntropyLoss() built
        # in main is), so multiply by batch size to get a per-sample sum.
        loss_sum += loss.item() * inputs.size(0)
        # .item() keeps the counter a Python int instead of a device tensor.
        num_correct += (preds == labels).sum().item()
    return loss_sum, num_correct


@torch.no_grad()
def _evaluate(model, dataloader, criterion):
    """Evaluate *model* on *dataloader* without tracking gradients.

    Returns:
        ``(loss_sum, num_correct)`` with the same meaning as in
        ``_train_one_epoch``.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        loss_sum += loss.item() * inputs.size(0)
        num_correct += (preds == labels).sum().item()
    return loss_sum, num_correct


@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
    """Train the configured classifier, report per-epoch metrics, save weights.

    Reads ``cfg.classifier.{num_epochs, lr, momentum, model}``,
    ``cfg.paths.classifier`` and ``cfg.dataset``.  After every epoch it prints
    train and test loss/accuracy plus the wall-clock time elapsed since training
    started.  The final ``state_dict`` is written to
    ``<cfg.paths.classifier>/<cfg.classifier.model>_<cfg.dataset>.pth``.
    """
    # get_model's device placement is not visible here; move the model
    # explicitly so it lives on the same device as the batches.
    model = get_model(cfg).to(device)
    train_dataloader, test_dataloader, train_dataset, test_dataset = get_dataset(cfg)

    num_epochs = cfg.classifier.num_epochs
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=cfg.classifier.lr,
        momentum=cfg.classifier.momentum,
    )

    start_time = time.time()
    for epoch in range(num_epochs):
        loss_sum, num_correct = _train_one_epoch(
            model, train_dataloader, criterion, optimizer
        )
        epoch_loss = loss_sum / len(train_dataset)
        epoch_acc = num_correct / len(train_dataset) * 100.
        print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(
            epoch, epoch_loss, epoch_acc, time.time() - start_time))

        loss_sum, num_correct = _evaluate(model, test_dataloader, criterion)
        epoch_loss = loss_sum / len(test_dataset)
        epoch_acc = num_correct / len(test_dataset) * 100.
        print('[Test #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(
            epoch, epoch_loss, epoch_acc, time.time() - start_time))

    save_dir = cfg.paths.classifier
    # Create the output directory up front so a missing folder does not make
    # torch.save fail after the whole training run has completed.
    os.makedirs(save_dir, exist_ok=True)
    save_path = '{}/{}_{}.pth'.format(save_dir, cfg.classifier.model, cfg.dataset)
    torch.save(model.state_dict(), save_path)


if __name__ == "__main__":
    main()