Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-19 20:23:41 +00:00)
[shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fix. (#4498)
* [shardformer] chatglm support sequence parallel
* fix
* [shardformer] jit fused fix
* activate checks
This commit is contained in:
parent
e04436a82a
commit
3353e55c80
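The core of the change is the same opt-out guard added to LlamaPolicy, T5BasePolicy, ViTPolicy and WhisperPolicy in the hunks below: when the sequence parallelism flag is set for a model family that has no sequence-parallel forward yet, the policy clears the flag and warns instead of failing. A minimal, self-contained sketch of that pattern (the ShardConfig stand-in and the function name are illustrative, not the library API):

import warnings
from dataclasses import dataclass


@dataclass
class ShardConfig:
    # Stand-in for colossalai.shardformer.ShardConfig; only the field used here is modeled.
    enable_sequence_parallelism: bool = False


def ignore_unsupported_sequence_parallelism(shard_config: ShardConfig, model_name: str) -> None:
    """Clear the flag and warn when the model family has no sequence-parallel forward yet."""
    if shard_config.enable_sequence_parallelism:
        shard_config.enable_sequence_parallelism = False
        warnings.warn(f"{model_name} doesn't support sequence parallelism now, "
                      "will ignore the sequence parallelism flag.")


cfg = ShardConfig(enable_sequence_parallelism=True)
ignore_unsupported_sequence_parallelism(cfg, "Llama")
assert cfg.enable_sequence_parallelism is False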
@@ -187,6 +187,9 @@ class BertPipelineForwards:
             hidden_states = split_forward_gather_backward(hidden_states,
                                                           dim=1,
                                                           process_group=shard_config.tensor_parallel_process_group)
+            if encoder_hidden_states is not None:
+                encoder_hidden_states = split_forward_gather_backward(
+                    encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)

         for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
             if stage_manager.is_first_stage() and idx == 0:
@@ -1241,6 +1244,9 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
        embedding_output = split_forward_gather_backward(embedding_output,
                                                         dim=1,
                                                         process_group=shard_config.tensor_parallel_process_group)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = split_forward_gather_backward(
+                encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)

        encoder_outputs = self.encoder(
            embedding_output,
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union

@@ -35,6 +36,10 @@ class LlamaPolicy(Policy):

         policy = {}

+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
@@ -104,16 +104,20 @@ class OPTPolicy(Policy):

         # use flash attention
         if self.shard_config.enable_flash_attention:
-            policy[OPTAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_opt_flash_attention_forward(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTAttention)

         # use jit fused operator
         if self.shard_config.enable_jit_fused:
-            policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_jit_fused_opt_decoder_layer_forward(),
                 'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=OPTDecoderLayer)

         return policy

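The OPT (and, further down, Whisper) flash-attention and jit-fused hooks now go through self.append_or_create_method_replacement instead of assigning a fresh ModulePolicyDescription, so they no longer overwrite an entry that tensor parallelism may already have registered for the same target key. A simplified, standalone sketch of that merge-or-create behaviour (an illustration under reduced, assumed types, not the library's implementation):

from dataclasses import dataclass, field
from typing import Callable, Dict


@dataclass
class ModulePolicyDescription:
    # Reduced stand-in: the real description also carries param and sub-module replacements.
    attribute_replacement: Dict = field(default_factory=dict)
    method_replacement: Dict[str, Callable] = field(default_factory=dict)


def append_or_create_method_replacement(description: Dict[str, Callable],
                                        policy: Dict[str, ModulePolicyDescription],
                                        target_key: str) -> Dict[str, ModulePolicyDescription]:
    if target_key in policy:
        # Merge into the existing entry instead of replacing it.
        policy[target_key].method_replacement.update(description)
    else:
        policy[target_key] = ModulePolicyDescription(method_replacement=description)
    return policy


policy = {"OPTAttention": ModulePolicyDescription(attribute_replacement={"embed_dim": 512})}
append_or_create_method_replacement({"forward": lambda *args, **kwargs: None}, policy, "OPTAttention")
assert policy["OPTAttention"].attribute_replacement  # earlier tensor-parallel settings survive the merge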
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple

@@ -59,6 +60,10 @@ class T5BasePolicy(Policy):

         policy = {}

+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
                 SubModuleReplacementDescription(
@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Dict, List, Union

 import torch.nn as nn

@@ -32,6 +33,10 @@ class ViTPolicy(Policy):

         policy = {}

+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
                                                             param_replacement=[],
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Tuple

@@ -33,7 +34,6 @@ class WhisperPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        # TODO:
         vocab_size = self.model.config.vocab_size
         world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:

@@ -52,6 +52,14 @@ class WhisperPolicy(Policy):

         policy = {}

+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn(
+                "Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
+        if self.shard_config.enable_jit_fused:
+            self.shard_config.enable_jit_fused = False
+            warnings.warn("Whisper dosen't support jit fused operator now, will ignore the jit fused flag.")
+
         if self.shard_config.enable_tensor_parallelism:
             policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
                 "self_attn.embed_dim":

@@ -198,20 +206,11 @@ class WhisperPolicy(Policy):

         # enable flash attention
         if self.shard_config.enable_flash_attention:
-            policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
+            self.append_or_create_method_replacement(description={
                 'forward': get_whisper_flash_attention_forward(),
-            })
-
-        # use jit fused operator
-        if self.shard_config.enable_jit_fused:
-            policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_encoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
-            policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
-                'forward': get_jit_fused_whisper_decoder_layer_forward(),
-                'dropout_add': get_jit_fused_dropout_add_func(),
-            })
+            },
+                                                     policy=policy,
+                                                     target_key=WhisperAttention)

         return policy

@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
         if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-3, 1e-3
+            atol, rtol = 2e-4, 2e-4
         else:
             atol, rtol = 5e-3, 5e-3

@@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     # check weights and gradients
     if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-3, 1e-3
+        atol, rtol = 2e-4, 2e-4
     else:
         atol, rtol = 5e-3, 5e-3

@@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_optimizer.step()
     sharded_optimizer.step()
     if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-3, 1e-3
+        atol, rtol = 2e-4, 2e-4
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():

@@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,


 # TODO(jianghai) fix fp16
+#TODO fix WhisperForConditionalGeneration enable jit fused operator
 @parameterize('test_config', [{
     'tp_size': 2,
     'pp_size': 2,
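These last hunks are the "activate checks" part of the commit: the fp32 tolerances in check_forward_backward tighten from 1e-3 to 2e-4, while the fp16 path keeps 5e-3. A small sketch of the comparison this gates (torch.allclose stands in for what the test helpers verify; the helper name below is illustrative):

import torch


def pick_tolerance(precision: str) -> tuple:
    # fp32 outputs, weights and grads must now agree within 2e-4; fp16 keeps the looser 5e-3.
    return (2e-4, 2e-4) if precision == 'fp32' else (5e-3, 5e-3)


org_out = torch.randn(2, 4)
sharded_out = org_out + 1e-5      # pretend the sharded run differs by a tiny numerical error
atol, rtol = pick_tolerance('fp32')
assert torch.allclose(org_out, sharded_out, atol=atol, rtol=rtol)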