load safetensors version in clip
This commit is contained in:
parent
d57daf938a
commit
c2bdbb6ded
25
clip/clip.py
25
clip/clip.py
|
|
@ -12,6 +12,7 @@ from tqdm import tqdm
|
|||
|
||||
from .model import build_model
|
||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
from safetensors import safe_open
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
|
@ -124,16 +125,20 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
with open(model_path, 'rb') as opened_file:
|
||||
try:
|
||||
# loading JIT archive
|
||||
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
||||
state_dict = None
|
||||
except RuntimeError:
|
||||
# loading saved state dict
|
||||
if jit:
|
||||
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||
jit = False
|
||||
state_dict = torch.load(opened_file, map_location="cpu")
|
||||
if model_path.endswith('.safetensors'):
|
||||
with safe_open(model_path, framework="pt", device="cpu") as f:
|
||||
state_dict = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
try:
|
||||
# loading JIT archive
|
||||
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
||||
state_dict = None
|
||||
except RuntimeError:
|
||||
# loading saved state dict
|
||||
if jit:
|
||||
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||
jit = False
|
||||
state_dict = torch.load(opened_file, map_location="cpu")
|
||||
|
||||
if not jit:
|
||||
model = build_model(state_dict or model.state_dict()).to(device)
|
||||
|
|
|
|||
Loading…
Reference in New Issue