[shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardco… (#4516)

* fix overlap bug and support bert, add overlap as an option in shardconfig

* support overlap for chatglm and bloom
This commit is contained in:
Bin Jia
2023-08-28 17:16:40 +08:00
committed by GitHub
parent 376533a564
commit c554b7f559
7 changed files with 63 additions and 39 deletions

View File

@@ -45,6 +45,7 @@ class BloomPolicy(Policy):
policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -55,7 +56,10 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={'seq_parallel': use_sequence_parallel}),
kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
@@ -67,7 +71,10 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
kwargs={'seq_parallel': use_sequence_parallel}),
kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,