[Optimizer] Remove useless ColoOptimizer (#1312)
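The pattern is the same in every hunk below: the removed ColoOptimizer took named_parameters() plus an optimizer class, while the replacement wraps an already-constructed torch optimizer in ColossalaiOptimizer. A minimal sketch of the migration, with a toy nn.Linear standing in for the test models (the module and the lr value here are illustrative, not taken from the tests):

import torch
import torch.nn as nn

from colossalai.nn.optimizer import ColossalaiOptimizer

# Toy module purely for illustration; any nn.Module works the same way.
model = nn.Linear(16, 4)

# Old pattern removed by this commit:
#   optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
# New pattern: build a regular torch.optim optimizer first, then wrap it.
optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))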
@@ -12,7 +12,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ColoTensor, ProcessGroup
-from colossalai.nn.optimizer import ColoOptimizer
+from colossalai.nn.optimizer import ColossalaiOptimizer
 
 from tests.components_to_test.registry import non_distributed_component_funcs
 from _utils import split_param_row_tp1d, split_param_col_tp1d
@@ -33,7 +33,8 @@ def run_1d_hybrid_tp(model_name):
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()
-        optimizer_torch = ColoOptimizer(model_torch.named_parameters(), torch.optim.SGD, lr=0.1)
+
+        optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1))
 
         # Make two models have the same init params
         for p1, p2 in zip(model.parameters(), model_torch.parameters()):
@@ -80,7 +81,7 @@ def run_1d_hybrid_tp(model_name):
     if rank == 0:
         model_torch.train()
 
-    colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
+    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
 
     for i, (data, label) in enumerate(train_dataloader):
@@ -170,7 +171,7 @@ def test_colo_optimizer():
     with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
         model = model_builder(checkpoint=True)
 
-    colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
+    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
     for i, (data, label) in enumerate(train_dataloader):
         colo_optimizer.zero_grad()
         data = data.to(get_current_device())
@@ -18,7 +18,7 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
-from colossalai.nn.optimizer import ColoOptimizer
+from colossalai.nn.optimizer import ColossalaiOptimizer
 
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -117,7 +117,7 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
     model_reload = model_reload.cuda()
     model_reload.train()
 
-    colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)
+    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
 
     for i, (data, label) in enumerate(train_dataloader):
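Once wrapped, the optimizer is driven the way the test loops above already do: zero the gradients, compute a loss, backpropagate, and step. A hedged sketch with made-up data, assuming ColossalaiOptimizer delegates zero_grad() and step() to the inner torch.optim.SGD (zero_grad() is the call the unchanged test lines above make on it):

import torch
import torch.nn as nn

from colossalai.nn.optimizer import ColossalaiOptimizer

# Illustrative module and batch, not taken from the test files.
model = nn.Linear(16, 4)
optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
criterion = nn.CrossEntropyLoss()

data = torch.randn(8, 16)
label = torch.randint(0, 4, (8,))

optimizer.zero_grad()                 # same call the tests make on the wrapped optimizer
loss = criterion(model(data), label)
loss.backward()                       # plain autograd; the wrapper does not alter this path
optimizer.step()                      # assumed to delegate to the inner SGD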