mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-15 03:07:26 +00:00
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI into dev/zero_bubble
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
@@ -72,6 +73,9 @@ class MlpModel(nn.Module):
|
||||
else:
|
||||
return {"hidden_states": held_layers(hidden_states)}
|
||||
|
||||
def no_sync(self):
|
||||
return nullcontext()
|
||||
|
||||
|
||||
def assert_optim_param_groups(optim_base_param_groups, optim_pp_param_groups):
|
||||
for (key_base, val_base), (key_pp, val_pp) in zip(optim_base_param_groups.items(), optim_pp_param_groups.items()):
|
||||
|
||||
@@ -114,14 +114,12 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
|
||||
# check last hidden state & loss
|
||||
check_flag = False
|
||||
if stage_manager is None:
|
||||
if (
|
||||
(stage_manager is None)
|
||||
or (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True))
|
||||
or (not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True))
|
||||
):
|
||||
check_flag = True
|
||||
else:
|
||||
if stage_manager.use_zbv:
|
||||
if stage_manager.is_first_stage(ignore_chunk=True):
|
||||
check_flag = True
|
||||
elif stage_manager.is_last_stage(ignore_chunk=True):
|
||||
check_flag = True
|
||||
if check_flag:
|
||||
if test_config["precision"] == "fp32":
|
||||
atol, rtol = 1e-5, 1e-3
|
||||
@@ -292,6 +290,19 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
"enable_gradient_checkpointing": True,
|
||||
"parallel_output": False,
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"pp_style": "zbv",
|
||||
"num_model_chunks": 2,
|
||||
"num_microbatches": 4,
|
||||
"enable_all_optimization": False,
|
||||
"precision": "fp16",
|
||||
"zero_stage": 1,
|
||||
"initial_scale": 1,
|
||||
"enable_gradient_checkpointing": True,
|
||||
"parallel_output": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
def run_llama_test(test_config):
|
||||
|
||||
Reference in New Issue
Block a user