load safetensors version in clip

This commit is contained in:
Geremie Yeo 2025-06-03 19:59:43 +08:00 committed by GitHub
parent d57daf938a
commit c2bdbb6ded
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 15 additions and 10 deletions

View File

@ -12,6 +12,7 @@ from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from safetensors import safe_open
try:
from torchvision.transforms import InterpolationMode
@ -124,16 +125,20 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with open(model_path, 'rb') as opened_file:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(opened_file, map_location="cpu")
if model_path.endswith('.safetensors'):
with safe_open(model_path, framework="pt", device="cpu") as f:
state_dict = {key: f.get_tensor(key) for key in f.keys()}
else:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(opened_file, map_location="cpu")
if not jit:
model = build_model(state_dict or model.state_dict()).to(device)