Merge branch 'main' into feature/shardformer

This commit is contained in:
Hongxin Liu
2023-09-04 23:43:13 +08:00
committed by GitHub
138 changed files with 4664 additions and 4219 deletions

View File

@@ -514,7 +514,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
return safe_load_file(checkpoint_file)
else:
return torch.load(checkpoint_file)
return torch.load(checkpoint_file, map_location=torch.device('cpu'))
def load_state_dict_into_model(model: nn.Module,
@@ -574,7 +574,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
# Load list of param_groups from given file path.
# The params in saved_groups are in the form of integer indices.
saved_groups = torch.load(param_group_path)
saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
if not isinstance(saved_groups, List):
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
@@ -730,7 +730,7 @@ def load_state_dict(checkpoint_file_path: Path):
else:
# load with torch
return torch.load(checkpoint_file_path)
return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str: