Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] support ep for deepseek v3 (#6185)
* [feature] support ep for deepseek v3
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix test
* [shardformer] fix deepseek v3 init
* [lazy] fit lora for lazy init
* [example] support npu for deepseek v3
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
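For context, a minimal usage sketch of what enabling expert parallelism (EP) for a DeepSeek-V3-style MoE model could look like with MoeHybridParallelPlugin follows. It assumes the standard Booster workflow; the parallelism sizes are placeholder values and build_deepseek_v3_model is a hypothetical helper, so the repository's actual example script may differ.

    import colossalai
    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import MoeHybridParallelPlugin

    colossalai.launch_from_torch()

    # ep_size shards each MoE layer's experts across the EP group; the tp/pp/zero
    # settings here are illustrative placeholders, not the example's defaults.
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        ep_size=8,
        zero_stage=1,
        precision="bf16",
    )
    booster = Booster(plugin=plugin)

    model = build_deepseek_v3_model()  # hypothetical model-construction helper
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    model, optimizer, *_ = booster.boost(model, optimizer)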
@@ -19,7 +19,6 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelPlugin,
     HybridParallelZeroOptimizer,
     get_param_info,
-    reinitialize_optimizer,
 )
 from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
@@ -468,18 +467,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 use_fp8=self.use_fp8,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-            if self.ep_size > 1:
-                # if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
-                # but the optimizer is not aware of ep, so we need to update the optimizer
-                reinitialize_optimizer(optimizer, model)
-
             if self.zero_stage == 0:
                 is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         precision=self.precision,
                         max_norm=self.max_norm,
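The deleted comment above states the underlying issue: once EP shards the (MoE) expert parameters across the EP group, an optimizer created beforehand still references the pre-shard parameters. The old code fixed this by calling reinitialize_optimizer; the new code instead passes use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1, reusing the code path that already copes with a model owning only part of the parameters (as a pipeline stage does). Purely as an illustration of the first approach (this is not ColossalAI's reinitialize_optimizer), rebuilding an optimizer's param groups after sharding might look like:

    import torch

    def rebuild_param_groups(optimizer: torch.optim.Optimizer, model: torch.nn.Module) -> None:
        # Hypothetical helper: keep the first group's hyperparameters, drop the
        # stale references to pre-shard parameters and any state attached to
        # them, then re-register the parameters this rank actually holds.
        hparams = {k: v for k, v in optimizer.param_groups[0].items() if k != "params"}
        optimizer.param_groups.clear()
        optimizer.state.clear()
        optimizer.add_param_group({"params": list(model.parameters()), **hparams})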
@@ -489,7 +483,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     optimizer = HybridParallelNaiveOptimizer(
                         optimizer,
                         model,
-                        use_pipeline=self.enable_pipeline_parallelism,
+                        use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                         param_info=param_info,
                         max_norm=self.max_norm,
                         pp_process_group=self.pp_group,
@@ -507,7 +501,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             optimizer = MoeHybridParallelZeroOptimizer(
                 optimizer,
                 model,
-                use_pipeline=self.enable_pipeline_parallelism,
+                use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
                 param_info=param_info,
                 dp_process_group=self.mixed_dp_group,
                 tp_process_group=self.tp_group,