Mirror of https://github.com/hpcaitech/ColossalAI.git
[async io] support async io (#6137)
* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix
Committed by: Hongxin Liu
Parent: b90835bd32
Commit: eb69e640e5
@@ -94,7 +94,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
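The `use_async` flag added above lets the caller ask for the gathered optimizer state to be written without blocking the training loop. As a rough illustration of the general pattern only (a minimal sketch assuming pinned host buffers and a background writer thread; the helper names `_pinned_copy` and `async_save_state_dict` are hypothetical and not part of this commit):

import torch
from concurrent.futures import ThreadPoolExecutor

_writer = ThreadPoolExecutor(max_workers=1)  # single background thread for checkpoint writes

def _pinned_copy(obj):
    """Recursively copy tensors into pinned CPU buffers so device-to-host transfers can be non-blocking."""
    if isinstance(obj, torch.Tensor):
        buf = torch.empty(obj.shape, dtype=obj.dtype, device="cpu", pin_memory=True)
        buf.copy_(obj, non_blocking=True)
        return buf
    if isinstance(obj, dict):
        return {k: _pinned_copy(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_pinned_copy(v) for v in obj]
    return obj

def async_save_state_dict(state_dict, path):
    """Stage the state dict in pinned memory, then hand the file write to the background thread."""
    staged = _pinned_copy(state_dict)
    if torch.cuda.is_available():
        torch.cuda.current_stream().synchronize()  # make sure the non-blocking D2H copies have landed
    return _writer.submit(torch.save, staged, path)

The caller (or a synchronization step at shutdown) must wait on the returned future before assuming the checkpoint file is complete.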
@@ -178,7 +180,13 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
     def save_sharded_optimizer(
-        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: GeminiOptimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.
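From the caller's side the change is backward compatible: existing calls keep the default `use_async=False`, and opting in only requires the new keyword argument. A usage sketch under stated assumptions (distributed launch is omitted; obtaining the checkpoint IO object via `plugin.get_checkpoint_io()` and the exact argument values are assumptions, not taken from this commit):

from pathlib import Path

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# ... colossalai launch / distributed initialization omitted ...
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

ckpt_io = plugin.get_checkpoint_io()  # assumed accessor for the plugin's GeminiCheckpointIO

# Unsharded: gather the full optimizer state on the master rank and write it asynchronously.
ckpt_io.save_unsharded_optimizer(optimizer, "optimizer.pt", gather_dtensor=True, use_async=True)

# Sharded: split the state into shard files (roughly `size_per_shard` MB each) under the checkpoint folder.
ckpt_io.save_sharded_optimizer(
    optimizer,
    Path("optimizer_ckpt"),
    gather_dtensor=True,
    prefix="optimizer",
    size_per_shard=1024,
    use_async=True,
)

With asynchronous saving, ensure pending writes are flushed before the process exits, otherwise the checkpoint files may be incomplete.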