add cw loss
This commit is contained in:
parent
1c446414d0
commit
7671b4496e
|
|
@ -19,7 +19,26 @@ import hydra
|
|||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def generate_labels(labels, NUM_CLASSES=307):
    """Sample, for each true label, a random target class different from it.

    Used to pick target classes for a targeted adversarial attack: every
    entry of the returned tensor is a uniformly random class index that is
    guaranteed to differ from the corresponding entry of ``labels``.

    Args:
        labels: 1-D integer tensor of ground-truth class indices.
        NUM_CLASSES: total number of classes; valid indices are
            ``0 .. NUM_CLASSES - 1``.

    Returns:
        Tensor with the same shape/dtype/device as ``labels`` holding the
        sampled target classes.
    """
    targets = torch.zeros_like(labels)
    for i in range(len(labels)):
        # Fix 1: torch.randint requires an explicit `size` argument — the
        # original call `torch.randint(0, NUM_CLASSES-1)` raised TypeError.
        # Fix 2: `high` is exclusive, so the upper bound must be NUM_CLASSES
        # (the original NUM_CLASSES-1 could never sample the last class).
        rand_v = torch.randint(0, NUM_CLASSES, (1,)).item()
        # Re-draw until the target differs from the true label.
        while labels[i] == rand_v:
            rand_v = torch.randint(0, NUM_CLASSES, (1,)).item()
        targets[i] = rand_v
    return targets
|
||||
|
||||
def cw_loss(outputs, labels):
    """Carlini-Wagner margin loss for a *targeted* attack.

    Args:
        outputs: ``(batch, num_classes)`` logits.
        labels: ``(batch,)`` target class indices the attack steers toward.

    Returns:
        ``(batch,)`` tensor of ``max(0, max_other_logit - target_logit)``;
        zero once the target class logit dominates every other class.
    """
    # Fix: build the one-hot mask on the same device as the logits instead
    # of the module-level global `device`, so the loss also works when
    # `outputs` lives on a different device (e.g. CPU eval of a CUDA model).
    one_hot_labels = torch.eye(outputs.shape[1], device=outputs.device)[labels]
    # Largest logit among the non-target classes.
    # NOTE(review): the masked-out position contributes 0, so if every
    # non-target logit is negative this reads as 0 — standard in CW
    # implementations, but confirm logits are not systematically negative.
    other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0]
    # Logit of the target class.
    real = torch.max(one_hot_labels * outputs, dim=1)[0]
    # `min=-0.` in the original is numerically identical to 0.0.
    return torch.clamp(other - real, min=0.0)
