remove all to all

This commit is contained in:
GuangyaoZhang
2024-07-17 05:33:38 +00:00
parent 5a310b9ee1
commit 6a20f07b80
5 changed files with 36 additions and 176 deletions

View File

@@ -55,7 +55,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
return ret.to(ret_type)
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", group=None) -> None:
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
r"""
This is an in-place operation for compressed all_reduce using fp8.
It works like dist.all_reduce but during communication the data is cast to fp8 format.
@@ -167,7 +167,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
del inp["fp8_scale"]
def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None:
def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
r"""
This is an in-place operation for compressed reduce_scatter using fp8.
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.