mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer] fix chatglm implementation (#5644)
* [shardformer] fix chatglm policy
* [shardformer] fix chatglm flash attn
* [shardformer] update readme
* [shardformer] fix chatglm init
* [shardformer] fix chatglm test
* [pipeline] fix chatglm merge batch
@@ -281,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
             )
 
         LazyInitContext.materialize(module)
 
-        # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
-        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
-            normalized_shape = module.weight.shape[0]
-            eps = module.variance_epsilon
-            elementwise_affine = True
-        else:
-            # get the attributes of the module
-            normalized_shape = module.normalized_shape
-            eps = module.eps
-            elementwise_affine = module.elementwise_affine
+        # try to get normalized_shape, eps, elementwise_affine from the module
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)
+
         rmsnorm = FusedRMSNormWithHook(
-            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
+            normalized_shape=normalized_shape,
+            eps=eps,
+            elementwise_affine=elementwise_affine,
         )
 
         rmsnorm.weight = module.weight
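The change replaces the class-name check (which only recognized `LlamaRMSNorm` and `MistralRMSNorm`) with attribute probing, so RMSNorm variants with other names, such as ChatGLM's RMSNorm, are converted as well. Below is a minimal sketch of how that probing behaves; the two mock modules (`HFStyleRMSNorm`, `ChatGLMStyleRMSNorm`) are hypothetical stand-ins, not ColossalAI or HuggingFace classes, while the probing lines are taken directly from the diff above.

```python
import torch
import torch.nn as nn


class HFStyleRMSNorm(nn.Module):
    # Hypothetical mock of a HuggingFace-style RMSNorm (e.g. LlamaRMSNorm):
    # exposes `variance_epsilon` and a weight, but no `normalized_shape` attribute.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


class ChatGLMStyleRMSNorm(nn.Module):
    # Hypothetical mock of a ChatGLM-style RMSNorm: exposes `eps` instead of
    # `variance_epsilon`.
    def __init__(self, normalized_shape: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.eps = eps


def probe_rmsnorm_attrs(module: nn.Module):
    # Same probing strategy as the patched from_native_module: fall back to the
    # weight shape / default values when an attribute is absent.
    normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
    eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
    elementwise_affine = getattr(module, "elementwise_affine", True)
    return normalized_shape, eps, elementwise_affine


print(probe_rmsnorm_attrs(HFStyleRMSNorm(4096)))       # (4096, 1e-06, True)
print(probe_rmsnorm_attrs(ChatGLMStyleRMSNorm(4096)))  # (4096, 1e-05, True)
```

Both module styles yield the same tuple of constructor arguments, which is what lets the single `FusedRMSNormWithHook(...)` call replace the old if/else branch.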