mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 13:05:26 +00:00
[shardformer/fix overlap bug] fix overlap bug, add overlap as an option in shardco… (#4516)
* fix overlap bug and support bert, add overlap as an option in shardconfig * support overlap for chatglm and bloom
This commit is contained in:
@@ -20,6 +20,8 @@ class ShardConfig:
|
||||
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
|
||||
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
|
||||
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
|
||||
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
|
||||
enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
|
||||
"""
|
||||
tensor_parallel_process_group: Optional[ProcessGroup] = None
|
||||
pipeline_stage_manager: Optional[PipelineStageManager] = None
|
||||
@@ -29,6 +31,7 @@ class ShardConfig:
|
||||
enable_flash_attention: bool = False
|
||||
enable_jit_fused: bool = False
|
||||
enable_sequence_parallelism: bool = False
|
||||
enable_sequence_overlap: bool = False
|
||||
|
||||
# pipeline_parallel_size: int
|
||||
# data_parallel_size: int
|
||||
@@ -41,6 +44,11 @@ class ShardConfig:
|
||||
return self._tensor_parallel_size
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
|
||||
raise ValueError(
|
||||
"enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
|
||||
if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
|
||||
raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
|
||||
if not self.enable_tensor_parallelism:
|
||||
self._tensor_parallel_size = 1
|
||||
else:
|
||||
@@ -59,3 +67,4 @@ class ShardConfig:
|
||||
self.enable_flash_attention = True
|
||||
self.enable_jit_fused = True
|
||||
self.enable_sequence_parallelism = True
|
||||
self.enable_sequence_overlap = True
|
||||
|
Reference in New Issue
Block a user