This commit is contained in:
parent
02c8c5db1a
commit
c10ee107cc
|
|
@ -13,7 +13,8 @@ from utils import get_model,set_requires_grad,unnormalize
|
||||||
# from models import GanAttack
|
# from models import GanAttack
|
||||||
# from pixel2style2pixel.scripts.align_all_parallel import align_face
|
# from pixel2style2pixel.scripts.align_all_parallel import align_face
|
||||||
# from pixel2style2pixel.models.stylegan2.model import Generator
|
# from pixel2style2pixel.models.stylegan2.model import Generator
|
||||||
from model import GanAttack,CLIPLoss,VggLoss,get_prompt
|
from model import GanAttack,CLIPLoss,VggLoss
|
||||||
|
from prompt import get_prompt
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import time
|
import time
|
||||||
# from torchsummary import summary
|
# from torchsummary import summary
|
||||||
|
|
|
||||||
20
model.py
20
model.py
|
|
@ -9,6 +9,12 @@ from pixel2style2pixel.models.psp import pSp
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from utils import normalize
|
from utils import normalize
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
from omegaconf import DictConfig, OmegaConf
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append('./pixel2style2pixel')
|
||||||
|
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -110,4 +116,16 @@ class VggLoss(torch.nn.Module):
|
||||||
feature1=torch.flatten(feature1)
|
feature1=torch.flatten(feature1)
|
||||||
feature2=torch.flatten(feature2)
|
feature2=torch.flatten(feature2)
|
||||||
similarity = F.cosine_similarity(feature1,feature2)
|
similarity = F.cosine_similarity(feature1,feature2)
|
||||||
return similarity
|
return similarity
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path="./config", config_name="config")
|
||||||
|
def test(cfg):
|
||||||
|
prompt=torch.randn([1,1024]).to(device)
|
||||||
|
model=GanAttack(cfg,prompt).to(device)
|
||||||
|
data=torch.randn([2,3,256,256]).to(device)
|
||||||
|
result_images,x=model(data)
|
||||||
|
print(result_images.shape)
|
||||||
|
print(x.shape)
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test()
|
||||||
|
|
||||||
Loading…
Reference in New Issue