[shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardconfig (#4516)

* fix overlap bug and support bert, add overlap as an option in shardconfig

* support overlap for chatglm and bloom
This commit is contained in:
Bin Jia
2023-08-28 17:16:40 +08:00
committed by GitHub
parent 376533a564
commit c554b7f559
7 changed files with 63 additions and 39 deletions

View File

@@ -20,6 +20,8 @@ class ShardConfig:
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -29,6 +31,7 @@ class ShardConfig:
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
@@ -41,6 +44,11 @@ class ShardConfig:
return self._tensor_parallel_size
def __post_init__(self):
if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
raise ValueError(
"enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
@@ -59,3 +67,4 @@ class ShardConfig:
self.enable_flash_attention = True
self.enable_jit_fused = True
self.enable_sequence_parallelism = True
self.enable_sequence_overlap = True