Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 21:09:18 +00:00)
[zero] polish low level optimizer (#2473)
@@ -35,18 +35,15 @@ def exam_zero_1_2_grad_acc():
     # create model
     zero1_model = TestModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)
-    pg = ProcessGroup()
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
     zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
     zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
-                                            pg=pg,
                                             overlap_communication=True,
                                             initial_scale=32,
                                             clip_grad_norm=1.0,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
-                                            pg=pg,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=32,
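For orientation, here is a minimal sketch of the test setup as it looks after this change, with no ProcessGroup built or passed. It assumes colossalai.launch() has already initialized distributed state and that a CUDA device is available; TestModel below is a small stand-in, since the real model definition is outside this hunk.

# Minimal sketch of the post-change setup, assuming an initialized
# distributed environment and CUDA; TestModel is a stand-in for the
# model defined elsewhere in the test file.
import copy

import torch
import torch.nn as nn

from colossalai.zero import LowLevelZeroOptimizer


class TestModel(nn.Module):    # stand-in for the test's real model

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        return self.linear(x)


zero1_model = TestModel().cuda()
zero2_model = copy.deepcopy(zero1_model)

# After this commit the tests no longer construct or pass a ProcessGroup;
# the wrapper is built directly around a plain torch optimizer.
zero1_optimizer = LowLevelZeroOptimizer(torch.optim.Adam(zero1_model.parameters(), lr=1),
                                        overlap_communication=True,
                                        initial_scale=32,
                                        clip_grad_norm=1.0,
                                        verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(torch.optim.Adam(zero2_model.parameters(), lr=1),
                                        overlap_communication=True,
                                        partition_grad=True,    # ZeRO stage 2: shard gradients as well
                                        initial_scale=32)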
@@ -86,7 +83,7 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(z1p.data, z2p.data)


-def exam_zero_1_grad_acc(use_pg=True):
+def exam_zero_1_grad_acc():
     local_rank = torch.distributed.get_rank()
     grad_scale = 32
     seed_all(2008)
@@ -105,9 +102,7 @@ def exam_zero_1_grad_acc(use_pg=True):
     # we only test stage 1 here
     # in `check_sharded_param_consistency.py`, we will test whether
     # level 1 and 2 will produce exactly the same results
-    pg = ProcessGroup() if use_pg else None    #ProcessGroup()
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
-                                           pg=pg,
                                            overlap_communication=False,
                                            initial_scale=grad_scale,
                                            reduce_bucket_size=262144,
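The gradient-accumulation test above calls backward several times before a single step. A rough sketch of that pattern follows, under the assumption that LowLevelZeroOptimizer exposes backward(), step() and zero_grad(); zero_model and zero_optimizer stand for the test's wrapped model and optimizer, and the data shape is illustrative.

# Hypothetical sketch of the accumulation pattern exercised above; the
# backward()/step()/zero_grad() methods are assumed from the
# LowLevelZeroOptimizer API rather than copied from this hunk.
import torch


def accumulate_and_step(zero_model, zero_optimizer, accumulation_steps=4):
    for _ in range(accumulation_steps):
        data = torch.randn(32, 128).cuda()    # shape is illustrative only
        loss = zero_model(data).float().sum()
        # the wrapper applies loss scaling (initial_scale=grad_scale) and
        # reduces gradients across ranks internally
        zero_optimizer.backward(loss)
    # one optimizer step after the accumulated backward passes
    zero_optimizer.step()
    zero_optimizer.zero_grad()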
@@ -158,9 +153,8 @@ def exam_zero_1_grad_acc(use_pg=True):
 def run_dist(rank, world_size, port):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

-    exam_zero_1_grad_acc(True)
-    exam_zero_1_grad_acc(False)
-    # exam_zero_1_2_grad_acc()
+    exam_zero_1_grad_acc()
+    exam_zero_1_2_grad_acc()


 @pytest.mark.dist
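run_dist above is normally driven by a multi-process launcher in the pytest entry point. A sketch of that wiring, assuming torch.multiprocessing.spawn plus the free_port helper from colossalai.utils; the test function name and world size are illustrative, and the real test may add decorators beyond @pytest.mark.dist.

# Hypothetical entry-point wiring for run_dist.
import torch.multiprocessing as mp

from colossalai.utils import free_port


def test_grad_accumulation():
    world_size = 2
    port = free_port()
    # mp.spawn passes the process index as the first argument (rank),
    # matching run_dist(rank, world_size, port)
    mp.spawn(run_dist, args=(world_size, port), nprocs=world_size)


if __name__ == '__main__':
    test_grad_accumulation()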
@@ -9,7 +9,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

 import colossalai
-from colossalai.tensor import ProcessGroup
 from colossalai.testing.random import seed_all
 from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
@@ -59,17 +58,14 @@ def exam_zero_1_2():
     zero1_model = TestModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)

-    pg = ProcessGroup()
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
     zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
     zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
-                                            pg=pg,
                                             overlap_communication=True,
                                             initial_scale=128,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
-                                            pg=pg,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=128)
@@ -119,7 +115,7 @@ def exam_zero_1_torch_ddp():
     torch_model = copy.deepcopy(zero_model)

     zero_model = zero_model.cuda().half()
-    # torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
+    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
     torch_model = torch_model.cuda()

     # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
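The hunk above turns the fp32 baseline into a DistributedDataParallel model, so its gradients are averaged across ranks just as the ZeRO-sharded gradients are. A sketch of the comparison this enables follows; torch_optimizer, the optimizer calls, and the tolerances are assumed rather than copied from the test.

# Hypothetical update-and-compare step between the fp16 ZeRO-1 model and the
# fp32 DDP baseline; assert_close and DDP are imported in this test file.
import torch
from torch.testing import assert_close


def step_and_compare(zero_model, zero_optimizer, torch_model, torch_optimizer):
    data = torch.randn(32, 128).cuda()    # shape is illustrative only

    # ZeRO side: loss scaling and gradient reduction happen inside the wrapper
    zero_loss = zero_model(data.half()).float().sum()
    zero_optimizer.backward(zero_loss)
    zero_optimizer.step()

    # baseline side: DDP averages gradients across ranks during backward
    torch_loss = torch_model(data).sum()
    torch_loss.backward()
    torch_optimizer.step()

    # the fp16 ZeRO path should track the fp32 baseline within a loose tolerance
    for torch_p, zero_p in zip(torch_model.parameters(), zero_model.parameters()):
        assert_close(torch_p.data, zero_p.data.float(), rtol=2e-3, atol=2e-3)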
@@ -131,9 +127,7 @@ def exam_zero_1_torch_ddp():
     # we only test stage 1 here
     # in `check_sharded_param_consistency.py`, we will test whether
     # level 1 and 2 will produce exactly the same results
-    pg = ProcessGroup()
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
-                                           pg=pg,
                                            overlap_communication=True,
                                            initial_scale=1,
                                            reduce_bucket_size=262144)