[gemini] zero supports gemini (#1093)

* add placement policy

* add gemini mgr

* update mem stats collector

* update zero

* update zero optim

* fix bugs

* zero optim monitors optimizer states (OS)

* polish unit test

* polish unit test

* add assert
Author: ver217
Committed: 2022-06-10 14:48:28 +08:00 (via GitHub)
Parent: 2b2dc1c86b
Commit: 1f894e033f
9 changed files with 366 additions and 12 deletions
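In short, this commit changes ColoDDPV2's constructor to take a GeminiManager, which pairs the new placement policy with the existing ChunkManager, instead of taking the ChunkManager directly. A minimal sketch of the new wiring, assuming distributed state has already been initialized (e.g. via colossalai.launch); the ChunkManager import path and the placeholder model are assumptions, since the hunks below do not show them:

import torch.nn as nn
from colossalai.tensor import ChunkManager              # assumed import path, not shown in the diffs
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDPV2

# Placeholder model; the real tests build a GPT model from a component registry.
model = nn.Linear(8, 8).half().cuda()

chunk_size = 38 * 1024**2                               # same chunk size the tests below use
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=True)

# Before this commit: model = ColoDDPV2(model, chunk_manager)
# After: a GeminiManager wraps the chunk manager together with a placement policy.
gemini_manager = GeminiManager('cuda', chunk_manager)
model = ColoDDPV2(model, gemini_manager)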


@@ -14,6 +14,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.parallel import ColoDDPV2
 from colossalai.testing import parameterize
+from colossalai.gemini.gemini_mgr import GeminiManager
 
 
 def check_param_equal(model, torch_model):
@@ -44,7 +45,8 @@ def run_gpt(use_chunk, use_zero):
     model = model.half()
     chunk_size = 38 * 1024**2 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
-    model = ColoDDPV2(model, chunk_manager)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    model = ColoDDPV2(model, gemini_manager)
     torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
     print(chunk_manager)
     check_param_equal(model, torch_model)
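The first positional argument to GeminiManager in this hunk is the placement policy the commit introduces; the test pins it to 'cuda'. As a hypothetical variant (the 'cpu' policy name is an assumption inferred from the "add placement policy" bullet, not something these hunks show), the same wiring with CPU-resident chunks would look like:

# Hypothetical: keep chunks resident in CPU memory and let the Gemini
# manager move them to GPU as needed. Only 'cuda' appears in this diff;
# the 'cpu' policy name is an assumption.
gemini_manager = GeminiManager('cpu', chunk_manager)
model = ColoDDPV2(model, gemini_manager)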


@@ -18,6 +18,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
+from colossalai.gemini.gemini_mgr import GeminiManager
 
 
 def check_param_equal(model, torch_model):
@@ -53,7 +54,8 @@ def run_gpt(use_chunk, use_zero):
     chunk_size = 38 * 1024**2 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
-    model = ColoDDPV2(model, chunk_manager)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    model = ColoDDPV2(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=32)
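This second hunk wraps HybridAdam in ZeroOptimizer with an initial loss scale of 32. For context on how these objects are driven, a minimal training-step sketch under that setup; the forward signature, inputs, and loss are illustrative assumptions (in the real test they come from the GPT component registry), while the backward/step pattern follows ZeroOptimizer's loss-scaled interface:

# Hedged sketch of one training step with the optimizer built above.
# `input_ids`, `attn_mask`, and `criterion` are placeholders.
optim.zero_grad()
logits = model(input_ids, attn_mask)   # assumed forward signature
loss = criterion(logits, input_ids)    # illustrative loss computation
optim.backward(loss)                   # ZeroOptimizer scales the loss before backward
optim.step()                           # unscale grads, skip on overflow, update params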