diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index eba7d1c1f..bfe1f183d 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -633,6 +633,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         partition_grad: bool = False,  # stage 2 flag
+        sub_dp_size: int = 1,
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup] = None,  # if using tp
@@ -663,6 +664,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
+            sub_dp_size=sub_dp_size,
             cpu_offload=cpu_offload,
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
@@ -964,6 +966,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         gradient_as_bucket_view: bool = False,
         static_graph: bool = False,
         zero_bucket_size_in_m: int = 12,
+        sub_dp_size: int = 1,
         cpu_offload: bool = False,
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
@@ -1070,6 +1073,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
+            sub_dp_size=sub_dp_size,
             cpu_offload=cpu_offload,
             partition_grad=(self.zero_stage == 2),
             forced_dtype=PRECISION_TORCH_TYPE[precision],
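
Usage sketch (illustrative, not part of the diff): with this change, sub_dp_size can be passed to HybridParallelPlugin and is forwarded through HybridParallelZeroOptimizer down to LowLevelZeroOptimizer. The snippet below assumes the standard ColossalAI Booster workflow; the parallel sizes (tp_size, pp_size) and the sub_dp_size value are example choices, not values taken from this patch.

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # sub_dp_size defaults to 1, which preserves the previous behavior;
    # the values below are assumptions for illustration only.
    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=1,
        zero_stage=1,     # sub_dp_size is forwarded to HybridParallelZeroOptimizer
        sub_dp_size=2,    # new knob introduced by this diff
    )
    booster = Booster(plugin=plugin)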