[ddp] refactor ColoDDP and ZeroDDP (#1146)

* ColoDDP supports overwriting the default process group (see the sketch after this list)

* rename ColoDDPV2 to ZeroDDP

* add docstring for ZeroDDP

* polish docstring
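
The process-group override in the first item can be illustrated with a short sketch. This is a minimal, hypothetical usage example, assuming ColoDDP now exposes a `process_group` keyword as the commit message implies; the exact signature is not shown in this diff.

import torch
import torch.distributed as dist

from colossalai.nn.parallel import ColoDDP

# Set up the default (global) process group first.
dist.init_process_group(backend='nccl')

# Illustrative subgroup of even-ranked workers to use instead of the default group.
even_ranks = list(range(0, dist.get_world_size(), 2))
pg = dist.new_group(ranks=even_ranks)

model = torch.nn.Linear(16, 16).cuda()
# Assumed keyword: passing None (or omitting it) would presumably fall back
# to the default group, matching the old behavior.
model = ColoDDP(model, process_group=pg)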
Author: ver217
Date: 2022-06-21 16:35:23 +08:00 (committed by GitHub)
Parent: 0e4e62d30d
Commit: 8106d7b8c7
6 changed files with 66 additions and 23 deletions


@@ -13,7 +13,7 @@ from functools import partial
 from _utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
-from colossalai.nn.parallel import ColoDDPV2
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
@@ -87,7 +87,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
                                   enable_distributed_storage=use_zero,
                                   init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ColoDDPV2(model, gemini_manager)
+    model = ZeroDDP(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=32)
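
The hunk above shows only the renamed call sites; a condensed, self-contained version of the same setup might look as follows. The import path for ChunkManager and GeminiManager, the `chunk_size` argument, and the `build_gpt_model` helper are hypothetical stand-ins for details the diff truncates; every other name comes straight from the changed lines.

import torch
from colossalai.gemini import ChunkManager, GeminiManager  # assumed import path
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.zero import ZeroOptimizer

model = build_gpt_model()  # hypothetical helper standing in for the test's GPT model

placement_policy = 'cuda'
chunk_manager = ChunkManager(chunk_size=64 * 1024,  # assumed; the hunk truncates these arguments
                             enable_distributed_storage=True,
                             init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager)

model = ZeroDDP(model, gemini_manager)  # formerly ColoDDPV2
optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=32)

ZeroDDP wraps the module with Gemini's chunk-based memory management, and ZeroOptimizer pairs the wrapped model with a loss-scaled optimizer; the rename makes the ZeRO-style role of the wrapper explicit.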