mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[pipeline] Bert pipeline for shardformer and its tests (#4197)
* add pipeline forward * complete pipeline forward check * fix bert forward without pipeline * fix comments * discard useless line * add todo * clean prints * fix distribute layers
This commit is contained in:
@@ -2,6 +2,7 @@ import copy
|
||||
from contextlib import nullcontext
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
||||
|
||||
@@ -21,6 +22,28 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
|
||||
return org_model.cuda(), sharded_model.cuda()
|
||||
|
||||
|
||||
def build_pipeline_model(model_fn,
                         stage_manager=None,
                         enable_fused_normalization=False,
                         enable_tensor_parallelism=False,
                         use_lazy_init: bool = False):
    """Create a reference model and a sharded deep copy of it.

    Args:
        model_fn: Zero-argument callable that returns a new model.
        stage_manager: Passed to ``ShardConfig`` as ``pipeline_stage_manager``.
        enable_fused_normalization: Passed through to ``ShardConfig``.
        enable_tensor_parallelism: Passed through to ``ShardConfig``.
        use_lazy_init: When true, the model is created under a
            ``LazyInitContext`` and the reference model is materialized
            afterwards; otherwise a ``nullcontext`` is used.

    Returns:
        A ``(reference_model, sharded_model)`` tuple; ``.cuda()`` is called
        on each before returning.
    """
    if use_lazy_init:
        init_ctx = LazyInitContext()
    else:
        init_ctx = nullcontext()

    with init_ctx:
        # Build the reference model and take the copy that will be sharded.
        reference = model_fn()
        to_shard = copy.deepcopy(reference)

    # NOTE(review): indentation was lost in the reviewed view of this file;
    # materialization is placed after the context exits -- confirm upstream.
    if use_lazy_init:
        init_ctx.materialize(reference)

    # Shard the copy; only the first element of the result is used here.
    config = ShardConfig(
        enable_fused_normalization=enable_fused_normalization,
        enable_tensor_parallelism=enable_tensor_parallelism,
        pipeline_stage_manager=stage_manager,
    )
    former = ShardFormer(shard_config=config)
    sharded, _ = former.optimize(to_shard)

    return reference.cuda(), sharded.cuda()
|
||||
|
||||
|
||||
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# prepare input
|
||||
data = data_gen_fn()
|
||||
|
Reference in New Issue
Block a user