mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
Merge branch 'main' into feature/shardformer
This commit is contained in:
@@ -514,7 +514,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
|
||||
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
|
||||
return safe_load_file(checkpoint_file)
|
||||
else:
|
||||
return torch.load(checkpoint_file)
|
||||
return torch.load(checkpoint_file, map_location=torch.device('cpu'))
|
||||
|
||||
|
||||
def load_state_dict_into_model(model: nn.Module,
|
||||
@@ -574,7 +574,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
|
||||
|
||||
# Load list of param_groups from given file path.
|
||||
# The params in saved_groups are in the form of integer indices.
|
||||
saved_groups = torch.load(param_group_path)
|
||||
saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
|
||||
if not isinstance(saved_groups, List):
|
||||
raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
|
||||
|
||||
@@ -730,7 +730,7 @@ def load_state_dict(checkpoint_file_path: Path):
|
||||
|
||||
else:
|
||||
# load with torch
|
||||
return torch.load(checkpoint_file_path)
|
||||
return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
|
||||
|
||||
|
||||
def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
|
||||
|
Reference in New Issue
Block a user