From 61545fcfee52f28e71ad8a0393128ea79323fd32 Mon Sep 17 00:00:00 2001
From: Wenhao Chen
Date: Mon, 1 Apr 2024 15:58:02 +0800
Subject: [PATCH] feat: add `sub_dp_size` in plugin

---
 colossalai/booster/plugin/hybrid_parallel_plugin.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index eba7d1c1f..bfe1f183d 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -633,6 +633,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         partition_grad: bool = False,  # stage 2 flag
+        sub_dp_size: int = 1,
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
@@ -663,6 +664,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
+            sub_dp_size=sub_dp_size,
             cpu_offload=cpu_offload,
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
@@ -964,6 +966,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         gradient_as_bucket_view: bool = False,
         static_graph: bool = False,
         zero_bucket_size_in_m: int = 12,
+        sub_dp_size: int = 1,
         cpu_offload: bool = False,
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
@@ -1070,6 +1073,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
+            sub_dp_size=sub_dp_size,
             cpu_offload=cpu_offload,
             partition_grad=(self.zero_stage == 2),
             forced_dtype=PRECISION_TORCH_TYPE[precision],
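
The patch itself only threads a new `sub_dp_size` argument (defaulting to 1, so existing callers are unaffected) from `HybridParallelPlugin`'s zero config through `HybridParallelZeroOptimizer.__init__` into the `super().__init__` call, i.e. into `LowLevelZeroOptimizer`. Below is a minimal usage sketch, assuming the patch is applied and that `LowLevelZeroOptimizer` accepts `sub_dp_size` (a companion change not shown here); the concrete values (`sub_dp_size=2`, the toy model, the learning rate) are illustrative only, not taken from the patch.

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Standard ColossalAI distributed setup (launched via torchrun).
colossalai.launch_from_torch(config={})

# ZeRO-1 data parallelism; `sub_dp_size=2` (illustrative) would ask the
# ZeRO optimizer to split each data-parallel group into sub-groups of
# 2 ranks. The default of 1 preserves the pre-patch behavior.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    zero_stage=1,
    sub_dp_size=2,  # the parameter added by this patch
    precision="bf16",
)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# With zero_stage > 0 the boosted optimizer is a HybridParallelZeroOptimizer,
# which now receives sub_dp_size through the plugin's zero config.
model, optimizer, *_ = booster.boost(model, optimizer)
```

Note that the diff only forwards the argument; the actual sub-group construction is expected to live in `LowLevelZeroOptimizer`, whose `__init__` must accept `sub_dp_size` for this patch to run.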