Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-15 22:19:38 +00:00
[ddp] refactor ColoDDP and ZeroDDP (#1146)
* ColoDDP supports overwriting the default process group
* rename ColoDDPV2 to ZeroDDP
* add docstring for ZeroDDP
* polish docstring
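Only the first bullet changes behavior; the rest is the rename and documentation polish. A minimal sketch of how the process-group override might look from the caller's side; the `process_group` keyword is an assumption inferred from the commit title, not a signature shown in the diff below:

# Sketch only: assumes ColoDDP now accepts an optional process_group argument
# and falls back to the default (global) group when it is omitted.
import torch
import torch.distributed as dist
from colossalai.nn.parallel import ColoDDP

def wrap_with_custom_group(module: torch.nn.Module) -> ColoDDP:
    # Build a sub-group from the even ranks and hand it to ColoDDP instead of
    # letting it use the default process group.
    even_ranks = list(range(0, dist.get_world_size(), 2))
    pg = dist.new_group(ranks=even_ranks)
    return ColoDDP(module, process_group=pg)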
@@ -8,7 +8,7 @@ from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ChunkManager
 from functools import partial
-from colossalai.nn.parallel import ColoDDP, ColoDDPV2
+from colossalai.nn.parallel import ColoDDP, ZeroDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Callable
 import torch.distributed as dist
@@ -30,11 +30,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
     return ColoDDP(module)


-def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ColoDDPV2:
+def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ColoDDPV2(module, gemini_manager)
+    return ZeroDDP(module, gemini_manager)


 class Net(torch.nn.Module):
@@ -71,8 +71,8 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     set_seed(dist.get_rank())
     run_fwd_bwd(ColoDDP, init_ddp)
-    run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=False))
-    run_fwd_bwd(ColoDDPV2, partial(init_ddpv2, use_chunk=True))
+    run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=False))
+    run_fwd_bwd(ZeroDDP, partial(init_ddpv2, use_chunk=True))


 @pytest.mark.dist
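Both hunks above drive the same construction pipeline: pick a chunk size (or None for per-tensor granularity), build a ChunkManager, hand it to a GeminiManager that owns tensor placement, then wrap the module in the renamed ZeroDDP. A minimal sketch of one forward/backward step under that wrapper; the constructors are taken from the diff, but the model.backward(loss) call is an assumption about the wrapper's API, since the diff only shows construction:

# Sketch only: constructors come from the diff above; the training-step call
# (model.backward rather than loss.backward) is an assumed wrapper API.
import torch
from colossalai.tensor import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP

def fwd_bwd_once(module: torch.nn.Module, data: torch.Tensor) -> torch.Tensor:
    chunk_manager = ChunkManager(None)                    # None: no chunking
    gemini_manager = GeminiManager('cuda', chunk_manager)
    model = ZeroDDP(module, gemini_manager)
    loss = model(data).sum()
    model.backward(loss)    # assumed: the wrapper owns the backward pass
    return loss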
@@ -9,7 +9,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ChunkManager
 from functools import partial
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.nn.parallel import ColoDDPV2, ColoDDP
+from colossalai.nn.parallel import ZeroDDP, ColoDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Callable
 from collections import OrderedDict
@@ -25,11 +25,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
     return ColoDDP(module)


-def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ColoDDPV2:
+def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ColoDDPV2(module, gemini_manager)
+    return ZeroDDP(module, gemini_manager)


 def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
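The extra use_zero knob is the substantive difference from the first file: enable_distributed_storage=True makes the ChunkManager shard its chunks across ranks instead of replicating them, which is what turns the chunked wrapper into ZeRO proper. The run_state_dict test that follows presumably verifies that a state dict read back from such a sharded model still matches an unsharded reference; a hedged sketch of that check, in which torch_model, the gathering behavior of state_dict, and the tolerance are all assumptions:

# Hedged sketch of the round trip run_state_dict presumably performs; the
# reference torch_model and the tolerance are hypothetical, not from the diff.
import torch

def check_state_dict(zero_model, torch_model) -> None:
    zero_dict = zero_model.state_dict()   # assumed to return full tensors even when sharded
    for name, param in torch_model.state_dict().items():
        assert name in zero_dict, f'missing key: {name}'
        assert torch.allclose(param.float(), zero_dict[name].float(), atol=1e-3), name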
@@ -13,7 +13,7 @@ from functools import partial
 from _utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
-from colossalai.nn.parallel import ColoDDPV2
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
@@ -87,7 +87,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ColoDDPV2(model, gemini_manager)
+    model = ZeroDDP(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=32)
