[shardformer] refactored some doc and api (#4137)
* [shardformer] refactored some doc and api
* polish code
@@ -28,7 +28,7 @@ class LlamaPolicy(Policy):
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
-        return {
+        base_policy = {
             LlamaDecoderLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
@@ -37,7 +37,6 @@ class LlamaPolicy(Policy):
                         "self_attn.num_heads":
                             self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                     },
-                    param_replacement=[],
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
                             suffix="self_attn.q_proj",
@@ -70,14 +69,12 @@ class LlamaPolicy(Policy):
                     ],
                 ),
             LlamaModel:
-                ModulePolicyDescription(attribute_replacement={},
-                                        param_replacement=[],
-                                        sub_module_replacement=[
-                                            SubModuleReplacementDescription(
-                                                suffix="embed_tokens",
-                                                target_module=VocabParallelEmbedding1D,
-                                            )
-                                        ])
+                ModulePolicyDescription(sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="embed_tokens",
+                        target_module=VocabParallelEmbedding1D,
+                    )
+                ])
         }
 
         # optimization configuration
@@ -101,9 +98,6 @@ class LlamaPolicy(Policy):
 
         return base_policy
 
-    def new_model_class(self):
-        return None
-
     def postprocess(self):
         return self.model
 
@@ -117,13 +111,10 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         # add a new item for casual lm
         new_item = {
             LlamaForCausalLM:
-                ModulePolicyDescription(attribute_replacement={},
-                                        param_replacement=[],
-                                        sub_module_replacement=[
-                                            SubModuleReplacementDescription(suffix="lm_head",
-                                                                            target_module=Linear1D_Col,
-                                                                            kwargs=dict(gather_output=True))
-                                        ])
+                ModulePolicyDescription(sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+                ])
         }
         policy.update(new_item)
         return policy
@@ -139,13 +130,10 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         # add a new item for sequence classification
         new_item = {
             LlamaForSequenceClassification:
-                ModulePolicyDescription(attribute_replacement={},
-                                        param_replacement=[],
-                                        sub_module_replacement=[
-                                            SubModuleReplacementDescription(suffix="score",
-                                                                            target_module=Linear1D_Col,
-                                                                            kwargs=dict(gather_output=True))
-                                        ])
+                ModulePolicyDescription(sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+                ])
         }
         policy.update(new_item)
         return policy
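
Across every hunk the API change is the same: ModulePolicyDescription call sites drop the explicit attribute_replacement={} and param_replacement=[] arguments, which implies those fields now default to empty. Below is a minimal sketch of how such defaults can be declared; the field names come from this diff, but the types and defaults are assumptions for illustration, not the actual ColossalAI definitions.

# Hypothetical sketch, assuming the descriptors are dataclass-like objects
# whose replacement fields default to empty containers.
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class SubModuleReplacementDescription:
    suffix: str                 # dotted path of the submodule to replace
    target_module: type         # parallelized module class to substitute in
    kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ModulePolicyDescription:
    # Defaulting these two fields is what lets the call sites above shrink.
    attribute_replacement: Dict[str, Any] = field(default_factory=dict)
    param_replacement: List[Any] = field(default_factory=list)
    sub_module_replacement: List[SubModuleReplacementDescription] = field(default_factory=list)


# Post-refactor call shape from the diff, now valid without the empty args:
desc = ModulePolicyDescription(sub_module_replacement=[
    SubModuleReplacementDescription(suffix="embed_tokens", target_module=object),    # placeholder class
])
assert desc.attribute_replacement == {} and desc.param_replacement == []

Keyword defaults keep each policy declaration down to only the replacements a model actually needs, which is the visible effect of this commit on all three Llama policies.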