[fix] rm use_zbv flag in ShardConfig; rm debug info

duanjunwen
2024-10-16 03:25:04 +00:00
parent 90939b77e0
commit e76308c6e6
9 changed files with 212 additions and 651 deletions

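The hunks below all apply the same change: the per-layer use_zbv kwarg is no longer read from ShardConfig but is derived inside each policy's module_policy() from the pipeline stage manager. A minimal, self-contained sketch of that resolution pattern follows; _StageManagerStub is illustrative only and is not colossalai's PipelineStageManager.

# Sketch only: stand-in for the use_zbv resolution added to every policy in
# this commit. _StageManagerStub is a stub, not colossalai's PipelineStageManager.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _StageManagerStub:
    use_zbv: bool = False


def resolve_use_zbv(pipeline_stage_manager: Optional[_StageManagerStub]) -> bool:
    # Mirrors the if/else block added in each module_policy(): without a
    # pipeline stage manager there is no ZBV schedule, so default to False.
    if pipeline_stage_manager:
        return pipeline_stage_manager.use_zbv
    return False


assert resolve_use_zbv(_StageManagerStub(use_zbv=True)) is True
assert resolve_use_zbv(None) is False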

@@ -1201,7 +1201,6 @@ class HybridParallelPlugin(PipelinePluginBase):
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
-            use_zbv=(pp_style == "zbv"),
         )
         self.amp_config = dict(
             initial_scale=initial_scale,


@@ -373,7 +373,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
-            use_zbv=(pp_style == "zbv"),
         )
         self.amp_config = dict(
             initial_scale=initial_scale,


@@ -60,6 +60,11 @@ class LlamaPolicy(Policy):
         else:
             norm_cls = RMSNorm
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
         sp_mode = self.shard_config.sequence_parallelism_mode or None
         sp_size = self.shard_config.sequence_parallel_size or None
         sp_group = self.shard_config.sequence_parallel_process_group or None
@@ -129,7 +134,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -138,7 +143,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -147,7 +152,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -156,7 +161,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -165,7 +170,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -174,7 +179,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                     SubModuleReplacementDescription(
@@ -183,7 +188,7 @@ class LlamaPolicy(Policy):
                         kwargs=dict(
                             seq_parallel_mode=sp_mode,
                             fp8_communication=self.shard_config.fp8_communication,
-                            use_zbv=self.shard_config.use_zbv,
+                            use_zbv=use_zbv,
                         ),
                     ),
                 ],
@@ -413,6 +418,10 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         from transformers import LlamaForSequenceClassification
         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
@@ -425,6 +434,7 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
                             kwargs=dict(
                                 gather_output=True,
                                 fp8_communication=self.shard_config.fp8_communication,
+                                use_zbv=use_zbv,
                             ),
                         )
                     ]


@@ -52,6 +52,10 @@ class MixtralPolicy(Policy):
         sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
         tp_size = self.shard_config.tensor_parallel_size
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
         # modified for both SP and TP
         num_q_heads = self.model.config.num_attention_heads
@@ -126,7 +130,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -134,7 +138,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -142,7 +146,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Col,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -150,7 +154,7 @@ class MixtralPolicy(Policy):
                         target_module=Linear1D_Row,
                         kwargs={
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -159,7 +163,7 @@ class MixtralPolicy(Policy):
                         kwargs={
                             "gather_output": True,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     ),
                 ],
@@ -195,7 +199,7 @@ class MixtralPolicy(Policy):
                             "tp_group": self.shard_config.tensor_parallel_process_group,
                             "moe_dp_group": self.shard_config.moe_dp_group,
                             "fp8_communication": self.shard_config.fp8_communication,
-                            "use_zbv": self.shard_config.use_zbv,
+                            "use_zbv": use_zbv,
                         },
                     )
                 ],
@@ -330,6 +334,10 @@ class MixtralModelPolicy(MixtralPolicy):
 class MixtralForCausalLMPolicy(MixtralPolicy):
     def module_policy(self):
         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
         # TODO: assign pg mesh from plugin to all modules
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for causal lm
@@ -342,7 +350,7 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                             kwargs=dict(
                                 gather_output=True,
                                 fp8_communication=self.shard_config.fp8_communication,
-                                use_zbv=self.shard_config.use_zbv,
+                                use_zbv=use_zbv,
                             ),
                         )
                     ],
@@ -392,6 +400,10 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
         from transformers import MixtralForSequenceClassification
         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            use_zbv = self.pipeline_stage_manager.use_zbv
+        else:
+            use_zbv = False
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
@@ -404,7 +416,7 @@ class MixtralForSequenceClassificationPolicy(MixtralPolicy):
                             kwargs=dict(
                                 gather_output=True,
                                 fp8_communication=self.shard_config.fp8_communication,
-                                use_zbv=self.shard_config.use_zbv,
+                                use_zbv=use_zbv,
                             ),
                         )
                     ]


@@ -49,7 +49,6 @@ class ShardConfig:
     make_vocab_size_divisible_by: int = 64
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
-    use_zbv: bool = False
     # For ring attention
     inner_ring_size: Optional[int] = None
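
With the field removed from ShardConfig, the only remaining source of truth for ZBV is the pipeline stage manager built by the plugin. The sketch below shows the implied plugin-side wiring; the use_zbv constructor argument on the stage manager is an assumption inferred from the policies reading self.pipeline_stage_manager.use_zbv, and both classes are stubs rather than the colossalai implementations.

# Sketch of the implied data flow after this commit: the plugin decides ZBV
# from pp_style and records it on the stage manager; the policies read it back.
# StageManagerSketch is a stub, not colossalai's PipelineStageManager.
from dataclasses import dataclass


@dataclass
class StageManagerSketch:
    use_zbv: bool = False  # assumed to be set when the plugin constructs it


def build_stage_manager(pp_style: str) -> StageManagerSketch:
    # Carries what the removed plugin/ShardConfig line expressed: use_zbv=(pp_style == "zbv")
    return StageManagerSketch(use_zbv=(pp_style == "zbv"))


manager = build_stage_manager("zbv")
assert manager.use_zbv  # each policy then branches on pipeline_stage_manager.use_zbv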