Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 21:09:18 +00:00
[fix] rm use_zbv flag in ShardConfig; rm debug info;
@@ -1201,7 +1201,6 @@ class HybridParallelPlugin(PipelinePluginBase):
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
-            use_zbv=(pp_style == "zbv"),
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
@@ -373,7 +373,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
-            use_zbv=(pp_style == "zbv"),
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
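
Both plugin hunks drop the same line: the plugin no longer forwards use_zbv into ShardConfig. The pp_style == "zbv" decision itself stays on the plugin side, and after this change the policies read the flag from the pipeline stage manager instead (see the LlamaPolicy and MixtralPolicy hunks below). A minimal illustrative sketch of that flow, using simplified stand-ins rather than the real plugin and PipelineStageManager classes:

# Illustrative sketch only: simplified stand-ins for the plugin / stage-manager
# wiring. The real PipelineStageManager takes more constructor arguments.
from dataclasses import dataclass


@dataclass
class StageManagerSketch:
    # the zero-bubble-V flag now travels with the stage manager, not ShardConfig
    use_zbv: bool = False


def build_stage_manager(pp_style: str) -> StageManagerSketch:
    # mirrors the plugin-side decision that was removed from the ShardConfig call
    return StageManagerSketch(use_zbv=(pp_style == "zbv"))


print(build_stage_manager("zbv").use_zbv)    # True
print(build_stage_manager("1f1b").use_zbv)   # False
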
@@ -60,6 +60,11 @@ class LlamaPolicy(Policy):
         else:
             norm_cls = RMSNorm

+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
+
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
@@ -129,7 +134,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -138,7 +143,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -147,7 +152,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -156,7 +161,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -165,7 +170,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -174,7 +179,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -183,7 +188,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                 ],
@@ -413,6 +418,10 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         from transformers import LlamaForSequenceClassification

         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
@@ -425,6 +434,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
                             kwargs=dict(
                                 gather_output=True,
                                 fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=use_zbv,
                             ),
                         )
                     ]
@@ -52,6 +52,10 @@ class MixtralPolicy(Policy):
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         tp_size = self.shard_config.tensor_parallel_size
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False

         # modified for both SP and TP
         num_q_heads = self.model.config.num_attention_heads
@@ -126,7 +130,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -134,7 +138,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -142,7 +146,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -150,7 +154,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -159,7 +163,7 @@ class MixtralPolicy(Policy):
                         kwargs={
                             "gather_output": True,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                 ],
@@ -195,7 +199,7 @@ class MixtralPolicy(Policy):
                             "tp_group": self.shard_config.tensor_parallel_process_group,
                             "moe_dp_group": self.shard_config.moe_dp_group,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     )
                 ],
@@ -330,6 +334,10 @@ class MixtralModelPolicy(MixtralPolicy):
 class MixtralForCausalLMPolicy(MixtralPolicy):
     def module_policy(self):
         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
         # TODO: assign pg mesh from plugin to all modules
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for causal lm
@@ -342,7 +350,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                             kwargs=dict(
                                 gather_output=True,
                                 fp8_communication=self.shard_config.fp8_communication,
-                                use_zbv=self.shard_config.use_zbv,
+                                use_zbv=use_zbv,
                             ),
                         )
                     ],
@@ -392,6 +400,10 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
         from transformers import MixtralForSequenceClassification

         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
@@ -404,7 +416,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
                             kwargs=dict(
                                 gather_output=True,
                                 fp8_communication=self.shard_config.fp8_communication,
-                                use_zbv=self.shard_config.use_zbv,
+                                use_zbv=use_zbv,
                             ),
                         )
                     ]
@@ -49,7 +49,6 @@ class ShardConfig:
     make_vocab_size_divisible_by: int = 64
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
-    use_zbv: bool = False

     # For ring attention
     inner_ring_size: Optional[int] = None
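
With the field removed from ShardConfig, anything that previously read shard_config.use_zbv follows the same guard shown in the policy hunks above: take the flag from the pipeline stage manager when one exists, otherwise fall back to False. A standalone sketch of that guard (resolve_use_zbv and the fake stage manager are hypothetical names used only for illustration):

from typing import Optional


def resolve_use_zbv(stage_manager: Optional[object]) -> bool:
    # Same guard the policies use: only a pipeline stage manager knows whether
    # the zero-bubble-V schedule is active; without one, default to False.
    if stage_manager is not None:
        return getattr(stage_manager, "use_zbv", False)
    return False


class _FakeStageManager:
    use_zbv = True


print(resolve_use_zbv(_FakeStageManager()))  # True
print(resolve_use_zbv(None))                 # False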