[shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardco… (#4516)

* fix overlap bug and support bert, add overlap as an option in shardconfig

* support overlap for chatglm and bloom
This commit is contained in:
Bin Jia
2023-08-28 17:16:40 +08:00
committed by GitHub
parent 376533a564
commit c554b7f559
7 changed files with 63 additions and 39 deletions

View File

@@ -56,6 +56,7 @@ class BertPolicy(Policy):
policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size":
@@ -71,17 +72,26 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
@@ -99,7 +109,10 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="output.dense",