[shardformer] refactored some doc and api (#4137)

* [shardformer] refactored some doc and api

* polish code
This commit is contained in:
Frank Lee
2023-07-03 15:29:11 +08:00
parent 7f9b30335b
commit 74257cb446
15 changed files with 355 additions and 490 deletions

View File

@@ -98,7 +98,6 @@ class BloomPolicy(Policy):
"self_attention.num_heads":
self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
@@ -125,7 +124,6 @@ class BloomPolicy(Policy):
ModulePolicyDescription(attribute_replacement={
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
param_replacement=[],
method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
sub_module_replacement=[
SubModuleReplacementDescription(
@@ -160,10 +158,6 @@ class BloomPolicy(Policy):
return base_policy
def new_model_class(self):
# do nothing
return None
def postprocess(self):
return self.model
@@ -180,13 +174,10 @@ class BloomForCausalLMPolicy(BloomPolicy):
# add a new item for casual lm
new_item = {
BloomForCausalLM:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True))
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
@@ -213,13 +204,10 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
# add a new item for casual lm
new_item = {
BloomForSequenceClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="score",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True))
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy
@@ -233,17 +221,14 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
# add a new item for casual lm
new_item = {
BloomForTokenClassification:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="classifier",
target_module=col_nn.Linear1D_Col,
kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
])
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
),
])
}
policy.update(new_item)
return policy