Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[hotfix] fix hybrid checkpointio for sp+dp (#6184)
* Update hybrid_parallel_plugin.py
* Update hybrid_parallel_plugin.py
* Update hybrid_parallel_plugin.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update build_on_pr.yml
* Update test_zerobubble_pp.py
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
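Note (not part of the commit): reading the hunks below, the fix makes the plugin build a self.mixed_dp_group spanning both the DP and SP mesh axes whenever sequence parallelism is enabled in a mode that does not share the TP group, and then routes gradient sync, ZeRO, and in particular HybridParallelCheckpointIO through that merged group instead of the plain self.dp_group. A configuration that exercises this sp+dp path might look like the sketch below; the parameter names follow my reading of the plugin's public API and are not taken from this commit.

# Sketch only: assumes 8 processes launched with torchrun and distributed init done.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # signature may differ across ColossalAI versions

# tp=1, pp=1, sp=2 on 8 GPUs leaves dp=4; "all_to_all" SP does not share the TP group,
# so gradients, ZeRO shards, and checkpoints must follow the merged DP*SP group (size 8).
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="all_to_all",
    zero_stage=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, *_ = booster.boost(model, optimizer)
# booster.save_optimizer(optimizer, "ckpt/optim", shard=True)  # checkpoint IO now gets mixed_dp_group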
@@ -1188,6 +1188,15 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
 
+        # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
+        # Apply Hybrid ZeRO across DP * SP ranks
+        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            self.dp_size = get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
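For orientation only: the following plain torch.distributed sketch shows what collapsing the DP and SP axes into a single group amounts to. It is not ColossalAI's ProcessGroupMesh implementation, and the (dp, sp, tp) axis order and sizes are assumptions made for the example.

import torch.distributed as dist

def build_mixed_dp_group(dp_size: int = 2, sp_size: int = 2, tp_size: int = 2):
    """Return the merged DP*SP group that contains the calling rank."""
    assert dist.get_world_size() == dp_size * sp_size * tp_size
    mixed_group = None
    # dist.new_group() is collective, so every rank creates every group and
    # keeps only the one it belongs to (one merged group per TP coordinate).
    for tp_rank in range(tp_size):
        ranks = [
            (dp * sp_size + sp) * tp_size + tp_rank  # (dp, sp, tp) laid out with tp fastest
            for dp in range(dp_size)
            for sp in range(sp_size)
        ]
        group = dist.new_group(ranks)
        if dist.get_rank() in ranks:
            mixed_group = group
    return mixed_group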
@@ -1298,19 +1307,11 @@ class HybridParallelPlugin(PipelinePluginBase):
             use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
                 self.dp_size == 1 and self.pp_size == 1
             )
-            # sync gradients across DP * SP ranks
-            # sync gradients across DP * SP ranks
-            # Apply Hybrid ZeRO across DP * SP ranks
-            if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
-                dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-                self.dp_size = get_world_size(dp_group)
-            else:
-                dp_group = self.dp_group
             model = HybridParallelModule(
                 model,
                 precision=self.precision,
                 shard_config=self.shard_config,
-                dp_group=dp_group,
+                dp_group=self.mixed_dp_group,
                 tp_group=self.tp_group,
                 sp_group=self.sp_group,
                 use_ddp=use_ddp,
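A toy illustration (not the HybridParallelModule code) of why the wrapper now receives the merged group as dp_group: with a non-TP-sharing SP mode each SP rank processes a different sequence shard, so its gradients differ and must be averaged over DP*SP rather than DP alone.

import torch
import torch.distributed as dist

def sync_grads(model: torch.nn.Module, mixed_dp_group: dist.ProcessGroup) -> None:
    """Average gradients across every rank that holds a distinct (data, sequence) shard."""
    world = dist.get_world_size(mixed_dp_group)
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=mixed_dp_group)
            p.grad.div_(world)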
@@ -1359,7 +1360,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 model,
                 use_pipeline=self.enable_pipeline_parallelism,
                 param_info=param_info,
-                dp_process_group=dp_group,
+                dp_process_group=self.mixed_dp_group,
                 tp_process_group=self.tp_group,
                 pp_process_group=self.pp_group,
                 verbose=True,
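The optimizer hunk is the same substitution on the ZeRO side: dp_process_group decides which ranks an optimizer-state shard is split across and later reduced or gathered over. A toy round-robin split (nothing like ColossalAI's actual bucketing) just to make that dependence on the group explicit:

import torch.distributed as dist

def partition_params(params: list, group: dist.ProcessGroup) -> list:
    """Keep only the parameters this rank owns under a round-robin split over `group`."""
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)
    return [p for i, p in enumerate(params) if i % world == rank]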
@@ -1488,7 +1489,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(
+            self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
+        )
 
     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
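Why the get_checkpoint_io change matters: with sp+dp the optimizer states are sharded over the DP*SP group, so a checkpoint IO constructed with the plain DP group indexes shards by the wrong rank and world size at save and load time. Passing self.mixed_dp_group keeps the runtime sharding layout and the checkpoint layout consistent. A minimal sketch of the kind of group-keyed collection involved (assumed for illustration, not the HybridParallelCheckpointIO internals):

import torch.distributed as dist

def gather_optim_shards(local_shard: dict, group: dist.ProcessGroup):
    """Collect one optimizer-state shard per rank of `group`; group-rank 0 returns the full list."""
    shards = [None] * dist.get_world_size(group)
    dist.all_gather_object(shards, local_shard, group=group)
    return shards if dist.get_rank(group) == 0 else None

If the group passed here is narrower than the group actually used for sharding, some shards are never collected, which is consistent with the sp+dp checkpoint breakage this hotfix targets.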