[shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit

This commit is contained in:
klhhhhh
2023-07-20 19:14:04 +08:00
committed by Hongxin Liu
parent 4da05052f4
commit 8120eca0c0
4 changed files with 37 additions and 4 deletions

View File

@@ -90,7 +90,31 @@ class ChatGLMModelPolicy(Policy):
policy=policy,
target_key=ChatGLMModel)
else:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
SubModuleReplacementDescription(suffix="post_attention_layernorm",
target_module=col_nn.FusedRMSNorm)
],
policy=policy,
target_key=GLMBlock)
if self.model.config.post_layer_norm:
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(suffix="encoder.final_layernorm",
target_module=col_nn.FusedRMSNorm)
],
policy=policy,
target_key=ChatGLMModel)
return policy
def postprocess(self):
    """Hand back the policy's model as-is.

    Returns:
        The object stored in ``self.model``; this method does not modify it.
    """
    model = self.model
    return model
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
    """Policy for ``ChatGLMForConditionalGeneration``.

    Reuses the policy built by ``ChatGLMModelPolicy`` without adding or
    removing any entries.
    """

    def module_policy(self):
        """Return the parent class's module policy unchanged."""
        # No extra replacements are registered here; the parent's result
        # is forwarded directly.
        return super().module_policy()

View File

@@ -23,7 +23,7 @@ class ViTPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
policy = {}