mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[Gemini] add GeminiAdamOptimizer (#1960)
This commit is contained in:
@@ -7,16 +7,14 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import colossalai
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.gemini.chunk import search_chunk_configuration
|
||||
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
|
||||
from colossalai.nn.parallel import GeminiDDP, ZeroDDP
|
||||
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
|
||||
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
|
||||
@@ -96,19 +94,23 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
|
||||
init_device = torch.device('cpu')
|
||||
else:
|
||||
init_device = None
|
||||
chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
|
||||
model = GeminiDDP(model, init_device, placement_policy, True, False, 32)
|
||||
# The same as the following 3 lines
|
||||
# chunk_manager = ChunkManager(config_dict, init_device=init_device)
|
||||
# gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
# model = ZeroDDP(model, gemini_manager, pin_memory=True)
|
||||
|
||||
zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
|
||||
# The same as the following 2 lines
|
||||
# optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||
# zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)
|
||||
|
||||
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
|
||||
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
|
||||
|
||||
print(chunk_manager)
|
||||
check_param(model, torch_model, pg)
|
||||
|
||||
model.eval()
|
||||
|
Reference in New Issue
Block a user