mirror of https://github.com/hpcaitech/ColossalAI.git
shardformer fp8
@@ -12,7 +12,6 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te
        scale: scaling factor for fp8 casting. If it is None, it is computed automatically. Per-channel scaling
            is applied if the input tensor is 2-dimensional; otherwise, per-tensor scaling is applied.
        fp8_format: e4m3 or e5m2

    Returns:
        Tuple: a tuple (fp8_tensor, scale)
    """
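The hunk elides the body of cast_to_fp8. A minimal sketch of the scheme the docstring describes, not the committed code — the names fp8_max and amax and the 1e-12 clamp are illustrative assumptions:

import torch

def cast_to_fp8_sketch(inp: torch.Tensor, fp8_format: str = "e4m3"):
    # Pick the target fp8 dtype and its largest representable magnitude.
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max
    if inp.dim() == 2:
        # Per-channel: one scale per row, from the row-wise absolute maximum.
        amax = inp.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
    else:
        # Per-tensor: a single scale for the whole tensor.
        amax = inp.abs().amax().float().clamp(min=1e-12)
    scale = fp8_max / amax
    fp8_tensor = (inp.float() * scale).to(fp8_type)
    # Return the inverse scale so dequantization is a single multiply.
    return fp8_tensor, 1.0 / scale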
@@ -39,12 +38,10 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3") -> (torch.Tensor, torch.Te

def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
    r"""
    Args:
        inp: an fp8 torch tensor in one of the types: [torch.float8_e4m3fn, torch.float8_e5m2].
        scale_inv: scaling factor returned by the cast_to_fp8 function.
        ret_type: the datatype of the returned tensor.

    Returns:
        torch.Tensor
    """
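The decompression path is the mirror image. A minimal sketch, assuming scale_inv is the inverse scale returned alongside the fp8 tensor:

def cast_from_fp8_sketch(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dtype) -> torch.Tensor:
    # Upcast to float, undo the quantization scale, then convert to the
    # requested dtype (a per-channel scale_inv broadcasts over rows).
    return (inp.float() * scale_inv).to(ret_type)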
@@ -62,11 +59,9 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
    r"""
    This is an in-place operation for compressed all_reduce using fp8.
    It works like dist.all_reduce, but during communication the data is cast to fp8 format.

    Args:
        tensor: torch.Tensor in fp32, fp16, or bf16 datatype.
        fp8_format: e4m3 or e5m2

    Returns:
        None
    """
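Hypothetical usage, assuming torch.distributed is initialized and the function is imported from this module; the tensor name is illustrative:

import torch

# Each rank contributes a bf16 tensor; the payload travels as fp8 and the
# summed result is written back in place, like dist.all_reduce.
grad = torch.randn(1024, dtype=torch.bfloat16, device="cuda")
all_reduce_fp8(grad, fp8_format="e4m3")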
@@ -170,3 +165,40 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:

    if del_metadata:
        del inp["fp8_scale"]

def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None:
    r"""
    This is an in-place operation for compressed reduce_scatter using fp8.
    It works like dist.reduce_scatter, but during communication the data is cast to fp8 format.

    Args:
        output: torch.Tensor in fp32, fp16, or bf16 datatype that receives this rank's reduced chunk.
        input_list: list of torch.Tensor chunks to be reduced, one per rank in the group.
        group: the process group to communicate over.
        fp8_format: e4m3 or e5m2

    Returns:
        None
    """
    input_type = output.dtype

    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    scale_list = []
    cast_input_list = []
    output_chunks = []
    output_scale_list = []
    # fp8 values cannot be summed during communication, so each rank sends
    # its chunks and scales with all_to_all and performs the reduction locally.
    for input in input_list:
        ret, scale = cast_to_fp8(input, fp8_format=fp8_format)
        scale_list.append(scale)
        # Reinterpret the fp8 payload as uint8 bytes for transport.
        ret = ret.view(torch.uint8)
        cast_input_list.append(ret)
        output_chunks.append(torch.empty_like(ret))
        output_scale_list.append(torch.empty_like(scale))
    dist.all_to_all(output_chunks, cast_input_list, group=group)
    dist.all_to_all(output_scale_list, scale_list, group=group)

    # Reinterpret each received chunk as fp8, dequantize, and accumulate.
    summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
    for scale, out in zip(output_scale_list, output_chunks):
        out = out.view(fp8_type)
        summed_out += cast_from_fp8(out, scale, input_type)
    output.data = summed_out
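Hypothetical usage on a group of world_size ranks; shapes and names are illustrative, and torch.distributed is assumed to be initialized:

import torch
import torch.distributed as dist

world_size = dist.get_world_size()
# One equally sized chunk per rank; rank i ends up with the sum of every
# rank's i-th chunk, communicated in fp8.
input_list = [torch.randn(256, device="cuda") for _ in range(world_size)]
output = torch.empty(256, device="cuda")
reduce_scatter_fp8(output, input_list, group=dist.group.WORLD, fp8_format="e4m3")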