add support for huggingface bin/safetensors variant

bogoconic1 2025-06-03 15:00:21 +00:00
parent c2bdbb6ded
commit 1486984f3a
1 changed file with 211 additions and 1 deletion
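With this change, a checkpoint saved in Hugging Face format (`.bin` or `.safetensors`) can be passed straight to `clip.load`. A minimal usage sketch — the file path is hypothetical, and it assumes this fork keeps upstream CLIP's behavior of accepting a local file path and returning a `(model, preprocess)` pair:

import clip

# Hypothetical local checkpoint, e.g. exported via CLIPModel.save_pretrained(...)
model, preprocess = clip.load("weights/pytorch_model.bin", device="cpu")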

@@ -91,6 +91,210 @@ def available_models() -> List[str]:
     """Returns the names of available CLIP models"""
     return list(_MODELS.keys())
+
+
+def convert_hf_to_openai_full(hf_state_dict):
+    """
+    Complete conversion from Hugging Face CLIP to OpenAI CLIP format
+    """
+    converted_dict = {}
+
+    for key, tensor in hf_state_dict.items():
+        # Skip position_ids - they're not needed in OpenAI format
+        if 'position_ids' in key:
+            print(f"Skipping {key} - not needed in OpenAI format")
+            continue
+        # Handle projection weights
+        elif key == 'visual_projection.weight':
+            # In OpenAI CLIP, this is stored as visual.proj (transposed)
+            converted_dict['visual.proj'] = tensor.T  # Note the transpose!
+            continue
+        elif key == 'text_projection.weight':
+            # In OpenAI CLIP, this is stored as text_projection (transposed)
+            converted_dict['text_projection'] = tensor.T  # Note the transpose!
+            continue
+        # Handle other standard mappings
+        elif key.startswith('text_model.'):
+            if 'embeddings.token_embedding.weight' in key:
+                converted_dict['token_embedding.weight'] = tensor
+            elif 'embeddings.position_embedding.weight' in key:
+                converted_dict['positional_embedding'] = tensor
+            elif 'final_layer_norm.weight' in key:
+                converted_dict['ln_final.weight'] = tensor
+            elif 'final_layer_norm.bias' in key:
+                converted_dict['ln_final.bias'] = tensor
+            elif 'encoder.layers.' in key:
+                converted_dict.update(convert_text_layer(key, tensor))
+        elif key.startswith('vision_model.'):
+            if 'embeddings.patch_embedding.weight' in key:
+                converted_dict['visual.conv1.weight'] = tensor
+            elif 'embeddings.position_embedding.weight' in key:
+                converted_dict['visual.positional_embedding'] = tensor
+            elif 'embeddings.class_embedding' in key:
+                converted_dict['visual.class_embedding'] = tensor
+            # 'pre_layrnorm' (sic) is the actual key name in HF's CLIP vision model
+            elif 'pre_layrnorm.weight' in key:
+                converted_dict['visual.ln_pre.weight'] = tensor
+            elif 'pre_layrnorm.bias' in key:
+                converted_dict['visual.ln_pre.bias'] = tensor
+            elif 'post_layernorm.weight' in key:
+                converted_dict['visual.ln_post.weight'] = tensor
+            elif 'post_layernorm.bias' in key:
+                converted_dict['visual.ln_post.bias'] = tensor
+            elif 'encoder.layers.' in key:
+                converted_dict.update(convert_vision_layer(key, tensor))
+        elif key == 'logit_scale':
+            converted_dict['logit_scale'] = tensor
+        else:
+            print(f"Unhandled key: {key}")
+
+    # Handle the q/k/v -> in_proj_weight conversion
+    converted_dict = combine_qkv_projections_complete(hf_state_dict, converted_dict)
+    return converted_dict
+
+
+def convert_text_layer(key, tensor):
+    """Convert text layer keys"""
+    import re
+    result = {}
+
+    layer_match = re.search(r'encoder\.layers\.(\d+)', key)
+    if not layer_match:
+        return result
+    layer_num = layer_match.group(1)
+
+    if 'self_attn.out_proj.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.attn.out_proj.weight'] = tensor
+    elif 'self_attn.out_proj.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.attn.out_proj.bias'] = tensor
+    elif 'layer_norm1.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.ln_1.weight'] = tensor
+    elif 'layer_norm1.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.ln_1.bias'] = tensor
+    elif 'layer_norm2.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.ln_2.weight'] = tensor
+    elif 'layer_norm2.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.ln_2.bias'] = tensor
+    elif 'mlp.fc1.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.mlp.c_fc.weight'] = tensor
+    elif 'mlp.fc1.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.mlp.c_fc.bias'] = tensor
+    elif 'mlp.fc2.weight' in key:
+        result[f'transformer.resblocks.{layer_num}.mlp.c_proj.weight'] = tensor
+    elif 'mlp.fc2.bias' in key:
+        result[f'transformer.resblocks.{layer_num}.mlp.c_proj.bias'] = tensor
+    # Skip q/k/v proj weights - handled separately
+
+    return result
+
+
+def convert_vision_layer(key, tensor):
+    """Convert vision layer keys"""
+    import re
+    result = {}
+
+    layer_match = re.search(r'encoder\.layers\.(\d+)', key)
+    if not layer_match:
+        return result
+    layer_num = layer_match.group(1)
+
+    if 'self_attn.out_proj.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.attn.out_proj.weight'] = tensor
+    elif 'self_attn.out_proj.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.attn.out_proj.bias'] = tensor
+    elif 'layer_norm1.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.ln_1.weight'] = tensor
+    elif 'layer_norm1.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.ln_1.bias'] = tensor
+    elif 'layer_norm2.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.ln_2.weight'] = tensor
+    elif 'layer_norm2.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.ln_2.bias'] = tensor
+    elif 'mlp.fc1.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.mlp.c_fc.weight'] = tensor
+    elif 'mlp.fc1.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.mlp.c_fc.bias'] = tensor
+    elif 'mlp.fc2.weight' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.mlp.c_proj.weight'] = tensor
+    elif 'mlp.fc2.bias' in key:
+        result[f'visual.transformer.resblocks.{layer_num}.mlp.c_proj.bias'] = tensor
+    # Skip q/k/v proj weights - handled separately
+
+    return result
+
+
+def combine_qkv_projections_complete(hf_state_dict, converted_dict):
+    """Combine q, k, v projections for both text and vision models"""
+    import re
+
+    # Process text model layers
+    for key in hf_state_dict.keys():
+        if 'text_model.encoder.layers.' in key and 'self_attn.q_proj.weight' in key:
+            layer_match = re.search(r'layers\.(\d+)', key)
+            if layer_match:
+                layer_num = layer_match.group(1)
+                q_key = f'text_model.encoder.layers.{layer_num}.self_attn.q_proj.weight'
+                k_key = f'text_model.encoder.layers.{layer_num}.self_attn.k_proj.weight'
+                v_key = f'text_model.encoder.layers.{layer_num}.self_attn.v_proj.weight'
+                if all(k in hf_state_dict for k in [q_key, k_key, v_key]):
+                    # nn.MultiheadAttention stores q/k/v stacked along dim 0:
+                    # in_proj_weight has shape (3 * embed_dim, embed_dim)
+                    combined_weight = torch.cat([
+                        hf_state_dict[q_key],
+                        hf_state_dict[k_key],
+                        hf_state_dict[v_key]
+                    ], dim=0)
+                    converted_dict[f'transformer.resblocks.{layer_num}.attn.in_proj_weight'] = combined_weight
+
+                # Handle biases if they exist
+                q_bias_key = f'text_model.encoder.layers.{layer_num}.self_attn.q_proj.bias'
+                k_bias_key = f'text_model.encoder.layers.{layer_num}.self_attn.k_proj.bias'
+                v_bias_key = f'text_model.encoder.layers.{layer_num}.self_attn.v_proj.bias'
+                if all(k in hf_state_dict for k in [q_bias_key, k_bias_key, v_bias_key]):
+                    combined_bias = torch.cat([
+                        hf_state_dict[q_bias_key],
+                        hf_state_dict[k_bias_key],
+                        hf_state_dict[v_bias_key]
+                    ], dim=0)
+                    converted_dict[f'transformer.resblocks.{layer_num}.attn.in_proj_bias'] = combined_bias
+
+    # Process vision model layers
+    for key in hf_state_dict.keys():
+        if 'vision_model.encoder.layers.' in key and 'self_attn.q_proj.weight' in key:
+            layer_match = re.search(r'layers\.(\d+)', key)
+            if layer_match:
+                layer_num = layer_match.group(1)
+                q_key = f'vision_model.encoder.layers.{layer_num}.self_attn.q_proj.weight'
+                k_key = f'vision_model.encoder.layers.{layer_num}.self_attn.k_proj.weight'
+                v_key = f'vision_model.encoder.layers.{layer_num}.self_attn.v_proj.weight'
+                if all(k in hf_state_dict for k in [q_key, k_key, v_key]):
+                    combined_weight = torch.cat([
+                        hf_state_dict[q_key],
+                        hf_state_dict[k_key],
+                        hf_state_dict[v_key]
+                    ], dim=0)
+                    converted_dict[f'visual.transformer.resblocks.{layer_num}.attn.in_proj_weight'] = combined_weight
+
+                # Handle biases
+                q_bias_key = f'vision_model.encoder.layers.{layer_num}.self_attn.q_proj.bias'
+                k_bias_key = f'vision_model.encoder.layers.{layer_num}.self_attn.k_proj.bias'
+                v_bias_key = f'vision_model.encoder.layers.{layer_num}.self_attn.v_proj.bias'
+                if all(k in hf_state_dict for k in [q_bias_key, k_bias_key, v_bias_key]):
+                    combined_bias = torch.cat([
+                        hf_state_dict[q_bias_key],
+                        hf_state_dict[k_bias_key],
+                        hf_state_dict[v_bias_key]
+                    ], dim=0)
+                    converted_dict[f'visual.transformer.resblocks.{layer_num}.attn.in_proj_bias'] = combined_bias
+
+    return converted_dict
 
 
 def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
     """Load a CLIP model
@@ -125,7 +329,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
         raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
 
     with open(model_path, 'rb') as opened_file:
-        if model_path.endswith('.safetensors'):
+        if model_path.endswith('.bin'):
+            state_dict = torch.load(model_path, map_location="cpu")
+        elif model_path.endswith('.safetensors'):
             with safe_open(model_path, framework="pt", device="cpu") as f:
                 state_dict = {key: f.get_tensor(key) for key in f.keys()}
         else:
@@ -140,6 +346,10 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
                     jit = False
                 state_dict = torch.load(opened_file, map_location="cpu")
 
+    if model_path.endswith('.bin') or model_path.endswith('.safetensors'):
+        state_dict = convert_hf_to_openai_full(state_dict)
+
     if not jit:
         model = build_model(state_dict or model.state_dict()).to(device)
         if str(device) == "cpu":
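As a sanity check on a conversion like this, the converted tensors can be compared against the Hugging Face originals. A minimal sketch — it assumes `transformers` is installed, uses `openai/clip-vit-base-patch32` purely as an example checkpoint, and calls the `convert_hf_to_openai_full` added above:

import torch
from transformers import CLIPModel

hf_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
hf_sd = hf_model.state_dict()
converted = convert_hf_to_openai_full(hf_sd)

# The combined in_proj_weight should be q, k, v stacked along dim 0
q = hf_sd['text_model.encoder.layers.0.self_attn.q_proj.weight']
k = hf_sd['text_model.encoder.layers.0.self_attn.k_proj.weight']
v = hf_sd['text_model.encoder.layers.0.self_attn.v_proj.weight']
assert torch.equal(converted['transformer.resblocks.0.attn.in_proj_weight'],
                   torch.cat([q, k, v], dim=0))

# The projections should be transposed relative to the HF layout
assert torch.equal(converted['visual.proj'],
                   hf_sd['visual_projection.weight'].T)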