[checkpointio] optimize zero optim checkpoint io (#4591)

* [zero] update checkpoint io to save memory

* [checkpointio] add device map to save memory
Hongxin Liu
2023-09-04 11:26:45 +08:00
committed by GitHub
parent cfa607080f
commit 63ecafb1fb
4 changed files with 43 additions and 22 deletions
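
The saving comes from where torch.load materializes the tensors: by default each tensor is restored to the device it was saved from, so reading a checkpoint that was written from GPU-resident tensors allocates GPU memory again just to deserialize it. Passing an explicit device map keeps everything on the CPU and lets the caller decide what moves back. A minimal standalone sketch of the difference (assumes a CUDA device is available; the path and sizes are made up):

import torch

# A checkpoint written from GPU-resident tensors.
state = {"weight": torch.randn(1024, 1024, device="cuda")}
torch.save(state, "ckpt.pt")

# Default: tensors come back on the device they were saved from,
# spending GPU memory merely to read the file.
print(torch.load("ckpt.pt")["weight"].device)                                     # cuda:0

# With a device map the tensors stay on the CPU until explicitly moved.
print(torch.load("ckpt.pt", map_location=torch.device("cpu"))["weight"].device)   # cpu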


@@ -78,8 +78,6 @@ class GeneralCheckpointIO(CheckpointIO):
         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             load_states_into_optimizer(optimizer, state_dict, id_map)
-            del state_dict
-            gc.collect()
         sharded_optimizer_loading_epilogue(optimizer)
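
The loop above applies the same idea shard by shard: each shard is deserialized, its states are merged into the optimizer, and the temporary dict goes out of scope before the next shard is read, so the peak footprint stays near a single shard rather than the whole optimizer state. A rough standalone sketch of that pattern, not ColossalAI's actual implementation (shard_files and merge_into are hypothetical names):

import gc
from pathlib import Path
from typing import Callable, Dict, Iterable

import torch


def load_sharded_states(shard_files: Iterable[Path], merge_into: Callable[[Dict], None]) -> None:
    """Stream optimizer-state shards through CPU memory one shard at a time."""
    for shard_file in shard_files:
        # Deserialize onto the CPU so reading a shard never touches GPU memory.
        state_dict = torch.load(shard_file, map_location=torch.device("cpu"))
        merge_into(state_dict)  # e.g. copy the states into optimizer.state
        # Drop the reference before loading the next shard; an explicit gc pass
        # is optional and only helps if something else still holds the dict alive.
        del state_dict
        gc.collect()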


@@ -237,7 +237,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
                 f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.")
         return safe_load_file(checkpoint_file)
     else:
-        return torch.load(checkpoint_file)
+        return torch.load(checkpoint_file, map_location=torch.device('cpu'))
 
 
 def load_state_dict_into_model(model: nn.Module,
@@ -297,7 +297,7 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
     # Load list of param_groups from given file path.
     # The params in saved_groups are in the form of integer indices.
-    saved_groups = torch.load(param_group_path)
+    saved_groups = torch.load(param_group_path, map_location=torch.device('cpu'))
     if not isinstance(saved_groups, List):
         raise ValueError(f'The param_groups saved at {param_group_path} is not of List type')
@@ -608,7 +608,7 @@ def load_state_dict(checkpoint_file_path: Path):
     else:
         # load with torch
-        return torch.load(checkpoint_file_path)
+        return torch.load(checkpoint_file_path, map_location=torch.device('cpu'))
 
 
 def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
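
Loading onto the CPU is also safe when resuming a plain PyTorch optimizer: stock torch.optim.Optimizer.load_state_dict copies the state tensors (step counters aside) onto the device of their corresponding parameter, so CPU-loaded states migrate back to the GPU only at the moment they are installed. A small illustration (assumes a CUDA device; the file name is made up):

import torch

model = torch.nn.Linear(8, 8).cuda()
opt = torch.optim.Adam(model.parameters())
model(torch.randn(2, 8, device="cuda")).sum().backward()
opt.step()  # populate exp_avg / exp_avg_sq on the GPU

torch.save(opt.state_dict(), "optim.pt")
cpu_state = torch.load("optim.pt", map_location=torch.device("cpu"))

# load_state_dict casts the states to the parameters' device, so the optimizer
# ends up back on the GPU even though the file was deserialized onto the CPU.
opt.load_state_dict(cpu_state)
param = next(iter(opt.state))
print(opt.state[param]["exp_avg"].device)  # cuda:0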