[hotfix] fix hybrid checkpointio for sp+dp (#6184)
* Update hybrid_parallel_plugin.py
* Update hybrid_parallel_plugin.py
* Update hybrid_parallel_plugin.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update build_on_pr.yml
* Update test_zerobubble_pp.py
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: ca0aa2365d
Commit: 17062c83b9
Changed paths:
  .github/workflows/build_on_pr.yml
  colossalai/booster/plugin/
  tests/test_pipeline/test_schedule/
  tests/test_shardformer/test_model/
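In short, the fix stops the checkpoint IO (and the gradient and optimizer bookkeeping around it) from using the bare self.dp_group when sequence parallelism is enabled: the ranks that hold identical replicas then span both the DP and SP axes, so the plugins now build a merged self.mixed_dp_group and pass that around instead. A minimal sketch of the grouping idea follows; the (dp, sp, tp) mesh layout and sizes are assumptions of this note, not taken from the diff, and this is not ProcessGroupMesh code.

# Illustrative only: which ranks fall into a plain DP group vs. a merged DP*SP
# group for an assumed row-major (dp, sp, tp) rank mesh.
import numpy as np

dp, sp, tp = 2, 2, 2                                  # hypothetical sizes, world_size = 8
mesh = np.arange(dp * sp * tp).reshape(dp, sp, tp)

# Plain DP groups: vary only the dp index, one group per (sp, tp) coordinate.
dp_groups = [mesh[:, s, t].tolist() for s in range(sp) for t in range(tp)]

# Merged DP*SP groups: flatten dp and sp together, one group per tp coordinate.
mixed_dp_groups = [sorted(mesh[:, :, t].reshape(-1).tolist()) for t in range(tp)]

print("dp groups:      ", dp_groups)        # [[0, 4], [1, 5], [2, 6], [3, 7]]
print("mixed dp groups:", mixed_dp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]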
.github/workflows/build_on_pr.yml

@@ -199,7 +199,7 @@ jobs:
           fi
 
       - name: Upload test coverage artifact
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: report
           path: report/
colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -1188,6 +1188,15 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
 
+        # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
+        # Apply Hybrid ZeRO across DP * SP ranks
+        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            self.dp_size = get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
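The branch added above builds the merged group once at plugin construction. Conceptually, merging two mesh axes means that, for each fixed coordinate on the remaining axes, all ranks differing only in their DP or SP coordinate land in one new group. A hedged stand-alone sketch with plain torch.distributed, assuming a row-major (dp, sp, tp) rank layout and a torchrun launch; it is an illustration, not the create_group_along_axis implementation:

# Hypothetical sketch of a merged DP*SP group with plain torch.distributed.
# Assumes global_rank = (d * sp + s) * tp + t for a row-major (dp, sp, tp) mesh.
# Run with e.g.: torchrun --nproc_per_node=8 mixed_group_sketch.py
import torch.distributed as dist

def main() -> None:
    dist.init_process_group(backend="gloo")
    rank, world = dist.get_rank(), dist.get_world_size()
    dp, sp, tp = 2, 2, 2                      # assumed sizes; dp * sp * tp must equal world
    assert dp * sp * tp == world

    # Every rank must create every group, in the same order, then keep its own.
    mixed_group = None
    for t in range(tp):
        ranks = [(d * sp + s) * tp + t for d in range(dp) for s in range(sp)]
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            mixed_group = group

    print(f"rank {rank}: merged DP*SP ranks = {dist.get_process_group_ranks(mixed_group)}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()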
@@ -1298,19 +1307,11 @@ class HybridParallelPlugin(PipelinePluginBase):
         use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
             self.dp_size == 1 and self.pp_size == 1
         )
-        # sync gradients across DP * SP ranks
-        # sync gradients across DP * SP ranks
-        # Apply Hybrid ZeRO across DP * SP ranks
-        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
-            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            self.dp_size = get_world_size(dp_group)
-        else:
-            dp_group = self.dp_group
         model = HybridParallelModule(
             model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -1359,7 +1360,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 model,
                 use_pipeline=self.enable_pipeline_parallelism,
                 param_info=param_info,
-                dp_process_group=dp_group,
+                dp_process_group=self.mixed_dp_group,
                 tp_process_group=self.tp_group,
                 pp_process_group=self.pp_group,
                 verbose=True,
@@ -1488,7 +1489,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(
+            self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
+        )
 
     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
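Passing self.mixed_dp_group into HybridParallelCheckpointIO keeps checkpointing aligned with how states are actually replicated and sharded: whatever group the optimizer sharded over is the group the checkpointer has to gather and deduplicate over. A minimal sketch of that invariant, an assumption of this note rather than the real CheckpointIO code:

# Hedged sketch: gather ZeRO-style optimizer-state shards over the *same* group
# they were sharded over (after this fix, the merged DP*SP group).
import torch
import torch.distributed as dist

def gather_full_state(shard: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """All-gather one flattened state shard from every rank of `group` and re-join it."""
    pieces = [torch.empty_like(shard) for _ in range(dist.get_world_size(group))]
    dist.all_gather(pieces, shard, group=group)
    return torch.cat(pieces, dim=0)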
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py

@@ -351,6 +351,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
 
+        # sync gradients across DP * SP ranks
+        if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
+            self.dp_size = dist.get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.use_fp8 = use_fp8
+
         self.shard_config = ShardConfig(
@@ -404,7 +412,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
-            self.dp_group,
+            self.mixed_dp_group,
             self.pp_group,
             self.tp_group,
             self.sp_group,
@@ -435,12 +443,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             and self.sequence_parallelism_mode == "all_to_all"
         )
 
-        # sync gradients across DP * SP ranks
-        if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
-            dp_group = self.pg_mesh.create_group_along_axis([self.moe_dp_axis, self.ep_axis, self.sp_axis])
-        else:
-            dp_group = self.dp_group
-
         if use_ddp:
             self.logger.warning(
                 f"Will have to check all params are used in pytorch DDP since not all experts are always activated",
@@ -448,7 +450,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             )
             self.ddp_config["find_unused_parameters"] = True
 
-            if dist.get_process_group_ranks(dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
+            if dist.get_process_group_ranks(self.mixed_dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
                 raise ValueError(
                     f"if pytorch DDP is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to modify your config to bypass DDP \nhint: check the above ddp condition to by pass this"
                 )
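Only the group being compared changes here; the reasoning stays the same: plain PyTorch DDP reduces gradients over a single group, so the merged DP*SP group must contain exactly the same ranks as the MoE DP group. A small illustrative helper (hypothetical, not plugin code) spelling out that check:

# Hedged sketch of the DDP sanity check: both groups must cover identical ranks,
# since DDP can only all-reduce gradients across one group.
import torch.distributed as dist

def check_ddp_groups_match(mixed_dp_group: dist.ProcessGroup, moe_dp_group: dist.ProcessGroup) -> None:
    mixed_ranks = dist.get_process_group_ranks(mixed_dp_group)
    moe_ranks = dist.get_process_group_ranks(moe_dp_group)
    if mixed_ranks != moe_ranks:
        raise ValueError(f"DDP needs one reduce group, got {mixed_ranks} vs {moe_ranks}")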
@@ -457,7 +459,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             module=model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -507,7 +509,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 model,
                 use_pipeline=self.enable_pipeline_parallelism,
                 param_info=param_info,
-                dp_process_group=dp_group,
+                dp_process_group=self.mixed_dp_group,
                 tp_process_group=self.tp_group,
                 pp_process_group=self.pp_group,
                 moe_dp_group=self.moe_dp_group,
tests/test_pipeline/test_schedule/test_zerobubble_pp.py

@@ -885,12 +885,12 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             parallel_optimizer.backward(parallel_output)
             parallel_optimizer.step()
             parallel_optimizer.zero_grad()
-            dist.all_reduce(parallel_output, group=plugin.dp_group)
+            dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
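The test edits follow from the plugin change: plugin.dp_size now counts every rank of the merged group, so the reference comparison has to gather inputs and reduce outputs over plugin.mixed_dp_group for list lengths and sums to line up. A small worked example with assumed sizes:

# Worked example (assumed sizes, not taken from the test config): with dp = 2 and
# sp = 2 folded together, the merged group has 4 ranks, so the gather list passed
# to dist.all_gather must hold 4 tensors and the reference sums over 4 inputs.
dp, sp = 2, 2
dp_size = dp * sp                      # what plugin.dp_size becomes after the fix
all_inputs_len = dp_size               # len(all_inputs) handed to dist.all_gather
print(f"gather list length with SP folded into DP: {all_inputs_len}")  # -> 4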
@ -1040,12 +1040,12 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
|
|||||||
parallel_optimizer.backward(parallel_output)
|
parallel_optimizer.backward(parallel_output)
|
||||||
parallel_optimizer.step()
|
parallel_optimizer.step()
|
||||||
parallel_optimizer.zero_grad()
|
parallel_optimizer.zero_grad()
|
||||||
dist.all_reduce(parallel_output, group=plugin.dp_group)
|
dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
|
||||||
|
|
||||||
# ===================================================================================
|
# ===================================================================================
|
||||||
# run normal model with all dp(different) inputs
|
# run normal model with all dp(different) inputs
|
||||||
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
|
all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
|
||||||
dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
|
dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
|
||||||
torch_output_sum = 0
|
torch_output_sum = 0
|
||||||
for input_data_ in all_inputs:
|
for input_data_ in all_inputs:
|
||||||
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
|
||||||
|
tests/test_shardformer/test_model/test_shard_deepseek.py

@@ -125,12 +125,12 @@ def run_deepseek_commom(parallel_config: Tuple[int, ...]):
             parallel_optimizer.backward(parallel_output)
             parallel_optimizer.step()
             parallel_optimizer.zero_grad()
-            dist.all_reduce(parallel_output, group=plugin.dp_group)
+            dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
tests/test_shardformer/test_model/test_shard_mixtral.py

@@ -118,12 +118,12 @@ def run_mixtral_commom(config: Tuple[int, ...]):
             parallel_optimizer.backward(parallel_output)
             parallel_optimizer.step()
             parallel_optimizer.zero_grad()
-            dist.all_reduce(parallel_output, group=plugin.dp_group)
+            dist.all_reduce(parallel_output, group=plugin.mixed_dp_group)
 
     # ===================================================================================
     # run normal model with all dp(different) inputs
     all_inputs = [torch.empty_like(input_embeddings) for _ in range(dp_size)]
-    dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
+    dist.all_gather(all_inputs, input_embeddings, group=plugin.mixed_dp_group)
     torch_output_sum = 0
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()