Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-27 12:43:02 +00:00

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -1,6 +1,6 @@
 import warnings
 from functools import partial
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Tuple

 import numpy as np
 from torch import Tensor, nn
@@ -15,7 +15,6 @@ from colossalai.shardformer.layer import (
 )
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription

-from .._utils import getattr_, setattr_
 from ..modeling.jit import get_jit_fused_dropout_add_func
 from ..modeling.t5 import (
     T5PipelineForwards,
@@ -30,7 +29,6 @@ __all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationP


 class T5BasePolicy(Policy):
-
     def config_sanity_check(self):
         pass

@@ -65,151 +63,181 @@ class T5BasePolicy(Policy):
             warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")

         if self.shard_config.enable_tensor_parallelism:
-            policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="dropout",
-                    target_module=DropoutForParallelInput,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="embed_tokens",
-                    target_module=VocabParallelEmbedding1D,
-                )
-            ])
-            policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="dropout",
-                    target_module=DropoutForParallelInput,
-                ),
-            ])
-            policy[T5LayerCrossAttention] = ModulePolicyDescription(sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="dropout",
-                    target_module=DropoutForParallelInput,
-                )
-            ])
-            policy[T5Attention] = ModulePolicyDescription(attribute_replacement={
-                "d_model":
-                    self.model.config.d_model // self.shard_config.tensor_parallel_size,
-                "n_heads":
-                    self.model.config.num_heads // self.shard_config.tensor_parallel_size,
-                "inner_dim":
-                    self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
-            },
-                                                          sub_module_replacement=[
-                                                              SubModuleReplacementDescription(
-                                                                  suffix="q",
-                                                                  target_module=Linear1D_Col,
-                                                              ),
-                                                              SubModuleReplacementDescription(
-                                                                  suffix="k",
-                                                                  target_module=Linear1D_Col,
-                                                              ),
-                                                              SubModuleReplacementDescription(
-                                                                  suffix="v",
-                                                                  target_module=Linear1D_Col,
-                                                              ),
-                                                              SubModuleReplacementDescription(
-                                                                  suffix="o",
-                                                                  target_module=Linear1D_Row,
-                                                              ),
-                                                              SubModuleReplacementDescription(
-                                                                  suffix="relative_attention_bias",
-                                                                  target_module=Embedding1D,
-                                                                  kwargs=dict(gather_output=False),
-                                                                  ignore_if_not_exist=True)
-                                                          ])
-            policy[T5LayerFF] = ModulePolicyDescription(sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="dropout",
-                    target_module=DropoutForParallelInput,
-                ),
-            ])
-            policy[T5DenseGatedActDense] = ModulePolicyDescription(sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="wi_0 ",
-                    target_module=Linear1D_Col,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="wi_1",
-                    target_module=Linear1D_Row,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)),
-                SubModuleReplacementDescription(
-                    suffix="dropout",
-                    target_module=DropoutForParallelInput,
-                )
-            ])
-            policy[T5DenseActDense] = ModulePolicyDescription(sub_module_replacement=[
-                SubModuleReplacementDescription(
-                    suffix="wi",
-                    target_module=Linear1D_Col,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="wo",
-                    target_module=Linear1D_Row,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="dropout",
-                    target_module=DropoutForParallelInput,
-                )
-            ])
+            policy[T5Stack] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="dropout",
+                        target_module=DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="embed_tokens",
+                        target_module=VocabParallelEmbedding1D,
+                    ),
+                ]
+            )
+            policy[T5LayerSelfAttention] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="dropout",
+                        target_module=DropoutForParallelInput,
+                    ),
+                ]
+            )
+            policy[T5LayerCrossAttention] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="dropout",
+                        target_module=DropoutForParallelInput,
+                    )
+                ]
+            )
+            policy[T5Attention] = ModulePolicyDescription(
+                attribute_replacement={
+                    "d_model": self.model.config.d_model // self.shard_config.tensor_parallel_size,
+                    "n_heads": self.model.config.num_heads // self.shard_config.tensor_parallel_size,
+                    "inner_dim": self.model.config.num_heads
+                    * self.model.config.d_kv
+                    // self.shard_config.tensor_parallel_size,
+                },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="q",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="k",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="v",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="o",
+                        target_module=Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="relative_attention_bias",
+                        target_module=Embedding1D,
+                        kwargs=dict(gather_output=False),
+                        ignore_if_not_exist=True,
+                    ),
+                ],
+            )
+            policy[T5LayerFF] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="dropout",
+                        target_module=DropoutForParallelInput,
+                    ),
+                ]
+            )
+            policy[T5DenseGatedActDense] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="wi_0 ",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="wi_1",
+                        target_module=Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="dropout",
+                        target_module=DropoutForParallelInput,
+                    ),
+                ]
+            )
+            policy[T5DenseActDense] = ModulePolicyDescription(
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="wi",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="wo",
+                        target_module=Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="dropout",
+                        target_module=DropoutForParallelInput,
+                    ),
+                ]
+            )

         # optimization configuration
         if self.shard_config.enable_fused_normalization:
-            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-                suffix="layer_norm",
-                target_module=FusedRMSNorm,
-            ),
-                                                        policy=policy,
-                                                        target_key=T5LayerFF)
-            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-                suffix="layer_norm",
-                target_module=FusedRMSNorm,
-            ),
-                                                        policy=policy,
-                                                        target_key=T5LayerFF)
-            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-                suffix="layer_norm", target_module=FusedRMSNorm),
-                                                        policy=policy,
-                                                        target_key=T5LayerSelfAttention)
-            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-                suffix="layer_norm", target_module=FusedRMSNorm),
-                                                        policy=policy,
-                                                        target_key=T5LayerCrossAttention)
-            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-                suffix="final_layer_norm", target_module=FusedRMSNorm),
-                                                        policy=policy,
-                                                        target_key=T5Stack)
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="layer_norm",
+                    target_module=FusedRMSNorm,
+                ),
+                policy=policy,
+                target_key=T5LayerFF,
+            )
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="layer_norm",
+                    target_module=FusedRMSNorm,
+                ),
+                policy=policy,
+                target_key=T5LayerFF,
+            )
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
+                policy=policy,
+                target_key=T5LayerSelfAttention,
+            )
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
+                policy=policy,
+                target_key=T5LayerCrossAttention,
+            )
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
+                policy=policy,
+                target_key=T5Stack,
+            )

         # use flash attention
         if self.shard_config.enable_flash_attention:
-            self.append_or_create_method_replacement(description={
-                'forward': get_t5_flash_attention_forward(),
-            },
-                                                     policy=policy,
-                                                     target_key=T5Attention)
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_t5_flash_attention_forward(),
+                },
+                policy=policy,
+                target_key=T5Attention,
+            )

         # use jit operator
         if self.shard_config.enable_jit_fused:
-            self.append_or_create_method_replacement(description={
-                'forward': get_jit_fused_T5_layer_ff_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            },
-                                                     policy=policy,
-                                                     target_key=T5LayerFF)
-            self.append_or_create_method_replacement(description={
-                'forward': get_T5_layer_self_attention_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            },
-                                                     policy=policy,
-                                                     target_key=T5LayerSelfAttention)
-            self.append_or_create_method_replacement(description={
-                'forward': get_T5_layer_cross_attention_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            },
-                                                     policy=policy,
-                                                     target_key=T5LayerCrossAttention)
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_jit_fused_T5_layer_ff_forward(),
+                    "dropout_add": get_jit_fused_dropout_add_func(),
+                },
+                policy=policy,
+                target_key=T5LayerFF,
+            )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_T5_layer_self_attention_forward(),
+                    "dropout_add": get_jit_fused_dropout_add_func(),
+                },
+                policy=policy,
+                target_key=T5LayerSelfAttention,
+            )
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_T5_layer_cross_attention_forward(),
+                    "dropout_add": get_jit_fused_dropout_add_func(),
+                },
+                policy=policy,
+                target_key=T5LayerCrossAttention,
+            )

         return policy
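A note on the hunk above: the attribute_replacement entries for T5Attention only rescale the attention geometry to one tensor-parallel rank, while q/k/v/o themselves are swapped for the 1D-parallel linear layers. A minimal sketch of that arithmetic, assuming illustrative t5-small-like sizes (the numbers are not taken from this diff):

# Hypothetical numbers; tp_size stands in for shard_config.tensor_parallel_size.
d_model, num_heads, d_kv, tp_size = 512, 8, 64, 2

d_model_per_rank = d_model // tp_size             # 256
n_heads_per_rank = num_heads // tp_size           # 4 whole heads per rank
inner_dim_per_rank = num_heads * d_kv // tp_size  # 256

# Heads are never split internally; only their count is divided across ranks.
assert inner_dim_per_rank == n_heads_per_rank * d_kv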
@@ -217,8 +245,9 @@ class T5BasePolicy(Policy):
         return self.model

     @staticmethod
-    def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
-                             num_stages: int) -> Tuple[List[int], int]:
+    def distribute_t5_layers(
+        num_encoder_layers: int, num_decoder_layers: int, num_stages: int
+    ) -> Tuple[List[int], int]:
         """
         Distribute t5 layers into stages when pipeline parallel is used.
         Return the layer distribution as a list and the starting stage of decoder.
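The docstring states the contract: map (encoder depth, decoder depth, stage count) to per-stage layer counts plus the first decoder stage. As a hedged illustration only, not ColossalAI's actual algorithm, one simple way to satisfy that contract:

from typing import List, Tuple

def distribute_t5_layers_sketch(num_encoder_layers: int, num_decoder_layers: int,
                                num_stages: int) -> Tuple[List[int], int]:
    # Illustrative stand-in: grant the encoder a share of stages proportional to
    # its share of layers, then spread layers evenly within each half.
    # Assumes num_stages >= 2 whenever a decoder is present.
    if num_decoder_layers == 0:
        num_encoder_stages = num_stages
    else:
        total = num_encoder_layers + num_decoder_layers
        num_encoder_stages = min(max(1, round(num_stages * num_encoder_layers / total)), num_stages - 1)

    def even_split(n_layers: int, n_stages: int) -> List[int]:
        base, rem = divmod(n_layers, n_stages)
        return [base + (i < rem) for i in range(n_stages)]

    distribution = even_split(num_encoder_layers, num_encoder_stages)
    if num_decoder_layers:
        distribution += even_split(num_decoder_layers, num_stages - num_encoder_stages)
    return distribution, num_encoder_stages

# e.g. 12 encoder + 12 decoder layers over 4 stages -> ([6, 6, 6, 6], 2)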
@@ -251,8 +280,9 @@ class T5BasePolicy(Policy):
         return encoder_distribution + decoder_distribution, num_encoder_stages

     @staticmethod
-    def get_t5_stage_index(layers_per_stage: List[int], stage: int,
-                           decoder_starting_stage: int) -> Tuple[bool, int, int]:
+    def get_t5_stage_index(
+        layers_per_stage: List[int], stage: int, decoder_starting_stage: int
+    ) -> Tuple[bool, int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
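Per its docstring, get_t5_stage_index then turns that distribution into the half-open block range a given stage owns, counted within the encoder or within the decoder. A small illustration under the hypothetical [6, 6, 6, 6] distribution from the previous sketch:

# Illustrative only: cumulative sums of layers_per_stage give each stage's slice.
layers_per_stage, decoder_starting_stage = [6, 6, 6, 6], 2

def stage_slice(stage: int):
    in_decoder = stage >= decoder_starting_stage
    first = decoder_starting_stage if in_decoder else 0
    start = sum(layers_per_stage[first:stage])
    return in_decoder, start, start + layers_per_stage[stage]

print(stage_slice(1))  # (False, 6, 12): second half of encoder.block
print(stage_slice(2))  # (True, 0, 6): first half of decoder.block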
@@ -269,16 +299,18 @@ class T5BasePolicy(Policy):

         model = self.model
         encoder = self.model.encoder
-        decoder = getattr(self.model, 'decoder', None)
+        decoder = getattr(self.model, "decoder", None)

         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0

         held_layers = []
         layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
-            num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
-        start_idx, end_idx = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage,
-                                                             decoder_starting_stage)
+            num_encoder_layers, num_decoder_layers, stage_manager.num_stages
+        )
+        start_idx, end_idx = T5BasePolicy.get_t5_stage_index(
+            layers_per_stage, stage_manager.stage, decoder_starting_stage
+        )

         if stage_manager.stage < decoder_starting_stage:
             # current stage is in t5's encoder
@@ -303,47 +335,51 @@ class T5BasePolicy(Policy):

     def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
         """If under pipeline parallel setting, replacing the original forward method of huggingface
-           to customized forward method, and add this changing to policy."""
+        to customized forward method, and add this changing to policy."""
         if not self.pipeline_stage_manager:
             raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
         stage_manager = self.pipeline_stage_manager

         encoder = self.model.encoder
-        decoder = getattr(self.model, 'decoder', None)
+        decoder = getattr(self.model, "decoder", None)

         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0

         layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
-            num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
+            num_encoder_layers, num_decoder_layers, stage_manager.num_stages
+        )
         stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)

         method_replacement = {
-            'forward':
-                partial(new_forward,
-                        stage_manager=stage_manager,
-                        stage_index=stage_index,
-                        decoder_starting_stage=decoder_starting_stage)
+            "forward": partial(
+                new_forward,
+                stage_manager=stage_manager,
+                stage_index=stage_index,
+                decoder_starting_stage=decoder_starting_stage,
+            )
         }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)


 class T5ModelPolicy(T5BasePolicy):
-
     def __init__(self) -> None:
         super().__init__()

     def module_policy(self):
         from transformers import T5Model

         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
-            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-                suffix="shared",
-                target_module=VocabParallelEmbedding1D,
-            ),
-                                                        policy=policy,
-                                                        target_key=T5Model)
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="shared",
+                    target_module=VocabParallelEmbedding1D,
+                ),
+                policy=policy,
+                target_key=T5Model,
+            )
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy)

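The method_replacement in the hunk above leans on functools.partial: the pipeline metadata is bound once, so the substituted forward keeps the original calling convention. A self-contained sketch of the pattern (all names below are illustrative, not the library's):

from functools import partial

def staged_forward(self, hidden_states, *, stage_manager, stage_index, decoder_starting_stage):
    # Hypothetical stand-in for a T5PipelineForwards method: a real one would run
    # only the blocks in stage_index and hand activations to the next stage.
    start, end = stage_index
    return f"ran blocks [{start}, {end}) (decoder starts at stage {decoder_starting_stage})"

forward = partial(staged_forward, stage_manager=object(), stage_index=(0, 6), decoder_starting_stage=2)
print(forward(None, "activations"))  # callers still pass the usual (module, input) arguments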
@@ -356,9 +392,9 @@ class T5ModelPolicy(T5BasePolicy):
         module = self.model
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None and stage_manager.num_stages > 1:
-            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
-                                                                          len(module.decoder.block),
-                                                                          stage_manager.num_stages)
+            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+                len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
+            )

             if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
                 return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
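The id() comparison above is the standard test for weight tying: tied modules hold the very same Parameter object, so object identity, not value equality, is what matters. A short check in plain PyTorch:

import torch.nn as nn

shared = nn.Embedding(10, 4)
decoder_embed = nn.Embedding(10, 4)
decoder_embed.weight = shared.weight  # tie, as T5 ties shared and decoder.embed_tokens

print(id(decoder_embed.weight) == id(shared.weight))  # True: one tensor on two modules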
@@ -366,7 +402,6 @@ class T5ModelPolicy(T5BasePolicy):


 class T5ForConditionalGenerationPolicy(T5BasePolicy):
-
     def __init__(self) -> None:
         super().__init__()

@@ -376,22 +411,26 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
-            self.append_or_create_submodule_replacement(description=[
-                SubModuleReplacementDescription(
-                    suffix="shared",
-                    target_module=VocabParallelEmbedding1D,
-                ),
-                SubModuleReplacementDescription(suffix="lm_head",
-                                                target_module=Linear1D_Col,
-                                                kwargs=dict(gather_output=True))
-            ],
-                                                        policy=policy,
-                                                        target_key=T5ForConditionalGeneration)
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="shared",
+                        target_module=VocabParallelEmbedding1D,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                    ),
+                ],
+                policy=policy,
+                target_key=T5ForConditionalGeneration,
+            )

         if self.pipeline_stage_manager is not None:
-            self.set_pipeline_forward(model_cls=T5ForConditionalGeneration,
-                                      new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
-                                      policy=policy)
+            self.set_pipeline_forward(
+                model_cls=T5ForConditionalGeneration,
+                new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
+                policy=policy,
+            )
         return policy

     def get_held_layers(self) -> List[nn.Module]:
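Worth noting in the hunk above: lm_head becomes Linear1D_Col with gather_output=True because the sharded dimension is the vocabulary, and the downstream loss/decoding step needs the full logits row on every rank. Rough shape arithmetic, with an illustrative vocab size:

# Illustrative: Linear1D_Col splits the output (vocab) dimension across ranks.
vocab_size, tp_size = 32128, 2
local_logits = vocab_size // tp_size         # 16064 logit columns computed per rank
assert local_logits * tp_size == vocab_size  # gather_output=True concatenates them back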
@@ -404,9 +443,9 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
         module = self.model
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None and stage_manager.num_stages > 1:
-            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
-                                                                          len(module.decoder.block),
-                                                                          stage_manager.num_stages)
+            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+                len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
+            )

         shared_params = []
         shared_embedding = {}
@@ -427,7 +466,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):


 class T5EncoderPolicy(T5BasePolicy):
-
     def __init__(self) -> None:
         super().__init__()

@@ -437,17 +475,19 @@ class T5EncoderPolicy(T5BasePolicy):
         policy = super().module_policy()

         if self.shard_config.enable_tensor_parallelism:
-            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
-                suffix="shared",
-                target_module=VocabParallelEmbedding1D,
-            ),
-                                                        policy=policy,
-                                                        target_key=T5EncoderModel)
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="shared",
+                    target_module=VocabParallelEmbedding1D,
+                ),
+                policy=policy,
+                target_key=T5EncoderModel,
+            )

         if self.pipeline_stage_manager is not None:
-            self.set_pipeline_forward(model_cls=T5EncoderModel,
-                                      new_forward=T5PipelineForwards.t5_encoder_model_forward,
-                                      policy=policy)
+            self.set_pipeline_forward(
+                model_cls=T5EncoderModel, new_forward=T5PipelineForwards.t5_encoder_model_forward, policy=policy
+            )
         return policy

     def get_held_layers(self) -> List[nn.Module]: