[zero] support all-gather overlap (#5898)

* [zero] support all-gather overlap

* [zero] add overlap all-gather flag

* [misc] fix typo

* [zero] update api
This commit is contained in:
Hongxin Liu
2024-07-11 18:59:59 +08:00
committed by GitHub
parent dd9e1cdafe
commit c068ef0fa0
7 changed files with 119 additions and 25 deletions

View File

@@ -23,6 +23,7 @@ from colossalai.logging import get_dist_logger
from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, TensorBucket
from .zero_hook import set_all_gather_handle, wait_all_gather_handle
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
dp_process_group: Optional[ProcessGroup] = None,
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
overlap_allgather: bool = False,
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -121,6 +123,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# communication params
self._overlap_communication = overlap_communication
self._overlap_allgather = overlap_allgather
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
@@ -145,6 +148,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# record the padding size of each param
self._padding_map = dict()
# padded working param is all-gather buffer and it shares the same memory with working param
self._working_param_to_padded_working_param = dict()
# mapping working param and master param
self.master_to_working_param = dict()
@@ -245,11 +250,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
with torch.no_grad():
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
# reset working params' ptr when no master weights
if self._master_weights == False:
param.data = padding_param[: param.numel()].view(param.shape)
# # reset working params' ptr when no master weights
# if self._master_weights == False:
param.data = padding_param[: param.numel()].view(param.shape)
else:
padding_param = param.data.view(-1)
self._working_param_to_padded_working_param[param] = padding_param
splited_params = padding_param.split(
padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
@@ -258,7 +264,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# use fp32 when master_weights is True
if self._master_weights is True:
splited_param_current_rank = splited_params.detach().float().to(device)
splited_param_current_rank = splited_params.detach().clone().float().to(device)
else:
splited_param_current_rank = splited_params
@@ -549,22 +555,24 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = real_working_params[group_id][idx]
param_to_gather = master_param.to(device).to(self._dtype)
pg = self.param_to_pg[working_param]
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
padded_working_param = self._working_param_to_padded_working_param[working_param]
if self._overlap_allgather:
handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
set_all_gather_handle(working_param, handle)
else:
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
self.pg_to_tensor_bucket[pg].all_gather(pg)
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
if not self._overlap_allgather:
for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
if not tensor_bucket.is_empty():
tensor_bucket.all_gather(pg)
def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
@@ -892,3 +900,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
grad_store = self.pid_to_grad_store[param_id]
return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)
def _force_wait_all_gather(self):
for param in self._working_param_to_padded_working_param.keys():
wait_all_gather_handle(param)

View File

@@ -0,0 +1,33 @@
from typing import List
from torch._tensor import Tensor
from colossalai.tensor.param_op_hook import ColoParamOpHook
_ALL_GATHER_HANDLE = "_all_gather_handle"
def wait_all_gather_handle(p):
if hasattr(p, _ALL_GATHER_HANDLE):
handle = getattr(p, _ALL_GATHER_HANDLE)
handle.wait()
delattr(p, _ALL_GATHER_HANDLE)
def set_all_gather_handle(p, handle):
setattr(p, _ALL_GATHER_HANDLE, handle)
class ZeroOpHook(ColoParamOpHook):
def pre_forward(self, params: List[Tensor]) -> None:
for p in params:
wait_all_gather_handle(p)
def post_forward(self, params: List[Tensor]) -> None:
pass
def pre_backward(self, params: List[Tensor]) -> None:
pass
def post_backward(self, params: List[Tensor]) -> None:
pass