rm basecode

This commit is contained in:
leewlving 2024-01-13 18:45:20 +08:00
parent 1e92eadac6
commit 98d940cbf9
2 changed files with 11 additions and 17 deletions

View File

@@ -7,16 +7,10 @@ from torchvision import models
 from omegaconf import DictConfig, OmegaConf
 from data.dataset import get_dataset,get_adv_dataset
 from utils import get_model,set_requires_grad,unnormalize
-# from GanInverter.inference.two_stage_inference import TwoStageInference
-# from GanInverter.models.stylegan2.model import Generator
-# from models import GanAttack
-# from pixel2style2pixel.scripts.align_all_parallel import align_face
-# from pixel2style2pixel.models.stylegan2.model import Generator
 from model import GanAttack,CLIPLoss,VggLoss
 from prompt import get_prompt
 import torch.nn.functional as F
 import time
-# from torchsummary import summary
 import hydra
@@ -42,7 +36,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 @hydra.main(version_base=None, config_path="./config", config_name="config")
 def main(cfg: DictConfig) -> None:
     model=get_model(cfg)
-    train_dataloader,test_dataloader,train_dataset,test_dataset=get_dataset(cfg)
+    train_dataloader,test_dataloader,train_dataset,test_dataset=get_adv_dataset(cfg)
     classifier=get_model(cfg)
     classifier.load_state_dict(torch.load('{}/{}_{}.pth'.format(cfg.paths.classifier, cfg.classifier.model, cfg.dataset)))
@@ -73,10 +67,10 @@ def main(cfg: DictConfig) -> None:
         running_loss = 0
         running_corrects = 0
-        for i, (inputs, labels,base_code,detail_code) in enumerate(train_dataloader):
+        for i, (inputs, labels,detail_code) in enumerate(train_dataloader):
             inputs = inputs.to(device)
             labels = labels.to(device)
-            base_code=base_code.to(device)
+            # base_code=base_code.to(device)
             detail_code=detail_code.to(device)
             # codes = model.net.encoder(inputs)
             generated_img,adv_latent_codes=model(inputs,base_code,detail_code)

View File

@@ -22,10 +22,10 @@ transforms_test = transforms.Compose([
 ])
 class ImageDataset(Dataset):
-    def __init__(self, data_path, base_code,detail_code, mode, transform=None):
+    def __init__(self, data_path, detail_code, mode, transform=None):
         self.path=data_path
         data_dir=pathlib.Path(data_path)
-        self.base_dir=base_code
         self.detail_dir=detail_code
         self.mode=mode
         self.transform=transform
@@ -44,15 +44,15 @@ class ImageDataset(Dataset):
     def __getitem__(self, index):
         img = Image.open(os.path.join(self.path, self.image_path[index]))
         img = img.convert('RGB')
-        base_code=np.load(os.path.join(self.base_dir, self.image_path[index].replace('.jpg', '.npy')))
+        # base_code=np.load(os.path.join(self.base_dir, self.image_path[index].replace('.jpg', '.npy')))
         detail_code=np.load(os.path.join(self.detail_dir, self.image_path[index].replace('.jpg', '.npy')))
-        base_code=torch.from_numpy(base_code)
+        # base_code=torch.from_numpy(base_code)
         detail_code=torch.from_numpy(detail_code)
         if self.transform is not None:
             img = self.transform(img)
         label = torch.LongTensor([self.image_label[index]])
         # image_path=self.image_path[index]
-        return img, label,base_code,detail_code
+        return img, label,detail_code
     def __len__(self):
         return len(self.image_path)
@@ -75,10 +75,10 @@ def get_adv_dataset(config):
         path=config.paths.gender_dataset
     else:
         path=config.paths.identity_dataset
-    base_code=config.paths.base_code
+    # base_code=config.paths.base_code
     detail_code=config.paths.detail_code
-    train_dataset = ImageDataset(path,base_code,detail_code,'train',transforms_train)
-    test_dataset= ImageDataset(path,base_code,detail_code,'test',transforms_test)
+    train_dataset = ImageDataset(path,detail_code,'train',transforms_train)
+    test_dataset= ImageDataset(path,detail_code,'test',transforms_test)
     train_dataloader= torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=True, num_workers=config.optim.num_workers)
     test_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.optim.batch_size, shuffle=False, num_workers=config.optim.num_workers)