mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 11:44:03 +00:00
[shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit
This commit is contained in:
@@ -90,7 +90,31 @@ class ChatGLMModelPolicy(Policy):
|
||||
policy=policy,
|
||||
target_key=ChatGLMModel)
|
||||
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
|
||||
SubModuleReplacementDescription(suffix="post_attention_layernorm",
|
||||
target_module=col_nn.FusedRMSNorm)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=GLMBlock)
|
||||
|
||||
if self.model.config.post_layer_norm:
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(suffix="encoder.final_layernorm",
|
||||
target_module=col_nn.FusedRMSNorm)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=ChatGLMModel)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
return policy
|
||||
|
@@ -23,7 +23,7 @@ class ViTPolicy(Policy):
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
|
||||
|
||||
policy = {}
|
||||
|
||||
|
Reference in New Issue
Block a user