fix utf8 username on windows (#227)
This commit is contained in:
parent
c0065a27ad
commit
3482bb6ed3
|
|
@ -122,16 +122,17 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||||
|
|
||||||
|
with open(model_path, 'rb') as opened_file:
|
||||||
try:
|
try:
|
||||||
# loading JIT archive
|
# loading JIT archive
|
||||||
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
||||||
state_dict = None
|
state_dict = None
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# loading saved state dict
|
# loading saved state dict
|
||||||
if jit:
|
if jit:
|
||||||
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||||
jit = False
|
jit = False
|
||||||
state_dict = torch.load(model_path, map_location="cpu")
|
state_dict = torch.load(opened_file, map_location="cpu")
|
||||||
|
|
||||||
if not jit:
|
if not jit:
|
||||||
model = build_model(state_dict or model.state_dict()).to(device)
|
model = build_model(state_dict or model.state_dict()).to(device)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue