Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 10:34:41 +00:00
[pipeline] add chatglm (#4363)
* add pipeline policy and bert forward to be done
* add bertmodel pipeline forward and make tests
* add Bert_Policy and test for policy
* update formatting
* update formatting
* update the code
* fix bugs
* fix name conflict
* add bloom model and policy, revise the base class of policy
* revise
* revision
* add bert_for_pretraining
* add bert_for_pretraining forward and policy
* fix typos
* cancel warning
* change the immediate output to default dict
* change the default output of get_shared_params
* add chatglm
* add
* chatglm
* chatglm
* finish chatglm
* deletes
* fix rmsnorm
* chatglm
* fix chatglm shard
* init
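For orientation only, here is a hedged skeleton of what a pipeline-aware shardformer policy of the kind this commit adds might look like. It is not the ChatGLM policy from the PR; the base class import path is taken from the diff below, and the method names follow the shardformer Policy interface as commonly used in this branch (module_policy, get_held_layers, get_shared_params, etc.) — treat the exact set of overrides as an assumption.

# Hedged sketch, not the actual ChatGLM policy added by this PR.
from typing import Dict, List, Union

import torch.nn as nn

from colossalai.shardformer.policies.auto_policy import Policy  # import path as shown in the diff below


class ChatGLMPipelinePolicySketch(Policy):
    """Illustrative shape of a pipeline-capable policy; method set is assumed."""

    def config_sanity_check(self):
        # Validate the shard config before building substitution rules.
        pass

    def preprocess(self) -> nn.Module:
        # e.g. resize embeddings so the vocab size divides the TP world size.
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], object]:
        # Map ChatGLM submodules to their parallelization/substitution rules.
        return {}

    def postprocess(self) -> nn.Module:
        return self.model

    def get_held_layers(self) -> List[nn.Module]:
        # Layers owned by the current pipeline stage.
        return []

    def get_shared_params(self) -> List[Dict[int, nn.Parameter]]:
        # Per the commit message, the default output of get_shared_params is a
        # (list of) dict(s) mapping stage index to the shared parameter.
        return []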
@@ -1,5 +1,6 @@
 import copy
 from contextlib import nullcontext
 from typing import Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -15,6 +16,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.shardformer._utils import getattr_
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
 
@@ -39,7 +41,8 @@ def build_pipeline_model(model_fn,
                          stage_manager=None,
                          enable_fused_normalization=False,
                          enable_tensor_parallelism=False,
-                         use_lazy_init: bool = False):
+                         use_lazy_init: bool = False,
+                         policy: Optional[Policy] = None):
     ctx = LazyInitContext() if use_lazy_init else nullcontext()
     with ctx:
         # create new model
@@ -54,7 +57,7 @@ def build_pipeline_model(model_fn,
                               pipeline_stage_manager=stage_manager)
 
     shard_former = ShardFormer(shard_config=shard_config)
-    sharded_model, shared_params = shard_former.optimize(model_copy)
+    sharded_model, shared_params = shard_former.optimize(model_copy, policy=policy)
     return org_model.cuda(), sharded_model.cuda()
 
 
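For reference, a minimal usage sketch of the updated helper follows. It assumes build_pipeline_model is importable from the shardformer test utilities (the module path is not shown in this diff and is a guess) and that a stage manager and a custom policy instance are constructed elsewhere; it is not code from the commit.

# Sketch only: exercising the new `policy` argument added in this commit.
from tests.test_shardformer.test_model._utils import build_pipeline_model  # hypothetical import path


def build_sharded_chatglm(model_fn, stage_manager, chatglm_policy):
    # Passing `policy` explicitly bypasses automatic policy lookup, so a test
    # can drive a hand-written pipeline policy (e.g. the new ChatGLM one).
    org_model, sharded_model = build_pipeline_model(
        model_fn,
        stage_manager=stage_manager,
        enable_fused_normalization=False,
        enable_tensor_parallelism=False,
        use_lazy_init=False,
        policy=chatglm_policy,  # None falls back to the default automatic policy
    )
    return org_model, sharded_model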