update GanAttack
This commit is contained in:
parent
cf67f37cd3
commit
a6080d42a7
10
GanAttack.py
10
GanAttack.py
|
|
@ -97,9 +97,12 @@ def main(cfg: DictConfig) -> None:
|
||||||
loss_vgg=vgg_loss(inputs,generated_img)
|
loss_vgg=vgg_loss(inputs,generated_img)
|
||||||
loss_l1=F.l1_loss(codes,adv_latent_codes)
|
loss_l1=F.l1_loss(codes,adv_latent_codes)
|
||||||
loss_clip=clip_loss(generated_img,prompt)
|
loss_clip=clip_loss(generated_img,prompt)
|
||||||
|
outputs = classifier(generated_img)
|
||||||
|
preds=criterion(outputs,labels)
|
||||||
|
loss_classifier=max_loss(torch.ones_like(preds),preds,-torch.ones_like(preds))
|
||||||
|
# _, preds = torch.max(classifier(generated_img), 1)
|
||||||
|
# loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
|
||||||
|
|
||||||
_, preds = torch.max(classifier(generated_img), 1)
|
|
||||||
loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
|
|
||||||
loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
@ -124,8 +127,7 @@ def main(cfg: DictConfig) -> None:
|
||||||
|
|
||||||
generated_img,adv_latent_codes=model(inputs)
|
generated_img,adv_latent_codes=model(inputs)
|
||||||
outputs = classifier(generated_img)
|
outputs = classifier(generated_img)
|
||||||
_, preds = torch.max(outputs, 1)
|
preds=criterion(outputs ,labels)
|
||||||
loss = criterion(outputs, labels)
|
|
||||||
|
|
||||||
# running_loss += loss.item() * inputs.size(0)
|
# running_loss += loss.item() * inputs.size(0)
|
||||||
running_corrects += torch.sum(preds == labels.data)
|
running_corrects += torch.sum(preds == labels.data)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue