rm basecode
This commit is contained in:
parent
1e92eadac6
commit
98d940cbf9
12
GanAttack.py
12
GanAttack.py
|
|
@ -7,16 +7,10 @@ from torchvision import models
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from data.dataset import get_dataset,get_adv_dataset
|
from data.dataset import get_dataset,get_adv_dataset
|
||||||
from utils import get_model,set_requires_grad,unnormalize
|
from utils import get_model,set_requires_grad,unnormalize
|
||||||
# from GanInverter.inference.two_stage_inference import TwoStageInference
|
|
||||||
# from GanInverter.models.stylegan2.model import Generator
|
|
||||||
# from models import GanAttack
|
|
||||||
# from pixel2style2pixel.scripts.align_all_parallel import align_face
|
|
||||||
# from pixel2style2pixel.models.stylegan2.model import Generator
|
|
||||||
from model import GanAttack,CLIPLoss,VggLoss
|
from model import GanAttack,CLIPLoss,VggLoss
|
||||||
from prompt import get_prompt
|
from prompt import get_prompt
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import time
|
import time
|
||||||
# from torchsummary import summary
|
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
|
||||||
|
|
@ -42,7 +36,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||||
def main(cfg: DictConfig) -> None:
|
def main(cfg: DictConfig) -> None:
|
||||||
model=get_model(cfg)
|
model=get_model(cfg)
|
||||||
train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
|
train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg)
|
||||||
|
|
||||||
classifier=get_model(cfg)
|
classifier=get_model(cfg)
|
||||||
classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
|
classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
|
||||||
|
|
@ -73,10 +67,10 @@ def main(cfg: DictConfig) -> None:
|
||||||
running_loss = 0
|
running_loss = 0
|
||||||
running_corrects = 0
|
running_corrects = 0
|
||||||
|
|
||||||
for i, (inputs, labels,base_code,detail_code) in enumerate(train_dataloader):
|
for i, (inputs, labels,detail_code) in enumerate(train_dataloader):
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
labels = labels.to(device)
|
labels = labels.to(device)
|
||||||
base_code=base_code.to(device)
|
# base_code=base_code.to(device)
|
||||||
detail_code=detail_code.to(device)
|
detail_code=detail_code.to(device)
|
||||||
# codes = model.net.encoder(inputs)
|
# codes = model.net.encoder(inputs)
|
||||||
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
generated_img,adv_latent_codes=model(inputs,base_code,detail_code)
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,10 @@ transforms_test = transforms.Compose([
|
||||||
])
|
])
|
||||||
|
|
||||||
class ImageDataset(Dataset):
|
class ImageDataset(Dataset):
|
||||||
def __init__(self, data_path, base_code,detail_code, mode, transform=None):
|
def __init__(self, data_path, detail_code, mode, transform=None):
|
||||||
self.path=data_path
|
self.path=data_path
|
||||||
data_dir=pathlib.Path(data_path)
|
data_dir=pathlib.Path(data_path)
|
||||||
self.base_dir=base_code
|
|
||||||
self.detail_dir=detail_code
|
self.detail_dir=detail_code
|
||||||
self.mode=mode
|
self.mode=mode
|
||||||
self.transform=transform
|
self.transform=transform
|
||||||
|
|
@ -44,15 +44,15 @@ class ImageDataset(Dataset):
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
img = Image.open(os.path.join(self.path, self.image_path[index]))
|
img = Image.open(os.path.join(self.path, self.image_path[index]))
|
||||||
img = img.convert('RGB')
|
img = img.convert('RGB')
|
||||||
base_code=np.load(os.path.join(self.base_dir, self.image_path[index].replace('.jpg', '.npy')))
|
# base_code=np.load(os.path.join(self.base_dir, self.image_path[index].replace('.jpg', '.npy')))
|
||||||
detail_code=np.load(os.path.join(self.detail_dir, self.image_path[index].replace('.jpg', '.npy')))
|
detail_code=np.load(os.path.join(self.detail_dir, self.image_path[index].replace('.jpg', '.npy')))
|
||||||
base_code=torch.from_numpy(base_code)
|
# base_code=torch.from_numpy(base_code)
|
||||||
detail_code=torch.from_numpy(detail_code)
|
detail_code=torch.from_numpy(detail_code)
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
img = self.transform(img)
|
img = self.transform(img)
|
||||||
label = torch.LongTensor([self.image_label[index]])
|
label = torch.LongTensor([self.image_label[index]])
|
||||||
# image_path=self.image_path[index]
|
# image_path=self.image_path[index]
|
||||||
return img, label,base_code,detail_code
|
return img, label,detail_code
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.image_path)
|
return len(self.image_path)
|
||||||
|
|
@ -75,10 +75,10 @@ def get_adv_dataset(config):
|
||||||
path=config.paths.gender_dataset
|
path=config.paths.gender_dataset
|
||||||
else:
|
else:
|
||||||
path=config.paths.identity_dataset
|
path=config.paths.identity_dataset
|
||||||
base_code=config.paths.base_code
|
# base_code=config.paths.base_code
|
||||||
detail_code=config.paths.detail_code
|
detail_code=config.paths.detail_code
|
||||||
train_dataset = ImageDataset(path,base_code,detail_code,'train',transforms_train)
|
train_dataset = ImageDataset(path,detail_code,'train',transforms_train)
|
||||||
test_dataset= ImageDataset(path,base_code,detail_code,'test',transforms_test)
|
test_dataset= ImageDataset(path,detail_code,'test',transforms_test)
|
||||||
|
|
||||||
train_dataloader= torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
|
train_dataloader= torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
|
||||||
test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers)
|
test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue