From 4f6194e2af5be90408dc6c1db16bfcc209802e1b Mon Sep 17 00:00:00 2001 From: Li Wenyun Date: Sun, 3 Dec 2023 16:19:54 +0800 Subject: [PATCH] add adv model --- GanAttack.py | 84 +++++++++++++++++++++++++++++++++++++++++++++- config/config.yaml | 3 +- model.py | 14 +++++--- utils.py | 37 +++++++++++++++++++- 4 files changed, 131 insertions(+), 7 deletions(-) diff --git a/GanAttack.py b/GanAttack.py index 73cb849..290e024 100644 --- a/GanAttack.py +++ b/GanAttack.py @@ -7,6 +7,7 @@ from data.dataset import get_dataset from utils import get_model from model.GanInverter.models.stylegan2.model import Generator from model.GanInverter.inference.two_stage_inference import TwoStageInference +from model import GanAttack import time import hydra @@ -32,4 +33,85 @@ def get_stylegan_inverter(cfg): path=cfg.paths.inverter_cfg inverter=TwoStageInference(opts=path) - return inverter.inverse \ No newline at end of file + return inverter.inverse + +@hydra.main(version_base=None, config_path="./config", config_name="config") +def main(cfg: DictConfig) -> None: + model=get_model(cfg) + train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg) + + classifier=get_model(cfg) + g_ema, _=get_stylegan_generator(cfg) + + inverter=get_stylegan_inverter(cfg) + model=GanAttack(g_ema,inverter,images_resize=cfg.optim.images_resize,prompt=cfg.prompt).to(device) + + num_epochs = cfg.optim.num_epochs + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(model.parameters(), lr=cfg.classifier.lr, momentum=cfg.classifier.momentum) + start_time = time.time() + + for epoch in range(num_epochs): + model.train() + + running_loss = 0 + running_corrects = 0 + + for i, (inputs, labels) in enumerate(train_dataloader): + inputs = inputs.to(device) + labels = labels.to(device) + + optimizer.zero_grad() + outputs = model(inputs) + _, preds = torch.max(outputs, 1) + + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + running_loss += loss.item() * inputs.size(0) + running_corrects += torch.sum(preds == labels.data) + + epoch_loss = running_loss / len(train_dataset) + epoch_acc = running_corrects / len(train_dataset) * 100. + print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time)) + + + model.eval() + + with torch.no_grad(): + running_loss = 0. + running_corrects = 0 + + for inputs, labels in test_dataloader: + inputs = inputs.to(device) + labels = labels.to(device) + + outputs = model(inputs) + _, preds = torch.max(outputs, 1) + loss = criterion(outputs, labels) + + running_loss += loss.item() * inputs.size(0) + running_corrects += torch.sum(preds == labels.data) + + epoch_loss = running_loss / len(test_dataset) + epoch_acc = running_corrects / len(test_dataset) * 100. + print('[Test #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time)) + + save_path = '{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset) + torch.save(model.state_dict(), save_path) + + + + + + + + + + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/config/config.yaml b/config/config.yaml index 8c7583c..5a7f645 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -22,4 +22,5 @@ prompt: red lipstick optim: batch_size: 8 num_epochs: 200 - num_workers : 4 \ No newline at end of file + num_workers : 4 + images_resize: 256 \ No newline at end of file diff --git a/model.py b/model.py index 9495253..7c0a21f 100644 --- a/model.py +++ b/model.py @@ -6,7 +6,15 @@ import os import clip device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -model, preprocess = clip.load("ViT-B/32", device=device) + + +def get_prompt(cfg): + model, preprocess = clip.load("ViT-B/32", device=device) + prompt=cfg.prompt + text=clip.tokenize(prompt).to(device) + with torch.no_grad(): + prompt = model.encode_text(text) + return prompt # class GanAttack(nn.Module): class GanAttack(nn.Module): @@ -17,9 +25,7 @@ class GanAttack(nn.Module): self.generator.eval() self.inverter=inverter self.images_resize=images_resize - text=clip.tokenize(prompt).to(device) - with torch.no_grad(): - self.prompt = model.encode_text(text) + self.prompt=prompt text_len=self.prompt.shape[0] self.mlp=nn.Sequential( nn.Linear(text_len+512, 4096), diff --git a/utils.py b/utils.py index 75b9cfe..59e8aad 100644 --- a/utils.py +++ b/utils.py @@ -45,4 +45,39 @@ def get_model(config): # Transfer execution to GPU model = model.to('cuda') - return model \ No newline at end of file + return model + +def unnormalize(image): + mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float() + std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float() + + image = image.detach().cpu() + image *= std + image += mean + image[image < 0] = 0 + image[image > 1] = 1 + + return image + +def normalize(image): + mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float().cuda() + std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 3, 1, 1).float().cuda() + + image = image.clone() + image -= mean + image /= std + + return image + +def set_requires_grad( nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad \ No newline at end of file