[zero] fix gradient clipping in hybrid parallelism (#2521)

* [zero] fix gradient clipping in hybrid parallelism

* [testing] change model name to avoid pytest warning

* [hotfix] fix unit testing
Author: HELSON
Date: 2023-01-29 15:09:57 +08:00 (committed by GitHub)
Parent: fd8d19a6e7
Commit: 077a5cdde4
6 changed files with 45 additions and 26 deletions
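
For orientation: "gradient clipping in hybrid parallelism" here means clipping through the LowLevelZeroOptimizer wrapper while the model parameters are sharded by a tensor-parallel ProcessGroup, so the gradient norm has to be accumulated across both the TP and the ZeRO/DP dimension before scaling. Below is a minimal sketch of that setup, pieced together from the test hunks in this commit; the import paths for ProcessGroup and get_current_device, the forward() method, and the distributed launch are not shown in this diff and are assumptions.

import torch
import torch.nn as nn
from colossalai.tensor import ProcessGroup                              # assumed import path
from colossalai.utils import get_current_device                         # assumed import path
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import LowLevelZeroOptimizer


class MlpModel(nn.Module):
    # same toy model used by the hybrid-parallel test in this commit

    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(32, 128)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(128, 32)

    def forward(self, x):    # assumed; the forward pass is not part of this diff
        return self.linear2(self.act(self.linear1(x)))


# distributed initialization (e.g. colossalai.launch) is omitted in this sketch
tp_pg = ProcessGroup(tp_degree=2)    # 2-way tensor parallelism; remaining ranks form the ZeRO/DP group

with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
    hybrid_model = MlpModel()

# clipping is requested on the ZeRO wrapper; this commit fixes how the
# norm is computed under hybrid (TP + ZeRO/DP) sharding
hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2)
hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
                                     initial_scale=2,
                                     clip_grad_norm=1.0,
                                     overlap_communication=False,
                                     partition_grad=False)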

View File

@@ -15,10 +15,10 @@ from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(MlpModel, self).__init__()
         self.linear1 = nn.Linear(128, 256)
         self.linear2 = nn.Linear(256, 512)
@@ -33,7 +33,7 @@ def exam_zero_1_2_grad_acc():
     seed_all(2009)
     # create model
-    zero1_model = TestModel().cuda()
+    zero1_model = MlpModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)
     # create optimizer
     zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
@@ -89,7 +89,7 @@ def exam_zero_1_grad_acc():
     seed_all(2008)
     # create models
-    zero_model = TestModel()
+    zero_model = MlpModel()
     torch_model = copy.deepcopy(zero_model)
     seed_all(2008)

View File

@@ -14,10 +14,10 @@ from colossalai.utils import free_port
 from colossalai.zero import LowLevelZeroOptimizer
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(MlpModel, self).__init__()
         self.linear1 = nn.Linear(128, 256)
         self.linear2 = nn.Linear(256, 512)
@@ -55,7 +55,7 @@ def exam_zero_1_2():
     seed_all(2001)
     # create model
-    zero1_model = TestModel().cuda()
+    zero1_model = MlpModel().cuda()
     zero2_model = copy.deepcopy(zero1_model)
     # create optimizer
@@ -111,7 +111,7 @@ def exam_zero_1_torch_ddp():
     seed_all(1453)
     # create models
-    zero_model = TestModel()
+    zero_model = MlpModel()
     torch_model = copy.deepcopy(zero_model)
     zero_model = zero_model.cuda().half()

View File

@@ -13,10 +13,10 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero import LowLevelZeroOptimizer
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(MlpModel, self).__init__()
         self.linear1 = nn.Linear(128, 256)
         self.linear2 = nn.Linear(256, 512)
@@ -28,9 +28,9 @@ class TestModel(nn.Module):
 def exam_zero_init():
     dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2)
-    model1 = TestModel().cuda()
+    model1 = MlpModel().cuda()
     with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg):
-        model2 = TestModel()
+        model2 = MlpModel()
     optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1))
     optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1))

View File

@@ -20,10 +20,10 @@ def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
     return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
     def __init__(self):
-        super(TestModel, self).__init__()
+        super(MlpModel, self).__init__()
         self.linear1 = nn.Linear(32, 128)
         self.act = nn.GELU()
         self.linear2 = nn.Linear(128, 32)
@@ -42,8 +42,8 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
     tp_pg = ProcessGroup(tp_degree=2)
     with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
-        hybrid_model = TestModel()
-    torch_model = TestModel().cuda()
+        hybrid_model = MlpModel()
+    torch_model = MlpModel().cuda()
     for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
         pt.data.copy_(ph.data)
@@ -55,10 +55,11 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
            split_param_col_tp1d(param, tp_pg)
     torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
-    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1)
-    hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1)
+    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2)    # set to 1e-2 for torch-1.11
+    hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2)
     hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
-                                         initial_scale=1,
+                                         initial_scale=2,
+                                         clip_grad_norm=1.0,
                                          overlap_communication=overlap_flag,
                                          partition_grad=partition_flag)
@@ -71,6 +72,7 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
     assert_close(torch_loss, hybrid_loss)
     torch_loss.backward()
+    torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
     hybrid_optim.backward(hybrid_loss)
     hybrid_optim.sync_grad()
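
Right after this hunk, the test presumably steps both optimizers and checks that the sharded hybrid parameters still track the DDP baseline (using the strict_shard_equal helper from the hunk header above). A hedged sketch of that continuation; the step() calls, the DDP unwrap via .module, and the comparison loop are assumptions, not part of this diff.

# assumed continuation of the step shown above
torch_optim.step()
hybrid_optim.step()    # the wrapper's clip_grad_norm=1.0 is applied internally before the update

# both sides clip to the same max norm, so the sharded parameters should stay aligned
for pt, ph in zip(torch_model.module.parameters(), hybrid_model.parameters()):
    assert strict_shard_equal(pt.data, ph.data, tp_pg)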