diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index bc9425a0b..62046bc36 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1188,6 +1188,15 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
 
+        # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
+        # Apply Hybrid ZeRO across DP * SP ranks
+        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            self.dp_size = get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -1298,19 +1307,11 @@ class HybridParallelPlugin(PipelinePluginBase):
         use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
             self.dp_size == 1 and self.pp_size == 1
         )
-        # sync gradients across DP * SP ranks
-        # sync gradients across DP * SP ranks
-        # Apply Hybrid ZeRO across DP * SP ranks
-        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
-            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            self.dp_size = get_world_size(dp_group)
-        else:
-            dp_group = self.dp_group
         model = HybridParallelModule(
             model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
@@ -1359,7 +1360,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                     model,
                     use_pipeline=self.enable_pipeline_parallelism,
                     param_info=param_info,
-                    dp_process_group=dp_group,
+                    dp_process_group=self.mixed_dp_group,
                     tp_process_group=self.tp_group,
                     pp_process_group=self.pp_group,
                     verbose=True,
@@ -1488,7 +1489,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(
+            self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
+        )
 
     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
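A minimal, single-process sketch of what the new mixed DP * SP group spans (not part of the patch): flattening the dp and sp mesh axes yields one rank group per (pp, tp) coordinate, so ZeRO sharding and gradient sync cover dp_size * sp_size ranks. The (pp, dp, sp, tp) axis order and the sizes below are illustrative assumptions, not values taken from the plugin.

import numpy as np

def mixed_dp_sp_ranks(pp_size=1, dp_size=2, sp_size=2, tp_size=2):
    # build a rank grid with an assumed (pp, dp, sp, tp) axis order
    world = pp_size * dp_size * sp_size * tp_size
    mesh = np.arange(world).reshape(pp_size, dp_size, sp_size, tp_size)
    groups = []
    for pp in range(pp_size):
        for tp in range(tp_size):
            # collapse the dp and sp axes into a single rank list, mirroring what
            # create_group_along_axis([self.dp_axis, self.sp_axis]) is used for above
            groups.append(mesh[pp, :, :, tp].reshape(-1).tolist())
    return groups

print(mixed_dp_sp_ranks())  # [[0, 2, 4, 6], [1, 3, 5, 7]]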
diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
index eb17d4c10..c547d3f43 100644
--- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
+++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
@@ -4,17 +4,34 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed.device_mesh import DeviceMesh
 
 from colossalai.quantization.fp8 import all_gather_fp8
 from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
 
 
+def chunks(lst, n):
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield list(lst)[i : i + n]
+
+
 class TensorBucket:
     def __init__(self, size):
         self._max_size = size
         self._current_size = 0
         self._bucket = []
         self._write_back_pairs = {}
+        self._allgather_handle = None
+        self._allgather_buffer = None
+        world_size = dist.get_world_size()
+        self.mesh = DeviceMesh(
+            device_type="cuda",
+            mesh=list(chunks(range(world_size), torch.cuda.device_count())),
+            mesh_dim_names=["internode", "intranode"],
+        )
+        self.internode_pg = self.mesh["internode"].get_group()
+        self.intranode_pg = self.mesh["intranode"].get_group()
 
     @property
     def max_size(self):
@@ -53,6 +70,9 @@ class TensorBucket:
         self._bucket = []
         self._current_size = 0
         self._write_back_pairs = {}
+        # del self._allgather_buffer
+        self._allgather_buffer = None
+        self._allgather_handle = None
 
     def flatten(self):
         return _flatten_dense_tensors(self._bucket)
@@ -65,6 +85,74 @@ class TensorBucket:
         for old, new in zip(self._bucket, unflattened_tensor_list):
             old.copy_(new)
 
+    def internode_all_gather_async(self, group=None, fp8_communication: bool = False):
+        assert fp8_communication is False, "fp8 communication is not supported yet"
+        # assert group is None, "internode_all_gather_async only support default group"
+        flat = self.flatten()
+
+        if isinstance(group, tuple):
+            world_size = np.prod([dist.get_world_size(pg) for pg in group])
+        else:
+            world_size = dist.get_world_size(group)
+
+        n_gpus = torch.cuda.device_count()
+        # print("==debug== flat", flat.dtype)
+        self._allgather_buffer = torch.empty(flat.numel() * world_size // n_gpus, device=flat.device, dtype=flat.dtype)
+        self._allgather_handle = all_gather_into_flat_tensor_nd(
+            self._allgather_buffer, flat, group=self.internode_pg, async_op=True
+        )
+
+    def intranode_allgather_and_write_back(self, group=None):
+        assert self._allgather_buffer is not None, "all_gather_async must be called before write_back_and_empty"
+        assert self._allgather_handle is not None, "all_gather_async must be called before write_back_and_empty"
+        self._allgather_handle.wait()
+        dist.get_world_size(group)
+        n_gpus = torch.cuda.device_count()
+
+        local_chunk = self._allgather_buffer
+        # print("==debug== local_chunk", local_chunk.dtype)
+        self._allgather_buffer = torch.empty(
+            local_chunk.numel() * n_gpus, device=local_chunk.device, dtype=local_chunk.dtype
+        )
+        all_gather_into_flat_tensor_nd(self._allgather_buffer, local_chunk, group=self.intranode_pg, async_op=False)
+        del local_chunk
+        self.write_back_and_empty()
+
+    def all_gather_async(self, group=None, fp8_communication: bool = False):
+        assert fp8_communication is False, "fp8 communication is not supported yet"
+
+        flat = self.flatten()
+        if isinstance(group, tuple):
+            world_size = np.prod([dist.get_world_size(pg) for pg in group])
+        else:
+            world_size = dist.get_world_size(group)
+        self._allgather_buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
+        self._allgather_handle = all_gather_into_flat_tensor_nd(
+            self._allgather_buffer, flat, group=group, async_op=True
+        )
+
+    def write_back_and_empty(self, group=None):
+        assert self._allgather_buffer is not None, "all_gather_async must be called before write_back_and_empty"
+        assert self._allgather_handle is not None, "all_gather_async must be called before write_back_and_empty"
+
+        if isinstance(group, tuple):
+            world_size = np.prod([dist.get_world_size(pg) for pg in group])
+        else:
+            world_size = dist.get_world_size(group)
+        self._allgather_handle.wait()
+        unflat_buffers = [self.unflatten(buffer) for buffer in self._allgather_buffer.chunk(world_size)]
+        # transpose the list of list
+        unflat_buffers = list(map(list, zip(*unflat_buffers)))
+        for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
+            write_back_tensor = self._write_back_pairs[tensor]
+            rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
+            if write_back_tensor.is_contiguous():
+                rec_tensor = rec_tensor.view_as(write_back_tensor)
+            else:
+                rec_tensor = rec_tensor.reshape_as(write_back_tensor)
+            write_back_tensor.data.copy_(rec_tensor)
+        self.empty()
+
     def all_gather(self, group=None, fp8_communication: bool = False):
         flat = self.flatten()
         if isinstance(group, tuple):
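A single-process CPU sketch (no process groups, shapes made up) of the write-back logic used above: each rank's bucket holds its shard of every parameter, so the gathered buffer is chunked per rank, unflattened with the local shard shapes, transposed, re-flattened per tensor, and truncated to each parameter's numel.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

world_size = 4
full = [torch.arange(10, dtype=torch.float32), torch.arange(6, dtype=torch.float32)]

def shard(t, r):
    # pad so the tensor splits evenly, then take "rank" r's piece
    padded = torch.nn.functional.pad(t, (0, (-t.numel()) % world_size))
    return padded.chunk(world_size)[r].clone()

rank_buckets = [[shard(t, r) for t in full] for r in range(world_size)]
gathered = torch.cat([_flatten_dense_tensors(b) for b in rank_buckets])  # stands in for the all-gather result

# mirror write_back_and_empty: chunk per rank, unflatten, transpose, re-flatten per tensor
local_shapes = rank_buckets[0]
unflat = [_unflatten_dense_tensors(c, local_shapes) for c in gathered.chunk(world_size)]
per_tensor = list(map(list, zip(*unflat)))
recovered = [_flatten_dense_tensors(s)[: t.numel()].reshape_as(t) for s, t in zip(per_tensor, full)]
assert all(torch.equal(r, t) for r, t in zip(recovered, full))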
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index 6b69ab133..0b762466e 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -237,6 +237,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         elif self._dtype is torch.bfloat16:
             self.mixed_precision_mixin = BF16MixedPrecisionMixin()
         self._current_grad_norm: Optional[float] = None
+        self._all_gather_async_in_bucket = False
+        self._2d_allgather = False
 
     def __del__(self):
         for hook in self.grad_handles:
@@ -594,8 +596,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for group_id in range(self.num_param_groups):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
+        import os
+
+        scale = int(os.environ.get("ALL_GATHER_BUCKET_SCALE", 1))
         self.pg_to_tensor_bucket = {
-            pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
+            pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size * scale) for pg in self.pg_to_param_list
         }
 
         device = get_accelerator().get_current_device()
@@ -641,10 +646,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if not self._overlap_allgather:
             for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
                 if not tensor_bucket.is_empty():
-                    tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
+                    if self._all_gather_async_in_bucket:
+                        if self._2d_allgather:
+                            tensor_bucket.internode_all_gather_async(pg, fp8_communication=self._fp8_communication)
+                        else:
+                            tensor_bucket.all_gather_async(pg, fp8_communication=self._fp8_communication)
+                    else:
+                        tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
         del params_to_gather_buffer
 
+    def params_all_gather_write_back(self):
+        if self._all_gather_async_in_bucket and hasattr(self, "pg_to_tensor_bucket"):
+            for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
+                if not tensor_bucket.is_empty():
+                    if self._2d_allgather:
+                        tensor_bucket.intranode_allgather_and_write_back(pg)
+                    else:
+                        tensor_bucket.write_back_and_empty(pg)
+
     def _compute_grad_norm(
         self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
     ) -> float:
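A hypothetical driver sketch, not part of the patch: the new paths are gated behind private attributes (_all_gather_async_in_bucket, _2d_allgather) and the ALL_GATHER_BUCKET_SCALE environment variable, and this change adds no public switch, so the sketch toggles them directly. booster, model, optimizer, criterion, batch, and target are placeholders.

import os

os.environ.setdefault("ALL_GATHER_BUCKET_SCALE", "4")  # enlarges the all-gather bucket; read when the buckets are rebuilt

def train_step_with_deferred_allgather(booster, model, optimizer, criterion, batch, target):
    optimizer._all_gather_async_in_bucket = True  # take the async bucket all-gather path
    optimizer._2d_allgather = True                # internode stage now, intranode stage at write-back
    loss = criterion(model(batch), target)
    booster.backward(loss, optimizer)
    optimizer.step()                              # launches the internode all-gather asynchronously (when overlap_allgather is off)
    # ...other work can overlap with the in-flight all-gather here...
    optimizer.params_all_gather_write_back()      # completes the intranode gather and writes params back
    return loss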