update other

This commit is contained in:
leewlving 2023-12-10 16:29:02 +08:00
parent 44b21f20e2
commit 3940991a46
3 changed files with 18 additions and 17 deletions

View File

@ -105,26 +105,26 @@ def main(cfg: DictConfig) -> None:
print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time)) print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
model.eval() model.eval()
print('Evaluating!') print('Evaluating!')
with torch.no_grad(): with torch.no_grad():
running_loss = 0. #test_dataloader running_loss = 0. #test_dataloader
running_corrects = 0 running_corrects = 0
for i, (inputs, labels) in enumerate(test_dataloader): for i, (inputs, labels) in enumerate(test_dataloader):
inputs = inputs.to(device) inputs = inputs.to(device)
labels = labels.to(device) labels = labels.to(device)
generated_img,adv_latent_codes=model(inputs) generated_img,adv_latent_codes=model(inputs)
outputs = classifier(generated_img) outputs = classifier(generated_img)
preds=criterion(outputs ,labels) preds=criterion(outputs ,labels)
# running_loss += loss.item() * inputs.size(0) # running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data) running_corrects += torch.sum(preds == labels.data)
# epoch_loss = running_loss / len(test_dataset) # epoch_loss = running_loss / len(test_dataset)
epoch_acc = running_corrects / len(test_dataset) * 100. epoch_acc = running_corrects / len(test_dataset) * 100.
print('[Test #{}] Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_acc, time.time() - start_time)) print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format( epoch_acc, time.time() - start_time))
save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset,cfg.prompt) save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset,cfg.prompt)
torch.save(model.state_dict(), save_path) torch.save(model.state_dict(), save_path)

View File

@ -16,6 +16,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@hydra.main(version_base=None, config_path="./config", config_name="config") @hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None: def main(cfg: DictConfig) -> None:
model=get_model(cfg) model=get_model(cfg)
model.train()
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg) train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
num_epochs = cfg.classifier.num_epochs num_epochs = cfg.classifier.num_epochs
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()

View File

@ -3,14 +3,14 @@ classifier:
model: resnet18 model: resnet18
lr: 0.01 lr: 0.01
momentum: 0.9 momentum: 0.9
num_epochs: 20 num_epochs: 120
num_workers : 4 num_workers : 4
batch_size: 64 batch_size: 128
paths: paths:
gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
identity_dataset: /dataset/face_identity/CelebA_HQ_facial_identity_dataset identity_dataset: dataset/CelebA_HQ_facial_identity_dataset
inverter_cfg: /dataset/face_identity/psp_ffhq_encode.pt inverter_cfg: /dataset/face_identity/psp_ffhq_encode.pt
classifier: checkpoint classifier: checkpoint
stylegan: checkpoint/stylegan2-ffhq-config-f.pt stylegan: checkpoint/stylegan2-ffhq-config-f.pt