[shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fixes. (#4498)

* [shardformer] chatglm support sequence parallel

* fix

* [shardformer] jit fused fix

* activate checks
This commit is contained in:
flybird11111
2023-08-24 15:50:02 +08:00
committed by GitHub
parent e04436a82a
commit 3353e55c80
7 changed files with 46 additions and 21 deletions

View File

@@ -187,6 +187,9 @@ class BertPipelineForwards:
             hidden_states = split_forward_gather_backward(hidden_states,
                                                           dim=1,
                                                           process_group=shard_config.tensor_parallel_process_group)
+            if encoder_hidden_states is not None:
+                encoder_hidden_states = split_forward_gather_backward(
+                    encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
         for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
             if stage_manager.is_first_stage() and idx == 0:
@@ -1241,6 +1244,9 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
         embedding_output = split_forward_gather_backward(embedding_output,
                                                          dim=1,
                                                          process_group=shard_config.tensor_parallel_process_group)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = split_forward_gather_backward(
+                encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
         encoder_outputs = self.encoder(
             embedding_output,
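
The two BERT hunks above split encoder_hidden_states along the sequence dimension (dim=1) with the same split_forward_gather_backward call already applied to hidden_states / embedding_output, so every tensor the encoder's cross-attention sees lives in the same sequence-parallel layout. As a rough illustration of what a split-forward/gather-backward op does (a minimal sketch, not the ColossalAI implementation; it assumes the sequence length divides evenly across the process group):

    import torch
    import torch.distributed as dist

    class _SplitForwardGatherBackward(torch.autograd.Function):
        """Keep one sequence chunk per rank in forward; rebuild the full-sequence grad in backward."""

        @staticmethod
        def forward(ctx, x, dim, process_group):
            ctx.dim = dim
            ctx.process_group = process_group
            rank = dist.get_rank(process_group)
            world_size = dist.get_world_size(process_group)
            # Each rank keeps only its own slice of the sequence dimension.
            return x.chunk(world_size, dim=dim)[rank].contiguous()

        @staticmethod
        def backward(ctx, grad_output):
            world_size = dist.get_world_size(ctx.process_group)
            chunks = [torch.empty_like(grad_output) for _ in range(world_size)]
            # Gather every rank's gradient chunk and concatenate back to the full sequence.
            dist.all_gather(chunks, grad_output.contiguous(), group=ctx.process_group)
            return torch.cat(chunks, dim=ctx.dim), None, None

    def split_forward_gather_backward(x, dim, process_group):
        return _SplitForwardGatherBackward.apply(x, dim, process_group)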

View File

@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
@@ -35,6 +36,10 @@ class LlamaPolicy(Policy):
         policy = {}
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("Llama doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
         if self.shard_config.enable_tensor_parallelism:
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
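
The same guard is added to the T5, ViT and Whisper policies below: instead of failing when an unsupported flag is set, the policy switches it off and warns. A self-contained sketch of the pattern (DummyShardConfig is illustrative only; the real class is ShardConfig in colossalai.shardformer):

    import warnings
    from dataclasses import dataclass

    # Illustrative stand-in for the real ShardConfig; only the one flag matters here.
    @dataclass
    class DummyShardConfig:
        enable_sequence_parallelism: bool = True

    def ignore_unsupported_sequence_parallelism(shard_config, model_name):
        """Mirror the policy change above: clear the flag and warn instead of raising."""
        if shard_config.enable_sequence_parallelism:
            shard_config.enable_sequence_parallelism = False
            warnings.warn(f"{model_name} doesn't support sequence parallelism now, "
                          "will ignore the sequence parallelism flag.")

    config = DummyShardConfig()
    ignore_unsupported_sequence_parallelism(config, "Llama")
    assert config.enable_sequence_parallelism is False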

View File

@@ -104,16 +104,20 @@ class OPTPolicy(Policy):
         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[OPTAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_opt_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTAttention)
         # use jit fused operator
         if self.shard_config.enable_jit_fused:
-            policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_opt_decoder_layer_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTDecoderLayer)
         return policy
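
This hunk is the shape of the "jit fused fix": the flash-attention and jit-fused branches used to overwrite whatever ModulePolicyDescription an earlier branch (for example the tensor-parallel one) had already registered for the same module. Routing both through append_or_create_method_replacement merges the new forwards into the existing entry instead. A rough sketch of what such a helper does (the real method lives on the shardformer base Policy class; the import path and field handling below are assumptions):

    # Sketch only: the real helper is a method on the base Policy class.
    from colossalai.shardformer.policies.base_policy import ModulePolicyDescription

    def append_or_create_method_replacement(description, policy, target_key):
        if target_key in policy:
            # An earlier branch (e.g. tensor parallelism) already registered this module:
            # merge the new forwards into it instead of overwriting the whole description.
            entry = policy[target_key]
            if entry.method_replacement is None:
                entry.method_replacement = {}
            entry.method_replacement.update(description)
        else:
            # No prior entry: create a description that only carries method replacements.
            policy[target_key] = ModulePolicyDescription(method_replacement=description)
        return policy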

View File

@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple
@@ -59,6 +60,10 @@ class T5BasePolicy(Policy):
         policy = {}
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
         if self.shard_config.enable_tensor_parallelism:
             policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
                 SubModuleReplacementDescription(

View File

@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Dict, List, Union
 import torch.nn as nn
@@ -32,6 +33,10 @@ class ViTPolicy(Policy):
         policy = {}
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("ViT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
         if self.shard_config.enable_tensor_parallelism:
             policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
                                                             param_replacement=[],

View File

@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Tuple
@@ -33,7 +34,6 @@ class WhisperPolicy(Policy):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
@@ -52,6 +52,14 @@ class WhisperPolicy(Policy):
         policy = {}
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn(
+                "Whisper doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
+        if self.shard_config.enable_jit_fused:
+            self.shard_config.enable_jit_fused = False
+            warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused flag.")
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
                 "self_attn.embed_dim":
@@ -198,20 +206,11 @@ class WhisperPolicy(Policy):
         # enable flash attention
         if self.shard_config.enable_flash_attention:
-            policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_whisper_flash_attention_forward(),
-            })
-        # use jit fused operator
-        if self.shard_config.enable_jit_fused:
-            policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_encoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_decoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=WhisperAttention)
         return policy
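
Whisper's standalone jit-fused encoder/decoder forwards are dropped here, since the earlier hunk now ignores enable_jit_fused for Whisper with a warning. For reference, the dropout_add handed out by get_jit_fused_dropout_add_func fuses dropout with the residual add under TorchScript; a minimal sketch of the idea, not necessarily the exact ColossalAI kernel:

    import torch
    import torch.nn.functional as F

    @torch.jit.script
    def jit_dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
        # Scripting lets TorchScript fuse the dropout mask and the residual add,
        # saving one elementwise pass over the activation.
        return F.dropout(x, p=prob, training=training) + residual

    out = jit_dropout_add(torch.randn(2, 16), torch.randn(2, 16), 0.1, True)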