Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
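For orientation, a minimal sketch of where the new flag surfaces. The import path, model, and optimizer below are illustrative assumptions, not part of this diff, and the optimizer must be constructed inside an already-initialized distributed context (launcher omitted here):

import torch
import torch.nn as nn

from colossalai.zero import LowLevelZeroOptimizer

# Assumes torch.distributed / Colossal-AI has already been initialized by a launcher.
model = nn.Linear(128, 128)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

zero_optim = LowLevelZeroOptimizer(
    optim,
    overlap_allgather=True,  # new flag: launch the post-step parameter all-gather asynchronously
)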
@@ -23,6 +23,7 @@ from colossalai.logging import get_dist_logger
 from ._utils import calculate_global_norm_from_list, has_inf_or_nan, release_param_grad, sync_tensor
 from .bookkeeping import BucketStore, GradientStore, TensorBucket
+from .zero_hook import set_all_gather_handle, wait_all_gather_handle


 class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
@@ -83,6 +84,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         dp_process_group: Optional[ProcessGroup] = None,
         forced_dtype: Optional[torch.dtype] = None,
         master_weights: bool = True,  # master weights
+        overlap_allgather: bool = False,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)

@@ -121,6 +123,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # communication params
         self._overlap_communication = overlap_communication
+        self._overlap_allgather = overlap_allgather
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype

@@ -145,6 +148,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         # record the padding size of each param
         self._padding_map = dict()
+        # padded working param is all-gather buffer and it shares the same memory with working param
+        self._working_param_to_padded_working_param = dict()

         # mapping working param and master param
         self.master_to_working_param = dict()
@@ -245,11 +250,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
            with torch.no_grad():
                if padding_size > 0:
                    padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                   # reset working params' ptr when no master weights
-                   if self._master_weights == False:
-                       param.data = padding_param[: param.numel()].view(param.shape)
+                   # # reset working params' ptr when no master weights
+                   # if self._master_weights == False:
+                   param.data = padding_param[: param.numel()].view(param.shape)
                else:
                    padding_param = param.data.view(-1)
+               self._working_param_to_padded_working_param[param] = padding_param

                splited_params = padding_param.split(
                    padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
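The hunk above re-points param.data at the front of the padded buffer, so the padded tensor can serve as the all-gather output while sharing storage with the working parameter. A standalone sketch of the padding and splitting scheme (plain PyTorch only; the world size of 4 is an assumption for illustration):

import torch
import torch.nn.functional as F

world_size = 4                      # assumed DP world size for illustration
param = torch.nn.Parameter(torch.randn(10))

# pad the flattened parameter so it divides evenly across ranks
padding_size = (world_size - param.numel() % world_size) % world_size   # -> 2
padded = F.pad(param.data.view(-1), [0, padding_size])

# re-point the working param at the front of the padded buffer, so the padded
# tensor doubles as the all-gather output buffer while sharing storage
param.data = padded[: param.numel()].view(param.shape)
assert param.data.data_ptr() == padded.data_ptr()

# each rank keeps one contiguous shard of the padded buffer
shards = padded.split(padded.numel() // world_size)
assert len(shards) == world_size and shards[0].numel() == 3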
@@ -258,7 +264,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

            # use fp32 when master_weights is True
            if self._master_weights is True:
-               splited_param_current_rank = splited_params.detach().float().to(device)
+               splited_param_current_rank = splited_params.detach().clone().float().to(device)
            else:
                splited_param_current_rank = splited_params

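The added .clone() matters when the split shard already has the master dtype and device: .float() and .to() return the tensor itself when nothing needs to change, so without an explicit copy the master shard would alias the working shard's storage. A small illustration with plain tensors:

import torch

shard = torch.randn(8, dtype=torch.float32)

aliased = shard.detach().float()          # .float() is a no-op on fp32 -> same storage
copied = shard.detach().clone().float()   # explicit copy -> independent storage

assert aliased.data_ptr() == shard.data_ptr()
assert copied.data_ptr() != shard.data_ptr()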
@@ -549,22 +555,24 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                working_param = real_working_params[group_id][idx]
                param_to_gather = master_param.to(device).to(self._dtype)
                pg = self.param_to_pg[working_param]
-               if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
-                   buffer_tensor = torch.empty_like(
-                       torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
-                   )
-                   dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
-                   working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
-                   continue
-               try:
-                   self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
-               except RuntimeError:
-                   self.pg_to_tensor_bucket[pg].all_gather(pg)
-                   self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
+               padded_working_param = self._working_param_to_padded_working_param[working_param]
+               if self._overlap_allgather:
+                   handle = dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg, async_op=True)
+                   set_all_gather_handle(working_param, handle)
+               else:
+                   if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
+                       dist.all_gather_into_tensor(padded_working_param, param_to_gather, pg)
+                       continue
+                   try:
+                       self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
+                   except RuntimeError:
+                       self.pg_to_tensor_bucket[pg].all_gather(pg)
+                       self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
-       for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
-           if not tensor_bucket.is_empty():
-               tensor_bucket.all_gather(pg)
+       if not self._overlap_allgather:
+           for pg, tensor_bucket in self.pg_to_tensor_bucket.items():
+               if not tensor_bucket.is_empty():
+                   tensor_bucket.all_gather(pg)

    def _compute_grad_norm(self, dp_pg: ProcessGroup, gradients: List[Tensor], norm_type: int = 2) -> float:
        r"""
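The overlap branch above launches the gather with async_op=True and stashes the returned work handle instead of blocking. A self-contained sketch of that launch-now/wait-later pattern (single process on the gloo backend so it runs without a launcher; it uses the list-based dist.all_gather instead of all_gather_into_tensor so it works on CPU, but the handle handling is the same):

import os
import torch
import torch.distributed as dist

# Single-process process group purely for illustration (gloo, CPU only).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)
pg = dist.group.WORLD

shard = torch.randn(6)                                  # this rank's updated shard
gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size(pg))]

# Launch the collective asynchronously and keep the returned work handle,
# mirroring async_op=True + set_all_gather_handle in the diff above.
handle = dist.all_gather(gathered, shard, group=pg, async_op=True)

# ... unrelated work (e.g. the next parameter's gather) can overlap here ...

handle.wait()                                           # must complete before the buffer is read
assert torch.equal(gathered[0], shard)

dist.destroy_process_group()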
@@ -892,3 +900,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
         grad_store = self.pid_to_grad_store[param_id]
         return grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)
+
+    def _force_wait_all_gather(self):
+        for param in self._working_param_to_padded_working_param.keys():
+            wait_all_gather_handle(param)
colossalai/zero/low_level/zero_hook.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+from typing import List
+
+from torch._tensor import Tensor
+
+from colossalai.tensor.param_op_hook import ColoParamOpHook
+
+_ALL_GATHER_HANDLE = "_all_gather_handle"
+
+
+def wait_all_gather_handle(p):
+    if hasattr(p, _ALL_GATHER_HANDLE):
+        handle = getattr(p, _ALL_GATHER_HANDLE)
+        handle.wait()
+        delattr(p, _ALL_GATHER_HANDLE)
+
+
+def set_all_gather_handle(p, handle):
+    setattr(p, _ALL_GATHER_HANDLE, handle)
+
+
+class ZeroOpHook(ColoParamOpHook):
+    def pre_forward(self, params: List[Tensor]) -> None:
+        for p in params:
+            wait_all_gather_handle(p)
+
+    def post_forward(self, params: List[Tensor]) -> None:
+        pass
+
+    def pre_backward(self, params: List[Tensor]) -> None:
+        pass
+
+    def post_backward(self, params: List[Tensor]) -> None:
+        pass
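The hook consumes at most one pending handle per parameter: set_all_gather_handle stashes the work object as an attribute, and wait_all_gather_handle waits on it once and removes it, so later calls are no-ops. A tiny standalone check of that contract (the stand-in handle class below is hypothetical, used only so the sketch runs without a process group; the import path is the one introduced by this diff):

import torch

from colossalai.zero.low_level.zero_hook import set_all_gather_handle, wait_all_gather_handle


class _FakeHandle:
    """Stand-in for a torch.distributed work handle; only for this sketch."""

    def __init__(self):
        self.waited = 0

    def wait(self):
        self.waited += 1


param = torch.nn.Parameter(torch.zeros(4))
handle = _FakeHandle()

set_all_gather_handle(param, handle)
wait_all_gather_handle(param)   # waits and removes the stashed attribute
wait_all_gather_handle(param)   # no pending handle -> no-op

assert handle.waited == 1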