From b480eec738b58bd40a6e4cc656c8439ffed0be98 Mon Sep 17 00:00:00 2001 From: Hanks Date: Thu, 8 Aug 2024 15:55:01 +0800 Subject: [PATCH] [Feature]: support FP8 communication in DDP, FSDP, Gemini (#5928) * support fp8_communication in the Torch DDP grad comm, FSDP grad comm, and FSDP params comm * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement communication hook for FSDP params all-gather * added unit test for fp8 operators * support fp8 communication in GeminiPlugin * update training scripts to support fsdp and fp8 communication * fixed some minor bugs observed in unit test * add all_gather_into_tensor_flat_fp8 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add skip the test if torch < 2.2.0 * add skip the test if torch < 2.2.0 * add fp8_comm flag * rebase latest fp8 operators * rebase latest fp8 operators * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- colossalai/booster/plugin/gemini_plugin.py | 2 + colossalai/booster/plugin/torch_ddp_plugin.py | 7 + .../booster/plugin/torch_fsdp_plugin.py | 15 ++ colossalai/quantization/fp8.py | 207 +++++++++++++++++- colossalai/quantization/utils.py | 112 ++++++++++ colossalai/zero/gemini/chunk/chunk.py | 15 +- colossalai/zero/gemini/chunk/manager.py | 4 + colossalai/zero/gemini/gemini_ddp.py | 3 + examples/language/bert/finetune.py | 17 +- .../gpt/hybridparallelism/finetune.py | 2 +- examples/language/llama/benchmark.py | 12 +- tests/test_fp8/test_fp8_cast.py | 26 +++ tests/test_fp8/test_fp8_ddp_comm_hook.py | 87 ++++++++ tests/test_fp8/test_fp8_fsdp_comm_hook.py | 107 +++++++++ 14 files changed, 602 insertions(+), 14 deletions(-) create mode 100644 colossalai/quantization/utils.py create mode 100644 tests/test_fp8/test_fp8_cast.py create mode 100644 tests/test_fp8/test_fp8_ddp_comm_hook.py create mode 100644 tests/test_fp8/test_fp8_fsdp_comm_hook.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ad131fbe7..5ab8f05ad 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -364,6 +364,7 @@ class GeminiPlugin(DPPluginBase): enable_sequence_overlap: bool = False, enable_async_reduce: bool = True, verbose: bool = False, + fp8_communication: bool = False, ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" @@ -395,6 +396,7 @@ class GeminiPlugin(DPPluginBase): master_weights=master_weights, max_prefetch=max_prefetch, enable_async_reduce=enable_async_reduce, + fp8_communication=fp8_communication, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 5116446a4..34caa2f68 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -177,6 +177,7 @@ class TorchDDPPlugin(DPPluginBase): check_reduction: bool = False, gradient_as_bucket_view: bool = False, static_graph: bool = False, + fp8_communication: bool = False, ) -> None: super().__init__() self.ddp_kwargs = dict( @@ -187,6 +188,7 @@ class TorchDDPPlugin(DPPluginBase): gradient_as_bucket_view=gradient_as_bucket_view, static_graph=static_graph, ) + self.fp8_communication = fp8_communication def support_no_sync(self) -> bool: return True @@ -226,6 +228,11 @@ class TorchDDPPlugin(DPPluginBase): if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = OptimizerWrapper(optimizer) + if self.fp8_communication: + from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async + + model.module.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_async) + return model, optimizer, criterion, dataloader, lr_scheduler def control_checkpoint_io(self) -> bool: diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index cd2f9e840..e3f81928d 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -298,6 +298,7 @@ class TorchFSDPPlugin(DPPluginBase): ignored_modules: Optional[Iterable[torch.nn.Module]] = None, param_init_fn: Optional[Callable[[nn.Module], None]] = None, sync_module_states: bool = False, + fp8_communication: bool = False, ): super().__init__() self.fsdp_kwargs = dict( @@ -311,6 +312,7 @@ class TorchFSDPPlugin(DPPluginBase): param_init_fn=param_init_fn, sync_module_states=sync_module_states, ) + self.fp8_communication = fp8_communication else: raise RuntimeError("FSDP is not supported while torch version under 1.12.0.") @@ -347,6 +349,19 @@ class TorchFSDPPlugin(DPPluginBase): # wrap the model with PyTorch FSDP fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs) + if self.fp8_communication: + from colossalai.quantization.utils import patch_fsdp_params_comm_hook + + patch_fsdp_params_comm_hook() + + from colossalai.quantization.fp8 import fp8_compress_fsdp_params_comm_hook + + fsdp_model.module.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook) + + from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook + + fsdp_model.module.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook) + if optimizer is not None: if len(optimizer.param_groups) > 1: warnings.warn( diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 52bb8cc9b..e933680a9 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -15,6 +15,7 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) - scale: scaling factor for fp8 casting. If it is None, then it is computed automatically. Per-channel scaling is applied if input tensor is 2 dimension, otherwise, per-tensor scaling is applied. fp8_format: e4m3 or e5m2 + Returns: Tuples: A tuple (fp8_tensor, scale) """ @@ -29,12 +30,13 @@ def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) - per_channel_max = inp.abs().max(dim=-1).values.float() per_channel_max = torch.where(per_channel_max > 0, per_channel_max, 1.0) scale = fp8_max / per_channel_max[:, None] + scale_inv = per_channel_max / fp8_max else: per_tensor_max = inp.abs().max().float() per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0) scale = fp8_max / per_tensor_max + scale_inv = 1.0 / scale - scale_inv = 1.0 / scale ret = (scale * inp.float()).to(fp8_type) return ret, scale_inv @@ -185,7 +187,11 @@ def cast_to_fp8_pipeline(inp: Any) -> None: return assert "hidden_states" in inp, "required by pipeline parallelism." + assert ( + inp["hidden_states"].size(-1) % 2 == 0 + ), "tensor size(-1) must be divisible by 2 to view Float8_e4m3fn as BFloat16 or Float16" inp_tensor = inp["hidden_states"] + inp_dtype = inp_tensor.dtype min_val, max_val = inp_tensor.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()) @@ -206,6 +212,7 @@ def cast_to_fp8_pipeline(inp: Any) -> None: inp_tensor.data = q_tensor.to(fp8_type).view(fp8_view_type) inp["fp8_scale"] = scale.float().reciprocal() + inp["dtype"] = torch.zeros_like(scale).to(inp_dtype) def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: @@ -230,10 +237,11 @@ def cast_from_fp8_pipeline(inp: Any, del_metadata=True) -> None: else: raise TypeError("Only float16, bfloat16 are implemented.") - inp_tensor.data = inp_tensor.data.view(fp8_type).to(torch.float16) * scale + inp_tensor.data = inp_tensor.data.view(fp8_type).to(inp["dtype"]) * scale if del_metadata: del inp["fp8_scale"] + del inp["dtype"] def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2") -> None: @@ -273,6 +281,199 @@ def reduce_scatter_fp8(output: torch.Tensor, input_list, group, fp8_format="e5m2 output.data = summed_out +def fp8_compress_ddp_grad_comm_hook_async( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, + fp8_format: str = "e5m2", +) -> torch.futures.Future[torch.Tensor]: + """ + Compress by casting ``GradBucket`` to FP8 floating-point format divided by process group size. + + This DDP communication hook implements a simple gradient compression approach that casts ``GradBucket`` tensor + to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.bfloat16_e4m3``), and then divides it + by the process group size. + Once compressed gradient tensors are allreduced, the chained callback ``decompress`` casts it back + to the input data type (such as ``float32``). + + Example:: + >>> ddp_model.register_comm_hook(process_group, fp8_compress_ddp_grad_comm_hook_async) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + + input_tensor = bucket.buffer() + world_size = dist.get_world_size() + input_type = input_tensor.dtype + input_device = input_tensor.device + flat_padded_x = input_tensor.flatten() + + if flat_padded_x.size(0) % world_size != 0: + pad_size = world_size - flat_padded_x.size(0) % world_size + flat_padded_x = F.pad(flat_padded_x, (0, pad_size)) + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + ret, scale = cast_to_fp8(flat_padded_x, fp8_format=fp8_format) + + inp = ret.view(torch.uint8) + output_chunks_single = torch.empty_like(inp) + split_sizes = [inp.numel() // world_size for _ in range(world_size)] + fut0 = dist.all_to_all_single( + output_chunks_single, + inp, + output_split_sizes=split_sizes, + input_split_sizes=split_sizes, + group=group_to_use, + async_op=True, + ).get_future() + + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + fut1 = dist.all_gather_into_tensor( + torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True + ).get_future() + all_to_all_fut = torch.futures.collect_all([fut0, fut1]) + + def sum_and_allgather(fut): + output_chunks_single = fut.value()[0].wait()[0] + scale_list_single = fut.value()[1].wait()[0] + + output_chunks = list(torch.chunk(output_chunks_single, world_size, dim=0)) + scale_list = scale_list_single.chunk(world_size, dim=0) + + summed_out = torch.zeros_like(output_chunks[0]).to(input_type) + for scale, out in zip(scale_list, output_chunks): + out = out.view(fp8_type) + summed_out += cast_from_fp8(out, scale, input_type) + summed_out.div_(world_size) + + summed_out_fp8, scale = cast_to_fp8(summed_out, fp8_format=fp8_format) + + tensor_list_single = torch.empty(summed_out_fp8.size(0) * world_size, device=input_device, dtype=torch.uint8) + fut2 = dist.all_gather_into_tensor( + tensor_list_single, summed_out_fp8.view(torch.uint8), group=group_to_use, async_op=True + ).get_future() + + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + fut3 = dist.all_gather_into_tensor( + torch.cat(scale_list, dim=0), scale, group=group_to_use, async_op=True + ).get_future() + fut_combined2 = torch.futures.collect_all([fut2, fut3]) + return fut_combined2 + + def decompress(fut): + tensor_list_single = fut.value().wait()[0].value()[0] + scale_list_single = fut.value().wait()[1].value()[0] + + tensor_list = list(torch.chunk(tensor_list_single, world_size, dim=0)) + scale_list = scale_list_single.chunk(world_size, dim=0) + + for i in range(world_size): + tensor_list[i] = tensor_list[i].view(fp8_type).to(input_type) * scale_list[i] + out = torch.cat(tensor_list, dim=0) + + input_tensor_size = input_tensor.numel() + input_shape = input_tensor.shape + out = out[:input_tensor_size] + + input_tensor.copy_(out.view(input_shape).to(input_type)) + return input_tensor + + return all_to_all_fut.then(sum_and_allgather).then(decompress) + + +def fp8_compress_ddp_grad_comm_hook_sync( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, + fp8_format="e5m2", +) -> torch.futures.Future[torch.Tensor]: + """ + Return a future that wraps the input, after the input is allreduced. However, the allreduce commnunication is synchronized. + This breaks the overlapping between allreduce communication and backward compuation. + + This hook should **only** be used for debugging purposes, instead of the normal gradient synchronization. + For asynchronized implementation, use fp8_compress_ddp_grad_comm_hook_async instead. + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(None, fp8_compress_ddp_grad_comm_hook_sync) + """ + + buffer = bucket.buffer() + all_reduce_fp8(buffer, fp8_format=fp8_format) + + fut: torch.futures.Future[torch.Tensor] = torch.futures.Future() + fut.set_result(bucket.buffer()) + + return fut + + +def fp8_compress_fsdp_grad_comm_hook( + state: object, + unsharded_gradient_flattened: torch.Tensor, + sharded_gradient: torch.Tensor, + group=None, + fp8_format="e5m2", +) -> None: + """ + This communication hook implements a simple gradient compression approach that casts unsharded_gradient_flattened tensor + to FP8 floating-point format (``torch.float8_e5m2`` or ``torch.bfloat16_e4m3``), and then perform scatter_allreduce logic + by using all_to_all and all_gather among the process group. + + Example:: + >>> fsdp_model.register_comm_hook(None, fp8_compress_fsdp_grad_comm_hook) + """ + grad = unsharded_gradient_flattened + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + input_type = grad.dtype + input_device = grad.device + world_size = dist.get_world_size(group=group) + + grad_fp8, scale = cast_to_fp8(grad, fp8_format=fp8_format) + uint8_buffer = torch.empty_like(grad_fp8).view(torch.uint8) + dist.all_to_all_single(uint8_buffer, grad_fp8.view(torch.uint8), group=group) + + scale_list = [torch.ones(1, dtype=scale.dtype, device=input_device) for _ in range(world_size)] + dist.all_gather(scale_list, scale, group=group) + + buffer_list = list(torch.chunk(uint8_buffer.view(fp8_type), world_size, dim=0)) + sharded_gradient.zero_() + for tensor, scale in zip(buffer_list, scale_list): + sharded_gradient += cast_from_fp8(tensor, scale, input_type) + + +def fp8_compress_fsdp_params_comm_hook( + state: object, + padded_unsharded_flat_param: torch.Tensor, + sharded_flat_param: torch.Tensor, + group=None, + fp8_format="e5m2", +) -> None: + """ + This hook is pending the official support for parameters communication hook in FSDP, e.g. register_params_comm_hook. + + Example:: + >>> fsdp_model.register_params_comm_hook(None, fp8_compress_fsdp_params_comm_hook) + """ + + fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2 + fp8_max = torch.finfo(fp8_type).max + inp = sharded_flat_param + out = padded_unsharded_flat_param + + per_tensor_max = inp.abs().max().float() + per_tensor_max = torch.where(per_tensor_max > 0, per_tensor_max, 1.0) + dist.all_reduce(per_tensor_max, op=torch.distributed.ReduceOp.MAX, group=group) + + scale = fp8_max / per_tensor_max + fp8_sharded_flat_param = (scale * inp.float()).to(fp8_type).view(torch.uint8) + + fp8_out = torch.empty(out.shape, dtype=torch.uint8, device=out.device) + dist.all_gather_into_tensor( + fp8_out, + fp8_sharded_flat_param, + group=group, + ) + padded_unsharded_flat_param.copy_((fp8_out.view(fp8_type).float() / scale).to(out.dtype)) + + def split_chunk_by_channel( chunk: torch.Tensor, channel_size: int, num_channels: int, rank: int = 0, world_size: int = 1 ): @@ -342,7 +543,7 @@ def all_gather_into_tensor_flat_fp8( scale_inv = 1.0 / scale buffer = torch.empty_like(output_tensor, dtype=fp8_type) dist.all_gather_into_tensor(buffer.view(torch.uint8), fp8_input.view(torch.uint8), group=group) - numel = np.prod(output_shape) + numel = output_shape.numel() valid_buffer = buffer[:numel].reshape(output_shape) valid_buffer = cast_from_fp8(valid_buffer, scale_inv, input_type, per_channel_scale=(len(output_shape) == 2)) output_tensor[:numel].copy_(valid_buffer.view(-1)) diff --git a/colossalai/quantization/utils.py b/colossalai/quantization/utils.py new file mode 100644 index 000000000..5b1e11c9f --- /dev/null +++ b/colossalai/quantization/utils.py @@ -0,0 +1,112 @@ +import torch +import torch.distributed as dist +from packaging import version +from torch import Tensor +from torch.distributed.fsdp._common_utils import _no_dispatch_record_stream +from torch.distributed.utils import _p_assert + + +def _all_gather_flat_param( + self, + padded_unsharded_flat_param: Tensor, +) -> Tensor: + """ + All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``. + + Then switch to use the all-gathered tensor. + """ + _p_assert( + hasattr(self, "process_group") and hasattr(self, "world_size"), + "Expects a process group and world size to have been set via `shard()`", + ) + sharded_flat_param = self.flat_param.data + expected_numel = sharded_flat_param.numel() * self.world_size + _p_assert( + padded_unsharded_flat_param.numel() == expected_numel, + f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}", + ) + + pg = self._fake_process_group if self._use_fake_all_gather else self.process_group + + # HACK this should be handled by C10D + if sharded_flat_param.is_cpu: # type: ignore[attr-defined] + tensor_list = list(torch.chunk(padded_unsharded_flat_param, dist.get_world_size(pg))) + work = dist.all_gather(tensor_list, sharded_flat_param, group=pg) + else: + if self._comm_hook is None: + dist.all_gather_into_tensor( + padded_unsharded_flat_param, + sharded_flat_param, + pg, + ) + else: + self._comm_hook(None, padded_unsharded_flat_param, sharded_flat_param, pg) + + if self._offload_params: + # In case of offloading, `flat_param.data` (i.e. sharded param) is + # created on the pre-unshard stream. We need to hand it over to the + # unshard stream for all-gather + _no_dispatch_record_stream( + sharded_flat_param, + self._device_handle.current_stream(), # unshard_stream + ) + return padded_unsharded_flat_param + + +def register_params_comm_hook(self, state: object, hook: callable): + """Register a communication hook for FlatParamHandle. + + This is an enhancement that provides a flexible hook to users where they can specify how FSDP unshards + parameters across multiple workers. + + .. warning :: + FSDP communication hook should be registered before running an initial forward pass + and only once. + + Args: + state (object): Passed to the hook to maintain any state information during the training process. + hook (Callable): Callable, which has one of the following signatures: + 1) ``hook: Callable[torch.Tensor] -> None``: + This function takes in a Python tensor, which represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). + It then performs all necessary processing and returns ``None``; + 2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``: + This function takes in two Python tensors, the first one represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). The latter + represents a pre-sized tensor to store a chunk of a sharded gradient after + reduction. + In both cases, callable performs all necessary processing and returns ``None``. + Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case. + Callables with signature 2 are expected to handle gradient communication for sharded cases. + + """ + if not self.check_is_root(): + raise AssertionError("register_comm_hook can only be called on a root instance.") + + # if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES: + # raise AssertionError( + # f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}" + # ) + if self._handle._comm_hook is not None: + raise AssertionError("A communication hook is already registered") + if not callable(hook): + raise ValueError(f"The communication hook must be callable but got {hook}") + self._handle._comm_hook = hook + self._handle._comm_hook_state = state + + +def patch_fsdp_params_comm_hook(): + if version.parse(torch.__version__) >= version.parse("2.2.0"): + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp._flat_param import FlatParamHandle + + FlatParamHandle._comm_hook = None + FlatParamHandle._comm_hook_state = None + FlatParamHandle._all_gather_flat_param = _all_gather_flat_param + FSDP.register_params_comm_hook = register_params_comm_hook + else: + raise RuntimeError("This fsdp_params_comm_hook patch is not supported while torch version under 2.2.0.") diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 969df9621..e2b7a8f56 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -166,6 +166,7 @@ class Chunk: self.grad_chunk = None # the async all-reduce/reduce-scatter work of this grad chunk (None means sync) self.grad_reduce_work = None + self.fp8_communication = False @property def memory_usage(self) -> Dict[str, int]: @@ -521,9 +522,17 @@ class Chunk: alloc_storage(self.cuda_global_chunk) assert self.cuda_global_chunk.is_contiguous() - work = dist.all_gather_into_tensor( - self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op - ) + if self.fp8_communication: + assert async_op == False, "fp8 all-gather does not support async_op!" + from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8 + + work = all_gather_into_tensor_flat_fp8( + self.cuda_global_chunk, self.cuda_shard, self.cuda_global_chunk.shape, self.torch_pg + ) + else: + work = dist.all_gather_into_tensor( + self.cuda_global_chunk, self.cuda_shard, self.torch_pg, async_op=async_op + ) self.cuda_shard = None self.is_gathered = True diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index d0e1755f4..06f9b6d18 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -26,6 +26,7 @@ class ChunkManager: init_device: Optional[torch.device] = None, reuse_fp16_chunk: bool = True, max_prefetch: int = 0, + fp8_communication: bool = False, ) -> None: self.device = init_device or get_accelerator().get_current_device() self.dp_degree_chunk_size_dict: Dict[int, int] = dict() @@ -44,6 +45,7 @@ class ChunkManager: self.accumulating_grads = False self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device()) self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None + self.fp8_communication = fp8_communication def register_tensor( self, @@ -101,6 +103,8 @@ class ChunkManager: extra_dp_group=extra_dp_group, **chunk_kwargs, ) + if self.fp8_communication: + chunk.fp8_communication = True chunk_group.append(chunk) chunk.append_tensor(tensor) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 80b2c7961..0b2039a4d 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -98,6 +98,7 @@ class GeminiDDP(ModelWrapper): extra_dp_group: Optional[ProcessGroup] = None, verbose: bool = False, enable_async_reduce: bool = True, + fp8_communication: bool = False, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False @@ -122,6 +123,8 @@ class GeminiDDP(ModelWrapper): verbose=verbose, max_prefetch=max_prefetch, ) + if fp8_communication: + self.chunk_manager.fp8_communication = True self.gemini_manager = GeminiManager( placement_policy, self.chunk_manager, diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 8a59ab683..f048abdd2 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -179,7 +179,7 @@ def main(): "--plugin", type=str, default="torch_ddp", - choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"], + choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel", "torch_fsdp"], help="plugin to use", ) parser.add_argument( @@ -215,9 +215,9 @@ def main(): if args.plugin == "torch_ddp_fp16": booster_kwargs["mixed_precision"] = "fp16" if args.plugin.startswith("torch_ddp"): - plugin = TorchDDPPlugin() + plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm) elif args.plugin == "gemini": - plugin = GeminiPlugin(initial_scale=2**5) + plugin = GeminiPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm) elif args.plugin == "low_level_zero": plugin = LowLevelZeroPlugin(initial_scale=2**5) elif args.plugin == "hybrid_parallel": @@ -235,6 +235,17 @@ def main(): initial_scale=1, fp8_communication=args.use_fp8_comm, ) + elif args.plugin == "torch_fsdp": + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + + from colossalai.booster.plugin import TorchFSDPPlugin + + plugin = TorchFSDPPlugin( + mixed_precision=MixedPrecision( + param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16 + ), + fp8_communication=args.use_fp8_comm, + ) booster = Booster(plugin=plugin, **booster_kwargs) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index f447adf69..e9f7203e9 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -212,7 +212,7 @@ def main(): if args.plugin == "torch_ddp_fp16": booster_kwargs["mixed_precision"] = "fp16" if args.plugin.startswith("torch_ddp"): - plugin = TorchDDPPlugin() + plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm) elif args.plugin == "gemini": plugin = GeminiPlugin(initial_scale=2**5) elif args.plugin == "low_level_zero": diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index e530e2d6a..2bd9671d8 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -98,7 +98,7 @@ def main(): parser.add_argument("--disable-async-reduce", action="store_true", help="Disable the asynchronous reduce operation") parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", action="store_true") - parser.add_argument("--overlap_allgather", action="store_true") + parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") args = parser.parse_args() colossalai.launch_from_torch() @@ -158,6 +158,7 @@ def main(): buffer_dtype=torch.float16, ), param_init_fn=empty_init(), + fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -165,7 +166,8 @@ def main(): param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16, - ) + ), + fp8_communication=args.use_fp8_comm, ) elif args.plugin == "fsdp_cpu": if use_empty_init: @@ -177,6 +179,7 @@ def main(): ), cpu_offload=CPUOffload(offload_params=True), param_init_fn=empty_init(), + fp8_communication=args.use_fp8_comm, ) else: plugin = TorchFSDPPlugin( @@ -186,6 +189,7 @@ def main(): buffer_dtype=torch.float16, ), cpu_offload=CPUOffload(offload_params=True), + fp8_communication=args.use_fp8_comm, ) elif args.plugin == "3d": plugin = HybridParallelPlugin( @@ -200,9 +204,9 @@ def main(): enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", + dp_outside=False, overlap_p2p=args.overlap, enable_metadata_cache=not args.no_cache, - overlap_allgather=args.overlap_allgather, **hybrid_kwargs, ) elif args.plugin == "3d_cpu": @@ -293,7 +297,7 @@ def main(): with get_profile_context( args.profile, args.ignore_steps, - 1, # avoid creating massive log files + len(dataloader) - 1, save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", ) as prof: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: diff --git a/tests/test_fp8/test_fp8_cast.py b/tests/test_fp8/test_fp8_cast.py new file mode 100644 index 000000000..db9a909e6 --- /dev/null +++ b/tests/test_fp8/test_fp8_cast.py @@ -0,0 +1,26 @@ +import torch +from torch.testing import assert_close + +from colossalai.accelerator import get_accelerator +from colossalai.quantization.fp8 import cast_from_fp8, cast_from_fp8_pipeline, cast_to_fp8, cast_to_fp8_pipeline +from colossalai.testing import parameterize + + +@parameterize("shape", [(100, 10), (10, 100), (3, 7), (2, 1), (1, 2), (2, 2), (4, 2), (5,), (4,), (2,)]) +@parameterize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +@parameterize("fp8_format", ["e4m3", "e5m2"]) +def test_fp8_cast(shape, dtype, fp8_format): + x = torch.rand(shape, dtype=dtype, device=get_accelerator().get_current_device()) + ret, scale_inv = cast_to_fp8(x, fp8_format=fp8_format) + out = cast_from_fp8(ret, scale_inv, x.dtype) + assert_close(out, x, rtol=0.1, atol=0.1) + + if x.size(-1) % 2 == 0: + inp_dict = {"hidden_states": x.clone()} + cast_to_fp8_pipeline(inp_dict) + cast_from_fp8_pipeline(inp_dict) + assert_close(inp_dict["hidden_states"], x, rtol=0.1, atol=0.1) + + +if __name__ == "__main__": + test_fp8_cast() diff --git a/tests/test_fp8/test_fp8_ddp_comm_hook.py b/tests/test_fp8/test_fp8_ddp_comm_hook.py new file mode 100644 index 000000000..9bdfe17a1 --- /dev/null +++ b/tests/test_fp8/test_fp8_ddp_comm_hook.py @@ -0,0 +1,87 @@ +import os + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html + + +def setup(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +def demo_basic(rank, world_size): + print(f"Running basic DDP example on rank {rank}.") + setup(rank, world_size) + + def get_grads_after_one_iteration(hook=None): + torch.manual_seed(0) + # create model and move it to GPU with id rank + model = ToyModel().to(rank) + + ddp_model = DDP(model, device_ids=[rank]) + + if hook is not None: + ddp_model.register_comm_hook(None, hook) + + loss_fn = nn.MSELoss() + optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) + + optimizer.zero_grad() + outputs = ddp_model(torch.randn(20, 10)) + labels = torch.randn(20, 5).to(rank) + loss_fn(outputs, labels).backward() + optimizer.step() + + torch.distributed.barrier() + + grad_dict = {} + for name, params in ddp_model.named_parameters(): + grad_dict[name] = params.grad + return grad_dict + + from colossalai.quantization.fp8 import fp8_compress_ddp_grad_comm_hook_async, fp8_compress_ddp_grad_comm_hook_sync + + grad_dict = get_grads_after_one_iteration() + for hook in [fp8_compress_ddp_grad_comm_hook_sync, fp8_compress_ddp_grad_comm_hook_async]: + grad_dict_w_hook = get_grads_after_one_iteration(hook) + if dist.get_rank() == 0: + for name in grad_dict: + assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1) + + cleanup() + + +def run_demo(demo_fn, world_size): + mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True) + + +if __name__ == "__main__": + n_gpus = torch.cuda.device_count() + assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + world_size = n_gpus + run_demo(demo_basic, world_size) diff --git a/tests/test_fp8/test_fp8_fsdp_comm_hook.py b/tests/test_fp8/test_fp8_fsdp_comm_hook.py new file mode 100644 index 000000000..3d0660961 --- /dev/null +++ b/tests/test_fp8/test_fp8_fsdp_comm_hook.py @@ -0,0 +1,107 @@ +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.optim as optim +from packaging import version +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.testing import assert_close + +from colossalai import launch +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn + +# example modified from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html + + +def cleanup(): + dist.destroy_process_group() + + +class ToyModel(nn.Module): + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(100, 100) + self.relu = nn.ReLU() + self.net2 = nn.Linear(100, 50) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +@parameterize("mode", ["grad", "params"]) +def run_model(mode): + rank = dist.get_rank() + + from colossalai.quantization.utils import patch_fsdp_params_comm_hook + + patch_fsdp_params_comm_hook() + + def get_grads_after_one_iteration(grad_hook=None, params_hook=None): + torch.manual_seed(0) + # create model and move it to GPU with id rank + model = ToyModel().to(rank) + fsdp_model = FSDP(model) + + if grad_hook is not None: + fsdp_model.register_comm_hook(None, grad_hook) + + if params_hook is not None: + fsdp_model.register_params_comm_hook(None, params_hook) + + loss_fn = nn.MSELoss() + optimizer = optim.SGD(fsdp_model.parameters(), lr=0.001) + + optimizer.zero_grad() + outputs = fsdp_model(torch.randn(20, 100)) + labels = torch.randn(20, 50).to(rank) + loss_fn(outputs, labels).backward() + optimizer.step() + + torch.distributed.barrier() + + grad_dict = {} + for name, params in fsdp_model.named_parameters(): + grad_dict[name] = params.grad + return grad_dict + + from colossalai.quantization.fp8 import fp8_compress_fsdp_grad_comm_hook, fp8_compress_fsdp_params_comm_hook + + if mode == "grad": + grad_dict = get_grads_after_one_iteration() + for hook in [ + fp8_compress_fsdp_grad_comm_hook, + ]: + grad_dict_w_hook = get_grads_after_one_iteration(grad_hook=hook) + if dist.get_rank() == 0: + for name in grad_dict: + assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1) + elif mode == "params": + grad_dict = get_grads_after_one_iteration() + for hook in [ + fp8_compress_fsdp_params_comm_hook, + ]: + grad_dict_w_hook = get_grads_after_one_iteration(params_hook=hook) + if dist.get_rank() == 0: + for name in grad_dict: + assert_close(grad_dict[name], grad_dict_w_hook[name], rtol=0.1, atol=0.1) + else: + raise NotImplementedError + + +def demo_basic(rank, world_size, port): + print(f"Running basic FSDP example on rank {rank}.") + launch(rank=rank, world_size=world_size, port=port, host="localhost") + run_model() + cleanup() + + +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse("2.2.0"), reason="torch version < 2.2.0.") +@rerun_if_address_is_in_use() +def test_fsdp(): + n_gpus = torch.cuda.device_count() + assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}" + spawn(demo_basic, n_gpus) + + +if __name__ == "__main__": + test_fsdp()