mirror of https://github.com/hpcaitech/ColossalAI.git
[ddp] refactor ColoDDP and ZeroDDP (#1146)
* ColoDDP supports overwriting the default process group
* rename ColoDDPV2 to ZeroDDP
* add docstring for ZeroDDP
* polish docstring
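Before the diff, a minimal sketch of what the first bullet describes: wrapping a model with ColoDDP while overriding the default process group. The `process_group` keyword name and the group construction are assumptions drawn from the commit message, not from this diff.

```python
import torch.distributed as dist
import torch.nn as nn

from colossalai.nn.parallel import ColoDDP  # same module that exports ZeroDDP after the rename


def wrap_with_custom_group(model: nn.Module) -> ColoDDP:
    # Assumption: ColoDDP now takes an optional process group that replaces the
    # default (data-parallel) group used for gradient all-reduce.
    sub_group = dist.new_group(ranks=[0, 1])  # hypothetical subgroup; any valid group works
    return ColoDDP(model, process_group=sub_group)
```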
@@ -13,7 +13,7 @@ from functools import partial
 from _utils import tensor_equal, set_seed, tensor_shard_equal
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
-from colossalai.nn.parallel import ColoDDPV2
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
@@ -87,7 +87,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ColoDDPV2(model, gemini_manager)
+    model = ZeroDDP(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=32)
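For context, a condensed sketch of how the renamed wrapper is assembled in this test after the change. Everything not visible in the hunk above (the import paths for ChunkManager/GeminiManager, the chunk size, and the flag values) is an assumption for illustration only.

```python
import torch

from colossalai.gemini import ChunkManager, GeminiManager  # assumed import path for this version
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP  # formerly ColoDDPV2
from colossalai.zero import ZeroOptimizer


def wrap_model(model: torch.nn.Module, placement_policy: str = 'cuda'):
    # Placeholder chunk size and storage flag; the real test derives them from
    # its parameterization (use_chunk, use_zero).
    chunk_manager = ChunkManager(64 * 1024**2,
                                 enable_distributed_storage=True,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)  # was: ColoDDPV2(model, gemini_manager)
    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optim, model, initial_scale=32)
    return model, optim
```

A training step is then driven through the wrappers rather than plain PyTorch calls (for example `optim.backward(loss)` instead of `loss.backward()`), so that ZeroDDP's hooks can manage gradients and chunk placement.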