mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[zero] test gradient accumulation (#1964)
* [zero] fix memory leak for zero2 * [zero] test gradient accumulation * [zero] remove grad clip test
This commit is contained in:
167
tests/test_zero/low_level_zero/test_grad_acc.py
Normal file
167
tests/test_zero/low_level_zero/test_grad_acc.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
||||
|
||||
class TestModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(TestModel, self).__init__()
|
||||
self.linear1 = nn.Linear(128, 256)
|
||||
self.linear2 = nn.Linear(256, 512)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
def exam_zero_1_2_grad_acc():
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(2009)
|
||||
|
||||
# create model
|
||||
zero1_model = TestModel().cuda()
|
||||
zero2_model = copy.deepcopy(zero1_model)
|
||||
|
||||
# create optimizer
|
||||
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
|
||||
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
|
||||
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
|
||||
overlap_communication=True,
|
||||
initial_scale=32,
|
||||
clip_grad_norm=1.0,
|
||||
verbose=True)
|
||||
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
|
||||
overlap_communication=True,
|
||||
partition_grad=True,
|
||||
initial_scale=32,
|
||||
clip_grad_norm=1.0)
|
||||
# create data
|
||||
seed_all(2021 + local_rank)
|
||||
input_data1 = torch.randn(32, 128).cuda()
|
||||
input_data2 = torch.randn(32, 128).cuda()
|
||||
|
||||
def fwd_bwd_func(number, cur_data):
|
||||
# zero-dp forward
|
||||
zero1_output = zero1_model(cur_data)
|
||||
zero2_output = zero2_model(cur_data)
|
||||
assert torch.equal(zero1_output, zero2_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero1_optimizer.backward(zero1_output.sum().float())
|
||||
zero2_optimizer.backward(zero2_output.sum().float())
|
||||
|
||||
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
|
||||
if z2p.grad is not None:
|
||||
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
|
||||
assert torch.equal(z1p.grad, z2p.grad)
|
||||
|
||||
zero1_optimizer.sync_grad()
|
||||
zero2_optimizer.sync_grad()
|
||||
|
||||
fwd_bwd_func(0, input_data1)
|
||||
fwd_bwd_func(1, input_data2)
|
||||
|
||||
# step
|
||||
zero1_optimizer.step()
|
||||
zero2_optimizer.step()
|
||||
|
||||
# check updated param
|
||||
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
||||
assert torch.equal(z1p.data, z2p.data)
|
||||
|
||||
|
||||
def exam_zero_1_grad_acc():
|
||||
local_rank = torch.distributed.get_rank()
|
||||
grad_scale = 32
|
||||
seed_all(2008)
|
||||
|
||||
# create models
|
||||
zero_model = TestModel()
|
||||
torch_model = copy.deepcopy(zero_model)
|
||||
|
||||
zero_model = zero_model.cuda()
|
||||
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
|
||||
|
||||
# create optimizer
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
|
||||
|
||||
# we only test stage 1 here
|
||||
# in `check_sharded_param_consistency.py`, we will test whether
|
||||
# level 1 and 2 will produce exactly the same results
|
||||
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
|
||||
overlap_communication=False,
|
||||
initial_scale=grad_scale,
|
||||
reduce_bucket_size=262144,
|
||||
clip_grad_norm=1.0)
|
||||
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
|
||||
|
||||
# create data
|
||||
seed_all(2022 + local_rank)
|
||||
input_data1 = torch.randn(32, 128).cuda()
|
||||
input_data2 = torch.randn(32, 128).cuda()
|
||||
|
||||
def fwd_bwd_func(number, cur_data, check_flag):
|
||||
# zero-dp forward
|
||||
zero_output = zero_model(cur_data)
|
||||
|
||||
# torch-ddp forward
|
||||
torch_output = torch_model(cur_data)
|
||||
assert torch.equal(zero_output, torch_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero_optimizer.backward(zero_output.sum().float())
|
||||
# torch-ddp backward
|
||||
torch_output.sum().backward()
|
||||
|
||||
if check_flag:
|
||||
# check grad
|
||||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
unscale_grad = z1p.grad / grad_scale
|
||||
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
|
||||
assert torch.equal(p.grad, unscale_grad)
|
||||
|
||||
zero_optimizer.sync_grad()
|
||||
|
||||
fwd_bwd_func(0, input_data1, True)
|
||||
fwd_bwd_func(1, input_data2, False)
|
||||
|
||||
zero_optimizer.step()
|
||||
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
|
||||
torch_optimizer.step()
|
||||
|
||||
# check updated param
|
||||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
# print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data)))
|
||||
assert_close(p.data, z1p.data)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
|
||||
exam_zero_1_grad_acc()
|
||||
# exam_zero_1_2_grad_acc()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_grad_accumulation():
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_grad_accumulation()
|
@@ -1,161 +0,0 @@
|
||||
import copy
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
||||
|
||||
def check_equal(a, b, rtol=1e-4, atol=1e-3):
|
||||
"""
|
||||
This function checks if two tensors are equal within tolerance
|
||||
"""
|
||||
assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol), f'a = {a}, b = {b}'
|
||||
|
||||
|
||||
def check_completely_equal(a, b):
|
||||
"""
|
||||
This function checks if two tensors are completely equal
|
||||
"""
|
||||
assert torch.all(a == b), f'a = {a}, b = {b}'
|
||||
|
||||
|
||||
class TestModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(TestModel, self).__init__()
|
||||
self.linear1 = nn.Linear(128, 256)
|
||||
self.linear2 = nn.Linear(256, 512)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
def exam_zero_1_2_grad_clip():
|
||||
# create model
|
||||
zero1_model = TestModel().cuda().half()
|
||||
zero2_model = copy.deepcopy(zero1_model)
|
||||
|
||||
# create optimizer
|
||||
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=0.001)
|
||||
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=0.001)
|
||||
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
|
||||
overlap_communication=True,
|
||||
initial_scale=32,
|
||||
clip_grad_norm=1.0,
|
||||
verbose=True)
|
||||
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
|
||||
overlap_communication=True,
|
||||
partition_grad=True,
|
||||
initial_scale=32,
|
||||
clip_grad_norm=1.0)
|
||||
|
||||
# create
|
||||
input_data = torch.rand(32, 128).cuda().half()
|
||||
|
||||
# forward
|
||||
zero1_output = zero1_model(input_data)
|
||||
zero2_output = zero2_model(input_data)
|
||||
check_completely_equal(zero1_output, zero2_output)
|
||||
|
||||
# backward
|
||||
zero1_optimizer.backward(zero1_output.mean().float())
|
||||
zero2_optimizer.backward(zero2_output.mean().float())
|
||||
|
||||
# check grad
|
||||
# as this param is small, the backward reduction
|
||||
# will not be fired
|
||||
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
||||
check_completely_equal(z1p.grad, z2p.grad)
|
||||
|
||||
# step
|
||||
zero1_optimizer.sync_grad()
|
||||
zero2_optimizer.sync_grad()
|
||||
|
||||
# step
|
||||
zero1_optimizer.step()
|
||||
zero2_optimizer.step()
|
||||
|
||||
# check updated param
|
||||
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
||||
check_completely_equal(z1p.data, z2p.data)
|
||||
|
||||
|
||||
def exam_zero_1_grad_clip():
|
||||
# create models
|
||||
zero_model = TestModel()
|
||||
torch_model = copy.deepcopy(zero_model)
|
||||
|
||||
zero_model = zero_model.cuda().half()
|
||||
torch_model = DDP(torch_model.cuda())
|
||||
|
||||
# create optimizer
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)
|
||||
|
||||
# we only test stage 1 here
|
||||
# in `check_sharded_param_consistency.py`, we will test whether
|
||||
# level 1 and 2 will produce exactly the same results
|
||||
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
|
||||
overlap_communication=True,
|
||||
initial_scale=1,
|
||||
clip_grad_norm=1.0)
|
||||
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)
|
||||
|
||||
# create
|
||||
input_data = torch.rand(32, 128).cuda()
|
||||
|
||||
# zero-dp forward
|
||||
zero_output = zero_model(input_data.half())
|
||||
|
||||
# torch-ddp forward
|
||||
torch_output = torch_model(input_data)
|
||||
check_equal(zero_output, torch_output)
|
||||
|
||||
# zero-dp backward
|
||||
zero_optimizer.backward(zero_output.mean().float())
|
||||
|
||||
# torch-ddp backward
|
||||
torch_output.mean().backward()
|
||||
|
||||
# check grad
|
||||
for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
|
||||
check_equal(p.grad, z1p.grad)
|
||||
|
||||
# zero-dp step
|
||||
zero_optimizer.sync_grad()
|
||||
zero_optimizer.step()
|
||||
|
||||
# torch ddp step
|
||||
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
|
||||
torch_optimizer.step()
|
||||
|
||||
# check updated param
|
||||
for p, z1p in zip(torch_model.parameters(), zero_model.parameters()):
|
||||
check_equal(p.data, z1p.data, atol=5e-4)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
|
||||
exam_zero_1_2_grad_clip()
|
||||
exam_zero_1_grad_clip()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_grad_clip():
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_grad_clip()
|
@@ -6,27 +6,41 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing import assert_close
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing.random import seed_all
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero import LowLevelZeroOptimizer
|
||||
|
||||
|
||||
def check_equal(a, b):
|
||||
"""
|
||||
This function checks if two tensors are equal within tolerance
|
||||
"""
|
||||
assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
|
||||
class TestModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(TestModel, self).__init__()
|
||||
self.linear1 = nn.Linear(128, 256)
|
||||
self.linear2 = nn.Linear(256, 512)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
def check_completely_equal(a, b):
|
||||
"""
|
||||
This function checks if two tensors are completely equal
|
||||
"""
|
||||
assert torch.all(a == b), f'a = {a}, b = {b}'
|
||||
def half_close(a, b, loose=False):
|
||||
rtol = None
|
||||
atol = None
|
||||
if loose:
|
||||
rtol = 5e-2
|
||||
atol = 5e-4
|
||||
|
||||
a = a.detach().half()
|
||||
b = b.detach().half()
|
||||
|
||||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def check_sharded_param_consistency():
|
||||
def exam_zero_1_2():
|
||||
"""
|
||||
In this test, we want to test whether zero stage 1 and 2
|
||||
deliver the same numerical results despite different communication
|
||||
@@ -37,67 +51,54 @@ def check_sharded_param_consistency():
|
||||
pg: partition gradients and optimizer states
|
||||
|
||||
"""
|
||||
|
||||
# create layers
|
||||
oss_linear1 = nn.Linear(128, 256)
|
||||
oss_linear2 = nn.Linear(256, 512)
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(2001)
|
||||
|
||||
# create model
|
||||
oss_model = nn.Sequential(oss_linear1, oss_linear2)
|
||||
pg_model = copy.deepcopy(oss_model)
|
||||
|
||||
oss_model = oss_model.cuda().half()
|
||||
pg_model = pg_model.cuda().half()
|
||||
zero1_model = TestModel().cuda()
|
||||
zero2_model = copy.deepcopy(zero1_model)
|
||||
|
||||
# create optimizer
|
||||
oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
|
||||
pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
|
||||
oss_optimizer = LowLevelZeroOptimizer(oss_optimizer,
|
||||
overlap_communication=True,
|
||||
initial_scale=1,
|
||||
clip_grad_norm=0.0)
|
||||
pg_optimizer = LowLevelZeroOptimizer(pg_optimizer,
|
||||
overlap_communication=True,
|
||||
partition_grad=True,
|
||||
initial_scale=1,
|
||||
clip_grad_norm=0.0)
|
||||
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
|
||||
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
|
||||
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
|
||||
overlap_communication=True,
|
||||
initial_scale=128,
|
||||
verbose=True)
|
||||
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
|
||||
overlap_communication=True,
|
||||
partition_grad=True,
|
||||
initial_scale=128)
|
||||
# create data
|
||||
seed_all(2001 + local_rank)
|
||||
input_data = torch.randn(32, 128).cuda()
|
||||
|
||||
# create
|
||||
input_data = torch.rand(32, 128).cuda().half()
|
||||
zero1_output = zero1_model(input_data)
|
||||
zero2_output = zero2_model(input_data)
|
||||
assert torch.equal(zero1_output, zero2_output)
|
||||
|
||||
# forward
|
||||
oss_output = oss_model(input_data)
|
||||
pg_output = pg_model(input_data)
|
||||
check_completely_equal(oss_output, pg_output)
|
||||
# zero-dp backward
|
||||
zero1_optimizer.backward(zero1_output.mean().float())
|
||||
zero2_optimizer.backward(zero2_output.mean().float())
|
||||
|
||||
# backward
|
||||
oss_optimizer.backward(oss_output.mean().float())
|
||||
pg_optimizer.backward(pg_output.mean().float())
|
||||
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
|
||||
if z2p.grad is not None:
|
||||
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
|
||||
assert torch.equal(z1p.grad, z2p.grad)
|
||||
|
||||
# check grad
|
||||
# as this param is small, the backward reduction
|
||||
# will not be fired
|
||||
oss_linear1_grad = oss_model[0].weight.grad
|
||||
oss_linear2_grad = oss_model[1].weight.grad
|
||||
pg_linear1_grad = pg_model[0].weight.grad
|
||||
pg_linear2_grad = pg_model[1].weight.grad
|
||||
check_completely_equal(oss_linear1_grad, pg_linear1_grad)
|
||||
check_completely_equal(oss_linear2_grad, pg_linear2_grad)
|
||||
zero1_optimizer.sync_grad()
|
||||
zero2_optimizer.sync_grad()
|
||||
|
||||
# step
|
||||
oss_optimizer.sync_grad()
|
||||
pg_optimizer.sync_grad()
|
||||
|
||||
# step
|
||||
oss_optimizer.step()
|
||||
pg_optimizer.step()
|
||||
zero1_optimizer.step()
|
||||
zero2_optimizer.step()
|
||||
|
||||
# check updated param
|
||||
check_completely_equal(oss_model[0].weight, pg_model[0].weight)
|
||||
check_completely_equal(oss_model[1].weight, pg_model[1].weight)
|
||||
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
|
||||
assert torch.equal(z1p.data, z2p.data)
|
||||
|
||||
|
||||
def check_sharded_optim_against_torch_ddp():
|
||||
def exam_zero_1_torch_ddp():
|
||||
"""
|
||||
In this test, two pairs of model and optimizers are created.
|
||||
1. zero: use sharded optimizer and fp16 parameters
|
||||
@@ -106,20 +107,22 @@ def check_sharded_optim_against_torch_ddp():
|
||||
We feed these two sets of models with the same input and check if the
|
||||
differences in model output and updated parameters are within tolerance.
|
||||
"""
|
||||
local_rank = torch.distributed.get_rank()
|
||||
seed_all(1453)
|
||||
|
||||
# create layer
|
||||
zero_linear1 = nn.Linear(128, 256)
|
||||
zero_linear2 = nn.Linear(256, 512)
|
||||
|
||||
# create model
|
||||
zero_model = nn.Sequential(zero_linear1, zero_linear2)
|
||||
# create models
|
||||
zero_model = TestModel()
|
||||
torch_model = copy.deepcopy(zero_model)
|
||||
|
||||
zero_model = zero_model.cuda().half()
|
||||
torch_model = DDP(torch_model.cuda())
|
||||
# torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
|
||||
torch_model = torch_model.cuda()
|
||||
|
||||
# for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
# half_close(p.data, z1p.data)
|
||||
|
||||
# create optimizer
|
||||
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)
|
||||
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
|
||||
|
||||
# we only test stage 1 here
|
||||
# in `check_sharded_param_consistency.py`, we will test whether
|
||||
@@ -127,10 +130,11 @@ def check_sharded_optim_against_torch_ddp():
|
||||
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
|
||||
overlap_communication=True,
|
||||
initial_scale=1,
|
||||
clip_grad_norm=0.0)
|
||||
reduce_bucket_size=262144)
|
||||
|
||||
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)
|
||||
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
|
||||
|
||||
seed_all(1453 + local_rank)
|
||||
# create
|
||||
input_data = torch.rand(32, 128).cuda()
|
||||
|
||||
@@ -139,7 +143,7 @@ def check_sharded_optim_against_torch_ddp():
|
||||
|
||||
# torch-ddp forward
|
||||
torch_output = torch_model(input_data)
|
||||
check_equal(zero_output, torch_output)
|
||||
half_close(zero_output, torch_output, loose=True)
|
||||
|
||||
# zero-dp backward
|
||||
zero_optimizer.backward(zero_output.mean().float())
|
||||
@@ -148,12 +152,8 @@ def check_sharded_optim_against_torch_ddp():
|
||||
torch_output.mean().backward()
|
||||
|
||||
# check grad
|
||||
zero_linear1_grad = zero_model[0].weight.grad
|
||||
zero_linear2_grad = zero_model[1].weight.grad
|
||||
torch_linear1_grad = torch_model.module[0].weight.grad
|
||||
torch_linear2_grad = torch_model.module[1].weight.grad
|
||||
check_equal(zero_linear1_grad, torch_linear1_grad)
|
||||
check_equal(zero_linear2_grad, torch_linear2_grad)
|
||||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
half_close(p.grad, z1p.grad, loose=True)
|
||||
|
||||
# zero-dp step
|
||||
zero_optimizer.sync_grad()
|
||||
@@ -163,23 +163,24 @@ def check_sharded_optim_against_torch_ddp():
|
||||
torch_optimizer.step()
|
||||
|
||||
# check updated param
|
||||
check_equal(zero_model[0].weight, torch_model.module[0].weight)
|
||||
check_equal(zero_model[1].weight, torch_model.module[1].weight)
|
||||
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
|
||||
# print(n, torch.max(torch.abs(p.data - z1p.data)))
|
||||
half_close(p.data, z1p.data, loose=True)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
|
||||
check_sharded_optim_against_torch_ddp()
|
||||
check_sharded_param_consistency()
|
||||
exam_zero_1_torch_ddp()
|
||||
exam_zero_1_2()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_sharded_optim():
|
||||
def test_zero_1_2():
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sharded_optim()
|
||||
test_zero_1_2()
|
||||
|
Reference in New Issue
Block a user