|
||||
|
||||
def get_stylegan_generator(cfg):
|
||||
generator=Generator(1024, 512, 8)
|
||||
|
|
@ -53,13 +72,7 @@ def main(cfg: DictConfig) -> None:
|
|||
max_loss=nn.MarginRankingLoss(0.1)
|
||||
clip_loss=CLIPLoss().to(device)
|
||||
loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
|
||||
# vgg_loss=VggLoss().to(device)
|
||||
# summary(model, input_size = (3, 256, 256), batch_size = 5)
|
||||
# set_requires_grad(model.mlp.parameters())
|
||||
# for p in (model.mlp.parameters()):
|
||||
# p.requires_grad =True
|
||||
# for p in (model.noise_mlp.parameters()):
|
||||
# p.requires_grad =True
|
||||
|
||||
model.train()
|
||||
optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
||||
start_time = time.time()
|
||||
|
|
@ -73,22 +86,21 @@ def main(cfg: DictConfig) -> None:
|
|||
for i, (inputs, labels,detail_code) in enumerate(train_dataloader):
|
||||
inputs = inputs.to(device)
|
||||
labels = labels.to(device)
|
||||
target=generate_labels(labels)
|
||||
# base_code=base_code.to(device)
|
||||
detail_code=detail_code.to(device)
|
||||
# codes = model.net.encoder(inputs)
|
||||
generated_img,adv_latent_codes=model(inputs,detail_code)
|
||||
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||
generated_img,adv_latent_codes=model(inputs,detail_code,target)
|
||||
|
||||
optimizer.zero_grad()
|
||||
# generated_img,adv_latent_codes=model(inputs)
|
||||
# loss_vgg=vgg_loss(inputs,generated_img)
|
||||
# loss_l1=F.l1_loss(base_code,adv_latent_codes)
|
||||
loss_vgg=loss_fn_vgg(inputs,generated_img)
|
||||
loss_clip=clip_loss(generated_img,prompt)
|
||||
adv_outputs = classifier(generated_img)
|
||||
clean_outputs=classifier(inputs)
|
||||
adv_preds=criterion(adv_outputs,labels)
|
||||
clean_preds=criterion(clean_outputs,labels)
|
||||
loss_classifier=max_loss(clean_preds,adv_preds,torch.ones_like(preds))
|
||||
loss_classifier=cw_loss(adv_outputs,target)
|
||||
# clean_outputs=classifier(inputs)
|
||||
# adv_preds=criterion(adv_outputs,labels)
|
||||
# clean_preds=criterion(clean_outputs,labels)
|
||||
# loss_classifier=max_loss(clean_preds,adv_preds,torch.ones_like(preds))
|
||||
# _, preds = torch.max(classifier(generated_img), 1)
|
||||
# loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,15 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|||
# mean_latent = g_ema.mean_latent(4096)
|
||||
|
||||
# return g_ema, mean_latent
|
||||
def cw_loss(outputs, labels):
    """Carlini-Wagner margin loss for an *untargeted* attack.

    Args:
        outputs: ``(batch, num_classes)`` logits.
        labels: ``(batch,)`` ground-truth class indices to push away from.

    Returns:
        ``(batch,)`` tensor of ``max(0, true_logit - max_other_logit)``;
        zero once some other class out-scores the true class
        (i.e. the input is misclassified).
    """
    # Fix: create the one-hot mask on `outputs.device` rather than relying
    # on the module-level global `device`, which silently mismatches when
    # the logits live on a different device.
    one_hot_labels = torch.eye(outputs.shape[1], device=outputs.device)[labels]
    # Largest logit among the non-true classes.
    other = torch.max((1 - one_hot_labels) * outputs, dim=1)[0]
    # Logit of the true class.
    real = torch.max(one_hot_labels * outputs, dim=1)[0]
    # Sign order (real - other) kept from the original: this variant
    # penalizes the true class still winning. `min=-0.` == 0.0.
    return torch.clamp(real - other, min=0.0)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -59,10 +68,7 @@ def main(cfg: DictConfig) -> None:
|
|||
# vgg_loss=VggLoss().to(device)
|
||||
# summary(model, input_size = (3, 256, 256), batch_size = 5)
|
||||
# set_requires_grad(model.mlp.parameters())
|
||||
for p in (model.mlp.parameters()):
|
||||
p.requires_grad =True
|
||||
for p in (model.noise_mlp.parameters()):
|
||||
p.requires_grad =True
|
||||
|
||||
optimizer = optim.SGD(model.mlp.parameters()+model.noise_mlp.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
||||
start_time = time.time()
|
||||
|
||||
|
|
@ -79,23 +85,12 @@ def main(cfg: DictConfig) -> None:
|
|||
detail_code=detail_code.to(device)
|
||||
# codes = model.net.encoder(inputs)
|
||||
generated_img,adv_latent_codes=model(inputs,detail_code)
|
||||
# _, _, _, clean_refine_images, clean_latent_codes, _=inverter(inputs,img_path)
|
||||
optimizer.zero_grad()
|
||||
# generated_img,adv_latent_codes=model(inputs)
|
||||
# loss_vgg=vgg_loss(inputs,generated_img)
|
||||
# loss_l1=F.l1_loss(base_code,adv_latent_codes)
|
||||
loss_vgg=loss_fn_vgg(inputs,generated_img)
|
||||
loss_clip=clip_loss(generated_img,prompt)
|
||||
adv_outputs = classifier(generated_img)
|
||||
clean_outputs=classifier(inputs)
|
||||
adv_preds=criterion(adv_outputs,labels)
|
||||
clean_preds=criterion(clean_outputs,labels)
|
||||
loss_classifier=max_loss(clean_preds,adv_preds,torch.ones_like(preds))
|
||||
# _, preds = torch.max(classifier(generated_img), 1)
|
||||
# loss_classifier=max_loss(torch.ones_like(criterion(outputs, labels)),criterion(outputs, labels),criterion(outputs, labels))
|
||||
loss_classifier=cw_loss(adv_outputs,labels)
|
||||
|
||||
# loss=loss_vgg+cfg.optim.alpha*loss_l1+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
# loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
loss=loss_vgg+cfg.optim.beta*loss_clip+cfg.optim.delta*loss_classifier
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
|
|
|||
Loading…
Reference in New Issue