[gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713)

* [gemini] async grad chunk reduce (all-reduce&reduce-scatter)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [gemini] add test

* [gemini] rename func

* [gemini] update llama benchmark

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [gemini] use tensor counter

* [gemini] change default config in GeminiPlugin and GeminiDDP

* [chore] typo

* [gemini] fix sync issue & add test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: botbw
Date: 2024-05-24 10:31:16 +08:00
Committed by: GitHub
Parent: 85946d4236
Commit: 2fc85abf43
11 changed files with 130 additions and 45 deletions


@@ -41,7 +41,7 @@ class ChunkManager:
         self.reuse_fp16_chunk = reuse_fp16_chunk
         # Whether model is accumulating gradients,
         self.accumulating_grads = False
-        self.overflow_counter = 0
+        self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())

     def register_tensor(
         self,
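
The hunk above replaces the Python int overflow counter with a device tensor. A minimal sketch of the motivation, using plain torch.distributed (the `OverflowCounter` helper below is hypothetical, not the ColossalAI implementation): a device-resident counter can be summed across ranks by a collective, optionally asynchronously, instead of forcing a host-side synchronization on a Python int.

```python
import torch
import torch.distributed as dist


class OverflowCounter:
    """Hypothetical helper: a device-resident counter that collectives can consume directly."""

    def __init__(self, device: torch.device):
        # Counter lives on the accelerator, mirroring the diff above.
        self.counter = torch.tensor([0], dtype=torch.int, device=device)

    def record(self, grad: torch.Tensor) -> None:
        # Accumulate an overflow flag without a .item() host sync.
        self.counter += (~torch.isfinite(grad)).any().int()

    def sum_across_ranks(self, async_op: bool = False):
        # Returns a Work handle when async_op=True; the caller waits on it later.
        return dist.all_reduce(self.counter, op=dist.ReduceOp.SUM, async_op=async_op)
```

With `async_op=True`, the returned Work handle can be waited on just before the unscale/step, so the overflow check overlaps with the rest of the backward pass.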
@@ -143,12 +143,12 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)

-    def reduce_chunk(self, chunk: Chunk) -> bool:
+    def reduce_chunk(self, chunk: Chunk, async_op: bool = False) -> bool:
         """Reduce or all reduce the chunk."""
         if not chunk.can_reduce:
            return False
         self.__sub_memory_usage(chunk.memory_usage)
-        chunk.reduce()
+        chunk.reduce(async_op=async_op)
         self.__sub_accessed_chunk(chunk)
         self.__add_memory_usage(chunk.memory_usage)
         return True
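
The second hunk threads an `async_op` flag down to `chunk.reduce`. A minimal sketch of the pattern, again with plain torch.distributed (`reduce_grad_buffer` is hypothetical; the docstring above, "Reduce or all reduce the chunk.", indicates both a reduce-scatter and an all-reduce path exist in the real code):

```python
import torch
import torch.distributed as dist


def reduce_grad_buffer(
    full_buffer: torch.Tensor,
    shard_out: torch.Tensor,
    use_reduce_scatter: bool,
    async_op: bool = False,
):
    """Hypothetical sketch: launch the gradient collective, optionally without blocking."""
    if use_reduce_scatter:
        # Sharded case: each rank keeps only its slice of the reduced gradients.
        # (full_buffer.numel() must equal shard_out.numel() * world_size.)
        work = dist.reduce_scatter_tensor(shard_out, full_buffer, op=dist.ReduceOp.SUM, async_op=async_op)
    else:
        # Replicated case: every rank ends up with the fully reduced buffer.
        work = dist.all_reduce(full_buffer, op=dist.ReduceOp.SUM, async_op=async_op)
    return work  # None when async_op=False, otherwise a Work handle to wait on


# Usage sketch: collect handles while backward runs, then drain them before the
# optimizer step so the reductions overlap with the remaining backward compute.
# handles = [h for h in launched if h is not None]
# for h in handles:
#     h.wait()
```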
@@ -272,7 +272,7 @@ class ChunkManager:
         return grad_chunk

     def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk:
-        """Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction."""
+        """Rearrange gradients accumulated in chunk.grad_chunk, and get prepared for gradient reduction."""
         assert chunk.grad_chunk is not None
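
The commit list above also mentions changing the default config in GeminiPlugin and GeminiDDP, but that diff is not part of this excerpt; the keyword below is an assumption for illustration only.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# `enable_async_reduce` is an assumed keyword, inferred from the commit message
# ("change default config in GeminiPlugin and GeminiDDP"); it is not confirmed
# by the hunks shown above.
plugin = GeminiPlugin(enable_async_reduce=True)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)
```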