[shardformer] refactored some doc and api (#4137)

* [shardformer] refactored some doc and api

* polish code
Frank Lee
2023-07-03 15:29:11 +08:00
parent 7f9b30335b
commit 74257cb446
15 changed files with 355 additions and 490 deletions

@@ -28,7 +28,7 @@ class LlamaPolicy(Policy):
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
-        return {
+        base_policy = {
             LlamaDecoderLayer:
                 ModulePolicyDescription(
                     attribute_replacement={
@@ -37,7 +37,6 @@ class LlamaPolicy(Policy):
"self_attn.num_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attn.q_proj",
@@ -70,14 +69,12 @@ class LlamaPolicy(Policy):
                     ],
                 ),
             LlamaModel:
-                ModulePolicyDescription(attribute_replacement={},
-                                        param_replacement=[],
-                                        sub_module_replacement=[
-                                            SubModuleReplacementDescription(
-                                                suffix="embed_tokens",
-                                                target_module=VocabParallelEmbedding1D,
-                                            )
-                                        ])
+                ModulePolicyDescription(sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="embed_tokens",
+                        target_module=VocabParallelEmbedding1D,
+                    )
+                ])
         }
 
         # optimization configuration
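
Note: the lines removed above are the empty attribute_replacement={} and param_replacement=[] arguments, which implies ModulePolicyDescription now supplies defaults for them. A minimal sketch of a description with that shape, assuming dataclass-style default factories (only the field names visible in this diff are used; everything else is an assumption, not the actual ColossalAI definition):

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ModulePolicyDescription:
    # assumed defaults so callers such as LlamaPolicy can omit empty mappings/lists
    attribute_replacement: Dict[str, Any] = field(default_factory=dict)
    param_replacement: List[Any] = field(default_factory=list)
    sub_module_replacement: List["SubModuleReplacementDescription"] = field(default_factory=list)
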
@@ -101,9 +98,6 @@ class LlamaPolicy(Policy):
         return base_policy
 
-    def new_model_class(self):
-        return None
 
     def postprocess(self):
         return self.model
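
With new_model_class() deleted, the methods left in view for a concrete policy are module_policy() and postprocess(). A hedged skeleton of that reduced surface (the class name is hypothetical, the import path is assumed, and Policy may define further hooks not shown in this hunk):

# assumption: base-class location as of this commit
from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription, Policy


class MyCustomPolicy(Policy):
    """Hypothetical policy illustrating the surface left after new_model_class() was removed."""

    def module_policy(self):
        # return a mapping from module classes to ModulePolicyDescription entries,
        # in the same spirit as base_policy in LlamaPolicy above
        return {}

    def postprocess(self):
        # hand the model back; LlamaPolicy simply returns self.model here
        return self.model
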
@@ -117,13 +111,10 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         # add a new item for casual lm
         new_item = {
             LlamaForCausalLM:
-                ModulePolicyDescription(attribute_replacement={},
-                                        param_replacement=[],
-                                        sub_module_replacement=[
-                                            SubModuleReplacementDescription(suffix="lm_head",
-                                                                            target_module=Linear1D_Col,
-                                                                            kwargs=dict(gather_output=True))
-                                        ])
+                ModulePolicyDescription(sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+                ])
         }
         policy.update(new_item)
         return policy
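
The kwargs=dict(gather_output=True) on the lm_head replacement asks the column-parallel linear to reassemble its vocabulary-sharded outputs into full logits, so the downstream loss sees the whole vocabulary. A standalone, shape-only illustration with plain torch and two simulated tensor-parallel ranks (not the real Linear1D_Col kernel):

import torch

batch, seq, hidden, vocab, tp_size = 2, 4, 8, 16, 2
x = torch.randn(batch, seq, hidden)
# each "rank" owns a vocab shard of the lm_head weight
weight_shards = [torch.randn(hidden, vocab // tp_size) for _ in range(tp_size)]
partial_logits = [x @ w for w in weight_shards]    # per rank: [batch, seq, vocab // tp_size]
full_logits = torch.cat(partial_logits, dim=-1)    # gather_output=True -> [batch, seq, vocab]
assert full_logits.shape == (batch, seq, vocab)
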
@@ -139,13 +130,10 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
         # add a new item for sequence classification
         new_item = {
             LlamaForSequenceClassification:
-                ModulePolicyDescription(attribute_replacement={},
-                                        param_replacement=[],
-                                        sub_module_replacement=[
-                                            SubModuleReplacementDescription(suffix="score",
-                                                                            target_module=Linear1D_Col,
-                                                                            kwargs=dict(gather_output=True))
-                                        ])
+                ModulePolicyDescription(sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
+                ])
         }
         policy.update(new_item)
         return policy
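
For context, a hedged, self-contained sketch of what a sub_module_replacement entry conceptually requests: walk the model, find every submodule whose qualified name ends with the given suffix, and swap in a replacement built from the original module. The helper below is illustration only; the real shardformer machinery (process groups, from_native_module constructors, etc.) is more involved.

import torch.nn as nn

def apply_sub_module_replacements(model: nn.Module, replacements):
    # replacements: list of (suffix, build_replacement) pairs, mirroring the
    # suffix/target_module fields used by SubModuleReplacementDescription above
    for suffix, build_replacement in replacements:
        for name, module in list(model.named_modules()):
            if name == suffix or name.endswith("." + suffix):
                parent_name, _, child_name = name.rpartition(".")
                parent = model.get_submodule(parent_name) if parent_name else model
                setattr(parent, child_name, build_replacement(module))

# demo: swap the "score" head for a wider Linear, purely to show the mechanics
model = nn.Sequential()
model.add_module("score", nn.Linear(16, 2))
apply_sub_module_replacements(model, [("score", lambda m: nn.Linear(m.in_features, m.out_features * 2))])
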