mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[pipeline] Add Simplified Alpa DP Partition (#2507)
* add alpa dp split * add alpa dp split * use fwd+bwd instead of fwd only --------- Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>
This commit is contained in:
@@ -8,11 +8,16 @@ from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass
|
||||
from colossalai.fx.passes.adding_split_node_pass import (
|
||||
avgnode_split_pass,
|
||||
gpipe_dp_split_pass,
|
||||
split_with_split_nodes_pass,
|
||||
)
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.pipeline.middleware.adaptor import get_fx_topology
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
|
||||
from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
|
||||
from colossalai.pipeline.rpc.utils import rpc_run
|
||||
|
||||
|
||||
@@ -55,13 +60,25 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
|
||||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
||||
|
||||
|
||||
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
|
||||
# Create annotated model which is noted where to be splitted.
|
||||
def get_annotated_model(model, data_kwargs, num_stages, num_microbatches):
|
||||
tracer = ColoTracer()
|
||||
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
|
||||
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
|
||||
annotated_model = avgnode_split_pass(gm, stage_num)
|
||||
|
||||
interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()])
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.run(*interp_meta_args)
|
||||
|
||||
#annotated_model = avgnode_split_pass(gm, num_stages)
|
||||
annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01)
|
||||
|
||||
return annotated_model
|
||||
|
||||
|
||||
def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, num_microbatches):
|
||||
annotated_model = get_annotated_model(model, data_kwargs, num_stages, num_microbatches)
|
||||
top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
|
||||
topo = get_fx_topology(top_module)
|
||||
for submodule in split_submodules:
|
||||
@@ -70,8 +87,8 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
|
||||
return split_submodules[pp_rank + 1]
|
||||
|
||||
|
||||
def partition(model, data_kwargs, pp_rank: int, chunk: int, stage_num: int):
|
||||
module = create_partition_module(pp_rank, stage_num, model, data_kwargs)
|
||||
def partition(model, data_kwargs, num_microbatches, pp_rank: int, chunk: int, stage_num: int):
|
||||
module = create_partition_module(pp_rank, stage_num, model, data_kwargs, num_microbatches)
|
||||
return module
|
||||
|
||||
|
||||
@@ -103,17 +120,19 @@ def run_master(args):
|
||||
warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask}
|
||||
|
||||
# create model
|
||||
logger.info(f'start model_builder')
|
||||
model = model_builder(model_type)(checkpoint=False)
|
||||
logger.info(f'end model_builder')
|
||||
|
||||
# set 1f1b pipeline engine
|
||||
pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs),
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=1,
|
||||
criterion=criterion,
|
||||
metric=None,
|
||||
checkpoint=False)
|
||||
pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches),
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
chunk=1,
|
||||
criterion=criterion,
|
||||
metric=None,
|
||||
checkpoint=False)
|
||||
|
||||
partition_numels = pp_engine.remote_numels()
|
||||
for rank, numel in partition_numels.items():
|
||||
|
Reference in New Issue
Block a user