[test] reorganize zero/gemini tests (#3445)

This commit is contained in:
ver217
2023-04-06 09:38:25 +08:00
committed by GitHub
parent 72cb4dd433
commit 933048ad3e
34 changed files with 7 additions and 8 deletions

View File

@@ -0,0 +1,167 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.testing.random import seed_all
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def exam_zero_1_2_grad_acc():
local_rank = torch.distributed.get_rank()
seed_all(2009)
# create model
zero1_model = MlpModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
overlap_communication=True,
initial_scale=32,
clip_grad_norm=1.0,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=32,
clip_grad_norm=1.0)
# create data
seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(number, cur_data):
# zero-dp forward
zero1_output = zero1_model(cur_data)
zero2_output = zero2_model(cur_data)
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
zero1_optimizer._sync_grad()
zero2_optimizer._sync_grad()
fwd_bwd_func(0, input_data1)
fwd_bwd_func(1, input_data2)
# step
zero1_optimizer.step()
zero2_optimizer.step()
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert torch.equal(z1p.data, z2p.data)
def exam_zero_1_grad_acc():
local_rank = torch.distributed.get_rank()
grad_scale = 32
seed_all(2008)
# create models
zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
seed_all(2008)
zero_model = zero_model.cuda()
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
# we only test stage 1 here
# in `check_sharded_param_consistency.py`, we will test whether
# level 1 and 2 will produce exactly the same results
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=False,
initial_scale=grad_scale,
reduce_bucket_size=262144,
clip_grad_norm=1.0)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
# create data
seed_all(2022 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward
zero_output = zero_model(cur_data)
# torch-ddp forward
torch_output = torch_model(cur_data)
assert torch.equal(zero_output, torch_output)
# zero-dp backward
zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
# torch-ddp backward
torch_output.sum().backward()
if check_flag:
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
unscale_grad = z1p.grad / grad_scale
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
assert torch.equal(p.grad, unscale_grad)
zero_optimizer._sync_grad()
fwd_bwd_func(0, input_data1, True)
fwd_bwd_func(1, input_data2, False)
zero_optimizer.step()
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
torch_optimizer.step()
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data)))
assert_close(p.data, z1p.data)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_grad_acc()
exam_zero_1_2_grad_acc()
@pytest.mark.dist
def test_grad_accumulation():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_grad_accumulation()

View File

@@ -0,0 +1,188 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.random import seed_all
from colossalai.utils import free_port
from colossalai.zero import LowLevelZeroOptimizer
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def half_close(a, b, loose=False):
rtol = None
atol = None
if loose:
rtol = 5e-2
atol = 5e-4
a = a.detach().half()
b = b.detach().half()
assert_close(a, b, rtol=rtol, atol=atol)
def exam_zero_1_2():
"""
In this test, we want to test whether zero stage 1 and 2
deliver the same numerical results despite different communication
pattern
we use these prefixes to differentiate the zero stage
oss: partition optimizer states
pg: partition gradients and optimizer states
"""
local_rank = torch.distributed.get_rank()
seed_all(2001)
# create model
zero1_model = MlpModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer,
overlap_communication=True,
initial_scale=128,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=128)
# create data
seed_all(2001 + local_rank)
input_data = torch.randn(32, 128).cuda()
zero1_output = zero1_model(input_data)
zero2_output = zero2_model(input_data)
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
zero1_optimizer._sync_grad()
zero2_optimizer._sync_grad()
# step
zero1_optimizer.step()
zero2_optimizer.step()
# check updated param
for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()):
assert torch.equal(z1p.data, z2p.data)
def exam_zero_1_torch_ddp():
"""
In this test, two pairs of model and optimizers are created.
1. zero: use sharded optimizer and fp16 parameters
2. torch: use torch DDP and fp32 parameters
We feed these two sets of models with the same input and check if the
differences in model output and updated parameters are within tolerance.
"""
local_rank = torch.distributed.get_rank()
seed_all(1453)
# create models
zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
zero_model = zero_model.cuda().half()
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
torch_model = torch_model.cuda()
# for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# half_close(p.data, z1p.data)
# create optimizer
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
# we only test stage 1 here
# in `check_sharded_param_consistency.py`, we will test whether
# level 1 and 2 will produce exactly the same results
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=True,
initial_scale=1,
reduce_bucket_size=262144)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
seed_all(1453 + local_rank)
# create
input_data = torch.rand(32, 128).cuda()
# zero-dp forward
zero_output = zero_model(input_data.half())
# torch-ddp forward
torch_output = torch_model(input_data)
half_close(zero_output, torch_output, loose=True)
# zero-dp backward
zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
# torch-ddp backward
torch_output.mean().backward()
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
half_close(p.grad, z1p.grad, loose=True)
# zero-dp step
zero_optimizer._sync_grad()
zero_optimizer.step()
# torch ddp step
torch_optimizer.step()
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# print(n, torch.max(torch.abs(p.data - z1p.data)))
half_close(p.data, z1p.data, loose=True)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_torch_ddp()
exam_zero_1_2()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_1_2():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_1_2()

View File

@@ -0,0 +1,60 @@
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.utils import free_port, get_current_device
from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def exam_zero_init():
dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2)
model1 = MlpModel().cuda()
with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg):
model2 = MlpModel()
optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1))
optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1))
assert optimizer1._local_rank == optimizer2._local_rank
assert optimizer1._world_size == optimizer2._world_size
assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks
mp_group1 = optimizer1._mp_torch_group
mp_group2 = optimizer2._mp_torch_group
assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
def run_dist(rank, world_size, port):
config_dict = dict(parallel=dict(data=2, tensor=dict(size=2, mode='1d')))
colossalai.launch(config=config_dict, rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_init()
@pytest.mark.dist
def test_zero_init():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_init()

View File

@@ -0,0 +1,98 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer
from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal
def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(32, 128)
self.act = nn.GELU()
self.linear2 = nn.Linear(128, 32)
def forward(self, x):
y = self.linear1(x)
y = self.act(y)
y = self.linear2(y)
return x + y
@parameterize("overlap_flag", [False, True])
@parameterize("partition_flag", [False, True])
def exam_zero_with_tp(overlap_flag, partition_flag):
set_seed(233010)
tp_pg = ProcessGroup(tp_degree=2)
with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
hybrid_model = MlpModel()
torch_model = MlpModel().cuda()
for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
pt.data.copy_(ph.data)
for name, param in hybrid_model.named_parameters():
if 'linear1' in name:
split_param_row_tp1d(param, tp_pg)
param.compute_spec.set_output_replicate(False)
if 'linear2.weight' in name:
split_param_col_tp1d(param, tp_pg)
torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2) # set to 1e-2 for torch-1.11
hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2)
hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
initial_scale=2,
clip_grad_norm=1.0,
overlap_communication=overlap_flag,
partition_grad=partition_flag)
dp_local_rank = tp_pg.dp_local_rank()
set_seed(255 + dp_local_rank)
data = torch.randn(8, 32, device=get_current_device())
torch_loss = torch_model(data).sum()
hybrid_loss = hybrid_model(data).sum()
assert_close(torch_loss, hybrid_loss)
torch_loss.backward()
torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
hybrid_optim.backward(hybrid_loss)
torch_optim.step()
hybrid_optim.step()
for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()):
assert strict_shard_equal(pt.data, ph.data, tp_pg)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_with_tp()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_with_tp():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_zero_with_tp()