Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 01:55:12 +00:00)
[bug] fix early return (#5740)
* [bug] fix silly bug
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)
* [chore] add test for prefetch
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -50,8 +50,14 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("max_prefetch", [0, 1, 4])
 def exam_gemini_grad_acc(
-    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+    placement_config,
+    keep_gathered: bool,
+    model_name: str,
+    master_weights: bool,
+    use_grad_checkpoint: bool,
+    max_prefetch: int,
 ):
     init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
@@ -81,6 +87,7 @@ def exam_gemini_grad_acc(
         pin_memory=True,
         enable_gradient_accumulation=True,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         **placement_config,
     )
     optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
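For orientation, below is a minimal sketch (not part of the commit) of how the new max_prefetch knob is wired into a Gemini-wrapped model. The GeminiDDP and HybridAdam names and the keyword arguments mirror the diff above; the nn.Linear toy model and the build_gemini_model helper are hypothetical stand-ins for the test's transformers_gpt_lm builder, and the code assumes a ColossalAI distributed context has already been launched.

    import torch.nn as nn

    from colossalai.nn.optimizer import HybridAdam
    from colossalai.zero import GeminiDDP


    def build_gemini_model(max_prefetch: int = 0):
        # Hypothetical stand-in for the GPT language model built in the real test.
        model = nn.Linear(32, 32)
        gemini_model = GeminiDDP(
            model,
            pin_memory=True,
            enable_gradient_accumulation=True,  # accumulate grads across micro-batches
            master_weights=True,
            max_prefetch=max_prefetch,  # 0 disables prefetching; the test sweeps 0, 1, 4
        )
        optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
        return gemini_model, optimizer

Sweeping max_prefetch over 0 (prefetching off) as well as 1 and 4 lets the test confirm that chunk prefetching leaves gradient-accumulation results unchanged, which is the regression the early-return fix guards against.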