Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
[zero] support extra dp (#6123)
* [zero] support extra dp
* [zero] update checkpoint
* fix bugs
* fix bugs
```diff
@@ -29,6 +29,7 @@ from colossalai.checkpoint_io.utils import (
     save_state_dict,
     sharded_optimizer_loading_epilogue,
 )
+from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
```
```diff
@@ -333,6 +334,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
+        extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
     """
 
     def __init__(
```
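For context, a minimal sketch of how the new argument would be passed at construction time. The plugin name and the `extra_dp_size` parameter come from this diff; the `stage` and `precision` values are arbitrary illustrative choices, not part of the commit:

```python
from colossalai.booster.plugin import LowLevelZeroPlugin

# Illustrative values: with 8 ranks, extra_dp_size must divide the world size,
# so 1, 2, 4, or 8 would be valid; each inner ZeRO dp group then spans
# world_size // extra_dp_size ranks.
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", extra_dp_size=2)
```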
```diff
@@ -358,11 +360,16 @@ class LowLevelZeroPlugin(DPPluginBase):
         cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        extra_dp_size: int = 1,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
         assert precision in SUPPORTED_PRECISION, f"LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training"
         assert norm_type == 2.0, f"LowLevelZeroPlugin only supports norm_type=2.0 now"
+        if extra_dp_size > 1:
+            assert dist.get_world_size() % extra_dp_size == 0, "extra_dp_size should be a factor of world_size"
+            inner_dp_size = dist.get_world_size() // extra_dp_size
+            self.pg_mesh = ProcessGroupMesh(extra_dp_size, inner_dp_size)
         self.stage = stage
         self.precision = precision
         self.zero_optim_kwargs = dict(
```
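To make the two mesh axes concrete, here is a plain-Python illustration (not part of the commit) of how an `(extra_dp_size, inner_dp_size)` grid partitions ranks, assuming the conventional row-major layout of a 2D mesh. Groups along axis 0 collect ranks that share a column; groups along axis 1 collect ranks that share a row:

```python
# Hypothetical setup: world_size=8, extra_dp_size=2 (illustration only).
world_size, extra_dp_size = 8, 2
inner_dp_size = world_size // extra_dp_size  # 4, matching the assert in the plugin

# Row-major rank grid of shape (extra_dp_size, inner_dp_size).
mesh = [[r * inner_dp_size + c for c in range(inner_dp_size)] for r in range(extra_dp_size)]
# mesh == [[0, 1, 2, 3], [4, 5, 6, 7]]

# Axis-0 groups (the extra dp groups): ranks in the same column.
axis0_groups = [list(col) for col in zip(*mesh)]  # [[0, 4], [1, 5], [2, 6], [3, 7]]
# Axis-1 groups (the inner ZeRO dp groups): ranks in the same row.
axis1_groups = [list(row) for row in mesh]        # [[0, 1, 2, 3], [4, 5, 6, 7]]
```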
```diff
@@ -383,6 +390,9 @@ class LowLevelZeroPlugin(DPPluginBase):
             overlap_allgather=overlap_allgather,
             fp8_communication=fp8_communication,
         )
+        if extra_dp_size > 1:
+            self.zero_optim_kwargs["extra_dp_group"] = self.pg_mesh.get_group_along_axis(0)
+            self.zero_optim_kwargs["dp_process_group"] = self.pg_mesh.get_group_along_axis(1)
         self.lora_enabled = False
         self.verbose = verbose
         self.logger = get_dist_logger()
```
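And an end-to-end sketch of how these kwargs come into play, assuming a distributed launch (e.g. `torchrun` or `colossalai run`) whose world size is divisible by `extra_dp_size`. The toy model, optimizer, and launch call are placeholders for illustration, not part of this commit:

```python
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Assumes the process group env vars set by the launcher; the exact
# launch_from_torch signature varies across ColossalAI versions.
colossalai.launch_from_torch()

plugin = LowLevelZeroPlugin(stage=2, precision="bf16", extra_dp_size=2)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# boost() wraps the optimizer with the ZeRO optimizer configured above, so the
# inner dp group (mesh axis 1) handles sharded reduction while the extra dp
# group (mesh axis 0) is layered on top.
model, optimizer, *_ = booster.boost(model, optimizer)
```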