[shardformer] refactored some doc and api (#4137)

* [shardformer] refactored some doc and api

* polish code
This commit is contained in:
Frank Lee
2023-07-03 15:29:11 +08:00
parent 7f9b30335b
commit 74257cb446
15 changed files with 355 additions and 490 deletions

View File

@@ -31,33 +31,28 @@ class OPTPolicy(Policy):
base_policy = {
OPTDecoder:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
]),
OPTDecoderLayer:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="fc1",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=Linear1D_Row,
)
]),
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="fc1",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="fc2",
target_module=Linear1D_Row,
)
]),
OPTAttention:
ModulePolicyDescription(attribute_replacement={
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="q_proj",
@@ -95,9 +90,6 @@ class OPTPolicy(Policy):
return base_policy
def new_model_class(self):
return None
def postprocess(self):
return self.model
@@ -116,13 +108,10 @@ class OPTForCausalLMPolicy(OPTPolicy):
policy = super().module_policy()
new_item = {
OPTForCausalLM:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)