importing the version module

This commit is contained in:
Jong Wook Kim 2024-06-04 01:09:22 +09:00
parent e5693f4006
commit 3e6026fd41
2 changed files with 5 additions and 4 deletions

View File

@ -2,8 +2,8 @@ import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
import packaging
from packaging import version
from typing import Union, List
import torch
from PIL import Image
@ -20,7 +20,7 @@ except ImportError:
BICUBIC = Image.BICUBIC
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
if version.parse(torch.__version__) < version.parse("1.7.1"):
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
@ -228,7 +228,7 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
if version.parse(torch.__version__) < version.parse("1.8.0"):
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
else:
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

View File

@ -1,4 +1,5 @@
ftfy
packaging
regex
tqdm
torch