mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-27 04:33:04 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -7,7 +7,6 @@ from torch.nn import Module

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from ..modeling.bert import (
    BertPipelineForwards,
    bert_sequence_parallel_forward_fn,
@@ -19,14 +18,20 @@ from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
    'BertPolicy', 'BertModelPolicy', 'BertForPreTrainingPolicy', 'BertLMdHeadModelPolicy', 'BertForMaskedLMPolicy',
    'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
    'BertForMultipleChoicePolicy', 'BertForQuestionAnsweringPolicy'
    "BertPolicy",
    "BertModelPolicy",
    "BertForPreTrainingPolicy",
    "BertLMdHeadModelPolicy",
    "BertForMaskedLMPolicy",
    "BertForNextSentencePredictionPolicy",
    "BertForSequenceClassificationPolicy",
    "BertForTokenClassificationPolicy",
    "BertForMultipleChoicePolicy",
    "BertForQuestionAnsweringPolicy",
]


class BertPolicy(Policy):

    def config_sanity_check(self):
        pass
@@ -58,136 +63,140 @@ class BertPolicy(Policy):
        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        overlap = self.shard_config.enable_sequence_overlap
        if self.shard_config.enable_tensor_parallelism:
            policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
                "attention.self.all_head_size":
                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "crossattention.self.all_head_size":
                    self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "attention.self.num_attention_heads":
                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "crossattention.self.num_attention_heads":
                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
            },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.self.query",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={
                            "seq_parallel": use_sequence_parallel,
                            "overlap": overlap
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.self.key",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={
                            "seq_parallel": use_sequence_parallel,
                            "overlap": overlap
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.self.value",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={
                            "seq_parallel": use_sequence_parallel,
                            "overlap": overlap
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.self.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dense",
                        target_module=col_nn.Linear1D_Row,
                        kwargs={"seq_parallel": use_sequence_parallel},
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="intermediate.dense",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={
                            "seq_parallel": use_sequence_parallel,
                            "overlap": overlap
                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dense",
                        target_module=col_nn.Linear1D_Row,
                        kwargs={"seq_parallel": use_sequence_parallel},
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    )
                ])
            policy[BertLayer] = ModulePolicyDescription(
                attribute_replacement={
                    "attention.self.all_head_size": self.model.config.hidden_size
                    // self.shard_config.tensor_parallel_size,
                    "crossattention.self.all_head_size": self.model.config.hidden_size
                    // self.shard_config.tensor_parallel_size,
                    "attention.self.num_attention_heads": self.model.config.num_attention_heads
                    // self.shard_config.tensor_parallel_size,
                    "crossattention.self.num_attention_heads": self.model.config.num_attention_heads
                    // self.shard_config.tensor_parallel_size,
                },
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="attention.self.query",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.self.key",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.self.value",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.self.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dense",
                        target_module=col_nn.Linear1D_Row,
                        kwargs={"seq_parallel": use_sequence_parallel},
                    ),
                    SubModuleReplacementDescription(
                        suffix="attention.output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                    SubModuleReplacementDescription(
                        suffix="intermediate.dense",
                        target_module=col_nn.Linear1D_Col,
                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dense",
                        target_module=col_nn.Linear1D_Row,
                        kwargs={"seq_parallel": use_sequence_parallel},
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                ],
            )

            policy[BertEmbeddings] = ModulePolicyDescription(sub_module_replacement=[
                SubModuleReplacementDescription(
                    suffix="word_embeddings",
                    target_module=col_nn.VocabParallelEmbedding1D,
                ),
                SubModuleReplacementDescription(
                    suffix="dropout",
                    target_module=col_nn.DropoutForReplicatedInput,
                )
            ])
            policy[BertEmbeddings] = ModulePolicyDescription(
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="word_embeddings",
                        target_module=col_nn.VocabParallelEmbedding1D,
                    ),
                    SubModuleReplacementDescription(
                        suffix="dropout",
                        target_module=col_nn.DropoutForReplicatedInput,
                    ),
                ]
            )

        if use_sequence_parallel:
            self.append_or_create_method_replacement(
                description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
                description={"forward": bert_sequence_parallel_forward_fn(self.shard_config)},
                policy=policy,
                target_key=BertModel)
                target_key=BertModel,
            )

        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            # Handle bert layer
            self.append_or_create_submodule_replacement(description=[
                SubModuleReplacementDescription(
                    suffix="attention.output.LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                ),
                SubModuleReplacementDescription(
                    suffix="output.LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                )
            ],
                policy=policy,
                target_key=BertLayer)
            self.append_or_create_submodule_replacement(
                description=[
                    SubModuleReplacementDescription(
                        suffix="attention.output.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                    SubModuleReplacementDescription(
                        suffix="output.LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    ),
                ],
                policy=policy,
                target_key=BertLayer,
            )
            # handle embedding layer
            self.append_or_create_submodule_replacement(
                description=[SubModuleReplacementDescription(
                    suffix="LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                )],
                description=[
                    SubModuleReplacementDescription(
                        suffix="LayerNorm",
                        target_module=col_nn.FusedLayerNorm,
                    )
                ],
                policy=policy,
                target_key=BertEmbeddings)
                target_key=BertEmbeddings,
            )

        # use flash attention
        if self.shard_config.enable_flash_attention:
            self.append_or_create_method_replacement(description={
                'forward': get_bert_flash_attention_forward(),
            },
                policy=policy,
                target_key=BertSelfAttention)
            self.append_or_create_method_replacement(
                description={
                    "forward": get_bert_flash_attention_forward(),
                },
                policy=policy,
                target_key=BertSelfAttention,
            )

        # use jit operator
        if self.shard_config.enable_jit_fused:
            self.append_or_create_method_replacement(description={
                'forward': get_jit_fused_bert_self_output_forward(),
                'dropout_add': get_jit_fused_dropout_add_func(),
            },
                policy=policy,
                target_key=BertSelfOutput)
            self.append_or_create_method_replacement(description={
                'forward': get_jit_fused_bert_output_forward(),
                'dropout_add': get_jit_fused_dropout_add_func(),
            },
                policy=policy,
                target_key=BertOutput)
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_bert_self_output_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=BertSelfOutput,
            )
            self.append_or_create_method_replacement(
                description={
                    "forward": get_jit_fused_bert_output_forward(),
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
                target_key=BertOutput,
            )

        return policy
@@ -196,31 +205,37 @@ class BertPolicy(Policy):
        # optimize for tensor parallelism
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
                policy=base_policy,
                target_key=BertLMPredictionHead)
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
                ),
                policy=base_policy,
                target_key=BertLMPredictionHead,
            )

        # optimize with fused normalization
        if self.shard_config.enable_fused_normalization:
            # Handle bert lm prediction head
            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                suffix="transform.LayerNorm",
                target_module=col_nn.FusedLayerNorm,
            ),
                policy=base_policy,
                target_key=BertLMPredictionHead)
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="transform.LayerNorm",
                    target_module=col_nn.FusedLayerNorm,
                ),
                policy=base_policy,
                target_key=BertLMPredictionHead,
            )
        return base_policy
    def add_lm_prediction_policy(self, base_policy):
        from transformers.models.bert.modeling_bert import BertLMPredictionHead

        method_replacement = {
            '_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict,
            '_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict,
            "_save_to_state_dict": col_nn.ParallelModule._save_to_state_dict,
            "_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict,
        }
        self.append_or_create_method_replacement(description=method_replacement,
            policy=base_policy,
            target_key=BertLMPredictionHead)
        self.append_or_create_method_replacement(
            description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead
        )
        return base_policy

    def postprocess(self):
@@ -228,7 +243,7 @@ class BertPolicy(Policy):
    def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
        """If under pipeline parallel setting, replacing the original forward method of huggingface
        to customized forward method, and add this changing to policy."""
        to customized forward method, and add this changing to policy."""
        if self.pipeline_stage_manager:
            stage_manager = self.pipeline_stage_manager
            if self.model.__class__.__name__ == "BertModel":
@@ -239,15 +254,13 @@ class BertPolicy(Policy):
            layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
            method_replacement = {
                'forward':
                    partial(new_forward,
                            stage_manager=stage_manager,
                            stage_index=stage_index,
                            shard_config=self.shard_config)
                "forward": partial(
                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
                )
            }
            self.append_or_create_method_replacement(description=method_replacement,
                policy=policy,
                target_key=model_cls)
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=model_cls
            )

        return
@@ -255,7 +268,7 @@ class BertPolicy(Policy):
        """Get pipeline layers for current stage."""
        assert self.pipeline_stage_manager is not None

        if self.model.__class__.__name__ == 'BertModel':
        if self.model.__class__.__name__ == "BertModel":
            module = self.model
        else:
            module = self.model.bert
@@ -275,17 +288,17 @@ class BertPolicy(Policy):
# BertModel
class BertModelPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bert.modeling_bert import BertModel

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=BertModel,
                new_forward=BertPipelineForwards.bert_model_forward,
                policy=policy)
            self.set_pipeline_forward(
                model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy
            )
        return policy

    def get_held_layers(self) -> List[Module]:
@@ -300,7 +313,6 @@ class BertModelPolicy(BertPolicy):
# BertForPreTraining
class BertForPreTrainingPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()
@@ -309,10 +321,13 @@ class BertForPreTrainingPolicy(BertPolicy):
        policy = self.add_lm_head_policy(policy)
        policy = self.add_lm_prediction_policy(policy)
        from transformers.models.bert.modeling_bert import BertForPreTraining

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=BertForPreTraining,
                new_forward=BertPipelineForwards.bert_for_pretraining_forward,
                policy=policy)
            self.set_pipeline_forward(
                model_cls=BertForPreTraining,
                new_forward=BertPipelineForwards.bert_for_pretraining_forward,
                policy=policy,
            )
        return policy

    def get_held_layers(self) -> List[Module]:
@@ -329,16 +344,17 @@ class BertForPreTrainingPolicy(BertPolicy):
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight):
                # tie weights
                return [{
                    0: model.bert.embeddings.word_embeddings.weight,
                    self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight
                }]
                return [
                    {
                        0: model.bert.embeddings.word_embeddings.weight,
                        self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight,
                    }
                ]
        return []


# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()
@@ -347,10 +363,11 @@ class BertLMHeadModelPolicy(BertPolicy):
        policy = self.add_lm_head_policy(policy)
        policy = self.add_lm_prediction_policy(policy)
        from transformers.models.bert.modeling_bert import BertLMHeadModel

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=BertLMHeadModel,
                new_forward=BertPipelineForwards.bert_lm_head_model_forward,
                policy=policy)
            self.set_pipeline_forward(
                model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy
            )
        return policy

    def get_held_layers(self) -> List[Module]:
@@ -368,16 +385,17 @@ class BertLMHeadModelPolicy(BertPolicy):
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
                # tie weights
                return [{
                    0: bert_model.embeddings.word_embeddings.weight,
                    self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
                }]
                return [
                    {
                        0: bert_model.embeddings.word_embeddings.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight,
                    }
                ]
        return []


# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()
@@ -386,10 +404,11 @@ class BertForMaskedLMPolicy(BertPolicy):
        policy = self.add_lm_head_policy(policy)
        policy = self.add_lm_prediction_policy(policy)
        from transformers.models.bert.modeling_bert import BertForMaskedLM

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=BertForMaskedLM,
                new_forward=BertPipelineForwards.bert_for_masked_lm_forward,
                policy=policy)
            self.set_pipeline_forward(
                model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy
            )
        return policy

    def get_held_layers(self) -> List[Module]:
@@ -407,16 +426,17 @@ class BertForMaskedLMPolicy(BertPolicy):
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
                # tie weights
                return [{
                    0: bert_model.embeddings.word_embeddings.weight,
                    self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
                }]
                return [
                    {
                        0: bert_model.embeddings.word_embeddings.weight,
                        self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight,
                    }
                ]
        return []


# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()
@@ -427,19 +447,22 @@ class BertForSequenceClassificationPolicy(BertPolicy):
        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                BertForSequenceClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                BertForSequenceClassification: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="dropout",
                            target_module=col_nn.DropoutForParallelInput,
                        )
                    ])
                    ]
                )
            }
            policy.update(addon_module)
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=BertForSequenceClassification,
                new_forward=BertPipelineForwards.bert_for_sequence_classification_forward,
                policy=policy)
            self.set_pipeline_forward(
                model_cls=BertForSequenceClassification,
                new_forward=BertPipelineForwards.bert_for_sequence_classification_forward,
                policy=policy,
            )

        return policy
@@ -461,7 +484,6 @@ class BertForSequenceClassificationPolicy(BertPolicy):
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()
@@ -472,19 +494,22 @@ class BertForTokenClassificationPolicy(BertPolicy):
        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                BertForTokenClassification:
                    ModulePolicyDescription(sub_module_replacement=[
                BertForTokenClassification: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="dropout",
                            target_module=col_nn.DropoutForParallelInput,
                        )
                    ])
                    ]
                )
            }
            policy.update(addon_module)
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=BertForTokenClassification,
                new_forward=BertPipelineForwards.bert_for_token_classification_forward,
                policy=policy)
            self.set_pipeline_forward(
                model_cls=BertForTokenClassification,
                new_forward=BertPipelineForwards.bert_for_token_classification_forward,
                policy=policy,
            )

        return policy
@@ -506,17 +531,19 @@ class BertForTokenClassificationPolicy(BertPolicy):
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        from transformers.models.bert.modeling_bert import BertForNextSentencePrediction

        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=BertForNextSentencePrediction,
                new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward,
                policy=policy)
            self.set_pipeline_forward(
                model_cls=BertForNextSentencePrediction,
                new_forward=BertPipelineForwards.bert_for_next_sentence_prediction_forward,
                policy=policy,
            )

        return policy
@@ -537,7 +564,6 @@ class BertForNextSentencePredictionPolicy(BertPolicy):
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()
@@ -548,19 +574,22 @@ class BertForMultipleChoicePolicy(BertPolicy):
        if self.shard_config.enable_tensor_parallelism:
            addon_module = {
                BertForMultipleChoice:
                    ModulePolicyDescription(sub_module_replacement=[
                BertForMultipleChoice: ModulePolicyDescription(
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="dropout",
                            target_module=col_nn.DropoutForParallelInput,
                        )
                    ])
                    ]
                )
            }
            policy.update(addon_module)
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=BertForMultipleChoice,
                new_forward=BertPipelineForwards.bert_for_multiple_choice_forward,
                policy=policy)
            self.set_pipeline_forward(
                model_cls=BertForMultipleChoice,
                new_forward=BertPipelineForwards.bert_for_multiple_choice_forward,
                policy=policy,
            )

        return policy
@@ -581,17 +610,19 @@ class BertForMultipleChoicePolicy(BertPolicy):
class BertForQuestionAnsweringPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        from transformers.models.bert.modeling_bert import BertForQuestionAnswering

        policy = super().module_policy()
        if self.pipeline_stage_manager:
            self.set_pipeline_forward(model_cls=BertForQuestionAnswering,
                new_forward=BertPipelineForwards.bert_for_question_answering_forward,
                policy=policy)
            self.set_pipeline_forward(
                model_cls=BertForQuestionAnswering,
                new_forward=BertPipelineForwards.bert_for_question_answering_forward,
                policy=policy,
            )

        return policy
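
Note: the policy entries reformatted in this diff all follow one pattern: a ModulePolicyDescription bundles attribute overrides and submodule swaps for a single module class, and each SubModuleReplacementDescription names a child module by suffix together with the parallelized layer that replaces it. Below is a minimal, self-contained sketch of that pattern. It reuses the classes this file imports (the absolute import path is inferred from the relative "from .base_policy import ..." above), while the hidden size, head count, and tensor-parallel degree are placeholder values standing in for self.model.config and self.shard_config.

import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    SubModuleReplacementDescription,
)

# Placeholder values standing in for self.model.config and self.shard_config.
hidden_size, num_attention_heads, tensor_parallel_size = 768, 12, 2

# One policy entry: shrink the per-rank head bookkeeping and swap the query
# projection for a column-parallel linear layer, mirroring the BertLayer entry above.
example_entry = ModulePolicyDescription(
    attribute_replacement={
        "attention.self.all_head_size": hidden_size // tensor_parallel_size,
        "attention.self.num_attention_heads": num_attention_heads // tensor_parallel_size,
    },
    sub_module_replacement=[
        SubModuleReplacementDescription(
            suffix="attention.self.query",
            target_module=col_nn.Linear1D_Col,
        ),
    ],
)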