diff --git a/GanAttack.py b/GanAttack.py index c0d7225..f51df0f 100644 --- a/GanAttack.py +++ b/GanAttack.py @@ -7,16 +7,10 @@ from torchvision import models from omegaconf import DictConfig, OmegaConf from data.dataset import get_dataset,get_adv_dataset from utils import get_model,set_requires_grad,unnormalize -# from GanInverter.inference.two_stage_inference import TwoStageInference -# from GanInverter.models.stylegan2.model import Generator -# from models import GanAttack -# from pixel2style2pixel.scripts.align_all_parallel import align_face -# from pixel2style2pixel.models.stylegan2.model import Generator from model import GanAttack,CLIPLoss,VggLoss from prompt import get_prompt import torch.nn.functional as F import time -# from torchsummary import summary import hydra @@ -42,7 +36,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @hydra.main(version_base=None, config_path="./config", config_name="config") def main(cfg: DictConfig) -> None: model=get_model(cfg) - train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg) + train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg) classifier=get_model(cfg) classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset))) @@ -73,10 +67,10 @@ def main(cfg: DictConfig) -> None: running_loss = 0 running_corrects = 0 - for i, (inputs, labels,base_code,detail_code) in enumerate(train_dataloader): + for i, (inputs, labels,detail_code) in enumerate(train_dataloader): inputs = inputs.to(device) labels = labels.to(device) - base_code=base_code.to(device) + # base_code=base_code.to(device) detail_code=detail_code.to(device) # codes = model.net.encoder(inputs) - generated_img,adv_latent_codes=model(inputs,base_code,detail_code) + generated_img,adv_latent_codes=model(inputs,detail_code) diff --git a/data/dataset.py b/data/dataset.py index 5cd16c4..5c6735b 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -22,10 +22,10 @@ transforms_test = transforms.Compose([ 
]) class ImageDataset(Dataset): - def __init__(self, data_path, base_code,detail_code, mode, transform=None): + def __init__(self, data_path, detail_code, mode, transform=None): self.path=data_path data_dir=pathlib.Path(data_path) - self.base_dir=base_code + self.detail_dir=detail_code self.mode=mode self.transform=transform @@ -44,15 +44,15 @@ class ImageDataset(Dataset): def __getitem__(self, index): img = Image.open(os.path.join(self.path, self.image_path[index])) img = img.convert('RGB') - base_code=np.load(os.path.join(self.base_dir, self.image_path[index].replace('.jpg', '.npy'))) + # base_code=np.load(os.path.join(self.base_dir, self.image_path[index].replace('.jpg', '.npy'))) detail_code=np.load(os.path.join(self.detail_dir, self.image_path[index].replace('.jpg', '.npy'))) - base_code=torch.from_numpy(base_code) + # base_code=torch.from_numpy(base_code) detail_code=torch.from_numpy(detail_code) if self.transform is not None: img = self.transform(img) label = torch.LongTensor([self.image_label[index]]) # image_path=self.image_path[index] - return img, label,base_code,detail_code + return img, label,detail_code def __len__(self): return len(self.image_path) @@ -75,10 +75,10 @@ def get_adv_dataset(config): path=config.paths.gender_dataset else: path=config.paths.identity_dataset - base_code=config.paths.base_code + # base_code=config.paths.base_code detail_code=config.paths.detail_code - train_dataset = ImageDataset(path,base_code,detail_code,'train',transforms_train) - test_dataset= ImageDataset(path,base_code,detail_code,'test',transforms_test) + train_dataset = ImageDataset(path,detail_code,'train',transforms_train) + test_dataset= ImageDataset(path,detail_code,'test',transforms_test) train_dataloader= torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers) - test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=False, + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config.optim.batch_size, shuffle=False, 
num_workers=config.optim.num_workers)