Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[fp8]Moe support fp8 communication (#5977)
* fix
* support moe fp8
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix
* fix
* fix fix fi
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -6,6 +6,8 @@ from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
 
+from colossalai.quantization.fp8 import all_to_all_single_fp8
+
 MOE_KERNEL = None
 
 
@@ -380,6 +382,7 @@ def _all_to_all(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     async_op: bool = False,
+    fp8_communication: bool = False,
 ):
     """
     Returns:
@@ -392,9 +395,14 @@ def _all_to_all(
     outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
     inputs = inputs.contiguous()
     outputs = outputs.contiguous()
-    handle = dist.all_to_all_single(
-        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
-    )
+    if fp8_communication:
+        handle = all_to_all_single_fp8(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
+        )
+    else:
+        handle = dist.all_to_all_single(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+        )
     return outputs, handle
 
 
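The FP8 branch delegates to all_to_all_single_fp8, imported above from colossalai.quantization.fp8; its implementation is not part of this diff. As a rough mental model only (a minimal sketch, not the library's code: the helper name, the per-tensor e4m3 scaling, and the uint8 reinterpretation below are all assumptions, and it handles only equal splits, synchronously), an FP8-compressed all-to-all quantizes locally, exchanges the 1-byte payload plus per-rank scales, and dequantizes on the receiving side:

# Illustrative sketch only -- NOT the implementation of
# colossalai.quantization.fp8.all_to_all_single_fp8.
import torch
import torch.distributed as dist


def naive_all_to_all_single_fp8(out: torch.Tensor, inp: torch.Tensor, group=None) -> None:
    world_size = dist.get_world_size(group)

    # Per-tensor scale so the largest magnitude lands near the e4m3 maximum (~448).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (inp.float().abs().amax() / fp8_max).clamp(min=1e-12).reshape(1)
    quantized = (inp.float() / scale).to(torch.float8_e4m3fn)

    # Exchange one byte per element; reinterpret as uint8 since NCCL has no FP8 dtype.
    recv = torch.empty_like(quantized)
    dist.all_to_all_single(recv.view(torch.uint8), quantized.view(torch.uint8), group=group)

    # Every receiver needs every sender's scale to dequantize its chunk.
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(scales, scale, group=group)

    # Chunk i of the received buffer came from rank i; rescale and cast back.
    chunks = recv.float().chunk(world_size)
    out.copy_(torch.cat([chunk * s for chunk, s in zip(chunks, scales)]))

The payoff is halving (vs. bf16/fp16) or quartering (vs. fp32) the bytes moved in the MoE token exchange, at the cost of FP8 quantization error.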
@@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function):
         output_split_sizes=None,
         group=None,
         overlap: bool = False,
+        fp8_communication: bool = False,
     ):
         """
         Returns:
@@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function):
         ctx.input_split_sizes = input_split_sizes
         ctx.output_split_sizes = output_split_sizes
         ctx.group = group
-        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+        return _all_to_all(
+            inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
+        )
 
     @staticmethod
     def backward(ctx: Any, *grad_outputs):
@@ -426,6 +437,7 @@ class AllToAllUneven(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
 
 
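The extra None appended to backward's return tuple pairs up with the new fp8_communication argument of forward: a torch.autograd.Function must return exactly one gradient (or None for non-differentiable inputs such as flags and split-size lists) per forward argument. A minimal standalone sketch of that rule (the class and names below are illustrative, not from ColossalAI):

# Minimal sketch: backward returns one entry per forward input (after ctx),
# with None for inputs that carry no gradient, e.g. a boolean flag.
import torch


class ScaleMaybe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, enabled=False):  # three inputs after ctx
        ctx.factor = factor
        ctx.enabled = enabled
        return x * factor if enabled else x

    @staticmethod
    def backward(ctx, grad_out):
        grad_x = grad_out * ctx.factor if ctx.enabled else grad_out
        return grad_x, None, None  # gradients for x, factor, enabled


x = torch.randn(4, requires_grad=True)
ScaleMaybe.apply(x, 2.0, True).sum().backward()
print(x.grad)  # 2.0 everywhere; the two None slots are simply ignored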
@@ -435,8 +447,9 @@ def all_to_all_uneven(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     overlap: bool = False,
+    fp8_communication: bool = False,
 ):
     assert (
         inputs.requires_grad
     ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
-    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
+    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
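Taken together, the change is opt-in: FP8 communication is used only when callers pass fp8_communication=True down through all_to_all_uneven. A minimal usage sketch, assuming torch.distributed is already initialized with one process per GPU; the import path and split sizes below are assumptions, since the diff does not show the file name:

# Minimal usage sketch (import path is assumed, not shown in the diff).
import torch
import torch.distributed as dist

from colossalai.moe._operation import all_to_all_uneven

rank = dist.get_rank()
world_size = dist.get_world_size()

# Uneven token routing: this rank sends (rank + 1) tokens to every peer
# and receives (r + 1) tokens from each peer r.
input_split_sizes = [rank + 1] * world_size
output_split_sizes = [r + 1 for r in range(world_size)]

tokens = torch.randn(sum(input_split_sizes), 16, device="cuda", requires_grad=True)

# fp8_communication=True routes the exchange through all_to_all_single_fp8.
outputs, _handle = all_to_all_uneven(
    tokens,
    input_split_sizes,
    output_split_sizes,
    fp8_communication=True,
)
outputs.sum().backward()  # inputs must require grad, per the assert above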