mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928)
* support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement communication hook for FSDP params all-gather * added unit test for fp8 operators * support fp8 communication in GeminiPlugin * update training scripts to support fsdp and fp8 communication * fixed some minor bugs observed in unit test * add all_gather_into_tensor_flat_fp8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * add skip the test if torch < 2.2.0 * add fp8_comm flag * rebase latest fp8 operators * rebase latest fp8 operators * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
|
||||
scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling
|
||||
is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied.
|
||||
fp8_format: e4m3 or e5m2
|
||||
|
||||
Returns:
|
||||
Tuples: A tuple (fp8_tensor, scale)
|
||||
"""
|
||||
@@ -29,12 +30,13 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -
|
||||
per_channel_max = inp.abs().max(dim=-1).values.float()
|
||||
per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0)
|
||||
scale = fp8_max / per_channel_max[:, None]
|
||||
scale_inv = per_channel_max / fp8_max
|
||||
else:
|
||||
per_tensor_max = inp.abs().max().float()
|
||||
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
||||
scale = fp8_max / per_tensor_max
|
||||
scale_inv = 1.0 / scale
|
||||
|
||||
scale_inv = 1.0 / scale
|
||||
ret = (scale * inp.float()).to(fp8_type)
|
||||
return ret, scale_inv
|
||||
|
||||
@@ -185,7 +187,11 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
|
||||
return
|
||||
|
||||
assert "hidden_states" in inp, "required by pipeline parallelism."
|
||||
assert (
|
||||
inp["hidden_states"].size(-1) % 2 == 0
|
||||
), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16"
|
||||
inp_tensor = inp["hidden_states"]
|
||||
inp_dtype = inp_tensor.dtype
|
||||
|
||||
min_val, max_val = inp_tensor.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs())
|
||||
@@ -206,6 +212,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None:
|
||||
inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type)
|
||||
|
||||
inp["fp8_scale"] = scale.float().reciprocal()
|
||||
inp["dtype"] = torch.zeros_like(scale).to(inp_dtype)
|
||||
|
||||
|
||||
def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
|
||||
@@ -230,10 +237,11 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None:
|
||||
else:
|
||||
raise TypeError("Only float16, bfloat16 are implemented.")
|
||||
|
||||
inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale
|
||||
inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale
|
||||
|
||||
if del_metadata:
|
||||
del inp["fp8_scale"]
|
||||
del inp["dtype"]
|
||||
|
||||
|
||||
def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None:
|
||||
@@ -273,6 +281,199 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2
|
||||
output.data = summed_out
|
||||
|
||||
|
||||
def fp8_compress_ddp_grad_comm_hook_async(
|
||||
process_group: dist.ProcessGroup,
|
||||
bucket: dist.GradBucket,
|
||||
fp8_format: str = "e5m2",
|
||||
) -> torch.futures.Future[torch.Tensor]:
|
||||
"""
|
||||
Compress by casting ``GradBucket`` to FP8 floating-point format divided by process group size.
|
||||
|
||||
This DDP communication hook implements a simple gradient compression approach that casts ``GradBucket`` tensor
|
||||
to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.bfloat16_e4m3``), and then divides it
|
||||
by the process group size.
|
||||
Once compressed gradient tensors are allreduced, the chained callback ``decompress`` casts it back
|
||||
to the input data type (such as ``float32``).
|
||||
|
||||
Example::
|
||||
>>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async)
|
||||
"""
|
||||
group_to_use = process_group if process_group is not None else dist.group.WORLD
|
||||
|
||||
input_tensor = bucket.buffer()
|
||||
world_size = dist.get_world_size()
|
||||
input_type = input_tensor.dtype
|
||||
input_device = input_tensor.device
|
||||
flat_padded_x = input_tensor.flatten()
|
||||
|
||||
if flat_padded_x.size(0) % world_size != 0:
|
||||
pad_size = world_size - flat_padded_x.size(0) % world_size
|
||||
flat_padded_x = F.pad(flat_padded_x, (0, pad_size))
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format)
|
||||
|
||||
inp = ret.view(torch.uint8)
|
||||
output_chunks_single = torch.empty_like(inp)
|
||||
split_sizes = [inp.numel() // world_size for _ in range(world_size)]
|
||||
fut0 = dist.all_to_all_single(
|
||||
output_chunks_single,
|
||||
inp,
|
||||
output_split_sizes=split_sizes,
|
||||
input_split_sizes=split_sizes,
|
||||
group=group_to_use,
|
||||
async_op=True,
|
||||
).get_future()
|
||||
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
fut1 = dist.all_gather_into_tensor(
|
||||
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
|
||||
).get_future()
|
||||
all_to_all_fut = torch.futures.collect_all([fut0, fut1])
|
||||
|
||||
def sum_and_allgather(fut):
|
||||
output_chunks_single = fut.value()[0].wait()[0]
|
||||
scale_list_single = fut.value()[1].wait()[0]
|
||||
|
||||
output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0))
|
||||
scale_list = scale_list_single.chunk(world_size, dim=0)
|
||||
|
||||
summed_out = torch.zeros_like(output_chunks[0]).to(input_type)
|
||||
for scale, out in zip(scale_list, output_chunks):
|
||||
out = out.view(fp8_type)
|
||||
summed_out += cast_from_fp8(out, scale, input_type)
|
||||
summed_out.div_(world_size)
|
||||
|
||||
summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format)
|
||||
|
||||
tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8)
|
||||
fut2 = dist.all_gather_into_tensor(
|
||||
tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True
|
||||
).get_future()
|
||||
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
fut3 = dist.all_gather_into_tensor(
|
||||
torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True
|
||||
).get_future()
|
||||
fut_combined2 = torch.futures.collect_all([fut2, fut3])
|
||||
return fut_combined2
|
||||
|
||||
def decompress(fut):
|
||||
tensor_list_single = fut.value().wait()[0].value()[0]
|
||||
scale_list_single = fut.value().wait()[1].value()[0]
|
||||
|
||||
tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0))
|
||||
scale_list = scale_list_single.chunk(world_size, dim=0)
|
||||
|
||||
for i in range(world_size):
|
||||
tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i]
|
||||
out = torch.cat(tensor_list, dim=0)
|
||||
|
||||
input_tensor_size = input_tensor.numel()
|
||||
input_shape = input_tensor.shape
|
||||
out = out[:input_tensor_size]
|
||||
|
||||
input_tensor.copy_(out.view(input_shape).to(input_type))
|
||||
return input_tensor
|
||||
|
||||
return all_to_all_fut.then(sum_and_allgather).then(decompress)
|
||||
|
||||
|
||||
def fp8_compress_ddp_grad_comm_hook_sync(
|
||||
process_group: dist.ProcessGroup,
|
||||
bucket: dist.GradBucket,
|
||||
fp8_format="e5m2",
|
||||
) -> torch.futures.Future[torch.Tensor]:
|
||||
"""
|
||||
Return a future that wraps the input, after the input is allreduced. However, the allreduce commnunication is synchronized.
|
||||
This breaks the overlapping between allreduce communication and backward compuation.
|
||||
|
||||
This hook should **only** be used for debugging purposes, instead of the normal gradient synchronization.
|
||||
For asynchronized implementation, use fp8_compress_ddp_grad_comm_hook_async instead.
|
||||
|
||||
Example::
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync)
|
||||
"""
|
||||
|
||||
buffer = bucket.buffer()
|
||||
all_reduce_fp8(buffer, fp8_format=fp8_format)
|
||||
|
||||
fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
|
||||
fut.set_result(bucket.buffer())
|
||||
|
||||
return fut
|
||||
|
||||
|
||||
def fp8_compress_fsdp_grad_comm_hook(
|
||||
state: object,
|
||||
unsharded_gradient_flattened: torch.Tensor,
|
||||
sharded_gradient: torch.Tensor,
|
||||
group=None,
|
||||
fp8_format="e5m2",
|
||||
) -> None:
|
||||
"""
|
||||
This communication hook implements a simple gradient compression approach that casts unsharded_gradient_flattened tensor
|
||||
to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.bfloat16_e4m3``), and then perform scatter_allreduce logic
|
||||
by using all_to_all and all_gather among the process group.
|
||||
|
||||
Example::
|
||||
>>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook)
|
||||
"""
|
||||
grad = unsharded_gradient_flattened
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
input_type = grad.dtype
|
||||
input_device = grad.device
|
||||
world_size = dist.get_world_size(group=group)
|
||||
|
||||
grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format)
|
||||
uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8)
|
||||
dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group)
|
||||
|
||||
scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)]
|
||||
dist.all_gather(scale_list, scale, group=group)
|
||||
|
||||
buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0))
|
||||
sharded_gradient.zero_()
|
||||
for tensor, scale in zip(buffer_list, scale_list):
|
||||
sharded_gradient += cast_from_fp8(tensor, scale, input_type)
|
||||
|
||||
|
||||
def fp8_compress_fsdp_params_comm_hook(
|
||||
state: object,
|
||||
padded_unsharded_flat_param: torch.Tensor,
|
||||
sharded_flat_param: torch.Tensor,
|
||||
group=None,
|
||||
fp8_format="e5m2",
|
||||
) -> None:
|
||||
"""
|
||||
This hook is pending the official support for parameters communication hook in FSDP, e.g. register_params_comm_hook.
|
||||
|
||||
Example::
|
||||
>>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook)
|
||||
"""
|
||||
|
||||
fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
|
||||
fp8_max = torch.finfo(fp8_type).max
|
||||
inp = sharded_flat_param
|
||||
out = padded_unsharded_flat_param
|
||||
|
||||
per_tensor_max = inp.abs().max().float()
|
||||
per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0)
|
||||
dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group)
|
||||
|
||||
scale = fp8_max / per_tensor_max
|
||||
fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8)
|
||||
|
||||
fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device)
|
||||
dist.all_gather_into_tensor(
|
||||
fp8_out,
|
||||
fp8_sharded_flat_param,
|
||||
group=group,
|
||||
)
|
||||
padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype))
|
||||
|
||||
|
||||
def split_chunk_by_channel(
|
||||
chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1
|
||||
):
|
||||
@@ -342,7 +543,7 @@ def all_gather_into_tensor_flat_fp8(
|
||||
scale_inv = 1.0 / scale
|
||||
buffer = torch.empty_like(output_tensor, dtype=fp8_type)
|
||||
dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group)
|
||||
numel = np.prod(output_shape)
|
||||
numel = output_shape.numel()
|
||||
valid_buffer = buffer[:numel].reshape(output_shape)
|
||||
valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2))
|
||||
output_tensor[:numel].copy_(valid_buffer.view(-1))
|
||||
|
Reference in New Issue
Block a user