diff --git a/GanAttack.py b/GanAttack.py
index d688dab..bbe189e 100644
--- a/GanAttack.py
+++ b/GanAttack.py
@@ -105,26 +105,26 @@ def main(cfg: DictConfig) -> None:
         print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
 
-        model.eval()
-        print('Evaluating!')
-        with torch.no_grad():
-            running_loss = 0. #test_dataloader
-            running_corrects = 0
+    model.eval()
+    print('Evaluating!')
+    with torch.no_grad():
+        running_loss = 0. #test_dataloader
+        running_corrects = 0
 
-            for i, (inputs, labels) in enumerate(test_dataloader):
-                inputs = inputs.to(device)
-                labels = labels.to(device)
+        for i, (inputs, labels) in enumerate(test_dataloader):
+            inputs = inputs.to(device)
+            labels = labels.to(device)
 
-                generated_img,adv_latent_codes=model(inputs)
-                outputs = classifier(generated_img)
-                preds=criterion(outputs ,labels)
+            generated_img,adv_latent_codes=model(inputs)
+            outputs = classifier(generated_img)
+            preds=criterion(outputs ,labels)
 
 # running_loss += loss.item() * inputs.size(0)
-                running_corrects += torch.sum(preds == labels.data)
+            running_corrects += torch.sum(preds == labels.data)
 
 # epoch_loss = running_loss / len(test_dataset)
-        epoch_acc = running_corrects / len(test_dataset) * 100.
-        print('[Test #{}] Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_acc, time.time() - start_time))
+    epoch_acc = running_corrects / len(test_dataset) * 100.
+    print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format( epoch_acc, time.time() - start_time))
 
     save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset,cfg.prompt)
     torch.save(model.state_dict(), save_path)
 
diff --git a/classifier_training.py b/classifier_training.py
index 86f042f..70bfa56 100644
--- a/classifier_training.py
+++ b/classifier_training.py
@@ -16,6 +16,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 @hydra.main(version_base=None, config_path="./config", config_name="config")
 def main(cfg: DictConfig) -> None:
     model=get_model(cfg)
+    model.train()
     train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
     num_epochs = cfg.classifier.num_epochs
     criterion = nn.CrossEntropyLoss()
diff --git a/config/config.yaml b/config/config.yaml
index c03f642..a6fd6bd 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -3,14 +3,14 @@ classifier:
   model: resnet18
   lr: 0.01
   momentum: 0.9
-  num_epochs: 20
+  num_epochs: 120
   num_workers : 4
-  batch_size: 64
+  batch_size: 128
 
 paths:
   gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
-  identity_dataset: /dataset/face_identity/CelebA_HQ_facial_identity_dataset
+  identity_dataset: dataset/CelebA_HQ_facial_identity_dataset
   inverter_cfg: /dataset/face_identity/psp_ffhq_encode.pt
   classifier: checkpoint
   stylegan: checkpoint/stylegan2-ffhq-config-f.pt