diff --git a/GanAttack.py b/GanAttack.py index 7a1f925..ff69b93 100644 --- a/GanAttack.py +++ b/GanAttack.py @@ -97,9 +97,12 @@ def main(cfg: DictConfig) -> None: loss_vgg=vgg_loss(inputs,generated_img) loss_l1=F.l1_loss(codes,adv_latent_codes) loss_clip=clip_loss(generated_img,prompt) + outputs = classifier(generated_img) + preds=criterion(outputs,labels) + loss_classifier=max_loss(torch.ones_like(preds),preds,-torch.ones_like(preds)) + # _, preds = torch.max(classifier(generated_img), 1) + # loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels)) - _, preds = torch.max(classifier(generated_img), 1) - loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels)) loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier loss.backward() optimizer.step() @@ -124,8 +127,7 @@ def main(cfg: DictConfig) -> None: generated_img,adv_latent_codes=model(inputs) outputs = classifier(generated_img) - _, preds = torch.max(outputs, 1) - loss = criterion(outputs, labels) + preds=criterion(outputs ,labels) # running_loss += loss.item() * inputs.size(0) running_corrects += torch.sum(preds == labels.data)