[pipeline] Add Simplified Alpa DP Partition (#2507)

* add alpa dp split

* add alpa dp split

* use fwd+bwd instead of fwd only

---------

Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
Ziyue Jiang
2023-03-07 10:34:31 +08:00
committed by GitHub
parent b42d3d28ed
commit 400f63012e
4 changed files with 197 additions and 15 deletions

View File

@@ -8,11 +8,16 @@ from torch import nn
from tqdm import tqdm
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass
from colossalai.fx.passes.adding_split_node_pass import (
avgnode_split_pass,
gpipe_dp_split_pass,
split_with_split_nodes_pass,
)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.pipeline.rpc.utils import rpc_run
@@ -55,13 +60,25 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
# Create annotated model which is noted where to be splitted.
def get_annotated_model(model, data_kwargs, num_stages, num_microbatches):
tracer = ColoTracer()
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
annotated_model = avgnode_split_pass(gm, stage_num)
interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()])
interp = MetaInfoProp(gm)
interp.run(*interp_meta_args)
#annotated_model = avgnode_split_pass(gm, num_stages)
annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01)
return annotated_model
def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, num_microbatches):
annotated_model = get_annotated_model(model, data_kwargs, num_stages, num_microbatches)
top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
topo = get_fx_topology(top_module)
for submodule in split_submodules:
@@ -70,8 +87,8 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
return split_submodules[pp_rank + 1]
def partition(model, data_kwargs, pp_rank: int, chunk: int, stage_num: int):
module = create_partition_module(pp_rank, stage_num, model, data_kwargs)
def partition(model, data_kwargs, num_microbatches, pp_rank: int, chunk: int, stage_num: int):
module = create_partition_module(pp_rank, stage_num, model, data_kwargs, num_microbatches)
return module
@@ -103,17 +120,19 @@ def run_master(args):
warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask}
# create model
logger.info(f'start model_builder')
model = model_builder(model_type)(checkpoint=False)
logger.info(f'end model_builder')
# set 1f1b pipeline engine
pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs),
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
chunk=1,
criterion=criterion,
metric=None,
checkpoint=False)
pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches),
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
chunk=1,
criterion=criterion,
metric=None,
checkpoint=False)
partition_numels = pp_engine.remote_numels()
for rank, numel in partition_numels.items():