mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-28 16:28:10 +00:00
[fp8]support all2all fp8 (#5953)
* support all2all fp8 * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
0c10afd372
commit
afb26de873
@ -115,6 +115,62 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
|
|||||||
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
|
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
|
||||||
|
|
||||||
|
|
||||||
|
def all_to_all_single_fp8(
|
||||||
|
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
|
||||||
|
) -> None:
|
||||||
|
r"""
|
||||||
|
This is an in-place operation for compressed all_reduce using fp8.
|
||||||
|
It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
|
||||||
|
Args:
|
||||||
|
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
||||||
|
fp8_format: e4m3 or e5m2
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
world_size = dist.get_world_size(group=group)
|
||||||
|
input_type = input.dtype
|
||||||
|
input_shape = input.shape
|
||||||
|
input_device = input.device
|
||||||
|
input = input.flatten()
|
||||||
|
|
||||||
|
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||||
|
|
||||||
|
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
|
||||||
|
|
||||||
|
inp = ret.view(torch.uint8)
|
||||||
|
if input_split_sizes is not None:
|
||||||
|
input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)]
|
||||||
|
input_chunks = list(torch.split(inp, input_split_sizes))
|
||||||
|
else:
|
||||||
|
input_chunks = list(torch.chunk(inp, world_size, dim=0))
|
||||||
|
|
||||||
|
if output_split_sizes is not None:
|
||||||
|
output_chunks = [
|
||||||
|
torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype)
|
||||||
|
for i in range(world_size)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
if dist.get_rank() == world_size - 1:
|
||||||
|
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
|
||||||
|
else:
|
||||||
|
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
|
||||||
|
|
||||||
|
dist.all_to_all(output_chunks, input_chunks, group=group)
|
||||||
|
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||||
|
dist.all_gather(scale_list, scale, group=group)
|
||||||
|
cast_output_chunk = [
|
||||||
|
cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
|
||||||
|
]
|
||||||
|
|
||||||
|
tensor_out = torch.cat(cast_output_chunk, dim=0)
|
||||||
|
outputs_shape = list(input_shape)
|
||||||
|
if output_split_sizes is not None:
|
||||||
|
outputs_shape[0] = sum(output_split_sizes)
|
||||||
|
else:
|
||||||
|
outputs_shape = input_shape
|
||||||
|
output.data = tensor_out.view(outputs_shape).to(input_type)
|
||||||
|
|
||||||
|
|
||||||
def cast_to_fp8_pipeline(inp: Any) -> None:
|
def cast_to_fp8_pipeline(inp: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
|
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
|
||||||
|
67
tests/test_fp8/test_all_to_all_single.py
Normal file
67
tests/test_fp8/test_all_to_all_single.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed.distributed_c10d import _get_default_group
|
||||||
|
from torch.testing import assert_close
|
||||||
|
|
||||||
|
from colossalai import launch
|
||||||
|
from colossalai.accelerator import get_accelerator
|
||||||
|
from colossalai.quantization.fp8 import all_to_all_single_fp8
|
||||||
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize("shape", [(4,), (1, 8, 16), (4, 8, 16)])
|
||||||
|
@parameterize("dtype", [torch.bfloat16])
|
||||||
|
def check_all2all(shape, dtype):
|
||||||
|
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||||
|
output = torch.empty_like(x)
|
||||||
|
output_fp8 = torch.empty_like(x)
|
||||||
|
dist.all_to_all_single(output, x, group=_get_default_group(), async_op=False)
|
||||||
|
all_to_all_single_fp8(output_fp8, x, group=_get_default_group(), async_op=False)
|
||||||
|
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize("shape", [(8, 8, 16)])
|
||||||
|
@parameterize("dtype", [torch.bfloat16, torch.float16])
|
||||||
|
def check_all2all_uneven(shape, dtype):
|
||||||
|
x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device())
|
||||||
|
input_split_sizes = [3, 3, 1, 1]
|
||||||
|
if dist.get_rank() in [0, 1]:
|
||||||
|
output_split_sizes = [3, 3, 3, 3]
|
||||||
|
else:
|
||||||
|
output_split_sizes = [1, 1, 1, 1]
|
||||||
|
output_shape = list(shape)
|
||||||
|
output_shape[0] = sum(output_split_sizes)
|
||||||
|
output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
|
||||||
|
output_fp8 = torch.empty(output_shape, device=x.device, dtype=x.dtype)
|
||||||
|
dist.all_to_all_single(
|
||||||
|
output,
|
||||||
|
x,
|
||||||
|
output_split_sizes=output_split_sizes,
|
||||||
|
input_split_sizes=input_split_sizes,
|
||||||
|
group=_get_default_group(),
|
||||||
|
async_op=False,
|
||||||
|
)
|
||||||
|
all_to_all_single_fp8(
|
||||||
|
output_fp8,
|
||||||
|
x,
|
||||||
|
output_split_sizes=output_split_sizes,
|
||||||
|
input_split_sizes=input_split_sizes,
|
||||||
|
group=_get_default_group(),
|
||||||
|
async_op=False,
|
||||||
|
)
|
||||||
|
assert_close(output, output_fp8, rtol=0.1, atol=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||||
|
check_all2all()
|
||||||
|
check_all2all_uneven()
|
||||||
|
|
||||||
|
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_all_to_all_single():
|
||||||
|
spawn(run_dist, 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_all_to_all_single()
|
Loading…
Reference in New Issue
Block a user