Fix JIT error

This commit is contained in:
PiotrBLL 2024-12-19 15:24:38 +01:00
parent e8d6206c16
commit d17b83174d
2 changed files with 15 additions and 3 deletions

View File

@ -156,13 +156,25 @@ def load(
state_dict = torch.load(opened_file, map_location="cpu")
if not jit:
model = build_model(state_dict or model.state_dict()).to(device)
model = build_model(state_dict or model.state_dict())
if str(device) == "hpu":
from habana_frameworks.torch.utils.library_loader import load_habana_module
load_habana_module()
if torch.hpu.is_available():
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model)
model = model.eval().to(torch.device(device))
else:
model = model.to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.input_resolution)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device("cpu" if device == "hpu" else device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def _node_get(node: torch._C.Node, key: str):

View File

@ -28,7 +28,7 @@ def test_consistency(model_name):
@pytest.mark.parametrize("model_name", clip.available_models())
def test_hpu_support(model_name):
device = "hpu"
jit_model, transform = clip.load(model_name, device=device, jit=True)
jit_model, transform = clip.load(model_name, device="cpu", jit=True)
py_model, _ = clip.load(model_name, device=device, jit=False)
image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)