[plugin] support all-gather overlap for hybrid parallel (#5919)

* [plugin] fixed all-gather overlap support for hybrid parallel
This commit is contained in:
Hongxin Liu
2024-07-18 15:33:03 +08:00
committed by GitHub
parent 73494de577
commit e86127925a
4 changed files with 42 additions and 12 deletions

View File

@@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()
if os.path.isfile(checkpoint):
@@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
This argument should be manually set to False since params on same device might be stored in different files.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
model_before_wrapping = model # backup for model before wrapping
model = model.unwrap()
@@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model._force_wait_all_gather()
model = model.unwrap()
if self.dp_rank != 0:
@@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model._force_wait_all_gather()
strict = False
model_before_wrapping = model
model = model.unwrap()