[gemini] async grad chunk reduce (all-reduce & reduce-scatter) (#5713)

* [gemini] async grad chunk reduce (all-reduce & reduce-scatter)

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* [gemini] add test

* [gemini] rename func

* [gemini] update llama benchmark

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

* [gemini] use tensor counter

* [gemini] change default config in GeminiPlugin and GeminiDDP

* [chore] typo

* [gemini] fix sync issue & add test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Commit 2fc85abf43 by botbw, 2024-05-24 10:31:16 +08:00, committed by GitHub
Parent: 85946d4236
11 changed files with 130 additions and 45 deletions
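The commit message above describes moving per-chunk gradient reduction off the critical path. A minimal sketch of that general pattern, not ColossalAI's actual internals: torch.distributed collectives invoked with async_op=True return Work handles immediately, so reductions can overlap with the remaining backward computation and be waited on in a batch later. The names reduce_chunk_async, flush_reductions, and pending_works are illustrative, not Gemini APIs.

```python
import torch
import torch.distributed as dist

def reduce_chunk_async(chunk: torch.Tensor, pending_works: list) -> None:
    # async_op=True returns a Work handle instead of blocking the caller,
    # letting the reduction overlap with subsequent backward computation.
    work = dist.all_reduce(chunk, op=dist.ReduceOp.SUM, async_op=True)
    pending_works.append(work)

def flush_reductions(pending_works: list) -> None:
    # Wait on every in-flight reduction at once, e.g. right before the
    # optimizer step, rather than stalling after each chunk.
    for work in pending_works:
        work.wait()
    pending_works.clear()
```

The same shape applies to dist.reduce_scatter, which also accepts async_op=True; only the collective call changes.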


@@ -62,10 +62,10 @@ class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         self.module = module

     def check_local_overflow(self) -> bool:
-        return self.module.chunk_manager.overflow_counter > 0
+        return self.module.chunk_manager.overflow_counter.item() > 0

     def pre_zero_grad(self) -> None:
-        self.module.chunk_manager.overflow_counter = 0
+        self.module.chunk_manager.overflow_counter.zero_()


class GeminiOptimizer(OptimizerWrapper):
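The hunk above switches the overflow counter from a Python int to tensor operations (.item() for the read, .zero_() for the reset), matching the "[gemini] use tensor counter" entry in the commit message: with async reductions in flight, a device-resident counter can accumulate overflow flags without forcing a host synchronization until the check actually happens. A minimal sketch of the idea follows; ChunkManagerSketch and record_overflow are hypothetical names for illustration, not ColossalAI's API, and the counter is assumed to live on the GPU.

```python
import torch

class ChunkManagerSketch:
    def __init__(self, device: str = "cuda") -> None:
        # A device tensor instead of a Python int, so overflow flags from
        # asynchronously reduced chunks accumulate without a host sync.
        self.overflow_counter = torch.zeros(1, dtype=torch.int32, device=device)

    def record_overflow(self, grad: torch.Tensor) -> None:
        # Stays entirely on-device: no .item() here, so no per-chunk sync.
        self.overflow_counter += (~torch.isfinite(grad)).any().to(torch.int32)

def check_local_overflow(mgr: ChunkManagerSketch) -> bool:
    # .item() is the single point where the host waits on the device.
    return mgr.overflow_counter.item() > 0

def pre_zero_grad(mgr: ChunkManagerSketch) -> None:
    # In-place reset keeps the same tensor object alive across iterations,
    # mirroring the diff's overflow_counter.zero_().
    mgr.overflow_counter.zero_()
```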