add adv model
This commit is contained in:
parent
037d945fe3
commit
4f6194e2af
84
GanAttack.py
84
GanAttack.py
|
|
@ -7,6 +7,7 @@ from data.dataset import get_dataset
|
||||||
from utils import get_model
|
from utils import get_model
|
||||||
from model.GanInverter.models.stylegan2.model import Generator
|
from model.GanInverter.models.stylegan2.model import Generator
|
||||||
from model.GanInverter.inference.two_stage_inference import TwoStageInference
|
from model.GanInverter.inference.two_stage_inference import TwoStageInference
|
||||||
|
from model import GanAttack
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
|
@ -32,4 +33,85 @@ def get_stylegan_inverter(cfg):
|
||||||
path=cfg.paths.inverter_cfg
|
path=cfg.paths.inverter_cfg
|
||||||
inverter=TwoStageInference(opts=path)
|
inverter=TwoStageInference(opts=path)
|
||||||
|
|
||||||
return inverter.inverse
|
return inverter.inverse
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||||
|
def main(cfg: DictConfig) -> None:
|
||||||
|
model=get_model(cfg)
|
||||||
|
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
|
||||||
|
|
||||||
|
classifier=get_model(cfg)
|
||||||
|
g_ema, _=get_stylegan_generator(cfg)
|
||||||
|
|
||||||
|
inverter=get_stylegan_inverter(cfg)
|
||||||
|
model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device)
|
||||||
|
|
||||||
|
num_epochs = cfg.optim.num_epochs
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
running_loss = 0
|
||||||
|
running_corrects = 0
|
||||||
|
|
||||||
|
for i, (inputs, labels) in enumerate(train_dataloader):
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(inputs)
|
||||||
|
_, preds = torch.max(outputs, 1)
|
||||||
|
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item() * inputs.size(0)
|
||||||
|
running_corrects += torch.sum(preds == labels.data)
|
||||||
|
|
||||||
|
epoch_loss = running_loss / len(train_dataset)
|
||||||
|
epoch_acc = running_corrects / len(train_dataset) * 100.
|
||||||
|
print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
|
||||||
|
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
running_loss = 0.
|
||||||
|
running_corrects = 0
|
||||||
|
|
||||||
|
for inputs, labels in test_dataloader:
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
outputs = model(inputs)
|
||||||
|
_, preds = torch.max(outputs, 1)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
|
||||||
|
running_loss += loss.item() * inputs.size(0)
|
||||||
|
running_corrects += torch.sum(preds == labels.data)
|
||||||
|
|
||||||
|
epoch_loss = running_loss / len(test_dataset)
|
||||||
|
epoch_acc = running_corrects / len(test_dataset) * 100.
|
||||||
|
print('[Test #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))
|
||||||
|
|
||||||
|
save_path = '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)
|
||||||
|
torch.save(model.state_dict(), save_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -22,4 +22,5 @@ prompt: red lipstick
|
||||||
optim:
|
optim:
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
num_epochs: 200
|
num_epochs: 200
|
||||||
num_workers : 4
|
num_workers : 4
|
||||||
|
images_resize: 256
|
||||||
14
model.py
14
model.py
|
|
@ -6,7 +6,15 @@ import os
|
||||||
import clip
|
import clip
|
||||||
|
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
model, preprocess = clip.load("ViT-B/32", device=device)
|
|
||||||
|
|
||||||
|
def get_prompt(cfg):
|
||||||
|
model, preprocess = clip.load("ViT-B/32", device=device)
|
||||||
|
prompt=cfg.prompt
|
||||||
|
text=clip.tokenize(prompt).to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
prompt = model.encode_text(text)
|
||||||
|
return prompt
|
||||||
# class GanAttack(nn.Module):
|
# class GanAttack(nn.Module):
|
||||||
|
|
||||||
class GanAttack(nn.Module):
|
class GanAttack(nn.Module):
|
||||||
|
|
@ -17,9 +25,7 @@ class GanAttack(nn.Module):
|
||||||
self.generator.eval()
|
self.generator.eval()
|
||||||
self.inverter=inverter
|
self.inverter=inverter
|
||||||
self.images_resize=images_resize
|
self.images_resize=images_resize
|
||||||
text=clip.tokenize(prompt).to(device)
|
self.prompt=prompt
|
||||||
with torch.no_grad():
|
|
||||||
self.prompt = model.encode_text(text)
|
|
||||||
text_len=self.prompt.shape[0]
|
text_len=self.prompt.shape[0]
|
||||||
self.mlp=nn.Sequential(
|
self.mlp=nn.Sequential(
|
||||||
nn.Linear(text_len+512, 4096),
|
nn.Linear(text_len+512, 4096),
|
||||||
|
|
|
||||||
37
utils.py
37
utils.py
|
|
@ -45,4 +45,39 @@ def get_model(config):
|
||||||
# Transfer execution to GPU
|
# Transfer execution to GPU
|
||||||
model = model.to('cuda')
|
model = model.to('cuda')
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def unnormalize(image):
|
||||||
|
mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float()
|
||||||
|
std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float()
|
||||||
|
|
||||||
|
image = image.detach().cpu()
|
||||||
|
image *= std
|
||||||
|
image += mean
|
||||||
|
image[image < 0] = 0
|
||||||
|
image[image > 1] = 1
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def normalize(image):
|
||||||
|
mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float().cuda()
|
||||||
|
std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float().cuda()
|
||||||
|
|
||||||
|
image = image.clone()
|
||||||
|
image -= mean
|
||||||
|
image /= std
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def set_requires_grad( nets, requires_grad=False):
|
||||||
|
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
||||||
|
Parameters:
|
||||||
|
nets (network list) -- a list of networks
|
||||||
|
requires_grad (bool) -- whether the networks require gradients or not
|
||||||
|
"""
|
||||||
|
if not isinstance(nets, list):
|
||||||
|
nets = [nets]
|
||||||
|
for net in nets:
|
||||||
|
if net is not None:
|
||||||
|
for param in net.parameters():
|
||||||
|
param.requires_grad = requires_grad
|
||||||
Loading…
Reference in New Issue