Add Intel Gaudi HPU device usage
This commit is contained in:
parent
dcba3cb2e2
commit
5393fdde9c
52
clip/clip.py
52
clip/clip.py
|
|
@ -12,9 +12,11 @@ from tqdm import tqdm
|
|||
|
||||
from .model import build_model
|
||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
from .utils import get_device_initial
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
BICUBIC = InterpolationMode.BICUBIC
|
||||
except ImportError:
|
||||
BICUBIC = Image.BICUBIC
|
||||
|
|
@ -51,13 +53,24 @@ def _download(url: str, root: str):
|
|||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
||||
if (
|
||||
hashlib.sha256(open(download_target, "rb").read()).hexdigest()
|
||||
== expected_sha256
|
||||
):
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
|
|
@ -91,7 +104,12 @@ def available_models() -> List[str]:
|
|||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
||||
def load(
|
||||
name: str,
|
||||
device: Union[str, torch.device] = get_device_initial(),
|
||||
jit: bool = False,
|
||||
download_root: str = None,
|
||||
):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
|
|
@ -100,7 +118,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
||||
|
||||
device : Union[str, torch.device]
|
||||
The device to put the loaded model
|
||||
The device to put the loaded model, by default it uses the device returned by `clip.get_device_initial()`
|
||||
|
||||
jit : bool
|
||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||
|
|
@ -123,10 +141,12 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
with open(model_path, 'rb') as opened_file:
|
||||
with open(model_path, "rb") as opened_file:
|
||||
try:
|
||||
# loading JIT archive
|
||||
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
||||
model = torch.jit.load(
|
||||
opened_file, map_location=device if jit else "cpu"
|
||||
).eval()
|
||||
state_dict = None
|
||||
except RuntimeError:
|
||||
# loading saved state dict
|
||||
|
|
@ -171,9 +191,11 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
patch_device(model.encode_image)
|
||||
patch_device(model.encode_text)
|
||||
|
||||
# patch dtype to float32 on CPU
|
||||
if str(device) == "cpu":
|
||||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||
# patch dtype to float32 on CPU, HPU
|
||||
if str(device) in ["cpu", "hpu"]:
|
||||
float_holder = torch.jit.trace(
|
||||
lambda: torch.ones([]).float(), example_inputs=[]
|
||||
)
|
||||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||
float_node = float_input.node()
|
||||
|
||||
|
|
@ -199,10 +221,18 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
|||
|
||||
model.float()
|
||||
|
||||
if str(device) == "hpu":
|
||||
if torch.hpu.is_available():
|
||||
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
|
||||
|
||||
model = wrap_in_hpu_graph(model)
|
||||
model = model.eval().to(torch.device(device))
|
||||
return model, _transform(model.input_resolution.item())
|
||||
|
||||
|
||||
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
|
||||
def tokenize(
|
||||
texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
|
||||
) -> Union[torch.IntTensor, torch.LongTensor]:
|
||||
"""
|
||||
Returns the tokenized representation of given input string(s)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
import importlib.util
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_device_initial(preferred_device=None):
|
||||
"""
|
||||
Determine the appropriate device to use (cuda, hpu, or cpu).
|
||||
Args:
|
||||
preferred_device (str): User-preferred device ('cuda', 'hpu', or 'cpu').
|
||||
|
||||
Returns:
|
||||
str: Device string ('cuda', 'hpu', or 'cpu').
|
||||
"""
|
||||
# Check for HPU support
|
||||
if importlib.util.find_spec("habana_frameworks") is not None:
|
||||
from habana_frameworks.torch.utils.library_loader import load_habana_module
|
||||
|
||||
load_habana_module()
|
||||
if torch.hpu.is_available():
|
||||
if preferred_device == "hpu" or preferred_device is None:
|
||||
return "hpu"
|
||||
|
||||
# Check for CUDA (GPU support)
|
||||
if torch.cuda.is_available():
|
||||
if preferred_device == "cuda" or preferred_device is None:
|
||||
return "cuda"
|
||||
|
||||
# Default to CPU
|
||||
return "cpu"
|
||||
Loading…
Reference in New Issue