mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[zero] add zero wrappers (#2523)
* [zero] add zero wrappers * change names * add wrapper functions to init
This commit is contained in:
@@ -9,7 +9,6 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.tensor import ProcessGroup
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
@@ -60,16 +59,16 @@ def exam_zero_1_2_grad_acc():
|
||||
assert torch.equal(zero1_output, zero2_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero1_optimizer.backward(zero1_output.sum().float())
|
||||
zero2_optimizer.backward(zero2_output.sum().float())
|
||||
zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
|
||||
zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
|
||||
|
||||
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
|
||||
if z2p.grad is not None:
|
||||
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
|
||||
assert torch.equal(z1p.grad, z2p.grad)
|
||||
|
||||
zero1_optimizer.sync_grad()
|
||||
zero2_optimizer.sync_grad()
|
||||
zero1_optimizer._sync_grad()
|
||||
zero2_optimizer._sync_grad()
|
||||
|
||||
fwd_bwd_func(0, input_data1)
|
||||
fwd_bwd_func(1, input_data2)
|
||||
@@ -124,7 +123,7 @@ def exam_zero_1_grad_acc():
|
||||
assert torch.equal(zero_output, torch_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero_optimizer.backward(zero_output.sum().float())
|
||||
zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
|
||||
# torch-ddp backward
|
||||
torch_output.sum().backward()
|
||||
|
||||
@@ -135,7 +134,7 @@ def exam_zero_1_grad_acc():
|
||||
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
|
||||
assert torch.equal(p.grad, unscale_grad)
|
||||
|
||||
zero_optimizer.sync_grad()
|
||||
zero_optimizer._sync_grad()
|
||||
|
||||
fwd_bwd_func(0, input_data1, True)
|
||||
fwd_bwd_func(1, input_data2, False)
|
||||
|
@@ -78,16 +78,16 @@ def exam_zero_1_2():
|
||||
assert torch.equal(zero1_output, zero2_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero1_optimizer.backward(zero1_output.mean().float())
|
||||
zero2_optimizer.backward(zero2_output.mean().float())
|
||||
zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
|
||||
zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
|
||||
|
||||
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
|
||||
if z2p.grad is not None:
|
||||
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
|
||||
assert torch.equal(z1p.grad, z2p.grad)
|
||||
|
||||
zero1_optimizer.sync_grad()
|
||||
zero2_optimizer.sync_grad()
|
||||
zero1_optimizer._sync_grad()
|
||||
zero2_optimizer._sync_grad()
|
||||
|
||||
# step
|
||||
zero1_optimizer.step()
|
||||
@@ -146,7 +146,7 @@ def exam_zero_1_torch_ddp():
|
||||
half_close(zero_output, torch_output, loose=True)
|
||||
|
||||
# zero-dp backward
|
||||
zero_optimizer.backward(zero_output.mean().float())
|
||||
zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
|
||||
|
||||
# torch-ddp backward
|
||||
torch_output.mean().backward()
|
||||
@@ -156,7 +156,7 @@ def exam_zero_1_torch_ddp():
|
||||
half_close(p.grad, z1p.grad, loose=True)
|
||||
|
||||
# zero-dp step
|
||||
zero_optimizer.sync_grad()
|
||||
zero_optimizer._sync_grad()
|
||||
zero_optimizer.step()
|
||||
|
||||
# torch ddp step
|
||||
|
@@ -74,7 +74,6 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
|
||||
torch_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
|
||||
hybrid_optim.backward(hybrid_loss)
|
||||
hybrid_optim.sync_grad()
|
||||
|
||||
torch_optim.step()
|
||||
hybrid_optim.step()
|
||||
|
Reference in New Issue
Block a user