[shardformer] fix chatglm implementation (#5644)

* [shardformer] fix chatglm policy

* [shardformer] fix chatglm flash attn

* [shardformer] update readme

* [shardformer] fix chatglm init

* [shardformer] fix chatglm test

* [pipeline] fix chatglm merge batch
Author: Hongxin Liu
Date: 2024-04-25 14:41:17 +08:00 (committed by GitHub)
Parent: 5d88ef1aaf
Commit: bbb2c21f16
11 changed files with 193 additions and 117 deletions


@@ -281,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
             )
         LazyInitContext.materialize(module)
-        # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
-        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
-            normalized_shape = module.weight.shape[0]
-            eps = module.variance_epsilon
-            elementwise_affine = True
-        else:
-            # get the attributes of the module
-            normalized_shape = module.normalized_shape
-            eps = module.eps
-            elementwise_affine = module.elementwise_affine
+        # try to get normalized_shape, eps, elementwise_affine from the module
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)
         rmsnorm = FusedRMSNormWithHook(
-            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
+            normalized_shape=normalized_shape,
+            eps=eps,
+            elementwise_affine=elementwise_affine,
         )
         rmsnorm.weight = module.weight
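
The hunk above drops the hard-coded class-name check in favor of attribute probing: each constructor argument is read from the wrapped module when it exposes it, and otherwise falls back to the HuggingFace-style convention (shape taken from the weight, epsilon stored as variance_epsilon, affine assumed True). Below is a minimal sketch of that fallback pattern, not the code from this PR; LlamaLikeRMSNorm and resolve_rmsnorm_args are hypothetical names used only for illustration.

import torch
import torch.nn as nn


class LlamaLikeRMSNorm(nn.Module):
    # Hypothetical stand-in for a HuggingFace-style RMSNorm (e.g. LlamaRMSNorm):
    # it stores its epsilon as `variance_epsilon` and exposes neither
    # `normalized_shape` nor `elementwise_affine`.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


def resolve_rmsnorm_args(module: nn.Module):
    # Same fallback logic as in the diff: prefer the module's own attributes,
    # otherwise assume HuggingFace-style naming and defaults.
    normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
    eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
    elementwise_affine = getattr(module, "elementwise_affine", True)
    return normalized_shape, eps, elementwise_affine


# HF-style module: falls back to weight shape / variance_epsilon / True.
print(resolve_rmsnorm_args(LlamaLikeRMSNorm(4096)))  # (4096, 1e-06, True)
# A module that exposes the standard attributes takes the first branch of each lookup.
print(resolve_rmsnorm_args(nn.LayerNorm(4096)))      # ((4096,), 1e-05, True)

This makes the conversion work for any RMSNorm-like module, including ChatGLM's, without maintaining an explicit list of supported class names.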