Mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpointio] optimize zero optim checkpoint io (#4591)
* [zero] update checkpoint io to save memory
* [checkpointio] add device map to save memory
@@ -78,8 +78,6 @@ class GeneralCheckpointIO(CheckpointIO):
         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             load_states_into_optimizer(optimizer, state_dict, id_map)
-            del state_dict
-            gc.collect()
 
         sharded_optimizer_loading_epilogue(optimizer)
 
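For context, the loop above streams optimizer shards in one at a time. A minimal generic sketch of the same pattern in plain PyTorch (the shard layout and the merge step are assumptions for illustration, not ColossalAI's actual helpers):

from pathlib import Path

import torch


def load_optimizer_shards(optimizer, shard_files):
    # Sketch of the shard-streaming pattern: each shard is deserialized straight
    # to CPU and merged immediately, so at most one shard's worth of extra host
    # memory is alive at a time. The shard dict is freed when the loop variable
    # rebinds on the next iteration, which is why an explicit per-shard
    # del/gc.collect() is unnecessary.
    merged = {'state': {}, 'param_groups': optimizer.state_dict()['param_groups']}
    for shard_file in shard_files:
        shard = torch.load(Path(shard_file), map_location=torch.device('cpu'))
        # Assumption: each shard maps parameter index -> optimizer state.
        merged['state'].update(shard)
    optimizer.load_state_dict(merged)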
@@ -237,7 +237,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
                 f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
-        return torch.load(checkpoint_file)
+        return torch.load(checkpoint_file, map_location=torch.device('cpu'))
 
 
 def load_state_dict_into_model(model: nn.Module,
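The remaining hunks all make the same change: torch.load is told where to materialize the tensors. By default torch.load restores each tensor onto the device it was saved from (often a CUDA device), so merely deserializing a checkpoint can allocate GPU memory; mapping the load to CPU keeps it in host memory and lets the caller move states afterwards. A minimal illustration in plain PyTorch:

import torch


def load_checkpoint_to_cpu(checkpoint_file):
    # Without map_location, tensors saved from a GPU are restored onto that same
    # CUDA device during deserialization; forcing CPU avoids the allocation and
    # leaves device placement up to the caller.
    return torch.load(checkpoint_file, map_location=torch.device('cpu'))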
@@ -297,7 +297,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
 
     # Load list of param_groups from given file path.
     # The params in saved_groups are in the form of integer indices.
-    saved_groups = torch.load(param_group_path)
+    saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
     if not isinstance(saved_groups, List):
         raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
 
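The comment in this hunk notes that saved param groups refer to parameters by integer index. That is the standard torch.optim state-dict convention, and it is why the loader needs an id map to match indices back to live parameters; a quick illustration (plain PyTorch, not ColossalAI code):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 'params' holds integer indices into the optimizer's parameter order,
# not the tensors themselves.
print(optimizer.state_dict()['param_groups'])
# e.g. [{'lr': 0.1, 'momentum': 0, ..., 'params': [0, 1]}]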
@@ -608,7 +608,7 @@ def load_state_dict(checkpoint_file_path: Path):
 
     else:
         # load with torch
-        return torch.load(checkpoint_file_path)
+        return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
 
 
 def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
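For completeness, the load_state_dict function touched by the last hunk branches between a safetensors archive and a plain torch checkpoint. A hedged sketch of that dispatch (the extension check is an assumption for illustration, not ColossalAI's exact logic):

from pathlib import Path

import torch
from safetensors.torch import load_file as safe_load_file


def load_full_state_dict(checkpoint_file_path):
    # Hypothetical dispatcher: safetensors archives are read with the safetensors
    # loader (which returns CPU tensors), everything else with torch.load mapped
    # to CPU as in the diff above.
    path = Path(checkpoint_file_path)
    if path.suffix == '.safetensors':
        return safe_load_file(str(path))
    return torch.load(path, map_location=torch.device('cpu'))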