Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[fp8] support asynchronous FP8 communication (#5997)
* fix
* fix
* fix
* support async all2all
* support async op for all gather
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
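The behavior the updated tests exercise: when async_op=True is passed, both dist.all_to_all_single and all_to_all_single_fp8 return a work handle, and wait() must be called on that handle before the output buffer is read. A minimal usage sketch of the pattern follows; the fp8_all2all_async helper name and the colossalai.quantization.fp8 import path are assumptions (the test's import lines are not part of the visible hunk) — only the call signature and the handle/wait() behavior are taken from the diff below.

    # Minimal sketch of the async pattern exercised by the tests below.
    # Assumptions (not shown in this diff): the process group is already
    # initialized, and all_to_all_single_fp8 is importable from
    # colossalai.quantization.fp8.
    import torch
    from torch.distributed.distributed_c10d import _get_default_group

    from colossalai.quantization.fp8 import all_to_all_single_fp8


    def fp8_all2all_async(x: torch.Tensor) -> torch.Tensor:
        output = torch.empty_like(x)
        # With async_op=True the call returns immediately with a work handle.
        handle = all_to_all_single_fp8(output, x, group=_get_default_group(), async_op=True)
        # ... independent computation can overlap with the communication here ...
        handle.wait()  # output is only safe to read after wait()
        return output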
@@ -10,19 +10,24 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


 @parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
-@parameterize("dtype", [torch.bfloat16])
-def check_all2all(shape, dtype):
+@parameterize("dtype", [torch.bfloat16, torch.float16])
+@parameterize("async_op", [True, False])
+def check_all2all(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     output = torch.empty_like(x)
     output_fp8 = torch.empty_like(x)
-    dist.all_to_all_single(output, x, group=_get_default_group(), async_op=False)
-    all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
+    origin_hanle = dist.all_to_all_single(output, x, group=_get_default_group(), async_op=async_op)
+    fp8_handle = all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=async_op)
+    if async_op:
+        origin_hanle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)


 @parameterize("shape", [(8, 8, 16)])
 @parameterize("dtype", [torch.bfloat16, torch.float16])
-def check_all2all_uneven(shape, dtype):
+@parameterize("async_op", [True, False])
+def check_all2all_uneven(shape, dtype, async_op):
     x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
     input_split_sizes = [3, 3, 1, 1]
     if dist.get_rank() in [0, 1]:
@@ -33,22 +38,25 @@ def check_all2all_uneven(shape, dtype):
     output_shape[0] = sum(output_split_sizes)
     output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
     output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
-    dist.all_to_all_single(
+    origin_hanle = dist.all_to_all_single(
         output,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
-    all_to_all_single_fp8(
+    fp8_handle = all_to_all_single_fp8(
         output_fp8,
         x,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=_get_default_group(),
-        async_op=False,
+        async_op=async_op,
     )
+    if async_op:
+        origin_hanle.wait()
+        fp8_handle.wait()
     assert_close(output, output_fp8, rtol=0.1, atol=0.1)
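The hunk header context shows the test module importing parameterize, rerun_if_address_is_in_use, and spawn from colossalai.testing, and the uneven case (four split sizes, rank check against [0, 1]) implies a world size of 4. For orientation, a driver for such checks typically looks roughly like the sketch below; this is an assumption based on the usual ColossalAI test scaffolding (colossalai.launch, a run_dist worker, spawn over 4 processes), not code contained in this commit.

    # Hypothetical driver sketch (not part of the diff): each of 4 processes
    # initializes distributed state, then runs the parameterized checks.
    import colossalai
    from colossalai.testing import rerun_if_address_is_in_use, spawn


    def run_dist(rank, world_size, port):
        colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
        check_all2all()
        check_all2all_uneven()


    @rerun_if_address_is_in_use()
    def test_all_to_all_single():
        spawn(run_dist, 4)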