GanAttack/prompt.py

36 lines
996 B
Python

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from omegaconf import DictConfig, OmegaConf
from data.dataset import get_dataset,get_adv_dataset
from utils import get_model,set_requires_grad,unnormalize
import sys
import os
import clip
import torch.nn.functional as F
import time
import hydra
from scipy import io as spio
# Use the first CUDA GPU when PyTorch reports one; otherwise fall back to the CPU.
device = (
    torch.device("cuda:0")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
@hydra.main(version_base=None, config_path="./config", config_name="config")
def save_prompt(cfg):
    """Encode ``cfg.prompt`` with the CLIP "RN50" model and save it to ``prompt.mat``.

    Args:
        cfg: Configuration object supplied by Hydra; only ``cfg.prompt`` is read.

    The encoded text features are written as a NumPy array under the key
    ``'prompt'`` in the file ``prompt.mat`` (a relative path, so it resolves
    against the process's working directory at call time).
    """
    # The second value returned by clip.load is not used here.
    clip_model, _preprocess = clip.load("RN50", device=device)
    tokens = clip.tokenize(cfg.prompt).to(device)

    # Pure inference: skip autograd bookkeeping while encoding the text.
    with torch.no_grad():
        text_features = clip_model.encode_text(tokens)

    # savemat needs a host-side NumPy array, so move off the device first.
    embedding = text_features.cpu().numpy()
    spio.savemat("prompt.mat", {'prompt': embedding})
def get_prompt(cfg):
    """Load the saved prompt embedding from ``prompt.mat`` as a tensor on ``device``.

    Args:
        cfg: Not read by this function.

    Returns:
        A ``torch.Tensor`` built from the ``'prompt'`` entry of ``prompt.mat``
        and moved to the module-level ``device``.
    """
    stored = spio.loadmat("prompt.mat")
    embedding = torch.from_numpy(stored['prompt'])
    return embedding.to(device)
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on save_prompt
    # is what provides the cfg argument when run as a script.
    save_prompt()