shardformer fp8

commit 457a0de79f
parent 51f916b11d
Author: GuangyaoZhang
Date: 2024-07-08 07:04:48 +00:00
16 changed files with 520 additions and 234 deletions


@@ -945,7 +945,8 @@ class HybridParallelPlugin(PipelinePluginBase):
 gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
 enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
 make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
-overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
+overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
+fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism.
 """
 def __init__(
@@ -1119,6 +1120,7 @@ class HybridParallelPlugin(PipelinePluginBase):
 parallel_output=parallel_output,
 make_vocab_size_divisible_by=make_vocab_size_divisible_by,
 gradient_checkpoint_config=gradient_checkpoint_config,
+fp8_communication=fp8_communication,
 )
 self.amp_config = dict(
 initial_scale=initial_scale,
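
For context, a minimal usage sketch of the new flag. Only `fp8_communication` comes from this commit; the other constructor arguments follow the existing `HybridParallelPlugin` API, and the specific sizes here are illustrative assumptions:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes the process group has already been initialized, e.g. via
# colossalai.launch_from_torch, as in the project's standard examples.

# Enable fp8-compressed communication for model parallelism alongside
# the existing tensor/pipeline parallel settings.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    overlap_p2p=True,        # existing flag documented in the hunk above
    fp8_communication=True,  # new flag introduced by this commit
)
booster = Booster(plugin=plugin)
# model, optimizer, etc. are then wrapped via booster.boost(...)
```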
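And a back-of-the-envelope sketch of what fp8 communication means: quantize a tensor to an 8-bit float with a per-tensor scale before sending, dequantize on receipt. This is a generic illustration of the technique, not the implementation from this commit; it assumes PyTorch >= 2.1 for the `torch.float8_e4m3fn` dtype:

```python
import torch

def quantize_fp8(t: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Per-tensor scaling: map the tensor's max magnitude onto the fp8
    # representable range, then cast. Returns the fp8 payload plus the
    # inverse scale needed to recover the original magnitudes.
    fp8_max = torch.finfo(fp8_dtype).max
    amax = t.abs().max().clamp(min=1e-12)
    scale = fp8_max / amax
    return (t * scale).to(fp8_dtype), scale.reciprocal()

def dequantize_fp8(t_fp8: torch.Tensor, inv_scale: torch.Tensor, out_dtype=torch.float32):
    return t_fp8.to(out_dtype) * inv_scale

x = torch.randn(4, 4)
payload, inv_scale = quantize_fp8(x)
# In fp8 communication, this payload (half the bytes of fp16, a quarter
# of fp32) would go over the wire instead of the full-precision tensor,
# with the scale traveling alongside as one extra float.
x_restored = dequantize_fp8(payload, inv_scale)
print((x - x_restored).abs().max())  # small quantization error
```

The bandwidth saving in the collectives comes from that smaller payload, at the cost of the quantization error visible in the round trip above.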