feat: add sub_dp_size in plugin

This commit is contained in:
Wenhao Chen 2024-04-01 15:58:02 +08:00 committed by アマデウス
parent 6ceaf4f1f8
commit 61545fcfee

View File

@ -633,6 +633,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
communication_dtype: Optional[torch.dtype] = None, communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True, overlap_communication: bool = True,
partition_grad: bool = False, # stage 2 flag partition_grad: bool = False, # stage 2 flag
sub_dp_size: int = 1,
cpu_offload: bool = False, # cpu offload cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp tp_process_group: Optional[ProcessGroup] = None, # if using tp
@ -663,6 +664,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
communication_dtype=communication_dtype, communication_dtype=communication_dtype,
overlap_communication=overlap_communication, overlap_communication=overlap_communication,
partition_grad=partition_grad, partition_grad=partition_grad,
sub_dp_size=sub_dp_size,
cpu_offload=cpu_offload, cpu_offload=cpu_offload,
dp_process_group=dp_process_group, dp_process_group=dp_process_group,
forced_dtype=forced_dtype, forced_dtype=forced_dtype,
@ -964,6 +966,7 @@ class HybridParallelPlugin(PipelinePluginBase):
gradient_as_bucket_view: bool = False, gradient_as_bucket_view: bool = False,
static_graph: bool = False, static_graph: bool = False,
zero_bucket_size_in_m: int = 12, zero_bucket_size_in_m: int = 12,
sub_dp_size: int = 1,
cpu_offload: bool = False, cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None, communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True, overlap_communication: bool = True,
@ -1070,6 +1073,7 @@ class HybridParallelPlugin(PipelinePluginBase):
reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024, reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype, communication_dtype=communication_dtype,
overlap_communication=overlap_communication, overlap_communication=overlap_communication,
sub_dp_size=sub_dp_size,
cpu_offload=cpu_offload, cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2), partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision], forced_dtype=PRECISION_TORCH_TYPE[precision],