add support for huggingface bin/safetensors variant
This commit is contained in:
parent
c2bdbb6ded
commit
1486984f3a
212
clip/clip.py
212
clip/clip.py
|
|
@ -91,6 +91,210 @@ def available_models() -> List[str]:
|
|||
"""Returns the names of available CLIP models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
def convert_hf_to_openai_full(hf_state_dict):
|
||||
"""
|
||||
Complete conversion from Hugging Face CLIP to OpenAI CLIP format
|
||||
"""
|
||||
converted_dict = {}
|
||||
|
||||
for key, tensor in hf_state_dict.items():
|
||||
# Skip position_ids - they're not needed in OpenAI format
|
||||
if 'position_ids' in key:
|
||||
print(f"Skipping {key} - not needed in OpenAI format")
|
||||
continue
|
||||
|
||||
# Handle projection weights
|
||||
elif key == 'visual_projection.weight':
|
||||
# In OpenAI CLIP, this is stored as visual.proj (transposed)
|
||||
converted_dict['visual.proj'] = tensor.T # Note the transpose!
|
||||
continue
|
||||
|
||||
elif key == 'text_projection.weight':
|
||||
# In OpenAI CLIP, this is stored as text_projection (transposed)
|
||||
converted_dict['text_projection'] = tensor.T # Note the transpose!
|
||||
continue
|
||||
|
||||
# Handle other standard mappings
|
||||
elif key.startswith('text_model.'):
|
||||
if 'embeddings.token_embedding.weight' in key:
|
||||
converted_dict['token_embedding.weight'] = tensor
|
||||
elif 'embeddings.position_embedding.weight' in key:
|
||||
converted_dict['positional_embedding'] = tensor
|
||||
elif 'final_layer_norm.weight' in key:
|
||||
converted_dict['ln_final.weight'] = tensor
|
||||
elif 'final_layer_norm.bias' in key:
|
||||
converted_dict['ln_final.bias'] = tensor
|
||||
elif 'encoder.layers.' in key:
|
||||
converted_dict.update(convert_text_layer(key, tensor))
|
||||
|
||||
elif key.startswith('vision_model.'):
|
||||
if 'embeddings.patch_embedding.weight' in key:
|
||||
converted_dict['visual.conv1.weight'] = tensor
|
||||
elif 'embeddings.position_embedding.weight' in key:
|
||||
converted_dict['visual.positional_embedding'] = tensor
|
||||
elif 'embeddings.class_embedding' in key:
|
||||
converted_dict['visual.class_embedding'] = tensor
|
||||
elif 'pre_layrnorm.weight' in key:
|
||||
converted_dict['visual.ln_pre.weight'] = tensor
|
||||
elif 'pre_layrnorm.bias' in key:
|
||||
converted_dict['visual.ln_pre.bias'] = tensor
|
||||
elif 'post_layernorm.weight' in key:
|
||||
converted_dict['visual.ln_post.weight'] = tensor
|
||||
elif 'post_layernorm.bias' in key:
|
||||
converted_dict['visual.ln_post.bias'] = tensor
|
||||
elif 'encoder.layers.' in key:
|
||||
converted_dict.update(convert_vision_layer(key, tensor))
|
||||
|
||||
elif key == 'logit_scale':
|
||||
converted_dict['logit_scale'] = tensor
|
||||
|
||||
else:
|
||||
print(f"Unhandled key: {key}")
|
||||
|
||||
# Handle the q/k/v -> in_proj_weight conversion
|
||||
converted_dict = combine_qkv_projections_complete(hf_state_dict, converted_dict)
|
||||
|
||||
return converted_dict
|
||||
|
||||
def convert_text_layer(key, tensor):
|
||||
"""Convert text layer keys"""
|
||||
import re
|
||||
result = {}
|
||||
|
||||
layer_match = re.search(r'encoder\.layers\.(\d+)', key)
|
||||
if not layer_match:
|
||||
return result
|
||||
|
||||
layer_num = layer_match.group(1)
|
||||
|
||||
if 'self_attn.out_proj.weight' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.attn.out_proj.weight'] = tensor
|
||||
elif 'self_attn.out_proj.bias' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.attn.out_proj.bias'] = tensor
|
||||
elif 'layer_norm1.weight' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.ln_1.weight'] = tensor
|
||||
elif 'layer_norm1.bias' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.ln_1.bias'] = tensor
|
||||
elif 'layer_norm2.weight' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.ln_2.weight'] = tensor
|
||||
elif 'layer_norm2.bias' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.ln_2.bias'] = tensor
|
||||
elif 'mlp.fc1.weight' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.mlp.c_fc.weight'] = tensor
|
||||
elif 'mlp.fc1.bias' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.mlp.c_fc.bias'] = tensor
|
||||
elif 'mlp.fc2.weight' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.mlp.c_proj.weight'] = tensor
|
||||
elif 'mlp.fc2.bias' in key:
|
||||
result[f'transformer.resblocks.{layer_num}.mlp.c_proj.bias'] = tensor
|
||||
# Skip q/k/v proj weights - handled separately
|
||||
|
||||
return result
|
||||
|
||||
def convert_vision_layer(key, tensor):
|
||||
"""Convert vision layer keys"""
|
||||
import re
|
||||
result = {}
|
||||
|
||||
layer_match = re.search(r'encoder\.layers\.(\d+)', key)
|
||||
if not layer_match:
|
||||
return result
|
||||
|
||||
layer_num = layer_match.group(1)
|
||||
|
||||
if 'self_attn.out_proj.weight' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.attn.out_proj.weight'] = tensor
|
||||
elif 'self_attn.out_proj.bias' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.attn.out_proj.bias'] = tensor
|
||||
elif 'layer_norm1.weight' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.ln_1.weight'] = tensor
|
||||
elif 'layer_norm1.bias' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.ln_1.bias'] = tensor
|
||||
elif 'layer_norm2.weight' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.ln_2.weight'] = tensor
|
||||
elif 'layer_norm2.bias' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.ln_2.bias'] = tensor
|
||||
elif 'mlp.fc1.weight' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.mlp.c_fc.weight'] = tensor
|
||||
elif 'mlp.fc1.bias' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.mlp.c_fc.bias'] = tensor
|
||||
elif 'mlp.fc2.weight' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.mlp.c_proj.weight'] = tensor
|
||||
elif 'mlp.fc2.bias' in key:
|
||||
result[f'visual.transformer.resblocks.{layer_num}.mlp.c_proj.bias'] = tensor
|
||||
# Skip q/k/v proj weights - handled separately
|
||||
|
||||
return result
|
||||
|
||||
def combine_qkv_projections_complete(hf_state_dict, converted_dict):
|
||||
"""Combine q, k, v projections for both text and vision models"""
|
||||
import re
|
||||
|
||||
# Process text model layers
|
||||
for key in hf_state_dict.keys():
|
||||
if 'text_model.encoder.layers.' in key and 'self_attn.q_proj.weight' in key:
|
||||
layer_match = re.search(r'layers\.(\d+)', key)
|
||||
if layer_match:
|
||||
layer_num = layer_match.group(1)
|
||||
|
||||
q_key = f'text_model.encoder.layers.{layer_num}.self_attn.q_proj.weight'
|
||||
k_key = f'text_model.encoder.layers.{layer_num}.self_attn.k_proj.weight'
|
||||
v_key = f'text_model.encoder.layers.{layer_num}.self_attn.v_proj.weight'
|
||||
|
||||
if all(k in hf_state_dict for k in [q_key, k_key, v_key]):
|
||||
combined_weight = torch.cat([
|
||||
hf_state_dict[q_key],
|
||||
hf_state_dict[k_key],
|
||||
hf_state_dict[v_key]
|
||||
], dim=0)
|
||||
converted_dict[f'transformer.resblocks.{layer_num}.attn.in_proj_weight'] = combined_weight
|
||||
|
||||
# Handle biases if they exist
|
||||
q_bias_key = f'text_model.encoder.layers.{layer_num}.self_attn.q_proj.bias'
|
||||
k_bias_key = f'text_model.encoder.layers.{layer_num}.self_attn.k_proj.bias'
|
||||
v_bias_key = f'text_model.encoder.layers.{layer_num}.self_attn.v_proj.bias'
|
||||
|
||||
if all(k in hf_state_dict for k in [q_bias_key, k_bias_key, v_bias_key]):
|
||||
combined_bias = torch.cat([
|
||||
hf_state_dict[q_bias_key],
|
||||
hf_state_dict[k_bias_key],
|
||||
hf_state_dict[v_bias_key]
|
||||
], dim=0)
|
||||
converted_dict[f'transformer.resblocks.{layer_num}.attn.in_proj_bias'] = combined_bias
|
||||
|
||||
# Process vision model layers
|
||||
for key in hf_state_dict.keys():
|
||||
if 'vision_model.encoder.layers.' in key and 'self_attn.q_proj.weight' in key:
|
||||
layer_match = re.search(r'layers\.(\d+)', key)
|
||||
if layer_match:
|
||||
layer_num = layer_match.group(1)
|
||||
|
||||
q_key = f'vision_model.encoder.layers.{layer_num}.self_attn.q_proj.weight'
|
||||
k_key = f'vision_model.encoder.layers.{layer_num}.self_attn.k_proj.weight'
|
||||
v_key = f'vision_model.encoder.layers.{layer_num}.self_attn.v_proj.weight'
|
||||
|
||||
if all(k in hf_state_dict for k in [q_key, k_key, v_key]):
|
||||
combined_weight = torch.cat([
|
||||
hf_state_dict[q_key],
|
||||
hf_state_dict[k_key],
|
||||
hf_state_dict[v_key]
|
||||
], dim=0)
|
||||
converted_dict[f'visual.transformer.resblocks.{layer_num}.attn.in_proj_weight'] = combined_weight
|
||||
|
||||
# Handle biases
|
||||
q_bias_key = f'vision_model.encoder.layers.{layer_num}.self_attn.q_proj.bias'
|
||||
k_bias_key = f'vision_model.encoder.layers.{layer_num}.self_attn.k_proj.bias'
|
||||
v_bias_key = f'vision_model.encoder.layers.{layer_num}.self_attn.v_proj.bias'
|
||||
|
||||
if all(k in hf_state_dict for k in [q_bias_key, k_bias_key, v_bias_key]):
|
||||
combined_bias = torch.cat([
|
||||
hf_state_dict[q_bias_key],
|
||||
hf_state_dict[k_bias_key],
|
||||
hf_state_dict[v_bias_key]
|
||||
], dim=0)
|
||||
converted_dict[f'visual.transformer.resblocks.{layer_num}.attn.in_proj_bias'] = combined_bias
|
||||
|
||||
return converted_dict
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
||||
"""Load a CLIP model
|
||||
|
|
@ -125,7 +329,9 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
with open(model_path, 'rb') as opened_file:
|
||||
if model_path.endswith('.safetensors'):
|
||||
if model_path.endswith('.bin'):
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
elif model_path.endswith('.safetensors'):
|
||||
with safe_open(model_path, framework="pt", device="cpu") as f:
|
||||
state_dict = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
|
|
@ -140,6 +346,10 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
jit = False
|
||||
state_dict = torch.load(opened_file, map_location="cpu")
|
||||
|
||||
if model_path.endswith('.bin') or model_path.endswith('.safetensors'):
|
||||
state_dict = convert_hf_to_openai_full(state_dict)
|
||||
|
||||
|
||||
if not jit:
|
||||
model = build_model(state_dict or model.state_dict()).to(device)
|
||||
if str(device) == "cpu":
|
||||
|
|
|
|||
Loading…
Reference in New Issue