feat: add sub_dp_size in plugin

Wenhao Chen 2024-04-01 15:58:02 +08:00 committed by アマデウス
parent 6ceaf4f1f8
commit 61545fcfee


@@ -633,6 +633,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         partition_grad: bool = False,  # stage 2 flag
+        sub_dp_size: int = 1,
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
@@ -663,6 +664,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
+            sub_dp_size=sub_dp_size,
             cpu_offload=cpu_offload,
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
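
The two hunks above add the new keyword to HybridParallelZeroOptimizer's signature and forward it to the LowLevelZeroOptimizer base class. The diff does not show how LowLevelZeroOptimizer consumes the value; a plausible reading (an assumption, not shown in this commit) is that the data-parallel ranks are partitioned into sub-groups of sub_dp_size for ZeRO communication. A generic sketch of that kind of partitioning follows; the helper name build_sub_dp_groups is hypothetical and only standard torch.distributed APIs are used:

    # Illustrative sketch, not this commit's implementation.
    import torch.distributed as dist

    def build_sub_dp_groups(dp_ranks: list[int], sub_dp_size: int):
        """Partition dp_ranks into contiguous sub-groups of size sub_dp_size."""
        assert len(dp_ranks) % sub_dp_size == 0, "dp size must be divisible by sub_dp_size"
        groups = []
        for start in range(0, len(dp_ranks), sub_dp_size):
            ranks = dp_ranks[start : start + sub_dp_size]
            # every rank must call new_group with the same arguments
            groups.append(dist.new_group(ranks=ranks))
        return groups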
@@ -964,6 +966,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         gradient_as_bucket_view: bool = False,
         static_graph: bool = False,
         zero_bucket_size_in_m: int = 12,
+        sub_dp_size: int = 1,
         cpu_offload: bool = False,
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
@@ -1070,6 +1073,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
+            sub_dp_size=sub_dp_size,
             cpu_offload=cpu_offload,
             partition_grad=(self.zero_stage == 2),
             forced_dtype=PRECISION_TORCH_TYPE[precision],
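
On the plugin side, HybridParallelPlugin gains the same keyword (default 1, which should leave existing setups unchanged) and passes it through to the ZeRO optimizer it constructs. A minimal usage sketch, assuming an 8-GPU launch; the parallel sizes are illustrative, and only sub_dp_size is new in this commit (the other arguments are existing HybridParallelPlugin parameters):

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # hypothetical layout for 8 GPUs: 2-way tensor parallel, 4-way data parallel,
    # with the 4 dp ranks split into sub-groups of 2 via the new knob
    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=1,
        zero_stage=2,     # sub_dp_size is forwarded to the ZeRO optimizer
        sub_dp_size=2,    # added by this commit; default 1 keeps prior behavior
    )
    booster = Booster(plugin=plugin)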