Patch clip model for ONNX compatibility (#219)
* Patch clip model for ONNX compatibility: change tokenization to use INT32, since ONNX doesn't yet support ArgMax(INT64); use an explicit dimension for norm. * Add compatibility fix for torch 1.7.
This commit is contained in:
parent
40f5484c1c
commit
7ef63f265b
10
clip/clip.py
10
clip/clip.py
|
|
@@ -192,7 +192,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||||
return model, _transform(model.input_resolution.item())
|
return model, _transform(model.input_resolution.item())
|
||||||
|
|
||||||
|
|
||||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
|
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
|
||||||
"""
|
"""
|
||||||
Returns the tokenized representation of given input string(s)
|
Returns the tokenized representation of given input string(s)
|
||||||
|
|
||||||
|
|
@@ -209,7 +209,8 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
|
||||||
|
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
|
||||||
"""
|
"""
|
||||||
if isinstance(texts, str):
|
if isinstance(texts, str):
|
||||||
texts = [texts]
|
texts = [texts]
|
||||||
|
|
@@ -217,7 +218,10 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
|
||||||
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
||||||
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
||||||
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
||||||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
|
||||||
|
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
||||||
|
else:
|
||||||
|
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
|
||||||
|
|
||||||
for i, tokens in enumerate(all_tokens):
|
for i, tokens in enumerate(all_tokens):
|
||||||
if len(tokens) > context_length:
|
if len(tokens) > context_length:
|
||||||
|
|
|
||||||
|
|
@@ -356,8 +356,8 @@ class CLIP(nn.Module):
|
||||||
text_features = self.encode_text(text)
|
text_features = self.encode_text(text)
|
||||||
|
|
||||||
# normalized features
|
# normalized features
|
||||||
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
||||||
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
||||||
|
|
||||||
# cosine similarity as logits
|
# cosine similarity as logits
|
||||||
logit_scale = self.logit_scale.exp()
|
logit_scale = self.logit_scale.exp()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue