update other

This commit is contained in:
leewlving 2023-12-10 16:29:02 +08:00
parent 44b21f20e2
commit 3940991a46
3 changed files with 18 additions and 17 deletions

View File

@ -105,26 +105,26 @@ def main(cfg: DictConfig) -> None:
print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
model.eval()
print('Evaluating!')
with torch.no_grad():
running_loss = 0. #test_dataloader
running_corrects = 0
model.eval()
print('Evaluating!')
with torch.no_grad():
running_loss = 0. #test_dataloader
running_corrects = 0
for i, (inputs, labels) in enumerate(test_dataloader):
inputs = inputs.to(device)
labels = labels.to(device)
for i, (inputs, labels) in enumerate(test_dataloader):
inputs = inputs.to(device)
labels = labels.to(device)
generated_img,adv_latent_codes=model(inputs)
outputs = classifier(generated_img)
preds=criterion(outputs ,labels)
generated_img,adv_latent_codes=model(inputs)
outputs = classifier(generated_img)
preds=criterion(outputs ,labels)
# running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
running_corrects += torch.sum(preds == labels.data)
# epoch_loss = running_loss / len(test_dataset)
epoch_acc = running_corrects / len(test_dataset) * 100.
print('[Test #{}] Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_acc, time.time() - start_time))
epoch_acc = running_corrects / len(test_dataset) * 100.
print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format( epoch_acc, time.time() - start_time))
save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset,cfg.prompt)
torch.save(model.state_dict(), save_path)

View File

@ -16,6 +16,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(cfg: DictConfig) -> None:
model=get_model(cfg)
model.train()
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
num_epochs = cfg.classifier.num_epochs
criterion = nn.CrossEntropyLoss()

View File

@ -3,14 +3,14 @@ classifier:
model: resnet18
lr: 0.01
momentum: 0.9
num_epochs: 20
num_epochs: 120
num_workers : 4
batch_size: 64
batch_size: 128
paths:
gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
identity_dataset: /dataset/face_identity/CelebA_HQ_facial_identity_dataset
identity_dataset: dataset/CelebA_HQ_facial_identity_dataset
inverter_cfg: /dataset/face_identity/psp_ffhq_encode.pt
classifier: checkpoint
stylegan: checkpoint/stylegan2-ffhq-config-f.pt