[ddp] refactor ColoDDP and ZeroDDP (#1146)

* ColoDDP supports overwriting default process group

* rename ColoDDPV2 to ZeroDDP

* add docstr for ZeroDDP

* polish docstr
ver217 committed on 2022-06-21 16:35:23 +08:00 (committed via GitHub)
parent 0e4e62d30d
commit 8106d7b8c7
6 changed files with 66 additions and 23 deletions

View File

@@ -1,3 +1,3 @@
-from .data_parallel import ColoDDP, ColoDDPV2
+from .data_parallel import ColoDDP, ZeroDDP

-__all__ = ['ColoDDP', 'ColoDDPV2']
+__all__ = ['ColoDDP', 'ZeroDDP']

View File

@@ -8,7 +8,7 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.tensor.chunk import TensorState, Chunk
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Dict, Iterable, List
+from typing import Dict, Iterable, List, Optional
 from colossalai.logging import get_dist_logger
 from collections import OrderedDict
 from colossalai.tensor.colo_parameter import ColoParameter
@@ -38,12 +38,37 @@ def _cast_float(args, dtype: torch.dtype):
 class ColoDDP(torch.nn.Module):
-
-    def __init__(self, module: torch.nn.Module) -> None:
+    """Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now.
+
+    Example::
+        >>> from colossalai.core import global_context as gpc
+        >>> from colossalai.context import ParallelMode
+        >>> model = torch.nn.Linear(20, 1)
+        >>> model = ColoDDP(model)
+        >>> # model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
+        >>> logits = model(x)
+        >>> loss = criterion(logits, labels)
+        >>> model.backward(loss)
+
+    Args:
+        module (torch.nn.Module): Module to apply DDP.
+        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
+            If it's None, the default data parallel group will be used. Defaults to None.
+        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
+            If it's None, the default CPU data parallel group will be used. Defaults to None.
+    """
+
+    def __init__(self,
+                 module: torch.nn.Module,
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 cpu_process_group: Optional[dist.ProcessGroup] = None) -> None:
+        assert not isinstance(module, ColoDDP)
         super().__init__()
         self.module = module
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
-        self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
+        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
+        self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
+        self.dp_world_size = self.process_group.size()
         for p in module.parameters():
             if getattr(p, '_ddp_to_ignore', False):
                 continue
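The new `process_group` and `cpu_process_group` arguments let callers replace the default data-parallel groups taken from `gpc`. A minimal sketch of overriding them with groups built from the stock `torch.distributed.new_group` API, assuming a distributed context (e.g. `colossalai.launch`) is already initialized and the rank selection is purely illustrative:

import torch
import torch.distributed as dist
from colossalai.nn.parallel import ColoDDP

# Hypothetical topology: replicate the model over every rank.
ranks = list(range(dist.get_world_size()))
nccl_group = dist.new_group(ranks=ranks)                   # group used to all-reduce CUDA gradients
gloo_group = dist.new_group(ranks=ranks, backend='gloo')   # CPU-capable group for parameters on CPU

model = ColoDDP(torch.nn.Linear(20, 1).cuda(),
                process_group=nccl_group,
                cpu_process_group=gloo_group)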
@@ -77,8 +102,7 @@ class ColoDDP(torch.nn.Module):
             grad = grad / self.dp_world_size
             self.comm_stream.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(self.comm_stream):
-                group = gpc.get_group(ParallelMode.DATA)
-                dist.all_reduce(grad, group=group)
+                dist.all_reduce(grad, group=self.process_group)
                 ColoDDP._save_grad(p, grad)
                 grad.record_stream(self.comm_stream)
         else:
@@ -86,8 +110,7 @@ class ColoDDP(torch.nn.Module):
             return empty_grad
         else:
-            group = gpc.get_cpu_group(ParallelMode.DATA)
-            dist.all_reduce(grad, group=group)
+            dist.all_reduce(grad, group=self.cpu_process_group)
             return grad

     @staticmethod
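The two hunks above only swap the hard-coded groups for the ones stored on the instance; the underlying pattern is unchanged: pre-divide the gradient so a SUM all-reduce yields the mean, launch the all-reduce on a dedicated communication stream so it can overlap with the rest of the backward pass, and record the stream so the caching allocator does not recycle the tensor early. A standalone sketch of that pattern (not ColoDDP itself; the function name is illustrative):

import torch
import torch.distributed as dist

def allreduce_mean_async(grad: torch.Tensor,
                         group: dist.ProcessGroup,
                         comm_stream: torch.cuda.Stream) -> torch.Tensor:
    grad = grad / group.size()                              # pre-divide: SUM all-reduce -> mean
    comm_stream.wait_stream(torch.cuda.current_stream())    # wait until backward has produced grad
    with torch.cuda.stream(comm_stream):
        dist.all_reduce(grad, group=group)                  # communication runs on the side stream
    grad.record_stream(comm_stream)                         # keep grad alive while comm_stream uses it
    return grad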
@@ -136,7 +159,27 @@ class ColoDDP(torch.nn.Module):
         return self.module.load_state_dict(state_dict, strict)


-class ColoDDPV2(ColoDDP):
+class ZeroDDP(ColoDDP):
+    """ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now.
+
+    We can configure chunk and gemini via ChunkManager and GeminiManager respectively.
+    For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
+
+    Example::
+        >>> model = torch.nn.Linear(20, 1)
+        >>> placement_policy = 'cuda'
+        >>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
+        >>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
+        >>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
+        >>> model = ZeroDDP(model, gemini_manager)
+        >>> logits = model(x)
+        >>> loss = criterion(logits, labels)
+        >>> model.backward(loss)
+
+    Args:
+        module (torch.nn.Module): Module to apply ZeRO-DP.
+        gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
+            For more details, see the API reference of ``GeminiManager``.
+    """

     def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
         super().__init__(module.half())
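Putting the new docstring example into a self-contained form, with the chunk-search parameters (64, 2) and the 'cuda' placement policy borrowed from the tests in this commit; the input tensor and loss are placeholders, and a distributed context (e.g. `colossalai.launch`) is assumed to be initialized:

import torch
from colossalai.tensor import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP

model = torch.nn.Linear(20, 1)
placement_policy = 'cuda'
chunk_size = ChunkManager.search_chunk_size(model, 64, 2)       # or None to disable chunking
chunk_manager = ChunkManager(chunk_size,
                             enable_distributed_storage=True,   # True -> ZeRO-style sharded storage
                             init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager)

x = torch.rand(4, 20, device='cuda').half()   # ZeroDDP casts the wrapped module to fp16
logits = model(x)
loss = logits.float().sum()                   # stand-in for a real criterion
model.backward(loss)                          # use ZeroDDP's backward, not loss.backward()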

View File

@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from enum import Enum
 from torch.optim import Optimizer
-from colossalai.nn.parallel.data_parallel import ColoDDPV2
+from colossalai.nn.parallel.data_parallel import ZeroDDP
 from typing import Dict
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
@@ -19,7 +19,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
     def __init__(self,
                  optim: Optimizer,
-                 module: ColoDDPV2,
+                 module: ZeroDDP,
                  gpu_margin_mem_ratio: float = 0.0,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -29,7 +29,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
                  hysteresis: int = 2,
                  max_scale: float = 2**32):
         super().__init__(optim)
-        assert isinstance(module, ColoDDPV2)
+        assert isinstance(module, ZeroDDP)
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager = self.gemini_manager.chunk_manager
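Since ZeroOptimizer now asserts that `module` is a ZeroDDP instance, the wrapping order is fixed: build the ZeroDDP model first, then hand it together with the inner optimizer to ZeroOptimizer. A sketch using only the keyword arguments visible in this hunk (the values shown are the defaults from the signature above, and `model` is assumed to be the ZeroDDP instance from the previous example):

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer

optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim,
                      model,                      # must be a ZeroDDP instance (asserted above)
                      gpu_margin_mem_ratio=0.0,   # share of leftover GPU memory usable for optimizer states
                      initial_scale=2**32,        # dynamic loss-scaling knobs
                      min_scale=1,
                      hysteresis=2,
                      max_scale=2**32)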

View File

@@ -8,7 +8,7 @@ from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ChunkManager
 from functools import partial
-from colossalai.nn.parallel import ColoDDP, ColoDDPV2
+from colossalai.nn.parallel import ColoDDP, ZeroDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Callable
 import torch.distributed as dist
@@ -30,11 +30,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
     return ColoDDP(module)


-def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ColoDDPV2:
+def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ColoDDPV2(module, gemini_manager)
+    return ZeroDDP(module, gemini_manager)


 class Net(torch.nn.Module):
@@ -71,8 +71,8 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     set_seed(dist.get_rank())
     run_fwd_bwd(ColoDDP, init_ddp)
-    run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=False))
-    run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=True))
+    run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False))
+    run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True))


 @pytest.mark.dist

View File

@@ -9,7 +9,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ChunkManager
 from functools import partial
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.nn.parallel import ColoDDPV2, ColoDDP
+from colossalai.nn.parallel import ZeroDDP, ColoDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Callable
 from collections import OrderedDict
@@ -25,11 +25,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
     return ColoDDP(module)


-def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ColoDDPV2:
+def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ColoDDPV2(module, gemini_manager)
+    return ZeroDDP(module, gemini_manager)


 def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):

View File

@@ -13,7 +13,7 @@ from functools import partial
 from _utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
-from colossalai.nn.parallel import ColoDDPV2
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
@@ -87,7 +87,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ColoDDPV2(model, gemini_manager)
+    model = ZeroDDP(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=32)