fix: modify model config and add Qwen2RMSNorm

This commit is contained in:
Wenhao Chen
2024-03-04 11:22:34 +08:00
committed by アマデウス
parent 5c2a47a667
commit 5512bdf1fc
2 changed files with 88 additions and 1 deletions

View File

@@ -276,7 +276,7 @@ class FusedRMSNorm(BaseLayerNorm):
LazyInitContext.materialize(module)
# to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
if module.__class__.__name__ in ["LlamaRMSNorm", "Qwen2RMSNorm", "MistralRMSNorm"]:
normalized_shape = module.weight.shape[0]
eps = module.variance_epsilon
elementwise_affine = True