update other
This commit is contained in:
parent
44b21f20e2
commit
3940991a46
28
GanAttack.py
28
GanAttack.py
|
|
@ -105,26 +105,26 @@ def main(cfg: DictConfig) -> None:
|
||||||
print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
|
print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
|
||||||
|
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
print('Evaluating!')
|
print('Evaluating!')
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
running_loss = 0. #test_dataloader
|
running_loss = 0. #test_dataloader
|
||||||
running_corrects = 0
|
running_corrects = 0
|
||||||
|
|
||||||
for i, (inputs, labels) in enumerate(test_dataloader):
|
for i, (inputs, labels) in enumerate(test_dataloader):
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
labels = labels.to(device)
|
labels = labels.to(device)
|
||||||
|
|
||||||
generated_img,adv_latent_codes=model(inputs)
|
generated_img,adv_latent_codes=model(inputs)
|
||||||
outputs = classifier(generated_img)
|
outputs = classifier(generated_img)
|
||||||
preds=criterion(outputs ,labels)
|
preds=criterion(outputs ,labels)
|
||||||
|
|
||||||
# running_loss += loss.item() * inputs.size(0)
|
# running_loss += loss.item() * inputs.size(0)
|
||||||
running_corrects += torch.sum(preds == labels.data)
|
running_corrects += torch.sum(preds == labels.data)
|
||||||
|
|
||||||
# epoch_loss = running_loss / len(test_dataset)
|
# epoch_loss = running_loss / len(test_dataset)
|
||||||
epoch_acc = running_corrects / len(test_dataset) * 100.
|
epoch_acc = running_corrects / len(test_dataset) * 100.
|
||||||
print('[Test #{}] Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_acc, time.time() - start_time))
|
print('[Test ] Acc: {:.4f}% Time: {:.4f}s'.format( epoch_acc, time.time() - start_time))
|
||||||
|
|
||||||
save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset,cfg.prompt)
|
save_path = '{}/stylegan_{}_{}_{}.pth'.format(cfg.paths.pretrained_models, cfg.classifier.model, cfg.dataset,cfg.prompt)
|
||||||
torch.save(model.state_dict(), save_path)
|
torch.save(model.state_dict(), save_path)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||||
def main(cfg: DictConfig) -> None:
|
def main(cfg: DictConfig) -> None:
|
||||||
model=get_model(cfg)
|
model=get_model(cfg)
|
||||||
|
model.train()
|
||||||
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
|
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
|
||||||
num_epochs = cfg.classifier.num_epochs
|
num_epochs = cfg.classifier.num_epochs
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,14 @@ classifier:
|
||||||
model: resnet18
|
model: resnet18
|
||||||
lr: 0.01
|
lr: 0.01
|
||||||
momentum: 0.9
|
momentum: 0.9
|
||||||
num_epochs: 20
|
num_epochs: 120
|
||||||
num_workers : 4
|
num_workers : 4
|
||||||
batch_size: 64
|
batch_size: 128
|
||||||
|
|
||||||
|
|
||||||
paths:
|
paths:
|
||||||
gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
|
gender_dataset: ./dataset/CelebA_HQ_face_gender_dataset
|
||||||
identity_dataset: /dataset/face_identity/CelebA_HQ_facial_identity_dataset
|
identity_dataset: dataset/CelebA_HQ_facial_identity_dataset
|
||||||
inverter_cfg: /dataset/face_identity/psp_ffhq_encode.pt
|
inverter_cfg: /dataset/face_identity/psp_ffhq_encode.pt
|
||||||
classifier: checkpoint
|
classifier: checkpoint
|
||||||
stylegan: checkpoint/stylegan2-ffhq-config-f.pt
|
stylegan: checkpoint/stylegan2-ffhq-config-f.pt
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue