Mirror of https://github.com/hpcaitech/ColossalAI.git
[plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
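With all-gather overlap enabled, the hybrid-parallel/ZeRO path can launch parameter all-gathers asynchronously so they run behind subsequent computation. If a checkpoint is saved or loaded while such a collective is still in flight, the parameters may not be fully materialized yet, so each checkpoint entry point now calls model._force_wait_all_gather() on the wrapper before unwrapping it. A minimal sketch of the underlying pattern, written against plain torch.distributed rather than ColossalAI's actual internals (the function names and handle bookkeeping below are illustrative assumptions):

import torch
import torch.distributed as dist

# Illustrative only: track in-flight all-gathers so a checkpoint can drain them.
_pending_handles = []  # (async work handle, gathered buffer, destination param)

def launch_overlapped_all_gather(param: torch.nn.Parameter, shard: torch.Tensor) -> None:
    """Start an asynchronous all-gather of a parameter shard and return immediately."""
    world_size = dist.get_world_size()
    buffer = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    work = dist.all_gather_into_tensor(buffer, shard, async_op=True)
    _pending_handles.append((work, buffer, param))

def force_wait_all_gather() -> None:
    """Block until every pending all-gather finishes, then copy results into the params."""
    for work, buffer, param in _pending_handles:
        work.wait()
        param.data.copy_(buffer[: param.numel()].view_as(param))
    _pending_handles.clear()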
@@ -195,6 +195,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         """
 
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
@@ -303,6 +304,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 This argument should be manually set to False since params on same device might be stored in different files.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         model_before_wrapping = model  # backup for model before wrapping
         model = model.unwrap()
 
@@ -639,6 +641,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
 
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model._force_wait_all_gather()
         model = model.unwrap()
 
         if self.dp_rank != 0:
@@ -679,6 +682,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
 
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model._force_wait_all_gather()
         strict = False
         model_before_wrapping = model
         model = model.unwrap()
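All four entry points follow the same sequence, and the wait is issued on the wrapper before model.unwrap(), presumably because _force_wait_all_gather() is a method of the boosted ModelWrapper rather than of the inner nn.Module. A condensed, self-contained sketch of the unsharded save path under that assumption (the simplified ModelWrapper and the torch.save call stand in for ColossalAI's real wrapper and checkpoint writers):

import logging
import torch
from torch import nn

class ModelWrapper(nn.Module):
    """Illustrative stand-in for the boosted wrapper the checkpoint IO expects."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        return self.module

    def _force_wait_all_gather(self) -> None:
        pass  # a real wrapper would block on its pending all-gather handles here

def save_unsharded_model(model: nn.Module, checkpoint: str, dp_rank: int) -> None:
    logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
    assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
    model._force_wait_all_gather()  # drain overlapped all-gathers while the model is still wrapped
    model = model.unwrap()
    if dp_rank != 0:
        return  # only one data-parallel rank writes the file
    torch.save(model.state_dict(), checkpoint)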