overlap allgather draft

Author: BurkeHulk
Date: 2025-03-05 16:13:49 +08:00
Commit: fefb4b448a (parent: 2254d8ba64)
3 changed files with 124 additions and 13 deletions


@@ -1188,6 +1188,15 @@ class HybridParallelPlugin(PipelinePluginBase):
else:
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.dp_size = get_world_size(self.mixed_dp_group)
else:
self.mixed_dp_group = self.dp_group
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
@@ -1298,19 +1307,11 @@ class HybridParallelPlugin(PipelinePluginBase):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1 and self.pp_size == 1
)
# sync gradients across DP * SP ranks
# Apply Hybrid ZeRO across DP * SP ranks
if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.dp_size = get_world_size(dp_group)
else:
dp_group = self.dp_group
model = HybridParallelModule(
model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=dp_group,
dp_group=self.mixed_dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=use_ddp,
@@ -1359,7 +1360,7 @@ class HybridParallelPlugin(PipelinePluginBase):
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
dp_process_group=dp_group,
dp_process_group=self.mixed_dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
verbose=True,
@@ -1488,7 +1489,9 @@ class HybridParallelPlugin(PipelinePluginBase):
)
def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
return HybridParallelCheckpointIO(
self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
)
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (

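The changes above replace the `dp_group` that was previously rebuilt inside `configure` with a `self.mixed_dp_group` created once in `__init__`, so the ZeRO/DDP wrapper, the optimizer, and checkpoint IO all synchronize over the combined DP * SP ranks. As a rough illustration, a merged group like this can be built with plain `torch.distributed`; the helper name and rank layout below are assumptions for the sketch, not code from this commit.

import torch.distributed as dist

def build_mixed_dp_sp_group(tp_size: int):
    # Sketch only: assumes torch.distributed is already initialized and that TP is
    # the fastest-varying axis, so ranks sharing a TP coordinate differ only in
    # their DP/SP coordinates and can share gradient synchronization.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    mixed_group = None
    for tp_rank in range(tp_size):
        ranks = [r for r in range(world_size) if r % tp_size == tp_rank]
        group = dist.new_group(ranks)  # collective call: every rank creates every group
        if rank in ranks:
            mixed_group = group
    return mixed_group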

@@ -4,17 +4,34 @@ import numpy as np
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed.device_mesh import DeviceMesh
from colossalai.quantization.fp8 import all_gather_fp8
from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield list(lst)[i : i + n]
class TensorBucket:
def __init__(self, size):
self._max_size = size
self._current_size = 0
self._bucket = []
self._write_back_pairs = {}
self._allgather_handle = None
self._allgather_buffer = None
world_size = dist.get_world_size()
self.mesh = DeviceMesh(
device_type="cuda",
mesh=list(chunks(range(world_size), torch.cuda.device_count())),
mesh_dim_names=["internode", "intranode"],
)
self.internode_pg = self.mesh["internode"].get_group()
self.intranode_pg = self.mesh["intranode"].get_group()
@property
def max_size(self):
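The constructor above lays the ranks out as a 2-D device mesh: each row is one node ("intranode") and each column is the same local GPU index across nodes ("internode"). A small illustration of the resulting groups, assuming 2 nodes with 4 GPUs each (sizes here are illustrative, not from the commit):

world = list(range(8))        # 2 nodes x 4 GPUs (assumed)
gpus_per_node = 4
rows = [world[i : i + gpus_per_node] for i in range(0, len(world), gpus_per_node)]
intranode_groups = rows                               # [[0, 1, 2, 3], [4, 5, 6, 7]]
internode_groups = [list(col) for col in zip(*rows)]  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(intranode_groups, internode_groups)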
@@ -53,6 +70,9 @@ class TensorBucket:
self._bucket = []
self._current_size = 0
self._write_back_pairs = {}
# del self._allgather_buffer
self._allgather_buffer = None
self._allgather_handle = None
def flatten(self):
return _flatten_dense_tensors(self._bucket)
@@ -65,6 +85,74 @@ class TensorBucket:
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
def internode_all_gather_async(self, group=None, fp8_communication: bool = False):
assert fp8_communication is False, "fp8 communication is not supported yet"
# assert group is None, "internode_all_gather_async only support default group"
flat = self.flatten()
if isinstance(group, tuple):
world_size = np.prod([dist.get_world_size(pg) for pg in group])
else:
world_size = dist.get_world_size(group)
n_gpus = torch.cuda.device_count()
# print("==debug== flat", flat.dtype)
self._allgather_buffer = torch.empty(flat.numel() * world_size // n_gpus, device=flat.device, dtype=flat.dtype)
self._allgather_handle = all_gather_into_flat_tensor_nd(
self._allgather_buffer, flat, group=self.internode_pg, async_op=True
)
def intranode_allgather_and_write_back(self, group=None):
assert self._allgather_buffer is not None, "all_gather_async must be called before write_back_and_empty"
assert self._allgather_handle is not None, "all_gather_async must be called before write_back_and_empty"
self._allgather_handle.wait()
dist.get_world_size(group)
n_gpus = torch.cuda.device_count()
local_chunk = self._allgather_buffer
# print("==debug== local_chunk", local_chunk.dtype)
self._allgather_buffer = torch.empty(
local_chunk.numel() * n_gpus, device=local_chunk.device, dtype=local_chunk.dtype
)
all_gather_into_flat_tensor_nd(self._allgather_buffer, local_chunk, group=self.intranode_pg, async_op=False)
del local_chunk
self.write_back_and_empty()
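Together these two methods gather a bucket in two hops: an asynchronous all-gather across nodes over the small internode group, then a blocking all-gather among the GPUs of each node before writing back. A quick check of the buffer sizes at each stage, with illustrative numbers (not taken from the commit):

flat_numel = 1024          # elements in this rank's flattened bucket (assumed)
world_size, n_gpus = 8, 4  # 2 nodes x 4 GPUs (assumed)
stage1 = flat_numel * world_size // n_gpus  # internode gather output: 2048 elements
stage2 = stage1 * n_gpus                    # intranode gather output: 8192 elements
assert stage2 == flat_numel * world_size    # the final buffer holds every rank's shard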
def all_gather_async(self, group=None, fp8_communication: bool = False):
assert fp8_communication is False, "fp8 communication is not supported yet"
flat = self.flatten()
if isinstance(group, tuple):
world_size = np.prod([dist.get_world_size(pg) for pg in group])
else:
world_size = dist.get_world_size(group)
self._allgather_buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
self._allgather_handle = all_gather_into_flat_tensor_nd(
self._allgather_buffer, flat, group=group, async_op=True
)
def write_back_and_empty(self, group=None):
assert self._allgather_buffer is not None, "all_gather_async must be called before write_back_and_empty"
assert self._allgather_handle is not None, "all_gather_async must be called before write_back_and_empty"
if isinstance(group, tuple):
world_size = np.prod([dist.get_world_size(pg) for pg in group])
else:
world_size = dist.get_world_size(group)
self._allgather_handle.wait()
unflat_buffers = [self.unflatten(buffer) for buffer in self._allgather_buffer.chunk(world_size)]
# transpose the list of list
unflat_buffers = list(map(list, zip(*unflat_buffers)))
for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
write_back_tensor = self._write_back_pairs[tensor]
rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
if write_back_tensor.is_contiguous():
rec_tensor = rec_tensor.view_as(write_back_tensor)
else:
rec_tensor = rec_tensor.reshape_as(write_back_tensor)
write_back_tensor.data.copy_(rec_tensor)
self.empty()
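`all_gather_async` and `write_back_and_empty` split the ordinary bucket all-gather into a launch half and a completion half so other work can run while the collective is in flight. A minimal sketch of that overlap pattern with plain `torch.distributed` (this is not the class API and assumes an initialized default process group):

import torch
import torch.distributed as dist

def overlapped_allgather(shard: torch.Tensor, do_other_work):
    # Launch the gather asynchronously, overlap it with other work, and only
    # wait when the gathered values are actually needed.
    world_size = dist.get_world_size()
    out = torch.empty(shard.numel() * world_size, device=shard.device, dtype=shard.dtype)
    handle = dist.all_gather_into_tensor(out, shard, async_op=True)  # returns a Work handle
    do_other_work()  # e.g. prepare the next bucket while the gather runs
    handle.wait()
    return list(out.chunk(world_size))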
def all_gather(self, group=None, fp8_communication: bool = False):
flat = self.flatten()
if isinstance(group, tuple):


@@ -237,6 +237,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
elif self._dtype is torch.bfloat16:
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
self._current_grad_norm: Optional[float] = None
self._all_gather_async_in_bucket = False
self._2d_allgather = False
def __del__(self):
for hook in self.grad_handles:
@@ -594,8 +596,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
for group_id in range(self.num_param_groups):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
import os
scale = int(os.environ.get("ALL_GATHER_BUCKET_SCALE", 1))
self.pg_to_tensor_bucket = {
pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size * scale) for pg in self.pg_to_param_list
}
device = get_accelerator().get_current_device()
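The new `ALL_GATHER_BUCKET_SCALE` environment variable enlarges the all-gather bucket relative to the reduce bucket, trading a bigger staging buffer for fewer, larger collectives. A usage sketch with an arbitrary example size:

import os

os.environ["ALL_GATHER_BUCKET_SCALE"] = "4"           # set before the optimizer step
scale = int(os.environ.get("ALL_GATHER_BUCKET_SCALE", 1))
reduce_bucket_size = 12 * 1024 * 1024                 # illustrative reduce-bucket size
all_gather_bucket_size = reduce_bucket_size * scale   # 4x larger all-gather buckets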
@@ -641,10 +646,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if not self._overlap_allgather:
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
if self._all_gather_async_in_bucket:
if self._2d_allgather:
tensor_bucket.internode_all_gather_async(pg, fp8_communication=self._fp8_communication)
else:
tensor_bucket.all_gather_async(pg, fp8_communication=self._fp8_communication)
else:
tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
del params_to_gather_buffer
def params_all_gather_write_back(self):
if self._all_gather_async_in_bucket and hasattr(self, "pg_to_tensor_bucket"):
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
if self._2d_allgather:
tensor_bucket.intranode_allgather_and_write_back(pg)
else:
tensor_bucket.write_back_and_empty(pg)
def _compute_grad_norm(
self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
) -> float:
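With `_all_gather_async_in_bucket` (and optionally `_2d_allgather`) enabled, `step()` only launches the bucket all-gathers; `params_all_gather_write_back()` later waits on the handles and copies the gathered parameters back. A sketch of the intended call order (the actual call site is not part of this diff, so the placement below is an assumption):

optimizer.step()                          # launches async (or internode) all-gathers per bucket
# ... overlapping work can run here while the collectives are in flight ...
optimizer.params_all_gather_write_back()  # waits on the handles and writes parameters back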