[pipeline] add chatglm (#4363)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name conflict

* add bloom model and policy, revise the base class of policy

* revise

* revision

* add bert_for_pretraining

* add bert_for_pretraining forward and policy

* fix typos

* cancel warning

* change the intermediate output to default dict

* change the default output of get_shared_params

* add chatglm

* add

* chatglm

* chatglm

* finish chatglm

* deletes

* fix rmsnorm

* chatglm

* fix chatglm shard

* init
Authored by Jianghai on 2023-08-04 14:55:31 +08:00; committed by Hongxin Liu
Parent: b1feeced8e
Commit: a88e92251d
9 changed files with 1828 additions and 57 deletions
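
The change set centers on a new ChatGLM shard policy plus pipeline-aware forwards for it. For orientation only, the following is a minimal sketch of the shape such a policy takes when subclassing the shardformer Policy base class; the class name and hook bodies are illustrative assumptions, and only get_shared_params (whose default return value this PR changes) is named in the commit messages above.

# Illustrative sketch only -- not the merged ChatGLM policy.
from typing import Dict, List, Union

import torch
import torch.nn as nn

from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy


class ChatGLMModelPolicy(Policy):  # hypothetical name, for illustration

    def config_sanity_check(self):
        pass

    def preprocess(self) -> nn.Module:
        # e.g. resize embeddings so the vocabulary size divides the tensor-parallel size
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        # map ChatGLM submodules to sharding / forward-replacement descriptions;
        # left empty here because the real mapping depends on the model definition
        return {}

    def get_held_layers(self) -> List[nn.Module]:
        # return only the layers this pipeline stage should keep
        return []

    def get_shared_params(self) -> List[Dict[int, torch.Tensor]]:
        # per the commit above, the default is an empty list (no shared parameters)
        return []

    def postprocess(self) -> nn.Module:
        return self.model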


@@ -1,5 +1,6 @@
 import copy
 from contextlib import nullcontext
-from typing import Optional
+from typing import Any, Callable, Dict, List, Optional
 import torch
@@ -15,6 +16,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.shardformer._utils import getattr_
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
@@ -39,7 +41,8 @@ def build_pipeline_model(model_fn,
                          stage_manager=None,
                          enable_fused_normalization=False,
                          enable_tensor_parallelism=False,
-                         use_lazy_init: bool = False):
+                         use_lazy_init: bool = False,
+                         policy: Optional[Policy] = None):
     ctx = LazyInitContext() if use_lazy_init else nullcontext()
     with ctx:
         # create new model
@@ -54,7 +57,7 @@
                                pipeline_stage_manager=stage_manager)
     shard_former = ShardFormer(shard_config=shard_config)
-    sharded_model, shared_params = shard_former.optimize(model_copy)
+    sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy)
     return org_model.cuda(), sharded_model.cuda()
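
The new policy argument lets a test bypass auto-policy lookup and hand ShardFormer an explicit policy object. Below is a hedged usage sketch; the process-group and stage-manager setup are placeholders, model_fn is assumed to be a callable that builds the model, and only the policy= keyword comes from this diff.

# Hypothetical caller of the updated helper, for illustration only.
import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager


def run_with_custom_policy(model_fn, policy):
    # one-dimensional mesh along the pipeline axis (axis index 0 is an assumption)
    pg_mesh = ProcessGroupMesh(dist.get_world_size())
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)

    # the explicit policy is forwarded straight through to ShardFormer.optimize
    org_model, sharded_model = build_pipeline_model(model_fn,
                                                    stage_manager=stage_manager,
                                                    enable_fused_normalization=True,
                                                    use_lazy_init=False,
                                                    policy=policy)
    return org_model, sharded_model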