[pipeline] Bert pipeline for shardformer and its tests (#4197)

* add pipeline forward

* complete pipeline forward check

* fix bert forward without pipeline

* fix comments

* discard useless line

* add todo

* clean prints

* fix distribute layers
This commit is contained in:
Jianghai
2023-07-10 13:58:58 +08:00
committed by Hongxin Liu
parent 890774b2fb
commit 1094e0f0d3
5 changed files with 259 additions and 11 deletions

View File

@@ -2,6 +2,7 @@ import copy
from contextlib import nullcontext
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -21,6 +22,28 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
return org_model.cuda(), sharded_model.cuda()
def build_pipeline_model(model_fn,
stage_manager=None,
enable_fused_normalization=False,
enable_tensor_parallelism=False,
use_lazy_init: bool = False):
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
# create new model
org_model = model_fn()
model_copy = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
pipeline_stage_manager=stage_manager)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# prepare input
data = data_gen_fn()