Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713)
* [gemini] async grad chunk reduce (all-reduce&reduce-scatter)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* [gemini] add test
* [gemini] rename func
* [gemini] update llama benchmark
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* [gemini] use tensor counter
* [gemini] change default config in GeminiPlugin and GeminiDDP
* [chore] typo
* [gemini] fix sync issue & add test cases
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
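For context, the change below relies on the standard torch.distributed async-collective pattern: a collective launched with `async_op=True` returns a work handle immediately, and the caller invokes `wait()` on that handle only when the reduced result is needed. Here is a minimal single-process sketch of that pattern (illustrative only, not code from this commit; the gloo backend and TCP rendezvous address are assumptions made for the example):

```python
import torch
import torch.distributed as dist

# Single-process group just to make the sketch runnable; real training uses
# the process group(s) that Gemini sets up across data-parallel ranks.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

grad_shard = torch.ones(1024)

# Launch the reduction without blocking the caller.
work = dist.all_reduce(grad_shard, async_op=True)

# ... other computation (e.g. the next chunk's backward) can proceed here ...

# Block only when the reduced gradient is actually consumed.
if work is not None:
    work.wait()

dist.destroy_process_group()
```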
@@ -164,6 +164,8 @@ class Chunk:
         self.l2_norm = None
 
         self.grad_chunk = None
+        # the async all-reduce/reduce-scatter work of this grad chunk (None means sync)
+        self.grad_reduce_work = None
 
     @property
     def memory_usage(self) -> Dict[str, int]:
@@ -244,7 +246,7 @@ class Chunk:
             assert self.cuda_shard is not None  # only check on CUDA
             valid_tensor = self.cuda_shard[: self.valid_end]
 
-        return torch.isinf(valid_tensor).any().item() | torch.isnan(valid_tensor).any().item()
+        return torch.isinf(valid_tensor).any() | torch.isnan(valid_tensor).any()
 
     def set_l2_norm(self) -> None:
         """Record l2 norm of this chunks on CUDA."""
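A small standalone contrast of the two return forms in the hunk above (illustrative only; the presumed motivation is that dropping `.item()` keeps the inf/nan check as a tensor and avoids forcing a device-to-host sync on CUDA tensors):

```python
import torch

t = torch.tensor([1.0, float("inf"), 3.0])

# .item() materializes a Python bool; on a CUDA tensor this forces a
# device-to-host synchronization before the value is available.
blocking = torch.isinf(t).any().item() | torch.isnan(t).any().item()

# Keeping the result as a 0-dim bool tensor defers that sync to the caller.
lazy = torch.isinf(t).any() | torch.isnan(t).any()

assert bool(lazy) == blocking  # both report True here, since t contains inf
```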
@@ -374,37 +376,49 @@ class Chunk:
         if self.is_gathered:
             self.__scatter()
 
-    def reduce(self):
+    def reduce(self, async_op: bool = False):
         """Reduce scatter all the gradients. It's an operation done in CUDA."""
         # sanity check
         assert self.is_gathered
+        assert self.grad_reduce_work is None
         if self.pg_size == 1:
             # tricky code here
             # just move cuda_global_chunk to cuda_shard
             # the communication is not necessary
             self.__scatter()
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
         elif self.keep_gathered:
             # we use all-reduce here
-            dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
-            if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
+            self.grad_reduce_work = dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg, async_op=async_op)
+            if self.extra_dp_group is not None:  # cannot guarantee the order of multiple all-reduce
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(
+                    self.cuda_global_chunk, group=self.extra_dp_group, async_op=async_op
+                )
         else:
             self.cuda_shard = torch.empty(
                 self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
             )
 
             input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+            self.grad_reduce_work = dist.reduce_scatter(
+                self.cuda_shard, input_list, group=self.torch_pg, async_op=async_op
+            )
 
             if self.extra_dp_group is not None:
-                dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
+                self.wait_async_reduce()
+                self.grad_reduce_work = dist.all_reduce(self.cuda_shard, group=self.extra_dp_group, async_op=async_op)
 
         free_storage(self.cuda_global_chunk)
         self.is_gathered = False
         self.__update_tensors_state(TensorState.HOLD)
 
+    def wait_async_reduce(self) -> None:
+        if self.grad_reduce_work is not None:
+            self.grad_reduce_work.wait()
+            self.grad_reduce_work = None
+
     def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
         """
         Make a transition of the tensor into the next state.
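A hypothetical caller-side sketch of how the new `reduce(async_op=...)` / `wait_async_reduce()` pair could be used to overlap gradient communication across chunks (the function name and loop structure are invented for illustration and are not taken from this PR):

```python
def reduce_all_grad_chunks(grad_chunks, async_reduce: bool = True):
    # Launch the all-reduce/reduce-scatter of every chunk without blocking,
    # so communication for one chunk can overlap with work on the next.
    for chunk in grad_chunks:
        chunk.reduce(async_op=async_reduce)

    # Before the reduced gradients are consumed (clipping, optimizer step),
    # make sure each chunk's pending work has finished. This is a no-op for
    # chunks reduced synchronously, since their grad_reduce_work stays None.
    for chunk in grad_chunks:
        chunk.wait_async_reduce()
```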