mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[fp8]support all2all fp8 (#5953)
* support all2all fp8 * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -115,6 +115,62 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
|
||||
tensor.copy_(out[:input_size].view(input_shape).to(input_type))
|
||||
|
||||
|
||||
def all_to_all_single_fp8(
|
||||
output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
|
||||
) -> None:
|
||||
r"""
|
||||
This is an in-place operation for compressed all_reduce using fp8.
|
||||
It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
|
||||
Args:
|
||||
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
||||
fp8_format: e4m3 or e5m2
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
world_size = dist.get_world_size(group=group)
|
||||
input_type = input.dtype
|
||||
input_shape = input.shape
|
||||
input_device = input.device
|
||||
input = input.flatten()
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
|
||||
ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
|
||||
|
||||
inp = ret.view(torch.uint8)
|
||||
if input_split_sizes is not None:
|
||||
input_split_sizes = [input_split_sizes[i] * np.prod(input_shape[1:]) for i in range(world_size)]
|
||||
input_chunks = list(torch.split(inp, input_split_sizes))
|
||||
else:
|
||||
input_chunks = list(torch.chunk(inp, world_size, dim=0))
|
||||
|
||||
if output_split_sizes is not None:
|
||||
output_chunks = [
|
||||
torch.empty((output_split_sizes[i] * np.prod(input_shape[1:]),), device=input_device, dtype=inp.dtype)
|
||||
for i in range(world_size)
|
||||
]
|
||||
else:
|
||||
if dist.get_rank() == world_size - 1:
|
||||
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
|
||||
else:
|
||||
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
|
||||
|
||||
dist.all_to_all(output_chunks, input_chunks, group=group)
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
cast_output_chunk = [
|
||||
cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
|
||||
]
|
||||
|
||||
tensor_out = torch.cat(cast_output_chunk, dim=0)
|
||||
outputs_shape = list(input_shape)
|
||||
if output_split_sizes is not None:
|
||||
outputs_shape[0] = sum(output_split_sizes)
|
||||
else:
|
||||
outputs_shape = input_shape
|
||||
output.data = tensor_out.view(outputs_shape).to(input_type)
|
||||
|
||||
|
||||
def cast_to_fp8_pipeline(inp: Any) -> None:
|
||||
"""
|
||||
Cast the hidden_states tensor of inp object to fp8 format before p2p communication in pipeline.
|
||||
|
Reference in New Issue
Block a user