[zero] support no_sync method for zero1 plugin (#4138)

* support no sync for zero1 plugin

* polish

* polish
Author: LuGY
Date: 2023-07-04 12:00:33 +08:00
Committed by: Hongxin Liu
Parent: c6ab96983a
Commit: 79cf1b5f33
8 changed files with 45 additions and 49 deletions
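
The change drops the grad_accumulate_interval constructor argument from LowLevelZeroOptimizer and instead lets the caller decide when gradients are reduced through the new no_sync() context manager, mirroring torch DDP's no_sync(), which the test compares against. A minimal usage sketch of gradient accumulation with the new interface follows; the micro-batch loop and the step()/zero_grad() calls are illustrative assumptions, only no_sync() and backward(loss) appear in the diff below:

    # Hedged sketch, not the library's documented example.
    # Accumulate gradients locally on all but the last micro-batch,
    # then reduce across ranks and step.
    for i, batch in enumerate(micro_batches):
        loss = zero_model(batch).sum()
        if i < len(micro_batches) - 1:
            with zero_optimizer.no_sync():      # skip gradient synchronization
                zero_optimizer.backward(loss)
        else:
            zero_optimizer.backward(loss)       # reduce gradients on the last micro-batch
    zero_optimizer.step()
    zero_optimizer.zero_grad()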


@@ -9,6 +9,7 @@ from torch.testing import assert_close
 import colossalai
 from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
+from colossalai.utils import conditional_context
 from colossalai.zero import LowLevelZeroOptimizer
@@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc():
                                             overlap_communication=True,
                                             initial_scale=32,
                                             clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=32,
-                                            clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2)
+                                            clip_grad_norm=1.0)
     # create data
     seed_all(2021 + local_rank)
     input_data1 = torch.randn(32, 128).cuda()
@@ -59,8 +58,11 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(zero1_output, zero2_output)
         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float())
-        zero2_optimizer.backward(zero2_output.sum().float())
+        no_sync = number == 0
+        with conditional_context(zero1_optimizer.no_sync(), no_sync):
+            zero1_optimizer.backward(zero1_output.sum().float())
+        with conditional_context(zero2_optimizer.no_sync(), no_sync):
+            zero2_optimizer.backward(zero2_output.sum().float())
         if check_flag:
             for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
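
The conditional_context helper imported at the top is what lets the same backward path run with or without synchronization. Its implementation is not part of this diff; a sketch consistent with how the test calls it (an assumption about colossalai.utils.conditional_context, inferred from usage) could be:

    from contextlib import contextmanager

    @contextmanager
    def conditional_context(context_manager, enable: bool = True):
        # Enter the given context manager only when `enable` is True;
        # otherwise behave as a no-op context.
        if enable:
            with context_manager:
                yield
        else:
            yield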
@@ -101,8 +103,7 @@ def exam_zero_1_grad_acc():
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                            overlap_communication=False,
                                            reduce_bucket_size=262144,
-                                           clip_grad_norm=1.0,
-                                           grad_accumulate_interval=2)
+                                           clip_grad_norm=1.0)
     torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@@ -112,20 +113,15 @@ def exam_zero_1_grad_acc():
     input_data2 = torch.randn(32, 128).cuda()
     def fwd_bwd_func(number, cur_data, check_flag):
-        # zero-dp forward
-        zero_output = zero_model(cur_data)
-        # torch-ddp forward
+        no_sync = number == 0
+        # zero1 fwd and bwd
+        with conditional_context(zero_optimizer.no_sync(), no_sync):
+            zero_output = zero_model(cur_data)
+            zero_optimizer.backward(zero_output.sum().float())
-        # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float())
-        # torch-ddp backward
-        if number < 1:
-            with torch_model.no_sync():
-                torch_output = torch_model(cur_data)
-                assert torch.equal(zero_output, torch_output)
-                torch_output.sum().backward()
-        else:
-            torch_output = torch_model(cur_data)
-            assert torch.equal(zero_output, torch_output)
-            torch_output.sum().backward()
+        # torch-ddp fwd and bwd
+        with conditional_context(torch_model.no_sync(), no_sync):
+            torch_output = torch_model(cur_data)
+            assert torch.equal(zero_output, torch_output)
+            torch_output.sum().backward()
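
With this rewrite the ZeRO-1 path and the torch DDP path are symmetric: on the first call (number == 0) both skip gradient reduction inside no_sync(), and on the second call both reduce, so the grad comparison below can hold rank by rank. Condensed, the pattern the test exercises is (names taken from the test above):

    no_sync = number == 0   # accumulate locally first, reduce on the second micro-batch
    with conditional_context(zero_optimizer.no_sync(), no_sync):
        zero_output = zero_model(cur_data)
        zero_optimizer.backward(zero_output.sum().float())
    with conditional_context(torch_model.no_sync(), no_sync):
        torch_model(cur_data).sum().backward()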
@@ -133,7 +129,6 @@ def exam_zero_1_grad_acc():
         if check_flag:
             # check grad
             for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-                # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                 assert torch.equal(p.grad, z1p.grad)
     fwd_bwd_func(0, input_data1, True)