[gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713)

* [gemini] async grad chunk reduce (all-reduce&reduce-scatter)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [gemini] add test

* [gemini] rename func

* [gemini] update llama benchmark

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [gemini] use tensor counter

* [gemini] change default config in GeminiPlugin and GeminiDDP

* [chore] typo

* [gemini] fix sync issue & add test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
botbw
2024-05-24 10:31:16 +08:00
committed by GitHub
parent 85946d4236
commit 2fc85abf43
11 changed files with 130 additions and 45 deletions

View File

@@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("enable_async_reduce", [False, True])
def exam_gemini_grad_acc(
placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
placement_config,
keep_gathered: bool,
model_name: str,
master_weights: bool,
use_grad_checkpoint: bool,
enable_async_reduce: bool,
):
init_device = get_accelerator().get_current_device()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -81,10 +87,13 @@ def exam_gemini_grad_acc(
pin_memory=True,
enable_gradient_accumulation=True,
master_weights=master_weights,
enable_async_reduce=enable_async_reduce,
**placement_config,
)
optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
gemini_optim = GeminiOptimizer(
optimizer, gemini_model, initial_scale=1, max_norm=1.0, enable_async_reduce=enable_async_reduce
)
rank = dist.get_rank()