diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 7b5aec2aa..bfb8930bb 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -293,6 +293,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        fp8_communication: bool = False,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -315,6 +316,7 @@ class LowLevelZeroPlugin(DPPluginBase):
             partition_grad=(stage == 2),
             cpu_offload=cpu_offload,
             master_weights=master_weights,
+            fp8_communication=fp8_communication,
         )
         self.lora_enabled = False
         self.verbose = verbose
diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
index 5b09019b9..d5fd2fe51 100644
--- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
+++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
@@ -4,6 +4,8 @@ import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 
+from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8
+
 
 class TensorBucket:
     def __init__(self, size):
@@ -61,11 +63,14 @@ class TensorBucket:
         for old, new in zip(self._bucket, unflattened_tensor_list):
             old.copy_(new)
 
-    def all_gather(self, group=None):
+    def all_gather(self, group=None, fp8_communication: bool = False):
         flat = self.flatten()
-        buffers = [torch.empty_like(flat) for _ in range(dist.get_world_size(group))]
-        dist.all_gather(buffers, flat, group=group)
-        unflat_buffers = [self.unflatten(buffer) for buffer in buffers]
+        buffer = torch.empty(flat.numel() * dist.get_world_size(group), device=flat.device, dtype=flat.dtype)
+        if fp8_communication:
+            all_gather_into_tensor_flat_fp8(buffer, flat, output_shape=buffer.shape, group=group)
+        else:
+            dist.all_gather_into_tensor(buffer, flat, group=group)
+        unflat_buffers = [self.unflatten(buffer) for buffer in buffer.chunk(dist.get_world_size(group))]
         # transpose the list of list
         unflat_buffers = list(map(list, zip(*unflat_buffers)))
         for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index bdc91b51f..1353071c5 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -20,6 +20,7 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
+from colossalai.quantization.fp8 import all_gather_into_tensor_flat_fp8, all_reduce_fp8, reduce_scatter_fp8
 
 from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
@@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         dp_process_group: Optional[ProcessGroup] = None,
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
+        fp8_communication: bool = False,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
 
@@ -123,6 +125,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._overlap_communication = overlap_communication
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
+        self._fp8_communication = fp8_communication
 
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
@@ -323,7 +326,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     flat_grads = flat_grads.to(self._communication_dtype)
 
                 if not self._partition_grads:
-                    dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
+                    if self._fp8_communication:
+                        all_reduce_fp8(flat_grads, group=bucket_store.torch_pg)
+                    else:
+                        dist.all_reduce(flat_grads, group=bucket_store.torch_pg)
                     if flat_grads.dtype != grad_dtype:
                         flat_grads = flat_grads.to(grad_dtype)
 
@@ -333,7 +339,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 else:
                     flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
                     recieved_grad = torch.zeros_like(flat_grads_list[0])
-                    dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
+                    if self._fp8_communication:
+                        reduce_scatter_fp8(
+                            recieved_grad,
+                            flat_grads_list,
+                            group=bucket_store.torch_pg,
+                        )
+                    else:
+                        dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
                     if recieved_grad.dtype != grad_dtype:
                         recieved_grad = recieved_grad.to(grad_dtype)
 
@@ -553,18 +566,21 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                     buffer_tensor = torch.empty_like(
                         torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
                     )
-                    dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
+                    if self._fp8_communication:
+                        all_gather_into_tensor_flat_fp8(buffer_tensor, param_to_gather, pg, fp8_format="e4m3")
+                    else:
+                        dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
                     working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
                     continue
                 try:
                     self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
                 except RuntimeError:
-                    self.pg_to_tensor_bucket[pg].all_gather(pg)
+                    self.pg_to_tensor_bucket[pg].all_gather(pg, fp8_communication=self._fp8_communication)
                     self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
         for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
             if not tensor_bucket.is_empty():
-                tensor_bucket.all_gather(pg)
+                tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
     def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
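Usage note (not part of the diff): a minimal sketch of how the new fp8_communication flag is meant to be switched on through the Booster API. The toy model, optimizer, and launch call below are placeholder assumptions for illustration, not code from this PR.

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # assumes a torchrun-style launch with the usual distributed env vars

model = torch.nn.Linear(1024, 1024).cuda()         # placeholder model
optimizer = torch.optim.AdamW(model.parameters())  # placeholder optimizer

# stage=2 partitions gradients; fp8_communication=True routes the ZeRO all-reduce,
# reduce-scatter, and all-gather calls through the colossalai.quantization.fp8 helpers.
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", fp8_communication=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)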