mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[checkpoint] use gather_tensor in checkpoint and update its unit test (#1339)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os, shutil
|
||||
import torch
|
||||
import pytest
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
@@ -15,8 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
|
||||
from colossalai.nn.parallel.data_parallel import ColoDDP
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup
|
||||
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
|
||||
@@ -63,8 +63,8 @@ def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
|
||||
assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)
|
||||
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
|
||||
assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)
|
||||
|
||||
|
||||
def remove(path):
|
||||
@@ -84,9 +84,13 @@ def compare_optims(optim1, optim2):
|
||||
if k not in state2:
|
||||
continue
|
||||
p2 = state2[k]
|
||||
if isinstance(p1, ColoTensor):
|
||||
assert isinstance(p2, ColoTensor)
|
||||
assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)
|
||||
for n, t1 in p1.items():
|
||||
if n not in p2:
|
||||
continue
|
||||
t2 = p2[n]
|
||||
if isinstance(t1, ColoTensor):
|
||||
assert isinstance(t2, ColoTensor)
|
||||
assert torch.allclose(t1, t2, rtol=0, atol=0)
|
||||
|
||||
|
||||
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
|
||||
@@ -99,7 +103,6 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
||||
# set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
model_reload = model_builder(checkpoint=True)
|
||||
|
||||
if use_mp_reload:
|
||||
if 'bert' == model_name:
|
||||
@@ -119,25 +122,26 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
||||
elif 'token_type_embeddings' in name and 'weight' in name:
|
||||
init_1d_col_embedding(p, pg)
|
||||
elif p.process_group.tp_world_size() == 1:
|
||||
p.redistribute(ReplicaSpec(), pg)
|
||||
p.set_process_group(pg)
|
||||
elif "simple_net" == model_name:
|
||||
init_spec_func(model, pg)
|
||||
|
||||
model_reload = deepcopy(model)
|
||||
model = model.cuda()
|
||||
model.train()
|
||||
model.eval()
|
||||
|
||||
model_reload = model_reload.cuda()
|
||||
model_reload.train()
|
||||
model_reload.eval()
|
||||
|
||||
opt_class = torch.optim.Adam
|
||||
colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
|
||||
colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
|
||||
run_reload = False
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
|
||||
# Zero grad
|
||||
colo_optimizer.zero_grad()
|
||||
colo_optimizer_reload.zero_grad()
|
||||
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
@@ -155,43 +159,33 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
|
||||
loss.backward()
|
||||
loss_reload.backward()
|
||||
|
||||
if run_reload:
|
||||
colo_optimizer_reload.zero_grad()
|
||||
if criterion:
|
||||
output_reload = model_reload(data)
|
||||
loss_reload = criterion(output_reload, label)
|
||||
else:
|
||||
loss_reload = model_reload(data, label)
|
||||
loss_reload.backward()
|
||||
colo_optimizer_reload.step()
|
||||
colo_optimizer.step()
|
||||
colo_optimizer_reload.step()
|
||||
|
||||
if i > 2:
|
||||
break
|
||||
|
||||
if not os.path.isdir('./checkpoint') and rank == 0:
|
||||
os.mkdir('./checkpoint')
|
||||
dist.barrier()
|
||||
|
||||
save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
|
||||
dist.barrier()
|
||||
load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
|
||||
dist.barrier()
|
||||
|
||||
# Since model is sharded, we merge them before param checking.
|
||||
for p in model.parameters():
|
||||
p.to_replicate_()
|
||||
|
||||
for p in model_reload.parameters():
|
||||
p.to_replicate_()
|
||||
|
||||
check_param_equal(model, model_reload)
|
||||
compare_optims(colo_optimizer, colo_optimizer_reload)
|
||||
|
||||
if rank == 0:
|
||||
remove('./checkpoint')
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
for model_name in ['simple_net', 'bert']:
|
||||
# TODO(haichen) add BERT in the test
|
||||
# the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
|
||||
for model_name in ['simple_net']:
|
||||
_run_checkpoint(model_name,
|
||||
init_1d_row_for_linear_weight_spec,
|
||||
use_ddp,
|
||||
|
47
tests/test_utils/test_colo_checkpoint_tools.py
Normal file
47
tests/test_utils/test_colo_checkpoint_tools.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import pytest
|
||||
from functools import partial
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
import torch.distributed as dist
|
||||
|
||||
import colossalai
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec
|
||||
from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
|
||||
from tests.test_tensor._utils import tensor_shard_equal
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, dp_degree, tp_degree):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
|
||||
x = torch.randn(4, 4, device=get_current_device())
|
||||
param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
|
||||
spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
|
||||
param.set_tensor_spec(*spec)
|
||||
|
||||
gather_tensor(param)
|
||||
if dist.get_rank() == 0:
|
||||
assert torch.allclose(x, param.data, rtol=0, atol=0)
|
||||
else:
|
||||
assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
|
||||
dist.barrier()
|
||||
|
||||
scatter_tensor(param, spec[0])
|
||||
assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size())
|
||||
assert param.requires_grad is True
|
||||
dist.barrier()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_checkpoint(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_checkpoint(world_size=4)
|
Reference in New Issue
Block a user