mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
fix rebase
This commit is contained in:
@@ -55,7 +55,7 @@ def cast_from_fp8(inp: torch.Tensor, scale_inv: torch.Tensor, ret_type: torch.dt
|
||||
return ret.to(ret_type)
|
||||
|
||||
|
||||
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
|
||||
def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3", group=None) -> None:
|
||||
r"""
|
||||
This is an in-place operation for compressed all_reduce using fp8.
|
||||
It works like dist.all_reduce but during communication the data is cast to fp8 format.
|
||||
@@ -66,7 +66,7 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
|
||||
None
|
||||
"""
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
world_size = dist.get_world_size(group=group)
|
||||
input_type = tensor.dtype
|
||||
input_shape = tensor.shape
|
||||
input_device = tensor.device
|
||||
@@ -83,19 +83,19 @@ def all_reduce_fp8(tensor: torch.Tensor, fp8_format="e4m3") -> None:
|
||||
output_chunks = [torch.empty_like(input_chunks[-1]) for _ in range(world_size)]
|
||||
else:
|
||||
output_chunks = [torch.empty_like(input_chunks[0]) for _ in range(world_size)]
|
||||
dist.all_to_all(output_chunks, input_chunks)
|
||||
dist.all_to_all(output_chunks, input_chunks, group=group)
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
dist.all_gather(scale_list, scale)
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
||||
for scale, out in zip(scale_list, output_chunks):
|
||||
out = out.view(fp8_type)
|
||||
summed_out += cast_from_fp8(out, scale, input_type)
|
||||
|
||||
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
|
||||
dist.all_gather(scale_list, scale)
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
|
||||
tensor_list = list(torch.chunk(torch.empty(input_size, device=input_device, dtype=torch.uint8), world_size, dim=0))
|
||||
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8))
|
||||
dist.all_gather(tensor_list, summed_out_fp8.view(torch.uint8), group=group)
|
||||
for i in range(world_size):
|
||||
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
||||
tensor_out = torch.cat(tensor_list, dim=0)
|
||||
@@ -169,8 +169,8 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
|
||||
|
||||
def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e4m3") -> None:
|
||||
r"""
|
||||
This is an in-place operation for compressed all_reduce using fp8.
|
||||
It works like dist.all_reduce but during communication the data is cast to fp8 format.
|
||||
This is an in-place operation for compressed reduce_scatter using fp8.
|
||||
It works like dist.reduce_scatter but during communication the data is cast to fp8 format.
|
||||
|
||||
Args:
|
||||
tensor: torch.Tensor in fp32, fp16, bf16 datatype.
|
||||
|
Reference in New Issue
Block a user