Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-05-07 07:58:27 +00:00)
overlap allgather draft
commit fefb4b448a (parent 2254d8ba64)
@@ -1188,6 +1188,15 @@ class HybridParallelPlugin(PipelinePluginBase):
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
 
+        # sync gradients across DP * SP ranks
+        # sync gradients across DP * SP ranks
+        # Apply Hybrid ZeRO across DP * SP ranks
+        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
+            self.mixed_dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
+            self.dp_size = get_world_size(self.mixed_dp_group)
+        else:
+            self.mixed_dp_group = self.dp_group
+
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
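Note: the hunk above moves the DP×SP group construction into `__init__` and stores it as `self.mixed_dp_group`, so later consumers (DDP wrapping, the ZeRO optimizer, checkpoint IO) share one gradient-sync group. As a rough illustration of which ranks end up in that merged group, here is a standalone sketch assuming a hypothetical `(pp, dp, sp, tp)` rank ordering; it is not ColossalAI's actual `ProcessGroupMesh` layout:

```python
# Standalone sketch only (hypothetical (pp, dp, sp, tp) rank ordering, not
# ColossalAI's ProcessGroupMesh): the merged DP*SP gradient-sync group is the
# set of ranks sharing the same pipeline and tensor-parallel coordinates.
import itertools


def mixed_dp_sp_ranks(pp: int, tp: int, dp: int, sp: int):
    """Group global ranks by (pp, tp) coordinate; each group syncs gradients together."""
    groups = {}
    for rank, (p, d, s, t) in enumerate(itertools.product(range(pp), range(dp), range(sp), range(tp))):
        groups.setdefault((p, t), []).append(rank)
    return list(groups.values())


if __name__ == "__main__":
    for ranks in mixed_dp_sp_ranks(pp=2, tp=2, dp=2, sp=2):
        print(ranks)  # four groups of four ranks, e.g. [0, 2, 4, 6]
```

For 2-way PP/TP/DP/SP this prints four groups of four ranks, i.e. each merged group spans all DP and SP peers of one (pp, tp) slice.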
@@ -1298,19 +1307,11 @@ class HybridParallelPlugin(PipelinePluginBase):
         use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
             self.dp_size == 1 and self.pp_size == 1
         )
-        # sync gradients across DP * SP ranks
-        # sync gradients across DP * SP ranks
-        # Apply Hybrid ZeRO across DP * SP ranks
-        if self.enable_sequence_parallelism and not is_share_sp_tp(self.sequence_parallelism_mode):
-            dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
-            self.dp_size = get_world_size(dp_group)
-        else:
-            dp_group = self.dp_group
         model = HybridParallelModule(
             model,
             precision=self.precision,
             shard_config=self.shard_config,
-            dp_group=dp_group,
+            dp_group=self.mixed_dp_group,
             tp_group=self.tp_group,
             sp_group=self.sp_group,
             use_ddp=use_ddp,
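When `use_ddp` is true, `HybridParallelModule` now receives `self.mixed_dp_group` as its `dp_group`, so gradient averaging spans both DP and SP ranks. A minimal, hypothetical sketch of the underlying idiom of handing an explicit process group to PyTorch DDP (not ColossalAI's wrapper; assumes a `torchrun` launch):

```python
# Hypothetical sketch of the underlying idiom, not ColossalAI's HybridParallelModule.
# Run with e.g. `torchrun --nproc_per_node=2 ddp_sketch.py`; gloo keeps it CPU-only.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="gloo")
# Stand-in for the merged DP*SP group handed in as `dp_group` above.
mixed_dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

model = torch.nn.Linear(8, 8)
ddp_model = DDP(model, process_group=mixed_dp_group)  # grads all-reduce over this group only

loss = ddp_model(torch.randn(4, 8)).sum()
loss.backward()
dist.destroy_process_group()
```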
@@ -1359,7 +1360,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 model,
                 use_pipeline=self.enable_pipeline_parallelism,
                 param_info=param_info,
-                dp_process_group=dp_group,
+                dp_process_group=self.mixed_dp_group,
                 tp_process_group=self.tp_group,
                 pp_process_group=self.pp_group,
                 verbose=True,
@@ -1488,7 +1489,9 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage)
+        return HybridParallelCheckpointIO(
+            self.mixed_dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage
+        )
 
     def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert (
@@ -4,17 +4,34 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed.device_mesh import DeviceMesh
 
 from colossalai.quantization.fp8 import all_gather_fp8
+from colossalai.zero.low_level._utils import all_gather_into_flat_tensor_nd
 
 
+def chunks(lst, n):
+    """Yield successive n-sized chunks from lst."""
+    for i in range(0, len(lst), n):
+        yield list(lst)[i : i + n]
+
+
 class TensorBucket:
     def __init__(self, size):
         self._max_size = size
         self._current_size = 0
         self._bucket = []
         self._write_back_pairs = {}
+        self._allgather_handle = None
+        self._allgather_buffer = None
+        world_size = dist.get_world_size()
+        self.mesh = DeviceMesh(
+            device_type="cuda",
+            mesh=list(chunks(range(world_size), torch.cuda.device_count())),
+            mesh_dim_names=["internode", "intranode"],
+        )
+        self.internode_pg = self.mesh["internode"].get_group()
+        self.intranode_pg = self.mesh["intranode"].get_group()
 
     @property
     def max_size(self):
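The constructor above carves the global ranks into one row per node (`torch.cuda.device_count()` ranks each) and names the mesh dims `internode`/`intranode`. The layout can be previewed without GPUs; the sketch below reuses the diff's own `chunks` helper and assumes 8 ranks with 4 GPUs per node:

```python
# Preview of the 2-D layout built in __init__, assuming 8 ranks and 4 GPUs per node.
def chunks(lst, n):
    """Yield successive n-sized chunks from lst (same helper as in the diff)."""
    for i in range(0, len(lst), n):
        yield list(lst)[i : i + n]


world_size, gpus_per_node = 8, 4
mesh = list(chunks(range(world_size), gpus_per_node))
print(mesh)                               # [[0, 1, 2, 3], [4, 5, 6, 7]]  (rows = nodes)
print([list(col) for col in zip(*mesh)])  # internode groups: [[0, 4], [1, 5], [2, 6], [3, 7]]
print(mesh[0])                            # intranode group of node 0: [0, 1, 2, 3]
```

Dim 0 (`internode`) therefore groups same-local-rank peers across nodes, while dim 1 (`intranode`) groups the ranks of one node.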
@@ -53,6 +70,9 @@ class TensorBucket:
         self._bucket = []
         self._current_size = 0
         self._write_back_pairs = {}
+        # del self._allgather_buffer
+        self._allgather_buffer = None
+        self._allgather_handle = None
 
     def flatten(self):
         return _flatten_dense_tensors(self._bucket)
@@ -65,6 +85,74 @@ class TensorBucket:
         for old, new in zip(self._bucket, unflattened_tensor_list):
             old.copy_(new)
 
+    def internode_all_gather_async(self, group=None, fp8_communication: bool = False):
+        assert fp8_communication is False, "fp8 communication is not supported yet"
+        # assert group is None, "internode_all_gather_async only support default group"
+        flat = self.flatten()
+
+        if isinstance(group, tuple):
+            world_size = np.prod([dist.get_world_size(pg) for pg in group])
+        else:
+            world_size = dist.get_world_size(group)
+
+        n_gpus = torch.cuda.device_count()
+        # print("==debug== flat", flat.dtype)
+        self._allgather_buffer = torch.empty(flat.numel() * world_size // n_gpus, device=flat.device, dtype=flat.dtype)
+        self._allgather_handle = all_gather_into_flat_tensor_nd(
+            self._allgather_buffer, flat, group=self.internode_pg, async_op=True
+        )
+
+    def intranode_allgather_and_write_back(self, group=None):
+        assert self._allgather_buffer is not None, "all_gather_async must be called before write_back_and_empty"
+        assert self._allgather_handle is not None, "all_gather_async must be called before write_back_and_empty"
+        self._allgather_handle.wait()
+        dist.get_world_size(group)
+        n_gpus = torch.cuda.device_count()
+
+        local_chunk = self._allgather_buffer
+        # print("==debug== local_chunk", local_chunk.dtype)
+        self._allgather_buffer = torch.empty(
+            local_chunk.numel() * n_gpus, device=local_chunk.device, dtype=local_chunk.dtype
+        )
+        all_gather_into_flat_tensor_nd(self._allgather_buffer, local_chunk, group=self.intranode_pg, async_op=False)
+        del local_chunk
+        self.write_back_and_empty()
+
+    def all_gather_async(self, group=None, fp8_communication: bool = False):
+        assert fp8_communication is False, "fp8 communication is not supported yet"
+
+        flat = self.flatten()
+        if isinstance(group, tuple):
+            world_size = np.prod([dist.get_world_size(pg) for pg in group])
+        else:
+            world_size = dist.get_world_size(group)
+        self._allgather_buffer = torch.empty(flat.numel() * world_size, device=flat.device, dtype=flat.dtype)
+        self._allgather_handle = all_gather_into_flat_tensor_nd(
+            self._allgather_buffer, flat, group=group, async_op=True
+        )
+
+    def write_back_and_empty(self, group=None):
+        assert self._allgather_buffer is not None, "all_gather_async must be called before write_back_and_empty"
+        assert self._allgather_handle is not None, "all_gather_async must be called before write_back_and_empty"
+
+        if isinstance(group, tuple):
+            world_size = np.prod([dist.get_world_size(pg) for pg in group])
+        else:
+            world_size = dist.get_world_size(group)
+        self._allgather_handle.wait()
+        unflat_buffers = [self.unflatten(buffer) for buffer in self._allgather_buffer.chunk(world_size)]
+        # transpose the list of list
+        unflat_buffers = list(map(list, zip(*unflat_buffers)))
+        for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
+            write_back_tensor = self._write_back_pairs[tensor]
+            rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
+            if write_back_tensor.is_contiguous():
+                rec_tensor = rec_tensor.view_as(write_back_tensor)
+            else:
+                rec_tensor = rec_tensor.reshape_as(write_back_tensor)
+            write_back_tensor.data.copy_(rec_tensor)
+        self.empty()
+
     def all_gather(self, group=None, fp8_communication: bool = False):
         flat = self.flatten()
         if isinstance(group, tuple):
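`write_back_and_empty` is the piece that turns the gathered flat buffer back into full parameters: chunk per rank, unflatten against the bucket's shard shapes, transpose into per-tensor shard lists, then concatenate and truncate each tensor. The following self-contained sketch replays that unpack logic on CPU with made-up shapes (no process groups involved), to show why truncating to `write_back_tensor.numel()` undoes the per-rank padding:

```python
# CPU-only sketch of the unpack logic (made-up shapes, no process groups):
# the gathered buffer is chunked per rank, unflattened against the bucket's
# shard shapes, transposed into per-tensor shard lists, and each parameter is
# rebuilt by concatenating its shards and dropping the per-rank padding.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

world_size = 4
params = [torch.zeros(6), torch.zeros(3, 5)]                    # write-back targets
full_values = [torch.arange(6.0), torch.arange(15.0).reshape(3, 5)]


def shard(t, rank):
    """Rank's padded 1/world_size shard of a flattened tensor."""
    flat = torch.nn.functional.pad(t.flatten(), (0, -t.numel() % world_size))
    return flat.chunk(world_size)[rank].clone()


rank_buckets = [[shard(v, r) for v in full_values] for r in range(world_size)]
gathered = torch.cat([_flatten_dense_tensors(b) for b in rank_buckets])  # what the all-gather yields

bucket_shapes = rank_buckets[0]
unflat_buffers = [_unflatten_dense_tensors(c, bucket_shapes) for c in gathered.chunk(world_size)]
unflat_buffers = list(map(list, zip(*unflat_buffers)))                   # transpose: shards per tensor
for shards, param, full in zip(unflat_buffers, params, full_values):
    rec = _flatten_dense_tensors(shards)[: param.numel()].reshape_as(param)
    param.data.copy_(rec)
    assert torch.equal(param, full)
print("write-back reproduces the original tensors")
```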
@@ -237,6 +237,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         elif self._dtype is torch.bfloat16:
             self.mixed_precision_mixin = BF16MixedPrecisionMixin()
         self._current_grad_norm: Optional[float] = None
+        self._all_gather_async_in_bucket = False
+        self._2d_allgather = False
 
     def __del__(self):
         for hook in self.grad_handles:
@@ -594,8 +596,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for group_id in range(self.num_param_groups):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])
 
+        import os
+
+        scale = int(os.environ.get("ALL_GATHER_BUCKET_SCALE", 1))
         self.pg_to_tensor_bucket = {
-            pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size) for pg in self.pg_to_param_list
+            pg: TensorBucket(self.pg_to_bucket_store[pg].reduce_bucket_size * scale) for pg in self.pg_to_param_list
         }
 
         device = get_accelerator().get_current_device()
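`ALL_GATHER_BUCKET_SCALE` lets the gather bucket grow to a multiple of the reduce bucket without code changes. A small usage sketch (the element count below is only an illustrative value, not the plugin's actual default):

```python
# Usage sketch: pick the gather-bucket scale at launch time, e.g.
#   ALL_GATHER_BUCKET_SCALE=4 torchrun --nproc_per_node=8 train.py
import os

scale = int(os.environ.get("ALL_GATHER_BUCKET_SCALE", 1))
reduce_bucket_size = 12 * 1024 * 1024  # illustrative element count, not the plugin's actual default
print(f"gather bucket holds {reduce_bucket_size * scale} elements (scale={scale})")
```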
@@ -641,10 +646,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if not self._overlap_allgather:
             for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
                 if not tensor_bucket.is_empty():
-                    tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
+                    if self._all_gather_async_in_bucket:
+                        if self._2d_allgather:
+                            tensor_bucket.internode_all_gather_async(pg, fp8_communication=self._fp8_communication)
+                        else:
+                            tensor_bucket.all_gather_async(pg, fp8_communication=self._fp8_communication)
+                    else:
+                        tensor_bucket.all_gather(pg, fp8_communication=self._fp8_communication)
 
         del params_to_gather_buffer
 
+    def params_all_gather_write_back(self):
+        if self._all_gather_async_in_bucket and hasattr(self, "pg_to_tensor_bucket"):
+            for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
+                if not tensor_bucket.is_empty():
+                    if self._2d_allgather:
+                        tensor_bucket.intranode_allgather_and_write_back(pg)
+                    else:
+                        tensor_bucket.write_back_and_empty(pg)
+
     def _compute_grad_norm(
         self, dp_pg: Union[ProcessGroup, Tuple[ProcessGroup, ...]], gradients: List[Tensor], norm_type: int = 2
     ) -> float:
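With `_all_gather_async_in_bucket` set, `step()` only launches the (inter-node or flat) all-gather; the new `params_all_gather_write_back()` has to be called later to wait on the handle and copy the gathered parameters back. The diff does not show that call site, so the flow below is a hypothetical sketch of where it could sit in a training loop to buy overlap:

```python
# Hypothetical training-loop placement (names are illustrative; the real call
# site is not part of this diff). `optimizer` stands for a LowLevelZeroOptimizer
# configured with _all_gather_async_in_bucket = True.
def train_iteration(model, optimizer, batch, criterion):
    loss = criterion(model(batch["x"]), batch["y"])
    optimizer.backward(loss)
    optimizer.step()                          # launches the async param all-gather inside the buckets
    optimizer.zero_grad()
    # ... dataloader prefetch, logging, LR scheduling can overlap with the collective here ...
    optimizer.params_all_gather_write_back()  # wait, then copy gathered params back before next forward
    return loss.detach()
```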