add binary
This commit is contained in:
parent
7671b4496e
commit
0f7414bc5f
|
|
@ -62,12 +62,9 @@ def main(cfg: DictConfig) -> None:
|
||||||
|
|
||||||
num_epochs = cfg.optim.num_epochs
|
num_epochs = cfg.optim.num_epochs
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
max_loss=nn.MarginRankingLoss(0.1)
|
|
||||||
clip_loss=CLIPLoss().to(device)
|
clip_loss=CLIPLoss().to(device)
|
||||||
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
|
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
|
||||||
# vgg_loss=VggLoss().to(device)
|
|
||||||
# summary(model, input_size = (3, 256, 256), batch_size = 5)
|
|
||||||
# set_requires_grad(model.mlp.parameters())
|
|
||||||
|
|
||||||
optimizer = optim.SGD(model.mlp.parameters()+model.noise_mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
optimizer = optim.SGD(model.mlp.parameters()+model.noise_mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue