mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[gemini] async grad chunk reduce (all-reduce&reduce-scatter) (#5713)
* [gemini] async grad chunk reduce (all-reduce&reduce-scatter) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [gemini] add test * [gemini] rename func * [gemini] update llama benchmark * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [gemini] use tensor counter * [gemini] change default config in GeminiPlugin and GeminiDDP * [chore] typo * [gemini] fix sync issue & add test cases * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -73,7 +73,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
|
||||
@parameterize("model_name", TEST_MODELS)
|
||||
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
|
||||
@parameterize("master_weights", [True, False])
|
||||
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
|
||||
@parameterize("enable_async_reduce", [False, True])
|
||||
def exam_model_step(
|
||||
placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool, enable_async_reduce=True
|
||||
):
|
||||
set_seed(42)
|
||||
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
|
||||
iter(model_zoo.get_sub_registry(model_name).values())
|
||||
@@ -96,7 +99,12 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
|
||||
config_dict[world_size]["chunk_size"] = 5000
|
||||
config_dict[world_size]["keep_gathered"] = False
|
||||
model = GeminiDDP(
|
||||
model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights
|
||||
model,
|
||||
config_dict,
|
||||
**placement_config,
|
||||
mixed_precision=mixed_precision,
|
||||
master_weights=master_weights,
|
||||
enable_async_reduce=enable_async_reduce,
|
||||
)
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
|
Reference in New Issue
Block a user