mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 01:06:00 +00:00
remove all to all
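Both hunks below change the default fp8_format of the communication helpers from "e4m3" to "e5m2". The two fp8 formats trade precision for dynamic range: e4m3 (4 exponent bits, 3 mantissa bits) tops out at 448, while e5m2 (5 exponent bits, 2 mantissa bits) reaches 57344, so e5m2 tolerates the larger magnitudes that partial sums can produce during a reduction, which is a plausible reason to prefer it as the default here. A quick way to check the two ranges, assuming PyTorch 2.1+ with native fp8 dtypes:

import torch

# Representable maxima of the two fp8 formats.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0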
@@ -55,7 +55,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype
     return ret.to(ret_type)
 
 
-def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", group=None) -> None:
+def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e5m2", group=None) -> None:
     r"""
     This is an in-place operation for compressed all_reduce using fp8.
     It works like dist.all_reduce but during communication the data is cast to fp8 format.
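The docstring describes the scheme (cast to fp8 for the wire, reduce in the original precision) but the body is elided from this hunk. Below is a minimal sketch of how such a compressed all_reduce can be built; it is not ColossalAI's actual implementation, and naive_all_reduce_fp8, the per-tensor scaling, and the all_gather-based exchange are illustrative assumptions:

import torch
import torch.distributed as dist

def naive_all_reduce_fp8(tensor: torch.Tensor, fp8_format: str = "e5m2", group=None) -> None:
    # Hypothetical sketch, not ColossalAI's implementation.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_dtype).max

    # Per-tensor scale so the largest magnitude lands near the fp8 maximum.
    amax = tensor.abs().max().clamp(min=1e-12)
    scale = (fp8_max / amax).reshape(1)
    # View the fp8 payload as uint8 so any backend can ship the raw bytes.
    payload = (tensor * scale).to(fp8_dtype).view(torch.uint8)

    world_size = dist.get_world_size(group)
    payloads = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, payload, group=group)
    dist.all_gather(scales, scale, group=group)

    # Decompress each rank's contribution with its own inverse scale and
    # sum in the original dtype, writing back in place like dist.all_reduce.
    tensor.zero_()
    for p, s in zip(payloads, scales):
        tensor += p.view(fp8_dtype).to(tensor.dtype) / s

The real helper may fuse these steps or use a different exchange pattern; the point is only the cast, communicate, decompress, reduce shape of the operation.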
@@ -167,7 +167,7 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
         del inp["fp8_scale"]
 
 
-def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None:
+def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
     r"""
     This is an in-place operation for compressed reduce_scatter using fp8.
     It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
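As with all_reduce_fp8, here is a hedged sketch of the reduce_scatter variant under the same cast-then-exchange idea; naive_reduce_scatter_fp8 and the single shared scale per rank are illustrative choices, not the library's API. In keeping with the commit title, the exchange uses all_gather rather than all_to_all:

import torch
import torch.distributed as dist

def naive_reduce_scatter_fp8(output: torch.Tensor, input_list, group=None, fp8_format: str = "e5m2") -> None:
    # Hypothetical sketch, not ColossalAI's implementation.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_dtype).max
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    # One scale for all local chunks keeps the metadata to a single scalar.
    flat = torch.cat([t.flatten() for t in input_list])
    amax = flat.abs().max().clamp(min=1e-12)
    scale = (fp8_max / amax).reshape(1)
    payload = (flat * scale).to(fp8_dtype).view(torch.uint8)

    payloads = [torch.empty_like(payload) for _ in range(world_size)]
    scales = [torch.empty_like(scale) for _ in range(world_size)]
    dist.all_gather(payloads, payload, group=group)
    dist.all_gather(scales, scale, group=group)

    # Like dist.reduce_scatter: rank r keeps only the r-th chunk, summed
    # over every sender after decompression.
    numel = output.numel()
    output.zero_()
    for p, s in zip(payloads, scales):
        chunk = p.view(fp8_dtype)[rank * numel : (rank + 1) * numel]
        output += chunk.to(output.dtype).view_as(output) / s

This sketch ships every chunk to every rank, which is more traffic than a true reduce_scatter needs; a production version would presumably exchange only the chunks each rank actually reduces.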