fix utf8 username on windows (#227)

This commit is contained in:
liuxingbaoyu 2022-04-11 05:07:46 +08:00 committed by GitHub
parent c0065a27ad
commit 3482bb6ed3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 10 deletions

View File

@ -122,16 +122,17 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
else: else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}") raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
try: with open(model_path, 'rb') as opened_file:
# loading JIT archive try:
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() # loading JIT archive
state_dict = None model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
except RuntimeError: state_dict = None
# loading saved state dict except RuntimeError:
if jit: # loading saved state dict
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") if jit:
jit = False warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
state_dict = torch.load(model_path, map_location="cpu") jit = False
state_dict = torch.load(opened_file, map_location="cpu")
if not jit: if not jit:
model = build_model(state_dict or model.state_dict()).to(device) model = build_model(state_dict or model.state_dict()).to(device)