Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[gemini] zero supports gemini (#1093)
* add placement policy
* add gemini mgr
* update mem stats collector
* update zero
* update zero optim
* fix bugs
* zero optim monitor os
* polish unit test
* polish unit test
* add assert
@@ -14,6 +14,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.parallel import ColoDDPV2
 from colossalai.testing import parameterize
+from colossalai.gemini.gemini_mgr import GeminiManager


 def check_param_equal(model, torch_model):
@@ -44,7 +45,8 @@ def run_gpt(use_chunk, use_zero):
     model = model.half()
     chunk_size = 38 * 1024**2 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
-    model = ColoDDPV2(model, chunk_manager)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    model = ColoDDPV2(model, gemini_manager)
     torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
     print(chunk_manager)
     check_param_equal(model, torch_model)
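The hunk above captures the API change this commit makes: ColoDDPV2 no longer takes a ChunkManager directly, but a GeminiManager that couples a placement policy with the chunk manager. A minimal sketch of the new wiring, assuming colossalai is installed and the distributed environment is already launched; the ChunkManager import path is an assumption (this diff does not show it) and may differ by version:

import torch
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ColoDDPV2
# NOTE: assumption -- the ChunkManager import is not shown in this diff
# and its module path may differ across ColossalAI versions.
from colossalai.gemini import ChunkManager


def wrap_with_gemini(model: torch.nn.Module, use_chunk: bool, use_zero: bool) -> ColoDDPV2:
    # Same chunk-size expression the tests use; None disables chunking.
    chunk_size = 38 * 1024**2 if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
    # GeminiManager pairs a placement policy ('cuda' here) with the chunk manager.
    gemini_manager = GeminiManager('cuda', chunk_manager)
    # ColoDDPV2 now consumes the GeminiManager rather than the raw ChunkManager.
    return ColoDDPV2(model, gemini_manager)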
@@ -18,6 +18,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
+from colossalai.gemini.gemini_mgr import GeminiManager


 def check_param_equal(model, torch_model):
@@ -53,7 +54,8 @@ def run_gpt(use_chunk, use_zero):

     chunk_size = 38 * 1024**2 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
-    model = ColoDDPV2(model, chunk_manager)
+    gemini_manager = GeminiManager('cuda', chunk_manager)
+    model = ColoDDPV2(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=32)

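In the optimizer test, the Gemini-wrapped model is then driven through ZeroOptimizer. A short sketch of that setup under the same assumptions (`model` is already a ColoDDPV2 instance built as above; initial_scale=32 is simply the value the test picks for loss scaling, not a required setting):

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer

# Base optimizer: the HybridAdam variant the test imports; lr matches the diff.
optim = HybridAdam(model.parameters(), lr=1e-3)
# ZeroOptimizer wraps the base optimizer together with the ColoDDPV2 model;
# initial_scale sets the starting loss scale for fp16 training.
optim = ZeroOptimizer(optim, model, initial_scale=32)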