Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 04:33:04 +00:00)
[shardformer] Add overlap support for gpt2 (#4535)
* add overlap support for gpt2
* remove unused code
* remove unused code
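For context, the new flag is read from ShardConfig as shard_config.enable_sequence_overlap. Below is a minimal usage sketch; it is not part of this PR, it must run under a distributed launcher (e.g. colossalai.launch or torchrun), and the exact ShardConfig/ShardFormer signatures beyond the three enable_* flags referenced in the diff are assumptions to verify against your ColossalAI version:

    # Not from this PR -- a minimal sketch of turning the new overlap flag on.
    # Assumptions: ShardConfig exposes the enable_* flags referenced in the diff
    # below, and ShardFormer.optimize() returns (sharded_model, shared_params).
    from transformers import GPT2LMHeadModel

    from colossalai.shardformer import ShardConfig, ShardFormer

    model = GPT2LMHeadModel.from_pretrained("gpt2")

    shard_config = ShardConfig(
        enable_tensor_parallelism=True,      # shard GPT-2's fused linears across ranks
        enable_sequence_parallelism=True,    # split activations along the sequence dimension
        enable_sequence_overlap=True,        # new: overlap sequence-parallel communication with compute
    )
    shard_former = ShardFormer(shard_config=shard_config)
    model, shared_params = shard_former.optimize(model)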
@@ -37,7 +37,8 @@ class GPT2Policy(Policy):
         from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model

         policy = {}

         use_sequence_parallel = self.shard_config.enable_sequence_parallelism
+        overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
             policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
                 SubModuleReplacementDescription(
@@ -50,47 +51,54 @@ class GPT2Policy(Policy):
                 ),
             ])

-            policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
-                "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-            },
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="attn.c_attn",
-                        target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                        kwargs={
-                            "n_fused": 3,
-                        },
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attn.c_proj",
-                        target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.c_fc",
-                        target_module=col_nn.GPT2FusedLinearConv1D_Col,
-                        kwargs={
-                            "n_fused": 1,
-                        },
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.c_proj",
-                        target_module=col_nn.GPT2FusedLinearConv1D_Row,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attn.attn_dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="attn.resid_dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="mlp.dropout",
-                        target_module=col_nn.DropoutForParallelInput,
-                    ),
-                ])
+            policy[GPT2Block] = ModulePolicyDescription(
+                attribute_replacement={
+                    "attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                    "attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                    "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+                },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="attn.c_attn",
+                        target_module=col_nn.GPT2FusedLinearConv1D_Col,
+                        kwargs={
+                            "n_fused": 3,
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap
+                        },
+                    ),
+                    SubModuleReplacementDescription(suffix="attn.c_proj",
+                                                    target_module=col_nn.GPT2FusedLinearConv1D_Row,
+                                                    kwargs={
+                                                        "seq_parallel": use_sequence_parallel,
+                                                    }),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.c_fc",
+                        target_module=col_nn.GPT2FusedLinearConv1D_Col,
+                        kwargs={
+                            "n_fused": 1,
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap
+                        },
+                    ),
+                    SubModuleReplacementDescription(suffix="mlp.c_proj",
+                                                    target_module=col_nn.GPT2FusedLinearConv1D_Row,
+                                                    kwargs={
+                                                        "seq_parallel": use_sequence_parallel,
+                                                    }),
+                    SubModuleReplacementDescription(
+                        suffix="attn.attn_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attn.resid_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                ])

         # optimization configuration
         if self.shard_config.enable_fused_normalization:
@@ -126,8 +134,6 @@ class GPT2Policy(Policy):

         if self.shard_config.enable_sequence_parallelism:
             policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
-            suffix_list = ["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"]
-            self.append_seq_parallel_to_policy(suffix_list=suffix_list, module_policy_description=policy[GPT2Block])

         return policy
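Note on how the new kwargs take effect: the policy only describes replacements; the shardformer sharder later swaps each suffix for its target_module and is expected to forward the kwargs (including "seq_parallel" and the new "overlap" flag) to the parallel layer, which can then overlap the sequence-parallel all-gather with its computation. The sketch below is illustrative only and not code from this PR; replace_gpt2_block_submodule is a hypothetical helper, and the exact from_native_module signature is an assumption:

    import torch.nn as nn

    # Hypothetical helper (not ColossalAI's real sharder) showing roughly how one
    # SubModuleReplacementDescription from the GPT2Block policy above gets applied.
    def replace_gpt2_block_submodule(block: nn.Module, description, process_group) -> None:
        # Resolve the dotted suffix, e.g. "attn.c_attn", relative to the block.
        parent_name, _, child_name = description.suffix.rpartition(".")
        parent = block.get_submodule(parent_name) if parent_name else block
        native = getattr(parent, child_name)
        # Forward the policy kwargs -- for this PR that includes "seq_parallel"
        # and the new "overlap" flag -- when building the parallel replacement.
        kwargs = description.kwargs or {}
        new_module = description.target_module.from_native_module(
            native, process_group=process_group, **kwargs)
        setattr(parent, child_name, new_module)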