Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero] support no_sync method for zero1 plugin (#4138)

* support no sync for zero1 plugin
* polish
* polish
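The diff below replaces the optimizer-side grad_accumulate_interval argument of LowLevelZeroOptimizer with an explicit no_sync() context manager, so the tests now accumulate gradients by skipping synchronization on every micro-step except the last. A minimal sketch of that usage pattern; the Linear model, batch shapes, and hyperparameters here are illustrative placeholders, and an already-launched distributed environment is assumed:

# Sketch only: gradient accumulation with the no_sync() API on the zero1 optimizer.
# Model, data, and hyperparameters are made up; only LowLevelZeroOptimizer,
# backward(), no_sync(), and step() are taken from the change below.
import torch
import torch.nn as nn

from colossalai.zero import LowLevelZeroOptimizer

model = nn.Linear(128, 128).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1)
optimizer = LowLevelZeroOptimizer(optimizer,
                                  overlap_communication=True,
                                  initial_scale=32,
                                  clip_grad_norm=1.0)

micro_batches = [torch.randn(32, 128).cuda() for _ in range(2)]

# first micro-batch: accumulate locally, skip gradient synchronization
with optimizer.no_sync():
    optimizer.backward(model(micro_batches[0]).sum().float())

# last micro-batch: run outside no_sync() so gradients are reduced, then step
optimizer.backward(model(micro_batches[1]).sum().float())
optimizer.step()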
@@ -9,6 +9,7 @@ from torch.testing import assert_close
 import colossalai
 from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
+from colossalai.utils import conditional_context
 from colossalai.zero import LowLevelZeroOptimizer
 
 
@@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc():
                                             overlap_communication=True,
                                             initial_scale=32,
                                             clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=32,
-                                            clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2)
+                                            clip_grad_norm=1.0)
     # create data
     seed_all(2021 + local_rank)
     input_data1 = torch.randn(32, 128).cuda()
@@ -59,8 +58,11 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(zero1_output, zero2_output)
 
         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float())
-        zero2_optimizer.backward(zero2_output.sum().float())
+        no_sync = number == 0
+        with conditional_context(zero1_optimizer.no_sync(), no_sync):
+            zero1_optimizer.backward(zero1_output.sum().float())
+        with conditional_context(zero2_optimizer.no_sync(), no_sync):
+            zero2_optimizer.backward(zero2_output.sum().float())
 
         if check_flag:
             for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
@@ -101,8 +103,7 @@ def exam_zero_1_grad_acc():
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                            overlap_communication=False,
                                            reduce_bucket_size=262144,
-                                           clip_grad_norm=1.0,
-                                           grad_accumulate_interval=2)
+                                           clip_grad_norm=1.0)
 
     torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
 
@@ -112,20 +113,15 @@ def exam_zero_1_grad_acc():
     input_data2 = torch.randn(32, 128).cuda()
 
     def fwd_bwd_func(number, cur_data, check_flag):
-        # zero-dp forward
-        zero_output = zero_model(cur_data)
 
-        # torch-ddp forward
+        no_sync = number == 0
+        # zero1 fwd and bwd
+        with conditional_context(zero_optimizer.no_sync(), no_sync):
+            zero_output = zero_model(cur_data)
+            zero_optimizer.backward(zero_output.sum().float())
 
-        # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float())
-        # torch-ddp backward
-        if number < 1:
-            with torch_model.no_sync():
-                torch_output = torch_model(cur_data)
-                assert torch.equal(zero_output, torch_output)
-                torch_output.sum().backward()
-        else:
+        # torch-ddp fwd and bwd
+        with conditional_context(torch_model.no_sync(), no_sync):
             torch_output = torch_model(cur_data)
             assert torch.equal(zero_output, torch_output)
             torch_output.sum().backward()
@@ -133,7 +129,6 @@ def exam_zero_1_grad_acc():
         if check_flag:
             # check grad
             for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-                # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                 assert torch.equal(p.grad, z1p.grad)
 
     fwd_bwd_func(0, input_data1, True)
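The tests gate synchronization with colossalai.utils.conditional_context, which enters the supplied context manager only when its flag is true. A rough, hypothetical re-implementation (not the library's actual source) to clarify the semantics used above:

from contextlib import contextmanager


@contextmanager
def conditional_context(context_manager, enable: bool = True):
    # Illustrative stand-in for colossalai.utils.conditional_context:
    # enter `context_manager` only when `enable` is True, otherwise no-op.
    if enable:
        with context_manager:
            yield
    else:
        yield


# tiny self-contained demonstration of the toggle
@contextmanager
def announce():
    print("entered")
    yield
    print("exited")


with conditional_context(announce(), True):    # prints "entered" / "exited"
    pass
with conditional_context(announce(), False):   # prints nothing
    pass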