diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/nn/parallel/__init__.py
index c22b027ca..9645e95f6 100644
--- a/colossalai/nn/parallel/__init__.py
+++ b/colossalai/nn/parallel/__init__.py
@@ -1,3 +1,3 @@
-from .data_parallel import ColoDDP, ColoDDPV2
+from .data_parallel import ColoDDP, ZeroDDP
 
-__all__ = ['ColoDDP', 'ColoDDPV2']
+__all__ = ['ColoDDP', 'ZeroDDP']
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 9b0e88ea8..0c0a3e33a 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -8,7 +8,7 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.tensor.chunk import TensorState, Chunk
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Dict, Iterable, List
+from typing import Dict, Iterable, List, Optional
 from colossalai.logging import get_dist_logger
 from collections import OrderedDict
 from colossalai.tensor.colo_parameter import ColoParameter
@@ -38,12 +38,37 @@ def _cast_float(args, dtype: torch.dtype):
 
 
 class ColoDDP(torch.nn.Module):
+    """Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now.
 
-    def __init__(self, module: torch.nn.Module) -> None:
+    Example::
+        >>> from colossalai.core import global_context as gpc
+        >>> from colossalai.context import ParallelMode
+        >>> model = torch.nn.Linear(20, 1)
+        >>> model = ColoDDP(model)
+        >>> # model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
+        >>> logits = model(x)
+        >>> loss = criterion(logits, labels)
+        >>> model.backward(loss)
+
+    Args:
+        module (torch.nn.Module): Module to apply DDP.
+        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
+            If it's None, the default data parallel group will be used. Defaults to None.
+        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
+            If it's None, the default CPU data parallel group will be used. Defaults to None.
+ """ + + def __init__(self, + module: torch.nn.Module, + process_group: Optional[dist.ProcessGroup] = None, + cpu_process_group: Optional[dist.ProcessGroup] = None) -> None: + assert not isinstance(module, ColoDDP) super().__init__() self.module = module self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() - self.dp_world_size = gpc.get_world_size(ParallelMode.DATA) + self.process_group = process_group or gpc.get_group(ParallelMode.DATA) + self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA) + self.dp_world_size = self.process_group.size() for p in module.parameters(): if getattr(p, '_ddp_to_ignore', False): continue @@ -77,8 +102,7 @@ class ColoDDP(torch.nn.Module): grad = grad / self.dp_world_size self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): - group = gpc.get_group(ParallelMode.DATA) - dist.all_reduce(grad, group=group) + dist.all_reduce(grad, group=self.process_group) ColoDDP._save_grad(p, grad) grad.record_stream(self.comm_stream) else: @@ -86,8 +110,7 @@ class ColoDDP(torch.nn.Module): return empty_grad else: - group = gpc.get_cpu_group(ParallelMode.DATA) - dist.all_reduce(grad, group=group) + dist.all_reduce(grad, group=self.cpu_process_group) return grad @staticmethod @@ -136,7 +159,27 @@ class ColoDDP(torch.nn.Module): return self.module.load_state_dict(state_dict, strict) -class ColoDDPV2(ColoDDP): +class ZeroDDP(ColoDDP): + """ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now. + We can configure chunk and gemini via ChunkManager and GeminiManager respectively. + For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. + + Example:: + >>> model = torch.nn.Linear(20, 1) + >>> placement_policy = 'cuda' + >>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None + >>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy)) + >>> gemini_manager = GeminiManager(placement_policy, chunk_manager) + >>> model = ZeroDDP(model, gemini_manager) + >>> logits = model(x) + >>> loss = criterion(logits, labels) + >>> model.backward(loss) + + Args: + module (torch.nn.Module): Module to apply ZeRO-DP. + gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space. + For more details, see the API reference of ``GeminiManager``. 
+ """ def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None: super().__init__(module.half()) diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py index d263da59d..113b88364 100644 --- a/colossalai/zero/zero_optimizer.py +++ b/colossalai/zero/zero_optimizer.py @@ -2,7 +2,7 @@ import torch import torch.distributed as dist from enum import Enum from torch.optim import Optimizer -from colossalai.nn.parallel.data_parallel import ColoDDPV2 +from colossalai.nn.parallel.data_parallel import ZeroDDP from typing import Dict from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.logging import get_dist_logger @@ -19,7 +19,7 @@ class ZeroOptimizer(ColossalaiOptimizer): def __init__(self, optim: Optimizer, - module: ColoDDPV2, + module: ZeroDDP, gpu_margin_mem_ratio: float = 0.0, initial_scale: float = 2**32, min_scale: float = 1, @@ -29,7 +29,7 @@ class ZeroOptimizer(ColossalaiOptimizer): hysteresis: int = 2, max_scale: float = 2**32): super().__init__(optim) - assert isinstance(module, ColoDDPV2) + assert isinstance(module, ZeroDDP) self.module = module self.gemini_manager = module.gemini_manager self.chunk_manager = self.gemini_manager.chunk_manager diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index fb6e8cb8e..dba47b052 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -8,7 +8,7 @@ from colossalai.utils import free_port from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.tensor import ChunkManager from functools import partial -from colossalai.nn.parallel import ColoDDP, ColoDDPV2 +from colossalai.nn.parallel import ColoDDP, ZeroDDP from colossalai.gemini.gemini_mgr import GeminiManager from typing import Callable import torch.distributed as dist @@ -30,11 +30,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP: return ColoDDP(module) -def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ColoDDPV2: +def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP: chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None chunk_manager = ChunkManager(chunk_size) gemini_manager = GeminiManager('cuda', chunk_manager) - return ColoDDPV2(module, gemini_manager) + return ZeroDDP(module, gemini_manager) class Net(torch.nn.Module): @@ -71,8 +71,8 @@ def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') set_seed(dist.get_rank()) run_fwd_bwd(ColoDDP, init_ddp) - run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=False)) - run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=True)) + run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False)) + run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True)) @pytest.mark.dist diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index 37de68b81..782ff673a 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -9,7 +9,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.tensor import ChunkManager from functools import partial from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ColoDDPV2, ColoDDP +from colossalai.nn.parallel import ZeroDDP, ColoDDP from colossalai.gemini.gemini_mgr import GeminiManager from typing import 
Callable from collections import OrderedDict @@ -25,11 +25,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP: return ColoDDP(module) -def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ColoDDPV2: +def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP: chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero) gemini_manager = GeminiManager('cuda', chunk_manager) - return ColoDDPV2(module, gemini_manager) + return ZeroDDP(module, gemini_manager) def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]): diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_zero_optim.py index cdcfc4641..afd4fa1fa 100644 --- a/tests/test_tensor/test_zero_optim.py +++ b/tests/test_tensor/test_zero_optim.py @@ -13,7 +13,7 @@ from functools import partial from _utils import tensor_equal, set_seed, tensor_shard_equal from tests.components_to_test.registry import non_distributed_component_funcs from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.nn.parallel import ColoDDPV2 +from colossalai.nn.parallel import ZeroDDP from colossalai.nn.optimizer import HybridAdam from colossalai.zero import ZeroOptimizer from colossalai.testing import parameterize @@ -87,7 +87,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy)) gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ColoDDPV2(model, gemini_manager) + model = ZeroDDP(model, gemini_manager) optim = HybridAdam(model.parameters(), lr=1e-3) optim = ZeroOptimizer(optim, model, initial_scale=32)
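A minimal usage sketch of the renamed API, assembled from the `ZeroDDP` docstring and `test_zero_optim.py` above. It assumes `colossalai.launch` has already initialized the distributed context; the chunk-search arguments and `initial_scale=32` are illustrative values taken from the tests, and the training step is left as comments because `x`, `criterion`, and `labels` are placeholders.

```python
# Sketch only: run inside a process started via colossalai.launch / torch.distributed,
# since ZeroDDP and ChunkManager rely on the global parallel context (gpc).
import torch
from colossalai.tensor import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer

model = torch.nn.Linear(20, 1)
placement_policy = 'cuda'

# Chunk/gemini configuration mirrors the ZeroDDP docstring and the tests;
# search_chunk_size(model, 64, 4) and enable_distributed_storage=True are illustrative.
chunk_size = ChunkManager.search_chunk_size(model, 64, 4)
chunk_manager = ChunkManager(chunk_size,
                             enable_distributed_storage=True,
                             init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager)

# ZeroDDP is the new name for ColoDDPV2; ZeroOptimizer wraps the inner optimizer.
model = ZeroDDP(model, gemini_manager)
optim = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=32)

# Training step (placeholders): backward goes through the DDP wrapper, not loss.backward().
# logits = model(x)
# loss = criterion(logits, labels)
# model.backward(loss)
# optim.step()
```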