[fp8] support asynchronous FP8 communication (#5997)
* fix
* fix
* fix
* support async all2all
* support async op for all gather
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -10,6 +10,18 @@ from torch.distributed import ReduceOp
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
 
 
+class Handle:
+    def __init__(self, handles=[], remain_ops=None) -> None:
+        self.handles = handles
+        self.remain_ops = remain_ops
+
+    def wait(self):
+        for handle in self.handles:
+            handle.wait()
+        if self.remain_ops:
+            self.remain_ops()
+
+
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
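The new Handle class bundles the work objects returned by asynchronous torch.distributed calls with a deferred post-processing callback: wait() first blocks on each communication handle, then runs remain_ops. A minimal sketch of that contract, assuming an already-initialized default process group; the tensor and the post-processing step are illustrative, not from this commit:

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group(...) has already run (e.g. via torchrun).
    x = torch.ones(4)
    work = dist.all_reduce(x, async_op=True)  # returns a torch.distributed Work object

    def postprocess():
        # Illustrative deferred step, executed only after the reduce completes.
        x.div_(dist.get_world_size())

    h = Handle(handles=[work], remain_ops=postprocess)
    # ... independent computation can overlap with the all-reduce here ...
    h.wait()  # blocks on every handle in order, then invokes remain_ops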
@@ -68,7 +80,9 @@ def cast_from_fp8(
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None) -> None:
+def all_reduce_fp8(
+    tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, group=None, async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -105,6 +119,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
     dist.all_gather(scale_list, scale, group=group)
     summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+
     for scale, out in zip(scale_list, output_chunks):
         out = out.view(fp8_type)
         summed_out += cast_from_fp8(out, scale, input_type)
@@ -113,19 +128,28 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", op=ReduceOp.SUM, gro
         summed_out.div_(world_size)
 
     summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
-    dist.all_gather(scale_list, scale, group=group)
+    gather_scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
     tensor_list = [torch.empty_like(summed_out_fp8.view(torch.uint8)) for _ in range(world_size)]
-    dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
-    for i in range(world_size):
-        tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i].to(input_device)
-    out = torch.cat(tensor_list, dim=0)
-    tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+    gather_tensor_handle = dist.all_gather(
+        tensor_list, summed_out_fp8.view(torch.uint8), group=group, async_op=async_op
+    )
+
+    def cat_op():
+        for i in range(world_size):
+            tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
+        out = torch.cat(tensor_list, dim=0)
+        tensor.copy_(out[:input_size].view(input_shape).to(input_type))
+
+    if async_op:
+        return Handle([gather_scale_handle, gather_tensor_handle], cat_op)
+    else:
+        cat_op()
 
 
 def all_to_all_single_fp8(
-    output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None
-) -> None:
+    output, input, output_split_sizes=None, input_split_sizes=None, fp8_format="e5m2", group=None, async_op=False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed all_to_all_single using fp8.
     It works like dist.all_to_all_single but during communication the data is cast to fp8 format.
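With the change above, all_reduce_fp8 issues its two all-gathers asynchronously and defers the de-quantize-and-concatenate step (cat_op) into the returned Handle, so the caller can hide communication latency. A hedged usage sketch; the gradient tensor and the overlapped work are illustrative, and an initialized NCCL process group plus an FP8-capable PyTorch build are assumed:

    import torch

    grad = torch.randn(1024, device="cuda")  # reduced in place, illustrative name
    handle = all_reduce_fp8(grad, fp8_format="e4m3", async_op=True)
    # ... run communication-independent work here to overlap with the all-gathers ...
    if handle is not None:
        handle.wait()  # finishes both all-gathers, then cat_op writes the result into grad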
@@ -163,20 +187,27 @@ def all_to_all_single_fp8(
     else:
         output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
 
-    dist.all_to_all(output_chunks, input_chunks, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, input_chunks, group=group, async_op=async_op)
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
-    dist.all_gather(scale_list, scale, group=group)
-    cast_output_chunk = [
-        cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
-    ]
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
-    tensor_out = torch.cat(cast_output_chunk, dim=0)
-    outputs_shape = list(input_shape)
-    if output_split_sizes is not None:
-        outputs_shape[0] = sum(output_split_sizes)
-    else:
-        outputs_shape = input_shape
-    output.data = tensor_out.view(outputs_shape).to(input_type)
+    def cast_op():
+        cast_output_chunk = [
+            cast_from_fp8(out.view(fp8_type), scale, input_type) for scale, out in zip(scale_list, output_chunks)
+        ]
+
+        tensor_out = torch.cat(cast_output_chunk, dim=0)
+        outputs_shape = list(input_shape)
+        if output_split_sizes is not None:
+            outputs_shape[0] = sum(output_split_sizes)
+        else:
+            outputs_shape = input_shape
+        output.data = tensor_out.view(outputs_shape).to(input_type)
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 def cast_to_fp8_pipeline(inp: Any) -> None:
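all_to_all_single_fp8 follows the same pattern: the chunk exchange and the scale all-gather are launched asynchronously, and output.data is rebound only when the deferred cast_op runs, so the output must not be read before wait() returns. A sketch assuming equal splits across ranks (shapes illustrative, initialized group assumed):

    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()
    inp = torch.randn(world_size * 8, device="cuda")  # 8 elements per destination rank
    out = torch.empty_like(inp)
    handle = all_to_all_single_fp8(out, inp, fp8_format="e5m2", async_op=True)
    # ... overlap independent work ...
    if handle is not None:
        handle.wait()  # cast_op de-quantizes the received chunks into out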
@@ -250,7 +281,9 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["dtype"]
 
 
-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
+def reduce_scatter_fp8(
+    output: torch.Tensor, input_list, group, fp8_format="e5m2", async_op: bool = False
+) -> Optional[Handle]:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
@@ -277,14 +310,20 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
         cast_input_list.append(ret)
         output_chunks.append(torch.empty_like(ret))
         output_scale_list.append(torch.empty_like(scale))
-    dist.all_to_all(output_chunks, cast_input_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
+    chunk_handle = dist.all_to_all(output_chunks, cast_input_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
 
-    summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
-    for scale, out in zip(output_scale_list, output_chunks):
-        out = out.view(fp8_type)
-        summed_out += cast_from_fp8(out, scale, input_type)
-    output.data = summed_out
+    def cast_op():
+        summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
+        for scale, out in zip(output_scale_list, output_chunks):
+            out = out.view(fp8_type)
+            summed_out += cast_from_fp8(out, scale, input_type)
+        output.data = summed_out
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 def fp8_compress_ddp_grad_comm_hook_async(
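In reduce_scatter_fp8 the reduction itself happens on the receiver after communication (the scatter is implemented via dist.all_to_all), so the deferred cast_op both de-quantizes and sums. A hedged usage sketch with illustrative sizes, assuming an initialized group:

    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()
    input_list = [torch.randn(8, device="cuda") for _ in range(world_size)]  # one chunk per peer
    output = torch.empty(8, device="cuda")
    handle = reduce_scatter_fp8(output, input_list, group=None, async_op=True)
    if handle is not None:
        handle.wait()  # sums the de-quantized chunks received from all ranks into output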
@@ -500,7 +539,8 @@ def all_gather_into_tensor_flat_fp8(
     output_shape: torch.Size,
     group: dist.ProcessGroup,
     fp8_format: str = "e4m3",
-):
+    async_op: bool = False,
+) -> Optional[Handle]:
     """all gather into tensor in fp8 format
 
     Args:
@@ -547,15 +587,25 @@ def all_gather_into_tensor_flat_fp8(
         scale = fp8_max / per_tensor_max
     fp8_input = (scale * input_tensor.float()).to(fp8_type)
     scale_inv = 1.0 / scale
 
     buffer = torch.empty_like(output_tensor, dtype=fp8_type)
-    dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
-    numel = output_shape.numel()
-    valid_buffer = buffer[:numel].reshape(output_shape)
-    valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
-    output_tensor[:numel].copy_(valid_buffer.view(-1))
+    tensor_handle = dist.all_gather_into_tensor(
+        buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group, async_op=async_op
+    )
+
+    def cast_op():
+        numel = output_shape.numel()
+        valid_buffer = buffer[:numel].reshape(output_shape)
+        valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
+        output_tensor[:numel].copy_(valid_buffer.view(-1))
+
+    if async_op:
+        return Handle([tensor_handle], cast_op)
+    else:
+        cast_op()
 
 
-def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
+def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2", async_op=False):
 
     world_size = dist.get_world_size(group)
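all_gather_into_tensor_flat_fp8 gathers a flat FP8 shard from every rank into a padded buffer and de-quantizes only the valid prefix in the deferred cast_op, the shape a ZeRO-style flat parameter gather would take. A sketch under assumed sizes; the shard/full names and the use of the default world group are illustrative:

    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()
    shard = torch.randn(128, device="cuda")              # this rank's flat shard
    full = torch.empty(128 * world_size, device="cuda")  # padded gather target
    handle = all_gather_into_tensor_flat_fp8(
        full, shard, torch.Size((128 * world_size,)), group=dist.group.WORLD, async_op=True
    )
    if handle is not None:
        handle.wait()  # cast_op de-quantizes the gathered buffer into full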
@@ -573,17 +623,23 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"):
 
     output_scale_list = [torch.empty_like(x) for x in scale_list]
     output_tensor_list = [torch.empty_like(x) for x in tensor_list]
-    dist.all_to_all(output_tensor_list, tensor_list, group=group)
-    dist.all_to_all(output_scale_list, scale_list, group=group)
+    tensor_handle = dist.all_to_all(output_tensor_list, tensor_list, group=group, async_op=async_op)
+    scale_handle = dist.all_to_all(output_scale_list, scale_list, group=group, async_op=async_op)
 
-    for i in range(world_size):
-        scale = output_scale_list[i]
-        tensor = output_tensor_list[i]
-        tensor = tensor.view(fp8_type)
-        output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+    def cast_op():
+        for i in range(world_size):
+            scale = output_scale_list[i]
+            tensor = output_tensor_list[i]
+            tensor = tensor.view(fp8_type)
+            output_list[i].copy_(cast_from_fp8(tensor, scale, input_type))
+
+    if async_op:
+        return Handle([tensor_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
-def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
+def gather_fp8(output_list, input_, group=None, fp8_format="e5m2", async_op: bool = False) -> Optional[Handle]:
 
     world_size = dist.get_world_size(group)
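all_to_all_fp8 exchanges per-rank tensor lists together with their scales, then de-quantizes each received chunk in the deferred cast_op. A hedged sketch with equally sized chunks (sizes illustrative, initialized group assumed):

    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()
    input_list = [torch.randn(8, device="cuda") for _ in range(world_size)]
    output_list = [torch.empty(8, device="cuda") for _ in range(world_size)]
    handle = all_to_all_fp8(output_list, input_list, async_op=True)
    if handle is not None:
        handle.wait()  # copies each de-quantized chunk into output_list[i]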
@@ -593,13 +649,19 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
     input_ = ret.view(torch.uint8)
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     scale_list = [torch.ones(1, dtype=scale.dtype, device=input_.device) for _ in range(world_size)]
-    dist.all_gather(tensor_list, input_, group=group)
-    dist.all_gather(scale_list, scale, group=group)
+    chunk_handle = dist.all_gather(tensor_list, input_, group=group, async_op=async_op)
+    scale_handle = dist.all_gather(scale_list, scale, group=group, async_op=async_op)
 
-    for i in range(world_size):
-        output = tensor_list[i].view(fp8_type)
-        scale = scale_list[i]
-        output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+    def cast_op():
+        for i in range(world_size):
+            output = tensor_list[i].view(fp8_type)
+            scale = scale_list[i]
+            output_list[i].copy_(cast_from_fp8(output, scale, input_type))
+
+    if async_op:
+        return Handle([chunk_handle, scale_handle], cast_op)
+    else:
+        cast_op()
 
 
 class _LinearFp8(torch.autograd.Function):
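Despite its name, gather_fp8 wraps dist.all_gather, so every rank receives every rank's tensor. A hedged usage sketch (sizes illustrative, initialized group assumed):

    import torch
    import torch.distributed as dist

    world_size = dist.get_world_size()
    x = torch.randn(8, device="cuda")
    outs = [torch.empty(8, device="cuda") for _ in range(world_size)]
    handle = gather_fp8(outs, x, fp8_format="e5m2", async_op=True)
    if handle is not None:
        handle.wait()  # fills outs[i] with rank i's de-quantized tensor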