Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 19:13:01 +00:00)
[plugin] hybrid support zero bubble pipeline (#6060)
* hybrid support zbv
* fix fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* Update zero_bubble_pp.py
* fix
* fix-ci
* fix; [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci); fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* fix
* fix
* fix
* [zerobubble] Support ZeroBubble Pipeline (#6034)
  * [feat] add zerobubble pp (just a frame now); add POC test for dx_dw; add test for zerobubble;
  * [feat] add dw test;
  * [fix] fix weight not close;
  * [update] update text;
  * [feat] add test run_fwd_bwd automatic scheduling;
  * [feat] split communication and calculation; fix pop empty send_bwd_buffer error;
  * [feat] add test for p & p grad;
  * [feat] add comments for ZBV func;
  * [fix] rm useless assign and comments;
  * [fix] fix ci test; add pytest;
  * [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass;
  * [feat] add apply v_schedule graph; p & p.grad assert err exist;
  * [fix] update
  * [feat] fix ci; add assert;
  * [feat] fix poc format
  * [feat] fix func name & ci; add comments;
  * [fix] fix poc test; add comments in poc;
  * [feat] add optim backward_b_by_grad
  * [feat] fix optimizer bwd b & w; support return accum loss & output
  * [feat] add fwd_bwd_step, run_fwd_only;
  * [fix] fix optim bwd; add license for v_schedule; remove redundant attributes; fix schedule loop "while" --> "for"; add communication dict;
  * [fix] fix communication_map;
  * [feat] update test; rm comments;
  * [fix] rm zbv in hybridplugin
  * [fix] fix optim bwd;
  * [fix] fix optim bwd;
  * [fix] rm output.data after send fwd;
  * [fix] fix bwd step if condition; remove useless comments and format info;
  * [fix] fix detach output & release output;
  * [fix] rm requir_grad for output;
  * [fix] fix requir grad position and detach position and input&output local buffer append position;
  * [feat] add memory assertation;
  * [fix] fix mem check;
  * [fix] mem assertation
  * [fix] fix mem assertation
  * [fix] fix mem; use a new model shape; only assert mem less and equal than theo;
  * [fix] fix model zoo import;
  * [fix] fix redundant detach & clone; add buffer assertation in the end;
  * [fix] add output_obj_grad assert None at bwd b step; replace input_obj.require_grad_ with treemap;
  * [fix] update optim state dict assert (include param group & state); fix mem assert after add optim;
  * [fix] add testcase with microbatch 4;
* hybrid support zbv
* fix fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update zero_bubble_pp.py
* fix
* fix-ci
* fix; [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci); fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* fix
* fix
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* fix
* fix
* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <935724073@qq.com>
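As a rough usage sketch of what this commit enables (hedged: the HybridParallelPlugin keyword names below mirror the test config in the diff further down and are assumptions, not taken from plugin documentation), the zero-bubble V schedule is built with PipelineGraph and handed to the hybrid plugin as scheduler_nodes:

    # Hedged sketch: build a ZB-V schedule and pass it to the hybrid plugin.
    # PipelineGraph and get_v_schedule() appear in the diff below; the plugin
    # keyword names mirror the test config and are assumptions, not documented API.
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.pipeline.schedule.v_schedule import PipelineGraph

    pp_size, num_microbatches = 2, 4
    mem_f = 34 * 32 + 5 * 4 * 16      # relative per-microbatch memory cost of F
    mem_w = -32 * 32                  # memory released by the weight-grad pass W
    mem_b = -mem_w - mem_f            # memory released by the input-grad pass B

    scheduler_nodes = PipelineGraph(
        n_stage=pp_size,
        n_micro=num_microbatches,
        f_cost=1000,
        b_cost=1000,
        w_cost=1000,
        c_cost=1,
        f_mem=mem_f,
        b_mem=mem_b,
        w_mem=mem_w,
    ).get_v_schedule()

    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=pp_size,
        pp_style="zbv",               # zero-bubble V pipeline schedule
        num_model_chunks=2,           # two model chunks per stage, laid out in a V
        num_microbatches=num_microbatches,
        scheduler_nodes=scheduler_nodes,
        precision="fp16",
        initial_scale=1,
    )
    booster = Booster(plugin=plugin)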
@@ -157,7 +157,6 @@ def build_model_from_hybrid_plugin(
     sharded_optimizer = sharded_optim_class(sharded_model.parameters(), lr=1e-3)
 
     criterion = loss_fn
 
     plugin = pluggin_cls(**test_config)
     booster = Booster(plugin=plugin)
 
@@ -311,8 +310,16 @@ def check_output_hidden_state(
 ):
     org_hidden_state = org_output.last_hidden_state
 
-    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
-        sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
+    if stage_manager:
+        if stage_manager.use_zbv:
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
+            else:
+                sharded_hidden_state = sharded_output.last_hidden_state
+        elif stage_manager.is_last_stage(ignore_chunk=True):
+            sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
+        else:
+            sharded_hidden_state = sharded_output.last_hidden_state
     else:
         sharded_hidden_state = sharded_output.last_hidden_state
 
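The new branching follows from how the V schedule lays out model chunks: with num_model_chunks=2, each pipeline rank hosts two chunks arranged in a V, so the final chunk (the one that emits last_hidden_state and the loss) lives on the first stage rather than the last. A rough illustration of that placement (the helper below is made up for this note, not ColossalAI API):

    # Illustration only: V-shaped placement of 2*pp_size model pieces across ranks.
    # Rank r hosts piece r on the way down and piece 2*pp_size-1-r on the way back up,
    # so rank 0 holds both the first and the last piece of the model.
    def v_layout(pp_size: int) -> list[list[int]]:
        return [[rank, 2 * pp_size - 1 - rank] for rank in range(pp_size)]

    print(v_layout(2))  # [[0, 3], [1, 2]] -> the last piece (3) sits on rank 0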
@@ -390,7 +397,6 @@ def get_grad_tensors_for_check(
             pass
         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
 
         grad_to_check[suffix] = {
             "org_grad": org_grad.float(),
             "shard_grad": shard_grad.float(),
@@ -7,6 +7,7 @@ from torch.testing import assert_close
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.pipeline.schedule.v_schedule import PipelineGraph
 from colossalai.shardformer import PipelineGradientCheckpointConfig
 from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
@@ -33,7 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     )
     if enable_gradient_checkpointing:
         # org_model.gradient_checkpointing_enable()
-        sharded_model.unwrap().gradient_checkpointing_enable()
+        sharded_model.unwrap().gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
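The switch to use_reentrant=False lines up with the B/W split that zero bubble relies on (see "add optim backward_b_by_grad" and "fix optimizer bwd b & w" in the commit message): input gradients are taken first and weight gradients later from the same graph, which the legacy reentrant checkpointing path does not support; this reading is mine, not stated in the commit. A minimal plain-PyTorch sketch of that two-step backward:

    # Minimal sketch (assumption, not from this diff) of the B/W split:
    # "B" takes gradients w.r.t. the stage input while keeping the graph alive,
    # "W" takes the weight gradients later from the same retained graph.
    import torch

    w = torch.randn(4, 4, requires_grad=True)   # stand-in for a stage's weights
    x = torch.randn(2, 4, requires_grad=True)   # stand-in for the stage input
    y = (x @ w).sum()

    (dx,) = torch.autograd.grad(y, x, retain_graph=True)   # B: input grads only
    (dw,) = torch.autograd.grad(y, w)                       # W: weight grads afterwards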
@@ -112,12 +113,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     sharded_optimizer.step()
 
     # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
+    check_flag = False
+    if stage_manager is None:
+        check_flag = True
+    else:
+        if stage_manager.use_zbv:
+            if stage_manager.is_first_stage(ignore_chunk=True):
+                check_flag = True
+        elif stage_manager.is_last_stage(ignore_chunk=True):
+            check_flag = True
+    if check_flag:
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-5, 1e-3
         else:
             atol, rtol = 5e-3, 5e-3
 
         if org_model.__class__.__name__ == "LlamaModel":
             check_output_hidden_state(
                 org_output,
@@ -282,10 +291,39 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "zbv",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 0,
+            "initial_scale": 1,
+            "enable_gradient_checkpointing": True,
+            "parallel_output": False,
+        },
     ],
 )
 def run_llama_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
+    if test_config.get("pp_style", None) == "zbv":
+        mem_f = 34 * 32 + 5 * 4 * 16
+        mem_w = -32 * 32
+        mem_b = -mem_w - mem_f
+        scheduler_nodes = PipelineGraph(
+            n_stage=test_config["pp_size"],
+            n_micro=test_config["num_microbatches"],
+            f_cost=1000,
+            b_cost=1000,
+            w_cost=1000,
+            c_cost=1,
+            f_mem=mem_f,
+            b_mem=mem_b,
+            w_mem=mem_w,
+        ).get_v_schedule()
+        test_config["scheduler_nodes"] = scheduler_nodes
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
             continue
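One way to read the mem_f / mem_w / mem_b constants in the hunk above (my interpretation; the commit does not spell it out): they are relative per-microbatch activation-memory deltas fed to the schedule planner, defined so that a full F + B + W cycle for one microbatch is memory-neutral:

    # Interpretation of the scheduler memory coefficients (assumption, for illustration):
    # F allocates a microbatch's activations, W frees the share tied to weight grads,
    # and B frees the remainder, so one F + B + W cycle nets out to zero.
    mem_f = 34 * 32 + 5 * 4 * 16   # = 1408: forward allocates
    mem_w = -32 * 32               # = -1024: weight-grad pass frees
    mem_b = -mem_w - mem_f         # = -384: input-grad pass frees the rest
    assert mem_f + mem_b + mem_w == 0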