[shardformer] supported T5 and its variants (#4045)
@@ -1,159 +1,173 @@
-from typing import Dict
-
-import torch
-import torch.nn as nn
-from torch.nn import Embedding
+from transformers import T5ForConditionalGeneration
 from transformers.models.t5.modeling_t5 import (
     T5Attention,
     T5Block,
     T5DenseActDense,
     T5DenseGatedActDense,
     T5LayerCrossAttention,
     T5LayerFF,
     T5LayerSelfAttention,
     T5Model,
     T5Stack,
 )

-import colossalai.shardformer.layer.layers as col_nn
+from colossalai.shardformer.layer.dropout import Dropout1D
+from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row

-from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer
+from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]


 class T5ModelPolicy(Policy):

-    @staticmethod
-    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
-        print('config heads', config.num_heads)
+    def preprocess(self):
+        # reshape the embedding layer
+        r"""
+        Reshape the Embedding layer to make the embedding dimension divisible by world_size
+        """
+        vocab_size = self.model.config.vocab_size
+        world_size = self.shard_config.tensor_parallel_size
+        if vocab_size % world_size != 0:
+            new_vocab_size = vocab_size + world_size - vocab_size % world_size
+            self.model.resize_token_embeddings(new_vocab_size)
+        return self.model
+
+    def module_policy(self):
         return {
             T5Stack:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]),
-            T5Block:
-                Argument(attr_dict={}, param_funcs=[]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ]),
             T5LayerSelfAttention:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            ),
+                                        ]),
             T5LayerCrossAttention:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ]),
             T5Attention:
-                Argument(attr_dict={
-                    "d_model": config.d_model // world_size,
-                    "n_heads": config.num_heads // world_size,
-                    "inner_dim": config.num_heads * config.d_kv // world_size,
+                ModulePolicyDescription(attribute_replacement={
+                    "d_model":
+                        self.model.config.d_model // self.shard_config.tensor_parallel_size,
+                    "n_heads":
+                        self.model.config.num_heads // self.shard_config.tensor_parallel_size,
+                    "inner_dim":
+                        self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size
                 },
-                         param_funcs=[T5ModelPolicy.attn_layer]),
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="q",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="k",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="v",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="o",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                            SubModuleReplacementDescription(suffix="relative_attention_bias",
+                                                                            target_module=Embedding1D,
+                                                                            kwargs=dict(gather_output=False),
+                                                                            ignore_if_not_exist=True)
+                                        ]),
             T5LayerFF:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            ),
+                                        ]),
             T5DenseGatedActDense:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="wi_0",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="wi_1",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                            SubModuleReplacementDescription(suffix="wo",
+                                                                            target_module=Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True)),
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ]),
             T5DenseActDense:
-                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]),
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(
+                                                suffix="wi",
+                                                target_module=Linear1D_Col,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="wo",
+                                                target_module=Linear1D_Row,
+                                            ),
+                                            SubModuleReplacementDescription(
+                                                suffix="dropout",
+                                                target_module=Dropout1D,
+                                            )
+                                        ])
         }

-    @staticmethod
-    def dense_gated_layer():
-        return [
-            Col_Layer(
-                suffix="wi_0",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Row_Layer(
-                suffix="wi_1",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Row,
-            ),
-            Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)
-        ]
+    def new_model_class(self):
+        return None

-    @staticmethod
-    def dense_act_layer():
-        return [
-            Col_Layer(
-                suffix="wi",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Row_Layer(
-                suffix="wo",
-                weight="weight",
-                replace_layer=col_nn.Linear1D_Row,
-            )
-        ]
-
-    @staticmethod
-    def attn_layer():
-        return [
-            Col_Layer(
-                suffix="q",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Col_Layer(
-                suffix="k",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Col_Layer(
-                suffix="v",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Col,
-            ),
-            Row_Layer(
-                suffix="o",
-                weight="weight",
-                bias="bias",
-                replace_layer=col_nn.Linear1D_Row,
-            ),
-        ]
-
-    @staticmethod
-    def dropout():
-        return [Dropout_Layer(
-            suffix="dropout",
-            p="p",
-            replace_layer=col_nn.Dropout1D,
-        )]
-
-    @staticmethod
-    def embedding():
-        return [
-            Embedding_Layer(
-                suffix="block[0].layer[0].SelfAttention.relative_attention_bias",
-                weight="weight",
-                replace_layer=col_nn.Embedding1D,
-                gather_output=False,
-            )
-        ]
-
-
-from transformers import T5ForConditionalGeneration
+    def postprocess(self):
+        return self.model


 class T5ForConditionalGenerationPolicy(T5ModelPolicy):

-    @staticmethod
-    def argument_policy(config, world_size):
-        base_argument = T5ModelPolicy.argument_policy(config, world_size)
-        argument = {
-            T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head])
+    def module_policy(self):
+        policy = super().module_policy()
+
+        new_item = {
+            T5ForConditionalGeneration:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=Linear1D_Col,
+                                                                            kwargs=dict(gather_output=True))
+                                        ])
         }
-        argument.update(base_argument)
-        return argument
-
-    @staticmethod
-    def lm_head():
-        return [Col_Layer(
-            suffix="lm_head",
-            weight="weight",
-            replace_layer=col_nn.Linear1D_Col,
-            gather_output=True,
-        )]
-
-
-from transformers import T5EncoderModel
+        policy.update(new_item)
+        return policy


-class T5EncoderModelPolicy(T5ModelPolicy):
+class T5EncoderPolicy(T5ModelPolicy):
     pass
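
For context, a minimal sketch of how a sharding engine might consume a policy like the one above: look up each module's class in the dictionary returned by module_policy(), apply its attribute_replacement entries, then swap the sub-modules named by each SubModuleReplacementDescription. The apply_policy and build_replacement helpers below are illustrative assumptions, not ColossalAI's actual sharding code.

# Illustrative sketch only: apply_policy and build_replacement are hypothetical
# helpers showing how the descriptions above are meant to be interpreted.
import torch.nn as nn


def build_replacement(original: nn.Module, target_module, **kwargs):
    # Placeholder: a real engine would build e.g. Linear1D_Col / Dropout1D from the
    # original layer's parameters and the tensor-parallel process group.
    return original


def apply_policy(model: nn.Module, policy_map: dict) -> nn.Module:
    for module in model.modules():
        description = policy_map.get(type(module))
        if description is None:
            continue
        # Overwrite plain attributes first (e.g. d_model / n_heads on T5Attention).
        for name, value in description.attribute_replacement.items():
            setattr(module, name, value)
        # Then swap the named sub-modules (q/k/v/o, wi/wo, dropout, ...).
        for sub in description.sub_module_replacement:
            child = getattr(module, sub.suffix, None)
            if child is None:
                continue    # e.g. relative_attention_bias only exists in the first block
            kwargs = getattr(sub, "kwargs", None) or {}
            setattr(module, sub.suffix, build_replacement(child, sub.target_module, **kwargs))
    return model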