[fp8] MoE: support fp8 communication (#5977)

* fix

* support moe fp8

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* fix

* fix

fix

fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: flybird11111
Date: 2024-08-09 18:26:02 +08:00
Committed by: GitHub
Parent: e4aadeee20
Commit: f1a3a326c4

8 changed files with 160 additions and 52 deletions


@@ -6,6 +6,8 @@ from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.distributed import ProcessGroup
+from colossalai.quantization.fp8 import all_to_all_single_fp8

 MOE_KERNEL = None
@@ -380,6 +382,7 @@ def _all_to_all(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     async_op: bool = False,
+    fp8_communication: bool = False,
 ):
     """
     Returns:
@@ -392,9 +395,14 @@ def _all_to_all(
     outputs = torch.empty(outputs_shape, dtype=inputs.dtype, device=inputs.device)
     inputs = inputs.contiguous()
     outputs = outputs.contiguous()
-    handle = dist.all_to_all_single(
-        outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
-    )
+    if fp8_communication:
+        handle = all_to_all_single_fp8(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=False
+        )
+    else:
+        handle = dist.all_to_all_single(
+            outputs, inputs, output_split_sizes, input_split_sizes, group=group, async_op=async_op
+        )
     return outputs, handle
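Note that the fp8 branch above always runs synchronously (async_op=False), so callers that requested overlap still get back a completed exchange. For intuition only, here is a minimal sketch of what an fp8 all-to-all of this shape could look like. It is not the actual all_to_all_single_fp8 from colossalai.quantization.fp8 (that implementation is not part of this diff); it assumes per-tensor e4m3 scaling with one scale shared across the group via an extra max all-reduce, and it only supports the synchronous path.

import torch
import torch.distributed as dist

def naive_all_to_all_single_fp8(outputs, inputs, output_split_sizes=None, input_split_sizes=None, group=None):
    # Illustrative only: quantize -> exchange -> dequantize with one shared scale.
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max

    # Agree on a single scale across the group so receivers can dequantize
    # without shipping per-rank metadata (costs one small extra all-reduce).
    amax = inputs.abs().max().float()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    scale = fp8_max / amax.clamp(min=1e-12)

    # Quantize, then exchange the payload as raw bytes (the backend has no fp8 dtype).
    inputs_fp8 = (inputs.float() * scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    outputs_fp8 = torch.empty(outputs.shape, dtype=fp8_dtype, device=outputs.device)
    dist.all_to_all_single(
        outputs_fp8.view(torch.uint8),
        inputs_fp8.view(torch.uint8),
        output_split_sizes,
        input_split_sizes,
        group=group,
    )

    # Dequantize into the caller's buffer and dtype; synchronous, so no handle to return.
    outputs.copy_(outputs_fp8.float() / scale)
    return None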
@@ -407,6 +415,7 @@ class AllToAllUneven(torch.autograd.Function):
         output_split_sizes=None,
         group=None,
         overlap: bool = False,
+        fp8_communication: bool = False,
     ):
         """
         Returns:
@@ -416,7 +425,9 @@ class AllToAllUneven(torch.autograd.Function):
         ctx.input_split_sizes = input_split_sizes
         ctx.output_split_sizes = output_split_sizes
         ctx.group = group
-        return _all_to_all(inputs, input_split_sizes, output_split_sizes, group, overlap)
+        return _all_to_all(
+            inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication=fp8_communication
+        )

     @staticmethod
     def backward(ctx: Any, *grad_outputs):
@@ -426,6 +437,7 @@ class AllToAllUneven(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
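The extra None above is needed because torch.autograd.Function.backward must return one value per argument that forward received, and the new non-tensor fp8_communication flag takes no gradient. The surrounding gradient logic is elided from this diff; a plausible shape of the full method, assuming the backward pass reuses _all_to_all with the split sizes swapped, would be:

    @staticmethod
    def backward(ctx: Any, *grad_outputs):
        # One return value per forward argument:
        # (inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
        grad_inputs, _ = _all_to_all(grad_outputs[0], ctx.output_split_sizes, ctx.input_split_sizes, ctx.group)
        return grad_inputs, None, None, None, None, None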
@@ -435,8 +447,9 @@ def all_to_all_uneven(
     output_split_sizes: Optional[List[int]] = None,
     group=None,
     overlap: bool = False,
+    fp8_communication: bool = False,
 ):
     assert (
         inputs.requires_grad
     ), "Input must require grad to assure that backward is executed, otherwise it might hang the program."
-    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap)
+    return AllToAllUneven.apply(inputs, input_split_sizes, output_split_sizes, group, overlap, fp8_communication)
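Finally, a hypothetical call site, to show that opting into the new path is a single keyword argument. The tensor shapes, split sizes, and the ep_group process group below are illustrative and not taken from this diff; the snippet assumes torch.distributed is already initialized (e.g. with an NCCL backend) and, per the assert above, that the input requires grad.

# Illustrative MoE token dispatch over an expert-parallel group of 4 ranks.
tokens = torch.randn(16, 1024, device="cuda", requires_grad=True)  # must require grad
send_counts = [4, 4, 4, 4]   # tokens this rank routes to each peer
recv_counts = [4, 4, 4, 4]   # tokens this rank receives from each peer
dispatched, _ = all_to_all_uneven(
    tokens,
    input_split_sizes=send_counts,
    output_split_sizes=recv_counts,
    group=ep_group,               # hypothetical expert-parallel process group
    fp8_communication=True,       # quantize the payload to fp8 for the exchange
)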