mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-06 10:34:23 +00:00
[ddp] refactor ColoDDP and ZeroDDP (#1146)
* ColoDDP supports overwriting default process group * rename ColoDDPV2 to ZeroDDP * add docstr for ZeroDDP * polish docstr
This commit is contained in:
parent
0e4e62d30d
commit
8106d7b8c7
@ -1,3 +1,3 @@
|
|||||||
from .data_parallel import ColoDDP, ColoDDPV2
|
from .data_parallel import ColoDDP, ZeroDDP
|
||||||
|
|
||||||
__all__ = ['ColoDDP', 'ColoDDPV2']
|
__all__ = ['ColoDDP', 'ZeroDDP']
|
||||||
|
@ -8,7 +8,7 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
|||||||
from colossalai.tensor.chunk import TensorState, Chunk
|
from colossalai.tensor.chunk import TensorState, Chunk
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
from typing import Dict, Iterable, List
|
from typing import Dict, Iterable, List, Optional
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from colossalai.tensor.colo_parameter import ColoParameter
|
from colossalai.tensor.colo_parameter import ColoParameter
|
||||||
@ -38,12 +38,37 @@ def _cast_float(args, dtype: torch.dtype):
|
|||||||
|
|
||||||
|
|
||||||
class ColoDDP(torch.nn.Module):
|
class ColoDDP(torch.nn.Module):
|
||||||
|
"""Distributed data parallel for ColoTensor. Nested ColoDDP is not supported now.
|
||||||
|
|
||||||
def __init__(self, module: torch.nn.Module) -> None:
|
Example::
|
||||||
|
>>> from colossalai.core import global_context as gpc
|
||||||
|
>>> from colossalai.context import ParallelMode
|
||||||
|
>>> model = torch.nn.Linear(20, 1)
|
||||||
|
>>> model = ColoDDP(model)
|
||||||
|
>>> // model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
|
||||||
|
>>> logits = model(x)
|
||||||
|
>>> loss = criterion(logits, labels)
|
||||||
|
>>> model.backward(loss)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (torch.nn.Module): Module to apply DDP.
|
||||||
|
process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
|
||||||
|
If it's None, the default data parallel group will be used. Defaults to None.
|
||||||
|
process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
|
||||||
|
If it's None, the default CPU data parallel group will be used. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
process_group: Optional[dist.ProcessGroup] = None,
|
||||||
|
cpu_process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||||
|
assert not isinstance(module, ColoDDP)
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.module = module
|
self.module = module
|
||||||
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
|
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
|
||||||
self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
|
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
|
||||||
|
self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
|
||||||
|
self.dp_world_size = self.process_group.size()
|
||||||
for p in module.parameters():
|
for p in module.parameters():
|
||||||
if getattr(p, '_ddp_to_ignore', False):
|
if getattr(p, '_ddp_to_ignore', False):
|
||||||
continue
|
continue
|
||||||
@ -77,8 +102,7 @@ class ColoDDP(torch.nn.Module):
|
|||||||
grad = grad / self.dp_world_size
|
grad = grad / self.dp_world_size
|
||||||
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
||||||
with torch.cuda.stream(self.comm_stream):
|
with torch.cuda.stream(self.comm_stream):
|
||||||
group = gpc.get_group(ParallelMode.DATA)
|
dist.all_reduce(grad, group=self.process_group)
|
||||||
dist.all_reduce(grad, group=group)
|
|
||||||
ColoDDP._save_grad(p, grad)
|
ColoDDP._save_grad(p, grad)
|
||||||
grad.record_stream(self.comm_stream)
|
grad.record_stream(self.comm_stream)
|
||||||
else:
|
else:
|
||||||
@ -86,8 +110,7 @@ class ColoDDP(torch.nn.Module):
|
|||||||
return empty_grad
|
return empty_grad
|
||||||
|
|
||||||
else:
|
else:
|
||||||
group = gpc.get_cpu_group(ParallelMode.DATA)
|
dist.all_reduce(grad, group=self.cpu_process_group)
|
||||||
dist.all_reduce(grad, group=group)
|
|
||||||
return grad
|
return grad
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -136,7 +159,27 @@ class ColoDDP(torch.nn.Module):
|
|||||||
return self.module.load_state_dict(state_dict, strict)
|
return self.module.load_state_dict(state_dict, strict)
|
||||||
|
|
||||||
|
|
||||||
class ColoDDPV2(ColoDDP):
|
class ZeroDDP(ColoDDP):
|
||||||
|
"""ZeRO-DP for ColoTensor. Nested ZeroDDP is not supported now.
|
||||||
|
We can configure chunk and gemini via ChunkManager and GeminiManager respectively.
|
||||||
|
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
>>> model = torch.nn.Linear(20, 1)
|
||||||
|
>>> placement_policy = 'cuda'
|
||||||
|
>>> chunk_size = ChunkManager.search_chunk_size(model, search_range, n_grids) if use_chunk else None
|
||||||
|
>>> chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero, init_device=GeminiManager.get_default_device(placement_policy))
|
||||||
|
>>> gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||||
|
>>> model = ZeroDDP(model, gemini_manager)
|
||||||
|
>>> logits = model(x)
|
||||||
|
>>> loss = criterion(logits, labels)
|
||||||
|
>>> model.backward(loss)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (torch.nn.Module): Module to apply ZeRO-DP.
|
||||||
|
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
|
||||||
|
For more details, see the API reference of ``GeminiManager``.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
|
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
|
||||||
super().__init__(module.half())
|
super().__init__(module.half())
|
||||||
|
@ -2,7 +2,7 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from colossalai.nn.parallel.data_parallel import ColoDDPV2
|
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
@ -19,7 +19,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
optim: Optimizer,
|
optim: Optimizer,
|
||||||
module: ColoDDPV2,
|
module: ZeroDDP,
|
||||||
gpu_margin_mem_ratio: float = 0.0,
|
gpu_margin_mem_ratio: float = 0.0,
|
||||||
initial_scale: float = 2**32,
|
initial_scale: float = 2**32,
|
||||||
min_scale: float = 1,
|
min_scale: float = 1,
|
||||||
@ -29,7 +29,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||||||
hysteresis: int = 2,
|
hysteresis: int = 2,
|
||||||
max_scale: float = 2**32):
|
max_scale: float = 2**32):
|
||||||
super().__init__(optim)
|
super().__init__(optim)
|
||||||
assert isinstance(module, ColoDDPV2)
|
assert isinstance(module, ZeroDDP)
|
||||||
self.module = module
|
self.module = module
|
||||||
self.gemini_manager = module.gemini_manager
|
self.gemini_manager = module.gemini_manager
|
||||||
self.chunk_manager = self.gemini_manager.chunk_manager
|
self.chunk_manager = self.gemini_manager.chunk_manager
|
||||||
|
@ -8,7 +8,7 @@ from colossalai.utils import free_port
|
|||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.tensor import ChunkManager
|
from colossalai.tensor import ChunkManager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from colossalai.nn.parallel import ColoDDP, ColoDDPV2
|
from colossalai.nn.parallel import ColoDDP, ZeroDDP
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -30,11 +30,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
|||||||
return ColoDDP(module)
|
return ColoDDP(module)
|
||||||
|
|
||||||
|
|
||||||
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ColoDDPV2:
|
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
|
||||||
chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
|
chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
|
||||||
chunk_manager = ChunkManager(chunk_size)
|
chunk_manager = ChunkManager(chunk_size)
|
||||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
gemini_manager = GeminiManager('cuda', chunk_manager)
|
||||||
return ColoDDPV2(module, gemini_manager)
|
return ZeroDDP(module, gemini_manager)
|
||||||
|
|
||||||
|
|
||||||
class Net(torch.nn.Module):
|
class Net(torch.nn.Module):
|
||||||
@ -71,8 +71,8 @@ def run_dist(rank, world_size, port):
|
|||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
set_seed(dist.get_rank())
|
set_seed(dist.get_rank())
|
||||||
run_fwd_bwd(ColoDDP, init_ddp)
|
run_fwd_bwd(ColoDDP, init_ddp)
|
||||||
run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=False))
|
run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False))
|
||||||
run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=True))
|
run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
|
@ -9,7 +9,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
|
|||||||
from colossalai.tensor import ChunkManager
|
from colossalai.tensor import ChunkManager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
from colossalai.nn.parallel import ColoDDPV2, ColoDDP
|
from colossalai.nn.parallel import ZeroDDP, ColoDDP
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
@ -25,11 +25,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
|
|||||||
return ColoDDP(module)
|
return ColoDDP(module)
|
||||||
|
|
||||||
|
|
||||||
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ColoDDPV2:
|
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
|
||||||
chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
|
chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
|
||||||
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
||||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
gemini_manager = GeminiManager('cuda', chunk_manager)
|
||||||
return ColoDDPV2(module, gemini_manager)
|
return ZeroDDP(module, gemini_manager)
|
||||||
|
|
||||||
|
|
||||||
def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
|
def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
|
||||||
|
@ -13,7 +13,7 @@ from functools import partial
|
|||||||
from _utils import tensor_equal, set_seed, tensor_shard_equal
|
from _utils import tensor_equal, set_seed, tensor_shard_equal
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from colossalai.nn.parallel import ColoDDPV2
|
from colossalai.nn.parallel import ZeroDDP
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.zero import ZeroOptimizer
|
from colossalai.zero import ZeroOptimizer
|
||||||
from colossalai.testing import parameterize
|
from colossalai.testing import parameterize
|
||||||
@ -87,7 +87,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
|
|||||||
enable_distributed_storage=use_zero,
|
enable_distributed_storage=use_zero,
|
||||||
init_device=GeminiManager.get_default_device(placement_policy))
|
init_device=GeminiManager.get_default_device(placement_policy))
|
||||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||||
model = ColoDDPV2(model, gemini_manager)
|
model = ZeroDDP(model, gemini_manager)
|
||||||
optim = HybridAdam(model.parameters(), lr=1e-3)
|
optim = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
optim = ZeroOptimizer(optim, model, initial_scale=32)
|
optim = ZeroOptimizer(optim, model, initial_scale=32)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user