Fix JIT error
This commit is contained in:
parent
e8d6206c16
commit
d17b83174d
16
clip/clip.py
16
clip/clip.py
|
|
@ -156,13 +156,25 @@ def load(
|
|||
state_dict = torch.load(opened_file, map_location="cpu")
|
||||
|
||||
if not jit:
|
||||
model = build_model(state_dict or model.state_dict()).to(device)
|
||||
model = build_model(state_dict or model.state_dict())
|
||||
|
||||
if str(device) == "hpu":
|
||||
from habana_frameworks.torch.utils.library_loader import load_habana_module
|
||||
|
||||
load_habana_module()
|
||||
if torch.hpu.is_available():
|
||||
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
|
||||
|
||||
model = wrap_in_hpu_graph(model)
|
||||
model = model.eval().to(torch.device(device))
|
||||
else:
|
||||
model = model.to(device)
|
||||
if str(device) == "cpu":
|
||||
model.float()
|
||||
return model, _transform(model.visual.input_resolution)
|
||||
|
||||
# patch the device names
|
||||
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device("cpu" if device == "hpu" else device)), example_inputs=[])
|
||||
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
||||
|
||||
def _node_get(node: torch._C.Node, key: str):
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def test_consistency(model_name):
|
|||
@pytest.mark.parametrize("model_name", clip.available_models())
|
||||
def test_hpu_support(model_name):
|
||||
device = "hpu"
|
||||
jit_model, transform = clip.load(model_name, device=device, jit=True)
|
||||
jit_model, transform = clip.load(model_name, device="cpu", jit=True)
|
||||
py_model, _ = clip.load(model_name, device=device, jit=False)
|
||||
|
||||
image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
|
||||
|
|
|
|||
Loading…
Reference in New Issue