Fix torch._C.Node attribute access (#372)
Attribute access with subscripting previously worked thanks to the patching added in https://github.com/pytorch/pytorch/pull/82511, but that patching has since been removed. This commit adopts the fix proposed in https://github.com/pytorch/pytorch/pull/82628, defining a helper method that dispatches to the access method appropriate for the attribute's type.
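Not part of the diff itself; a minimal usage sketch of the helper being added. The toy traced graph below is hypothetical and exists only to produce prim::Constant nodes to read from:

    import torch

    # _node_get looks up the attribute's kind via node.kindOf(key) ("i", "f", "s", "t", ...)
    # and dispatches to the matching typed accessor, e.g. node.i(key) for ints,
    # instead of the removed node[key] subscript access.
    def _node_get(node: torch._C.Node, key: str):
        sel = node.kindOf(key)
        return getattr(node, sel)(key)

    # Toy traced graph, used here only so there are constant nodes to inspect.
    traced = torch.jit.trace(lambda: torch.ones([]) * 5, example_inputs=[])
    for node in traced.graph.findAllNodes("prim::Constant"):
        if "value" in node.attributeNames():
            print(_node_get(node, "value"))  # was: node["value"]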
commit a1d071733d
parent a9b1bf5920
clip/clip.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -145,6 +145,14 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
     device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
     device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
 
+    def _node_get(node: torch._C.Node, key: str):
+        """Gets attributes of a node which is polymorphic over return type.
+
+        From https://github.com/pytorch/pytorch/pull/82628
+        """
+        sel = node.kindOf(key)
+        return getattr(node, sel)(key)
+
     def patch_device(module):
         try:
             graphs = [module.graph] if hasattr(module, "graph") else []
@@ -156,7 +164,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
 
         for graph in graphs:
             for node in graph.findAllNodes("prim::Constant"):
-                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
                     node.copyAttributes(device_node)
 
     model.apply(patch_device)
@@ -182,7 +190,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
                 for node in graph.findAllNodes("aten::to"):
                     inputs = list(node.inputs())
                     for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
-                        if inputs[i].node()["value"] == 5:
+                        if _node_get(inputs[i].node(), "value") == 5:
                             inputs[i].node().copyAttributes(float_node)
 
         model.apply(patch_float)