Merge pull request #5749 from hpcaitech/prefetch

[Gemini] Prefetch next chunk before each op
Author: botbw
Date: 2024-05-29 15:35:54 +08:00
Committed by: GitHub
15 changed files with 239 additions and 65 deletions
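
The hunks below thread a new max_prefetch knob through the Gemini test suites, next to the existing enable_async_reduce flag. For orientation, here is a minimal usage sketch, not part of this diff, showing how the two settings might be passed to GeminiDDP directly; the model and hyperparameters are placeholders, import paths follow ColossalAI's public API, and a distributed environment started with torchrun is assumed.

import torch
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiDDP

colossalai.launch_from_torch()  # assumes the script was started with torchrun

model = torch.nn.Linear(1024, 1024)
model = GeminiDDP(
    model,
    pin_memory=True,
    master_weights=True,
    max_prefetch=4,            # prefetch up to 4 upcoming chunks before each op; 0 disables prefetching
    enable_async_reduce=True,  # overlap gradient reduction with computation
)
optimizer = HybridAdam(model.parameters(), lr=1e-3)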

@@ -40,6 +40,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("use_grad_checkpoint", [False, True])
 @parameterize("master_weights", [False, True])
+@parameterize("max_prefetch", [0, 4])
 @parameterize("enable_async_reduce", [False, True])
 def exam_gpt_fwd_bwd(
     placement_config,
@@ -47,6 +48,7 @@ def exam_gpt_fwd_bwd(
     model_name: str,
     use_grad_checkpoint: bool = False,
     master_weights: bool = True,
+    max_prefetch: int = 0,
     enable_async_reduce=True,
 ):
     init_device = get_accelerator().get_current_device()
@@ -77,6 +79,7 @@ def exam_gpt_fwd_bwd(
         pin_memory=True,
         **placement_config,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         enable_async_reduce=enable_async_reduce,
     )
     optimizer = HybridAdam(model.parameters(), lr=1e-3)

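For readers unfamiliar with colossalai.testing.parameterize: each stacked decorator reruns the test once per listed value, so the added @parameterize("max_prefetch", [0, 4]) line doubles the number of covered configurations. A toy illustration with hypothetical names:

from colossalai.testing import parameterize

@parameterize("max_prefetch", [0, 4])
@parameterize("enable_async_reduce", [False, True])
def dummy_exam(max_prefetch: int = 0, enable_async_reduce: bool = True):
    # Each invocation receives one combination of the parameterized values.
    print(f"max_prefetch={max_prefetch}, enable_async_reduce={enable_async_reduce}")

dummy_exam()  # runs 2 x 2 = 4 combinations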

@@ -50,6 +50,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
 @parameterize("use_grad_checkpoint", [False, True])
+@parameterize("max_prefetch", [0, 4])
 @parameterize("enable_async_reduce", [False, True])
 def exam_gemini_grad_acc(
     placement_config,
@@ -57,6 +58,7 @@ def exam_gemini_grad_acc(
     model_name: str,
     master_weights: bool,
     use_grad_checkpoint: bool,
+    max_prefetch: int,
     enable_async_reduce: bool,
 ):
     init_device = get_accelerator().get_current_device()
@@ -87,6 +89,7 @@ def exam_gemini_grad_acc(
         pin_memory=True,
         enable_gradient_accumulation=True,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         enable_async_reduce=enable_async_reduce,
         **placement_config,
     )

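The second file exercises max_prefetch together with enable_gradient_accumulation=True. Below is a hedged sketch of that training pattern; only the GeminiDDP kwargs visible in the hunks are taken from this PR, GeminiOptimizer is assumed as the usual wrapper around HybridAdam, the model and accumulation schedule are placeholders, and a GPU plus a launched distributed environment are assumed.

import torch
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import GeminiDDP, GeminiOptimizer

colossalai.launch_from_torch()  # assumes a torchrun launch

model = GeminiDDP(
    torch.nn.Linear(1024, 1024),
    pin_memory=True,
    enable_gradient_accumulation=True,  # accumulate grads across micro-batches
    max_prefetch=4,
    enable_async_reduce=True,
)
optimizer = GeminiOptimizer(HybridAdam(model.parameters(), lr=1e-3), model)

accumulation_steps = 2
for step in range(4):
    x = torch.randn(8, 1024, device=get_accelerator().get_current_device())
    loss = model(x).sum() / accumulation_steps
    optimizer.backward(loss)  # gradients stay accumulated inside the chunks
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()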

@@ -52,8 +52,11 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [True, False])
+@parameterize("max_prefetch", [0, 1, 4])
 @parameterize("enable_async_reduce", [False, True])
-def exam_grad_clipping(placement_config, model_name: str, master_weights: bool, enable_async_reduce: bool):
+def exam_grad_clipping(
+    placement_config, model_name: str, master_weights: bool, max_prefetch: int, enable_async_reduce: bool
+):
     set_seed(1912)
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
@@ -85,6 +88,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool,
         chunk_init_device=init_device,
         pin_memory=True,
         master_weights=master_weights,
+        max_prefetch=max_prefetch,
         enable_async_reduce=enable_async_reduce,
         **placement_config,
     )
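
Finally, a hedged end-to-end sketch through the Booster API. That GeminiPlugin exposes the same max_prefetch / enable_async_reduce arguments and forwards them to GeminiDDP is an assumption based on this PR's title, not something shown in the hunks above.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # assumes a torchrun launch

plugin = GeminiPlugin(pin_memory=True, max_prefetch=4, enable_async_reduce=True)  # plugin kwargs assumed
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)  # boost() wraps the model for Gemini training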