mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 20:46:00 +00:00
[shardformer] refactored some doc and api (#4137)
* [shardformer] refactored some doc and api * polish code
This commit is contained in:
@@ -98,7 +98,6 @@ class BloomPolicy(Policy):
|
||||
"self_attention.num_heads":
|
||||
self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
@@ -125,7 +124,6 @@ class BloomPolicy(Policy):
|
||||
ModulePolicyDescription(attribute_replacement={
|
||||
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
@@ -160,10 +158,6 @@ class BloomPolicy(Policy):
|
||||
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
@@ -180,13 +174,10 @@ class BloomForCausalLMPolicy(BloomPolicy):
|
||||
# add a new item for causal lm
|
||||
new_item = {
|
||||
BloomForCausalLM:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
@@ -213,13 +204,10 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
|
||||
# add a new item for sequence classification
|
||||
new_item = {
|
||||
BloomForSequenceClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="score",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
@@ -233,17 +221,14 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
|
||||
# add a new item for token classification
|
||||
new_item = {
|
||||
BloomForTokenClassification:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="classifier",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
])
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
),
|
||||
])
|
||||
}
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
Reference in New Issue
Block a user