[ddp] refactor ColoDDP and ZeroDDP (#1146)

* ColoDDP supports overwriting default process group

* rename ColoDDPV2 to ZeroDDP

* add docstr for ZeroDDP

* polish docstr
This commit is contained in:
ver217 2022-06-21 16:35:23 +08:00 committed by GitHub
parent 0e4e62d30d
commit 8106d7b8c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 66 additions and 23 deletions

View File

@ -1,3 +1,3 @@
from .data_parallel import ColoDDP, ColoDDPV2
from .data_parallel import ColoDDP, ZeroDDP
__all__ = ['ColoDDP', 'ColoDDPV2']
__all__ = ['ColoDDP', 'ZeroDDP']

View File

@ -8,7 +8,7 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.chunk import TensorState, Chunk
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict, Iterable, List
from typing import Dict, Iterable, List, Optional
from colossalai.logging import get_dist_logger
from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter
@ -38,12 +38,37 @@ def _cast_float(args, dtype: torch.dtype):
class ColoDDP(torch.nn.Module):
"""Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now.
def __init__(self, module: torch.nn.Module) -> None:
Example::
>>> from colossalai.core import global_context as gpc
>>> from colossalai.context import ParallelMode
>>> model = torch.nn.Linear(20, 1)
>>> model = ColoDDP(model)
>>> // model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): Module to apply DDP.
process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
If it's None, the default data parallel group will be used. Defaults to None.
process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
If it's None, the default CPU data parallel group will be used. Defaults to None.
"""
def __init__(self,
module: torch.nn.Module,
process_group: Optional[dist.ProcessGroup] = None,
cpu_process_group: Optional[dist.ProcessGroup] = None) -> None:
assert not isinstance(module, ColoDDP)
super().__init__()
self.module = module
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
self.dp_world_size = self.process_group.size()
for p in module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
@ -77,8 +102,7 @@ class ColoDDP(torch.nn.Module):
grad = grad / self.dp_world_size
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
group = gpc.get_group(ParallelMode.DATA)
dist.all_reduce(grad, group=group)
dist.all_reduce(grad, group=self.process_group)
ColoDDP._save_grad(p, grad)
grad.record_stream(self.comm_stream)
else:
@ -86,8 +110,7 @@ class ColoDDP(torch.nn.Module):
return empty_grad
else:
group = gpc.get_cpu_group(ParallelMode.DATA)
dist.all_reduce(grad, group=group)
dist.all_reduce(grad, group=self.cpu_process_group)
return grad
@staticmethod
@ -136,7 +159,27 @@ class ColoDDP(torch.nn.Module):
return self.module.load_state_dict(state_dict, strict)
class ColoDDPV2(ColoDDP):
class ZeroDDP(ColoDDP):
"""ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now.
We can configure chunk and gemini via ChunkManager and GeminiManager respectively.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
Example::
>>> model = torch.nn.Linear(20, 1)
>>> placement_policy = 'cuda'
>>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
>>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
>>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
>>> model = ZeroDDP(model, gemini_manager)
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): Module to apply ZeRO-DP.
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
For more details, see the API reference of ``GeminiManager``.
"""
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
super().__init__(module.half())

View File

@ -2,7 +2,7 @@ import torch
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from colossalai.nn.parallel.data_parallel import ColoDDPV2
from colossalai.nn.parallel.data_parallel import ZeroDDP
from typing import Dict
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
@ -19,7 +19,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
def __init__(self,
optim: Optimizer,
module: ColoDDPV2,
module: ZeroDDP,
gpu_margin_mem_ratio: float = 0.0,
initial_scale: float = 2**32,
min_scale: float = 1,
@ -29,7 +29,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
hysteresis: int = 2,
max_scale: float = 2**32):
super().__init__(optim)
assert isinstance(module, ColoDDPV2)
assert isinstance(module, ZeroDDP)
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager = self.gemini_manager.chunk_manager

View File

@ -8,7 +8,7 @@ from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager
from functools import partial
from colossalai.nn.parallel import ColoDDP, ColoDDPV2
from colossalai.nn.parallel import ColoDDP, ZeroDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable
import torch.distributed as dist
@ -30,11 +30,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
return ColoDDP(module)
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ColoDDPV2:
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
chunk_manager = ChunkManager(chunk_size)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ColoDDPV2(module, gemini_manager)
return ZeroDDP(module, gemini_manager)
class Net(torch.nn.Module):
@ -71,8 +71,8 @@ def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
set_seed(dist.get_rank())
run_fwd_bwd(ColoDDP, init_ddp)
run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=False))
run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=True))
run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False))
run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True))
@pytest.mark.dist

View File

@ -9,7 +9,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager
from functools import partial
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ColoDDPV2, ColoDDP
from colossalai.nn.parallel import ZeroDDP, ColoDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable
from collections import OrderedDict
@ -25,11 +25,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
return ColoDDP(module)
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ColoDDPV2:
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ColoDDPV2(module, gemini_manager)
return ZeroDDP(module, gemini_manager)
def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):

View File

@ -13,7 +13,7 @@ from functools import partial
from _utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDPV2
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
@ -87,7 +87,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
enable_distributed_storage=use_zero,
init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ColoDDPV2(model, gemini_manager)
model = ZeroDDP(model, gemini_manager)
optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=32)