mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088)
* [shardformer] implement policy for all GPT-J models and test * [shardformer] support interleaved pipeline parallel for bert finetune * [shardformer] shardformer support falcon (#4883) * [shardformer]: fix interleaved pipeline for bert model (#5048) * [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093) * Add Mistral support for Shardformer (#5103) * [shardformer] add tests to mistral (#5105) --------- Co-authored-by: Pengtai Xu <henryxu880@gmail.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com> Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: eric8607242 <e0928021388@gmail.com>
This commit is contained in:
@@ -275,8 +275,8 @@ class FusedRMSNorm(BaseLayerNorm):
|
||||
)
|
||||
|
||||
LazyInitContext.materialize(module)
|
||||
# to check if it is huggingface LlamaRMSNorm
|
||||
if module.__class__.__name__ == "LlamaRMSNorm":
|
||||
# to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
|
||||
if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
|
||||
normalized_shape = module.weight.shape[0]
|
||||
eps = module.variance_epsilon
|
||||
elementwise_affine = True
|
||||
|
Reference in New Issue
Block a user