[fp8] support all-gather flat tensor (#5932)

This commit is contained in:
Hongxin Liu
2024-07-24 16:55:20 +08:00
committed by GitHub
parent 62661cde22
commit 5fd0592767
2 changed files with 116 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
from typing import Any
import numpy as np
import torch
import torch.distributed as dist
@@ -202,3 +203,78 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
out = out.view(fp8_type)
summed_out += cast_from_fp8(out, scale, input_type)
output.data = summed_out
def split_chunk_by_channel(
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
):
offset = chunk.numel() * rank
end = offset + chunk.numel()
break_points = [x for x in range(0, channel_size * num_channels + 1, channel_size) if offset <= x <= end]
if len(break_points) == 0 or break_points[0] > offset:
break_points.insert(0, offset)
if break_points[-1] < end:
break_points.append(end)
sizes = [b - a for a, b in zip(break_points[:-1], break_points[1:])]
return chunk.split(sizes)
def all_gather_into_tensor_flat_fp8(
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
output_shape: torch.Size,
group: dist.ProcessGroup,
fp8_format: str = "e4m3",
):
"""all gather into tensor in fp8 format
Args:
output_tensor (torch.Tensor): output tensor, which is flattened
input_tensor (torch.Tensor): input tensor, which is flattened
group (dist.ProcessGroup): process group
fp8_format (str, optional): fp8 format, e4m3 or e5m2. Defaults to "e4m3".
"""
assert input_tensor.dim() == 1 and output_tensor.dim() == 1, "input/output tensor should be flattened"
world_size = dist.get_world_size(group)
assert (
output_tensor.numel() == input_tensor.numel() * world_size
), "output tensor size should be world_size times of input tensor size"
input_type = output_tensor.dtype
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
fp8_max = torch.finfo(fp8_type).max
if len(output_shape) == 2:
per_channel_max = torch.zeros(output_shape[0], device=output_tensor.device, dtype=torch.float)
num_channels, channel_size = output_shape
rank = dist.get_rank(group)
channel_start_idx = (input_tensor.numel() * rank) // channel_size
per_channel_splits = split_chunk_by_channel(input_tensor, channel_size, num_channels, rank, world_size)
for i, per_channel_split in enumerate(per_channel_splits):
idx = i + channel_start_idx
if idx < num_channels:
per_channel_max[idx] = per_channel_split.abs().max().float()
dist.all_reduce(per_channel_max, op=dist.ReduceOp.MAX, group=group)
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
scale = fp8_max / per_channel_max
fp8_input = input_tensor.float()
fp8_per_channel_splits = split_chunk_by_channel(fp8_input, channel_size, num_channels, rank, world_size)
for i, per_channel_split in enumerate(fp8_per_channel_splits):
idx = i + channel_start_idx
if idx < num_channels:
per_channel_split.mul_(scale[idx])
fp8_input = fp8_input.to(fp8_type)
else:
per_tensor_max = input_tensor.abs().max().float()
dist.all_reduce(per_tensor_max, op=dist.ReduceOp.MAX, group=group)
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
scale = fp8_max / per_tensor_max
fp8_input = (scale * input_tensor.float()).to(fp8_type)
scale_inv = 1.0 / scale
buffer = torch.empty_like(output_tensor, dtype=fp8_type)
dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
numel = np.prod(output_shape)
valid_buffer = buffer[:numel].reshape(output_shape)
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type)
output_tensor[:numel].copy_(valid_buffer.view(-1))