Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[fp8] MoE support fp8 communication (#5977)
* fix
* support moe fp8
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* fix
* fix fix fi
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
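The change threads an fp8_communication flag through the vocab-parallel embedding so the forward-pass all-reduce can compress its payload. As a rough sketch of what fp8-compressed communication means in general (this is not ColossalAI's reduce_forward implementation; the helper name and the gather-then-sum strategy below are assumptions), the payload is quantized to float8_e4m3fn with a per-tensor scale, the raw bytes are exchanged, and each rank dequantizes before summing:

# Hedged sketch only: NOT ColossalAI's reduce_forward. It illustrates the
# quantize -> communicate -> dequantize -> sum idea behind fp8 communication.
import torch
import torch.distributed as dist


def fp8_all_reduce_sketch(tensor: torch.Tensor, group=None) -> torch.Tensor:
    world_size = dist.get_world_size(group=group)

    # Per-tensor scale so values fit the float8_e4m3fn range (max ~448).
    amax = tensor.abs().max().clamp(min=1e-12)
    scale = (448.0 / amax).reshape(1)
    fp8_payload = (tensor * scale).to(torch.float8_e4m3fn)

    # NCCL cannot reduce fp8 directly, so gather the raw bytes and sum locally.
    byte_bufs = [torch.empty_like(fp8_payload.view(torch.uint8)) for _ in range(world_size)]
    dist.all_gather(byte_bufs, fp8_payload.view(torch.uint8), group=group)

    # Each rank used its own scale; gather the scales for dequantization.
    scale_bufs = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scale_bufs, scale, group=group)

    out = torch.zeros_like(tensor)
    for buf, s in zip(byte_bufs, scale_bufs):
        out += buf.view(torch.float8_e4m3fn).to(tensor.dtype) / s
    return out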
@@ -274,6 +274,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
         make_vocab_size_divisible_by: int = 64,
+        fp8_communication: bool = False,
         *args,
         **kwargs,
     ):
@@ -282,6 +283,7 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.process_group = process_group
+        self.fp8_communication = fp8_communication

         tensor_parallel_size = dist.get_world_size(group=process_group)
         tensor_parallel_rank = dist.get_rank(group=process_group)
@@ -390,5 +392,5 @@ class VocabParallelEmbedding1D(PaddingParallelModule):
         embedding_output = output_parallel.clone()
         embedding_output[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
-        output = reduce_forward(embedding_output, self.process_group)
+        output = reduce_forward(embedding_output, self.process_group, fp8_communication=self.fp8_communication)
         return output
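With the diff applied, the flag only needs to be passed at construction time; every forward pass then routes the embedding all-reduce through the fp8 path. A minimal usage sketch, assuming the usual num_embeddings/embedding_dim argument names and that the layer is exported from colossalai.shardformer.layer:

import torch.distributed as dist
from colossalai.shardformer.layer import VocabParallelEmbedding1D

# Assumes torch.distributed has already been initialized (e.g. via torchrun).
embedding = VocabParallelEmbedding1D(
    num_embeddings=32000,            # vocab size (argument name assumed)
    embedding_dim=4096,              # hidden size (argument name assumed)
    process_group=dist.group.WORLD,  # tensor-parallel group
    fp8_communication=True,          # new flag: reduce_forward compresses the all-reduce to fp8
)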