mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 12:43:02 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -8,7 +8,6 @@ from torch import Tensor
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from ..modeling.whisper import (
|
||||
WhisperPipelineForwards,
|
||||
@@ -19,13 +18,14 @@ from ..modeling.whisper import (
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
|
||||
'WhisperForAudioClassificationPolicy'
|
||||
"WhisperPolicy",
|
||||
"WhisperModelPolicy",
|
||||
"WhisperForConditionalGenerationPolicy",
|
||||
"WhisperForAudioClassificationPolicy",
|
||||
]
|
||||
|
||||
|
||||
class WhisperPolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
|
||||
@@ -55,179 +55,197 @@ class WhisperPolicy(Policy):
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn(
|
||||
"Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
"Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag."
|
||||
)
|
||||
|
||||
#TODO using the jit fused add_and_dropout affect the accuracy
|
||||
# TODO using the jit fused add_and_dropout affect the accuracy
|
||||
if self.shard_config.enable_jit_fused:
|
||||
self.shard_config.enable_jit_fused = False
|
||||
warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused operator flag.")
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
|
||||
"self_attn.embed_dim":
|
||||
self.model.config.d_model // self.shard_config.tensor_parallel_size,
|
||||
"self_attn.num_heads":
|
||||
self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
])
|
||||
policy[WhisperEncoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
|
||||
"self_attn.num_heads": self.model.config.encoder_attention_heads
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[WhisperDecoderLayer] = ModulePolicyDescription(attribute_replacement={
|
||||
"self_attn.embed_dim":
|
||||
self.model.config.d_model // self.shard_config.tensor_parallel_size,
|
||||
"self_attn.num_heads":
|
||||
self.model.config.decoder_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
"encoder_attn.embed_dim":
|
||||
self.model.config.d_model // self.shard_config.tensor_parallel_size,
|
||||
"encoder_attn.num_heads":
|
||||
self.model.config.encoder_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
])
|
||||
policy[WhisperDecoderLayer] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
|
||||
"self_attn.num_heads": self.model.config.decoder_attention_heads
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
"encoder_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
|
||||
"encoder_attn.num_heads": self.model.config.encoder_attention_heads
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.q_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.k_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.v_proj",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="encoder_attn.out_proj",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=col_nn.Linear1D_Row,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
policy[WhisperDecoder] = ModulePolicyDescription(sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
])
|
||||
policy[WhisperDecoder] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
# Handle encoder layer
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=WhisperEncoderLayer)
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=WhisperEncoderLayer,
|
||||
)
|
||||
|
||||
# Handle decoder layer
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=WhisperDecoderLayer)
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attn_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="final_layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=WhisperDecoderLayer,
|
||||
)
|
||||
|
||||
# handle encoder layer
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=WhisperEncoder)
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=WhisperEncoder,
|
||||
)
|
||||
|
||||
# handle decoder layer
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=WhisperDecoder)
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="layer_norm",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=WhisperDecoder,
|
||||
)
|
||||
|
||||
# enable flash attention
|
||||
if self.shard_config.enable_flash_attention:
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_whisper_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperAttention)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_whisper_flash_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperAttention,
|
||||
)
|
||||
|
||||
# use jit fused operator
|
||||
if self.shard_config.enable_jit_fused:
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_whisper_decoder_layer_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperDecoderLayer)
|
||||
self.append_or_create_method_replacement(description={
|
||||
'forward': get_jit_fused_whisper_encoder_layer_forward(),
|
||||
'dropout_add': get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperEncoderLayer)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_whisper_decoder_layer_forward(),
|
||||
"dropout_add": get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperDecoderLayer,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={
|
||||
"forward": get_jit_fused_whisper_encoder_layer_forward(),
|
||||
"dropout_add": get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=WhisperEncoderLayer,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
@@ -236,10 +254,13 @@ class WhisperPolicy(Policy):
|
||||
|
||||
# optimize for tensor parallelism
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
|
||||
suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}),
|
||||
policy=base_policy,
|
||||
target_key=WhisperForConditionalGeneration)
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="proj_out", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
|
||||
),
|
||||
policy=base_policy,
|
||||
target_key=WhisperForConditionalGeneration,
|
||||
)
|
||||
|
||||
return base_policy
|
||||
|
||||
@@ -247,8 +268,9 @@ class WhisperPolicy(Policy):
|
||||
return self.model
|
||||
|
||||
@staticmethod
|
||||
def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int,
|
||||
num_stages: int) -> Tuple[List[int], int]:
|
||||
def distribute_whisper_layers(
|
||||
num_encoder_layers: int, num_decoder_layers: int, num_stages: int
|
||||
) -> Tuple[List[int], int]:
|
||||
"""
|
||||
Distribute whisper layers into stages when pipeline parallel is used.
|
||||
Return the layer distribution as a list and the starting stage of decoder.
|
||||
@@ -281,8 +303,9 @@ class WhisperPolicy(Policy):
|
||||
return encoder_distribution + decoder_distribution, num_encoder_stages
|
||||
|
||||
@staticmethod
|
||||
def get_whisper_stage_index(layers_per_stage: List[int], stage: int,
|
||||
decoder_starting_stage: int) -> Tuple[bool, int, int]:
|
||||
def get_whisper_stage_index(
|
||||
layers_per_stage: List[int], stage: int, decoder_starting_stage: int
|
||||
) -> Tuple[bool, int, int]:
|
||||
"""
|
||||
Input the distribution of layers among stages, the current stage and the first stage of decoder.
|
||||
Return the starting/ending idx of layers in encoder/decoder
|
||||
@@ -293,13 +316,12 @@ class WhisperPolicy(Policy):
|
||||
return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
|
||||
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
if self.model.__class__.__name__ == 'WhisperModel':
|
||||
if self.model.__class__.__name__ == "WhisperModel":
|
||||
model = self.model
|
||||
elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
|
||||
elif self.model.__class__.__name__ == "WhisperForConditionalGeneration":
|
||||
model = self.model.model
|
||||
else:
|
||||
model = None
|
||||
@@ -320,9 +342,11 @@ class WhisperPolicy(Policy):
|
||||
|
||||
held_layers = []
|
||||
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
|
||||
start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
|
||||
decoder_starting_stage)
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
|
||||
)
|
||||
start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(
|
||||
layers_per_stage, stage_manager.stage, decoder_starting_stage
|
||||
)
|
||||
|
||||
if stage_manager.stage < decoder_starting_stage:
|
||||
# current stage is in whisper's encoder
|
||||
@@ -347,14 +371,14 @@ class WhisperPolicy(Policy):
|
||||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if not self.pipeline_stage_manager:
|
||||
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
|
||||
if self.model.__class__.__name__ == 'WhisperModel':
|
||||
if self.model.__class__.__name__ == "WhisperModel":
|
||||
model = self.model
|
||||
elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
|
||||
elif self.model.__class__.__name__ == "WhisperForConditionalGeneration":
|
||||
model = self.model.model
|
||||
else:
|
||||
model = None
|
||||
@@ -373,34 +397,37 @@ class WhisperPolicy(Policy):
|
||||
num_decoder_layers = 0
|
||||
|
||||
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
|
||||
stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
|
||||
decoder_starting_stage)
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
|
||||
)
|
||||
stage_index = WhisperPolicy.get_whisper_stage_index(
|
||||
layers_per_stage, stage_manager.stage, decoder_starting_stage
|
||||
)
|
||||
|
||||
method_replacement = {
|
||||
'forward':
|
||||
partial(new_forward,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
decoder_starting_stage=decoder_starting_stage)
|
||||
"forward": partial(
|
||||
new_forward,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
decoder_starting_stage=decoder_starting_stage,
|
||||
)
|
||||
}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||
|
||||
|
||||
# WhisperModel
|
||||
class WhisperModelPolicy(WhisperPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import WhisperModel
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(model_cls=WhisperModel,
|
||||
new_forward=WhisperPipelineForwards.whisper_model_forward,
|
||||
policy=policy)
|
||||
self.set_pipeline_forward(
|
||||
model_cls=WhisperModel, new_forward=WhisperPipelineForwards.whisper_model_forward, policy=policy
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
@@ -414,19 +441,21 @@ class WhisperModelPolicy(WhisperPolicy):
|
||||
|
||||
# WhisperForConditionalGeneration
|
||||
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import WhisperForConditionalGeneration
|
||||
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration,
|
||||
new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
|
||||
policy=policy)
|
||||
self.set_pipeline_forward(
|
||||
model_cls=WhisperForConditionalGeneration,
|
||||
new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
@@ -457,8 +486,9 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
|
||||
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager is not None and stage_manager.num_stages > 1:
|
||||
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers,
|
||||
stage_manager.num_stages)
|
||||
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
|
||||
num_encoder_layers, num_decoder_layers, stage_manager.num_stages
|
||||
)
|
||||
shared_params = []
|
||||
shared_embedding = {}
|
||||
if id(module.proj_out) == id(model.decoder.embed_tokens):
|
||||
@@ -472,7 +502,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
|
||||
|
||||
# WhisperForAudioClassification
|
||||
class WhisperForAudioClassificationPolicy(WhisperPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -481,12 +510,15 @@ class WhisperForAudioClassificationPolicy(WhisperPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import WhisperForAudioClassification
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(model_cls=WhisperForAudioClassification,
|
||||
new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
|
||||
policy=policy)
|
||||
self.set_pipeline_forward(
|
||||
model_cls=WhisperForAudioClassification,
|
||||
new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
|
Reference in New Issue
Block a user