Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
shardformer fp8
@@ -945,7 +945,8 @@ class HybridParallelPlugin(PipelinePluginBase):
         gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
         make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
-        overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism
+        overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
+        fp8_communication (bool, optional): Whether to enable fp8 communication in model parallelism
     """

     def __init__(
@@ -1119,6 +1120,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             parallel_output=parallel_output,
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
+            fp8_communication=fp8_communication,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
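
For context, a minimal usage sketch of the flag this commit introduces, assuming the usual Booster/HybridParallelPlugin workflow; the parallel sizes, model, optimizer, and data objects below are placeholders and are not part of this commit.

# Sketch only: enable fp8 communication when constructing the plugin.
# Assumes the distributed environment has already been launched
# (e.g. via colossalai.launch_from_torch).
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,                # tensor parallel size (placeholder value)
    pp_size=1,                # pipeline parallel size (placeholder value)
    fp8_communication=True,   # new flag from this commit: use fp8 for model-parallel communication
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion, dataloader, lr_scheduler
# )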