[test] remove useless tests (#4359)

* [test] remove legacy zero test

* [test] remove lazy distribute test

* [test] remove outdated checkpoint io
Author: Hongxin Liu
Date: 2023-08-01 18:52:14 +08:00
Committed by: GitHub
Parent: 16c0acc01b
Commit: 16bf4c0221
31 changed files with 0 additions and 3355 deletions


@@ -1,120 +0,0 @@
import torch
import torch.nn as nn
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.utils import build_checkpoints
from torch.optim import Adam
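# Tests for build_checkpoints: splitting a model (and optionally optimizer) state dict
# into checkpoint shards and collecting the accompanying metadata.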
class DummyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = nn.Linear(20, 1)
def test_global_model():
model = DummyModel()
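# a shard size limit of 0 disables sharding: expect one model checkpoint, no optimizer
# checkpoint and no dist meta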
model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model)
assert len(model_checkpoints) == 1
assert len(optimizer_checkpoints) == 0
assert meta['dist_meta'] is None
orig_state_dict = model.state_dict()
global_state_dict = model_checkpoints[0]
assert set(orig_state_dict.keys()) == set(global_state_dict.keys())
for k, v in orig_state_dict.items():
assert torch.equal(v, global_state_dict[k])
def test_global_model_shard():
model = DummyModel()
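# a small shard size limit (80) splits the state dict into two disjoint checkpoints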
model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model)
assert len(model_checkpoints) == 2
assert len(optimizer_checkpoints) == 0
assert meta['dist_meta'] is None
orig_state_dict = model.state_dict()
assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys())
assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0
for k, v in orig_state_dict.items():
for state_dict in model_checkpoints:
if k in state_dict:
assert torch.equal(v, state_dict[k])
def test_global_optimizer():
model = DummyModel()
for p in model.parameters():
p.grad = torch.rand_like(p)
optimizer = Adam(model.parameters(), lr=1e-3)
optimizer.step()
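# one step populates Adam state (step, exp_avg, exp_avg_sq), so there is optimizer state to checkpoint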
model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
assert len(optimizer_checkpoints) == 1
assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
for state in meta['paired_os'].values():
for k, is_paired in state.items():
if k == 'step':
assert not is_paired
else:
assert is_paired
orig_state_dict = optimizer.state_dict()
state_dict = optimizer_checkpoints[0]
for k, orig_state in orig_state_dict['state'].items():
state = state_dict['state'][k]
for v1, v2 in zip(orig_state.values(), state.values()):
if isinstance(v2, torch.Tensor):
assert torch.equal(v1, v2)
else:
assert v1 == v2
assert orig_state_dict['param_groups'] == state_dict['param_groups']
def test_global_optimizer_shard():
model = DummyModel()
for p in model.parameters():
p.grad = torch.rand_like(p)
optimizer = Adam(model.parameters(), lr=1e-3)
optimizer.step()
model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer)
assert len(optimizer_checkpoints) == 2
assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1]
orig_state_dict = optimizer.state_dict()
assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set(
optimizer_checkpoints[1]['state'].keys())
assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0
for k, orig_state in orig_state_dict['state'].items():
state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][
'state'] else optimizer_checkpoints[1]['state'][k]
for v1, v2 in zip(orig_state.values(), state.values()):
if isinstance(v2, torch.Tensor):
assert torch.equal(v1, v2)
else:
assert v1 == v2
assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups']
def test_dist_model_optimizer():
model = DummyModel()
for p in model.parameters():
p.grad = torch.rand_like(p)
optimizer = Adam(model.parameters(), lr=1e-3)
optimizer.step()
dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
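# the dist meta passed in should be forwarded into the checkpoint meta unchanged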
model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
assert dist_meta == meta['dist_meta']
assert len(model_checkpoints) == 1
assert len(optimizer_checkpoints) == 1
assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0]
assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state']
dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
assert dist_meta == meta['dist_meta']
assert len(model_checkpoints) == 1
assert len(optimizer_checkpoints) == 1
if __name__ == '__main__':
test_global_model()
test_global_model_shard()
test_global_optimizer()
test_global_optimizer_shard()
test_dist_model_optimizer()


@@ -1,186 +0,0 @@
from copy import deepcopy
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from torch.optim import Adam, Optimizer
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.io import load, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta
def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
assert set(a.keys()) == set(b.keys())
for k, v in a.items():
assert torch.equal(v, b[k])
def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
assert set(a['state'].keys()) == set(b['state'].keys())
for k, state in a['state'].items():
b_state = b['state'][k]
for v1, v2 in zip(state.values(), b_state.values()):
if isinstance(v1, Tensor):
assert torch.equal(v1, v2)
else:
assert v1 == v2
if not ignore_param_groups:
assert a['param_groups'] == b['param_groups']
class DummyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = nn.Linear(20, 1)
def prepare_model_optim(shard: bool = False, zero: bool = False):
model = DummyModel()
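# shard: 2-way tensor-parallel column split of fc.weight;
# zero: additionally flatten fc.weight and split it unevenly across the DP ranks
# (only dp_rank 0 keeps the bias)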
if shard:
model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
if zero:
dp_rank = dist.get_rank() // 2
model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
if dp_rank != 0:
model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
for p in model.parameters():
p.grad = torch.rand_like(p)
optimizer = Adam(model.parameters(), lr=1e-3)
optimizer.step()
return model, optimizer
def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0):
with torch.no_grad():
for p in model.parameters():
p.fill_(scalar)
for state in optimizer.state.values():
for v in state.values():
if isinstance(v, Tensor):
v.fill_(scalar)
def get_dist_metas(nprocs: int, zero: bool = False):
dp_world_size = nprocs // 2
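# ranks are laid out as dp_rank = rank // 2, tp_rank = rank % 2, i.e. 2-way TP inside each DP group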
dist_metas = []
for rank in range(nprocs):
if zero:
dist_metas.append({
'fc.weight':
ParamDistMeta(rank // 2,
dp_world_size,
rank % 2,
2,
tp_shard_dims=[1],
tp_num_parts=[2],
zero_numel=10,
zero_orig_shape=[1, 10]),
'fc.bias':
ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
})
else:
dist_metas.append({
'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
})
return dist_metas
def get_redist_meta(nprocs: int):
dp_world_size = nprocs // 2
rank_meta = {
'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
}
param_meta = {
'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
'fc.bias': ParamRedistMeta(dp_world_size, 1)
}
return RedistMeta(rank_meta, [], param_meta)
@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0])
def test_save_global_load_global(max_shard_size_gb: float):
model, optimizer = prepare_model_optim()
with TemporaryDirectory() as dir_name:
save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb)
new_model, new_optimizer = prepare_model_optim()
load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb)
check_model_state_dict(model.state_dict(), new_model.state_dict())
check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict())
def run_dist(rank, world_size, port, test_fn):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
test_fn()
def launch_dist(fn, world_size: int):
spawn(run_dist, world_size, test_fn=fn)
def save_dist(dir_name: str, zero: bool):
model, optimizer = prepare_model_optim(shard=True, zero=zero)
reset_model_optim(model, optimizer)
world_size = dist.get_world_size()
rank = dist.get_rank()
save(dir_name, model, optimizer, dist_meta=get_dist_metas(world_size, zero)[rank])
def load_and_check_dist(dir_name: str):
world_size = dist.get_world_size()
model, optimizer = prepare_model_optim(shard=True)
reset_model_optim(model, optimizer)
model_state_dict = deepcopy(model.state_dict())
optimizer_state_dict = deepcopy(optimizer.state_dict())
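# overwrite everything with 1s so that the following load must actually restore the saved zeros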
reset_model_optim(model, optimizer, 1)
load(dir_name, model, optimizer, get_redist_meta(world_size), get_dist_metas(world_size))
check_model_state_dict(model_state_dict, model.state_dict())
check_optim_state_dict(optimizer_state_dict, optimizer.state_dict())
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_global_load_dist():
model, optimizer = prepare_model_optim()
reset_model_optim(model, optimizer)
with TemporaryDirectory() as dir_name:
save(dir_name, model, optimizer)
fn = partial(load_and_check_dist, dir_name)
launch_dist(fn, 4)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist_load_dist():
with TemporaryDirectory() as dir_name:
# save tp + dp
fn = partial(save_dist, dir_name, False)
launch_dist(fn, 2)
# load tp + dp
fn = partial(load_and_check_dist, dir_name)
launch_dist(fn, 2)
with TemporaryDirectory() as dir_name:
# save tp + zero
fn = partial(save_dist, dir_name, True)
launch_dist(fn, 4)
# load back as tp + dp, with both a smaller and the original world size
fn = partial(load_and_check_dist, dir_name)
launch_dist(fn, 2)
launch_dist(fn, 4)
if __name__ == '__main__':
test_save_global_load_global(80 / 1024**3)
test_save_global_load_global(0)
test_save_global_load_dist()
test_save_dist_load_dist()


@@ -1,126 +0,0 @@
import os
from functools import partial
from tempfile import TemporaryDirectory
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Adam
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import merge, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
class DummyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = nn.Linear(20, 1)
def prepare_model_optim(shard: bool = False, zero: bool = False):
model = DummyModel()
if shard:
model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
if zero:
dp_rank = dist.get_rank() // 2
model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
if dp_rank != 0:
model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
for p in model.parameters():
p.grad = torch.ones_like(p)
optimizer = Adam(model.parameters(), lr=1e-3)
optimizer.step()
return model, optimizer
def test_merge_global():
model, optimizer = prepare_model_optim()
with TemporaryDirectory() as dir_name:
save(dir_name, model, optimizer)
with TemporaryDirectory() as output_dir:
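# merging a checkpoint that is already global is a no-op, so nothing is written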
merge(dir_name, output_dir)
assert len(os.listdir(output_dir)) == 0
with TemporaryDirectory() as dir_name:
save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
with TemporaryDirectory() as output_dir:
merge(dir_name, output_dir)
assert len(os.listdir(output_dir)) == 0
def run_dist(rank, world_size, port, test_fn):
colossalai.launch(config={'parallel': {
'tensor': {
'mode': '1d',
'size': 2
}
}},
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
test_fn()
def run_save_dist(dir_name: str, zero: bool):
model, optimizer = prepare_model_optim(shard=True, zero=zero)
rank = dist.get_rank()
dp_world_size = dist.get_world_size() // 2
if not zero:
dist_metas = {
'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
}
else:
dist_metas = {
'fc.weight':
ParamDistMeta(rank // 2,
dp_world_size,
rank % 2,
2,
tp_shard_dims=[1],
tp_num_parts=[2],
zero_numel=10,
zero_orig_shape=[1, 10]),
'fc.bias':
ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
}
save(dir_name, model, optimizer, dist_meta=dist_metas)
@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_merge_tp_dp(zero: bool):
with TemporaryDirectory() as dir_name:
fn = partial(run_save_dist, dir_name, zero)
world_size = 4
spawn(run_dist, world_size, test_fn=fn)
with TemporaryDirectory() as output_dir:
merge(dir_name, output_dir)
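# the merged output is expected to contain 5 files: global meta, meta, one model shard,
# one optimizer shard and (presumably) the "other" state file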
assert len(os.listdir(output_dir)) == 5
global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME))
assert len(global_meta['meta']) == 1
meta = torch.load(os.path.join(output_dir, global_meta['meta'][0]))
assert meta['dist_meta'] is None
assert len(meta['params']) == 2
assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0]))
assert len(model_state_dict) == 2
assert model_state_dict['fc.weight'].size(1) == 20
optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0]))
assert len(optimizer_state_dict['state']) == 2
assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20
assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20
if __name__ == '__main__':
test_merge_global()
test_merge_tp_dp(False)
test_merge_tp_dp(True)


@@ -1,101 +0,0 @@
import torch
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param
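# Tests for the offline merge helpers: unflatten ZeRO shards, gather TP shards,
# and merge_param, which combines both steps.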
def test_unflatten_zero_param_even() -> None:
dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)]
orig_tensor = torch.rand(4, 4)
tensors = list(orig_tensor.reshape(-1).chunk(4))
unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
assert torch.equal(orig_tensor, unflattened_tensor)
merged_tensor = merge_param(tensors, dist_metas)
assert torch.equal(orig_tensor, merged_tensor)
def test_unflatten_zero_param_uneven() -> None:
dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)]
orig_tensor = torch.rand(4, 4)
tensors = list(orig_tensor.reshape(-1).split([13, 3]))
unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
assert torch.equal(orig_tensor, unflattened_tensor)
merged_tensor = merge_param(tensors, dist_metas)
assert torch.equal(orig_tensor, merged_tensor)
def test_gather_tp_param_1d_row() -> None:
dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)]
orig_tensor = torch.rand(4, 4)
tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
gathered_tensor = gather_tp_param(tensors, dist_metas)
assert torch.equal(orig_tensor, gathered_tensor)
merged_tensor = merge_param(tensors, dist_metas)
assert torch.equal(orig_tensor, merged_tensor)
def test_gather_tp_param_1d_col() -> None:
dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)]
orig_tensor = torch.rand(4, 4)
tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
gathered_tensor = gather_tp_param(tensors, dist_metas)
assert torch.equal(orig_tensor, gathered_tensor)
merged_tensor = merge_param(tensors, dist_metas)
assert torch.equal(orig_tensor, merged_tensor)
def test_gather_tp_param_2d() -> None:
dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)]
orig_tensor = torch.rand(4, 6)
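# build a 2x3 grid of TP shards: dim 0 split into 2 parts, dim 1 into 3 parts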
tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
gathered_tensor = gather_tp_param(tensors, dist_metas)
assert torch.equal(orig_tensor, gathered_tensor)
merged_tensor = merge_param(tensors, dist_metas)
assert torch.equal(orig_tensor, merged_tensor)
def test_gather_tp_param_2d_reverse() -> None:
dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)]
orig_tensor = torch.rand(4, 6)
tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
gathered_tensor = gather_tp_param(tensors, dist_metas)
assert torch.equal(orig_tensor, gathered_tensor)
merged_tensor = merge_param(tensors, dist_metas)
assert torch.equal(orig_tensor, merged_tensor)
def test_merge_param_hybrid() -> None:
dist_metas = [
ParamDistMeta(i % 2,
2,
i // 2,
6,
tp_shard_dims=[1, 0],
tp_num_parts=[3, 2],
zero_numel=4,
zero_orig_shape=[2, 2]) for i in range(12)
]
orig_tensor = torch.rand(4, 6)
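# each of the 6 TP shards is additionally flattened and split unevenly ([1, 3]) across 2 ZeRO ranks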
tensors = [
chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
for chunk in t.contiguous().reshape(-1).split([1, 3])
]
merged_tensor = merge_param(tensors, dist_metas)
assert torch.equal(orig_tensor, merged_tensor)
def test_merge_param_dummy() -> None:
dist_metas = [ParamDistMeta(0, 1, 0, 1)]
orig_tensor = torch.rand(4, 6)
merged_tensor = merge_param([orig_tensor], dist_metas)
assert torch.equal(orig_tensor, merged_tensor)
if __name__ == '__main__':
test_unflatten_zero_param_even()
test_unflatten_zero_param_uneven()
test_gather_tp_param_1d_row()
test_gather_tp_param_1d_col()
test_gather_tp_param_2d()
test_gather_tp_param_2d_reverse()
test_merge_param_hybrid()
test_merge_param_dummy()


@@ -1,152 +0,0 @@
import os
from functools import partial
from tempfile import TemporaryDirectory
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Adam
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import redist, save
from colossalai.utils.checkpoint_io.meta import (
ParamDistMeta,
ParamRedistMeta,
PipelineRedistMeta,
RankRedistMeta,
RedistMeta,
)
class DummyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = nn.Linear(20, 1)
def prepare_model_optim(shard: bool = False, zero: bool = False):
model = DummyModel()
if shard:
model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
if zero:
dp_rank = dist.get_rank() // 2
model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
if dp_rank != 0:
model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
for p in model.parameters():
p.grad = torch.ones_like(p)
optimizer = Adam(model.parameters(), lr=1e-3)
optimizer.step()
return model, optimizer
def get_dist_metas(nprocs: int, zero: bool = False):
dp_world_size = nprocs // 2
dist_metas = []
for rank in range(nprocs):
if zero:
dist_metas.append({
'fc.weight':
ParamDistMeta(rank // 2,
dp_world_size,
rank % 2,
2,
tp_shard_dims=[1],
tp_num_parts=[2],
zero_numel=10,
zero_orig_shape=[1, 10]),
'fc.bias':
ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
})
else:
dist_metas.append({
'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
})
return dist_metas
def get_redist_meta(nprocs: int):
dp_world_size = nprocs // 2
rank_meta = {
'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
}
param_meta = {
'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
'fc.bias': ParamRedistMeta(dp_world_size, 1)
}
return RedistMeta(rank_meta, [], param_meta)
def check_checkpoint_shape(dir_name: str):
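# every redistributed rank checkpoint should hold a column half of fc.weight (10 of 20 columns)
# together with the matching optimizer state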
global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
for meta_name in global_meta['meta']:
meta = torch.load(os.path.join(dir_name, meta_name))
assert meta['dist_meta'] is not None
assert len(meta['params']) == 2
assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
assert len(model_state_dict) == 2
assert model_state_dict['fc.weight'].size(1) == 10
optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
assert len(optimizer_state_dict['state']) == 2
assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10
assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10
def test_global_to_dist():
model, optimizer = prepare_model_optim()
with TemporaryDirectory() as dir_name:
save(dir_name, model, optimizer)
with TemporaryDirectory() as output_dir:
redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
check_checkpoint_shape(output_dir)
def run_dist(rank, world_size, port, test_fn):
colossalai.launch(config={'parallel': {
'tensor': {
'mode': '1d',
'size': 2
}
}},
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
test_fn()
def run_save_dist(dir_name: str, zero: bool):
model, optimizer = prepare_model_optim(shard=True, zero=zero)
rank = dist.get_rank()
save(dir_name, model, optimizer, dist_meta=get_dist_metas(4, zero)[rank])
@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_dist_to_dist(zero: bool):
with TemporaryDirectory() as dir_name:
fn = partial(run_save_dist, dir_name, zero)
world_size = 4
spawn(run_dist, world_size, test_fn=fn)
with TemporaryDirectory() as output_dir:
redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
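# redistributing a tp + dp checkpoint to the same layout is a no-op;
# with zero, the flat shards have to be rebuilt, so new files are written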
if not zero:
assert len(os.listdir(output_dir)) == 0
else:
check_checkpoint_shape(output_dir)
if __name__ == '__main__':
test_global_to_dist()
test_dist_to_dist(False)
test_dist_to_dist(True)


@@ -1,149 +0,0 @@
import os
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Adam
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import (
GLOBAL_META_FILE_NAME,
META_CKPT_FILE_NAME,
MODEL_CKPT_FILE_NAME,
OTHER_CKPT_FILE_NAME,
)
from colossalai.utils.checkpoint_io.io import save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
assert set(a.keys()) == set(b.keys())
for k, v in a.items():
assert torch.equal(v, b[k])
def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
assert set(a['state'].keys()) == set(b['state'].keys())
for k, state in a['state'].items():
b_state = b['state'][k]
for v1, v2 in zip(state.values(), b_state.values()):
if isinstance(v1, Tensor):
assert torch.equal(v1, v2)
else:
assert v1 == v2
if not ignore_param_groups:
assert a['param_groups'] == b['param_groups']
class DummyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = nn.Linear(20, 1)
def prepare_model_optim():
model = DummyModel()
for p in model.parameters():
p.grad = torch.ones_like(p)
optimizer = Adam(model.parameters(), lr=1e-3)
optimizer.step()
return model, optimizer
def test_overwrite():
model = DummyModel()
with TemporaryDirectory() as dir_name:
with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f:
pass
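# a pre-existing shard file must raise an error instead of being silently overwritten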
with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'):
save(dir_name, model)
def test_save_global():
model, optimizer = prepare_model_optim()
with TemporaryDirectory() as dir_name:
save(dir_name, model, optimizer)
assert len(os.listdir(dir_name)) == 5
global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME
meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
assert len(meta['model']) == 1
assert len(meta['optimizer']) == 1
model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
check_model_state_dict(model.state_dict(), model_state_dict)
optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict)
other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME))
assert len(other_state_dict) == 0
def test_save_global_shard():
model, optimizer = prepare_model_optim()
with TemporaryDirectory() as dir_name:
save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
assert len(os.listdir(dir_name)) == 7
meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
assert len(meta['model']) == 2 and len(meta['optimizer']) == 2
model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']]
assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0
check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]})
optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']]
assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0
assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1]
check_optim_state_dict(
optimizer.state_dict(), {
'state': {
**optimizer_state_dicts[0]['state'],
**optimizer_state_dicts[1]['state']
},
'param_groups': optimizer_state_dicts[0]['param_groups']
})
def run_dist(rank, world_size, port, test_fn):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
test_fn()
def run_save_dist(dir_name):
model, optimizer = prepare_model_optim()
dist_metas = {
'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1),
'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1)
}
save(dir_name, model, optimizer, dist_meta=dist_metas)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist():
with TemporaryDirectory() as dir_name:
fn = partial(run_save_dist, dir_name)
world_size = 2
spawn(run_dist, world_size, test_fn=fn)
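# 2 ranks each write a meta, a model and an optimizer file, plus the shared global meta
# and "other" file: 8 files in total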
assert len(os.listdir(dir_name)) == 8
global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
assert len(global_meta['meta']) == 2
for rank, meta_name in enumerate(global_meta['meta']):
meta = torch.load(os.path.join(dir_name, meta_name))
assert meta.get('dist_meta', None) is not None
assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
assert len(model_state_dict) == 2
optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
assert len(optimizer_state_dict['state']) == 2
assert 'param_groups' in optimizer_state_dict
if __name__ == '__main__':
test_overwrite()
test_save_global()
test_save_global_shard()
test_save_dist()


@@ -1,137 +0,0 @@
import torch
from colossalai.utils.checkpoint_io.meta import ParamRedistMeta
from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param
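# Tests for the redistribution helpers: flatten a param into ZeRO shards, split it for TP,
# and unmerge_param, which combines both steps.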
def test_flatten_zero_param_even() -> None:
redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12])
orig_tensor = torch.rand(4, 4)
tensors = list(orig_tensor.reshape(-1).chunk(4))
flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
assert len(tensors) == len(flat_tensors)
for t, st in zip(tensors, flat_tensors):
assert torch.equal(t, st)
unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
assert len(unmerged_tensors) == 1
unmerged_tensors = unmerged_tensors[0]
assert len(tensors) == len(unmerged_tensors)
for t, tl in zip(tensors, unmerged_tensors):
assert torch.equal(t, tl)
def test_flatten_zero_param_uneven() -> None:
redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13])
orig_tensor = torch.rand(4, 4)
tensors = list(orig_tensor.reshape(-1).split([13, 3]))
flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
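# dp ranks before zero_start_dp_rank and ranks beyond the end of the data get empty shards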
assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0
flat_tensors = flat_tensors[1:-1]
assert len(tensors) == len(flat_tensors)
for t, st in zip(tensors, flat_tensors):
assert torch.equal(t, st)
unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
assert len(unmerged_tensors) == 1
unmerged_tensors = unmerged_tensors[0]
assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0
unmerged_tensors = unmerged_tensors[1:-1]
assert len(tensors) == len(unmerged_tensors)
for t, tl in zip(tensors, unmerged_tensors):
assert torch.equal(t, tl)
def test_split_tp_param_1d_row() -> None:
redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4])
orig_tensor = torch.rand(4, 4)
tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
split_tensors = split_tp_param(orig_tensor, redist_meta)
assert len(tensors) == len(split_tensors)
for t, st in zip(tensors, split_tensors):
assert torch.equal(t, st)
unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
assert len(tensors) == len(unmerged_tensors)
for t, tl in zip(tensors, unmerged_tensors):
assert len(tl) == 1
assert torch.equal(t, tl[0])
def test_split_tp_param_1d_col() -> None:
redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4])
orig_tensor = torch.rand(4, 4)
tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
split_tensors = split_tp_param(orig_tensor, redist_meta)
assert len(tensors) == len(split_tensors)
for t, st in zip(tensors, split_tensors):
assert torch.equal(t, st)
unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
assert len(tensors) == len(unmerged_tensors)
for t, tl in zip(tensors, unmerged_tensors):
assert len(tl) == 1
assert torch.equal(t, tl[0])
def test_split_tp_param_2d() -> None:
redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3])
orig_tensor = torch.rand(4, 6)
tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
split_tensors = split_tp_param(orig_tensor, redist_meta)
assert len(tensors) == len(split_tensors)
for t, st in zip(tensors, split_tensors):
assert torch.equal(t, st)
unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
assert len(tensors) == len(unmerged_tensors)
for t, tl in zip(tensors, unmerged_tensors):
assert len(tl) == 1
assert torch.equal(t, tl[0])
def test_split_tp_param_2d_reverse() -> None:
redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2])
orig_tensor = torch.rand(4, 6)
tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
split_tensors = split_tp_param(orig_tensor, redist_meta)
assert len(tensors) == len(split_tensors)
for t, st in zip(tensors, split_tensors):
assert torch.equal(t, st)
unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
assert len(tensors) == len(unmerged_tensors)
for t, tl in zip(tensors, unmerged_tensors):
assert len(tl) == 1
assert torch.equal(t, tl[0])
def test_unmerge_param_hybrid() -> None:
redist_meta = ParamRedistMeta(2,
6,
tp_shard_dims=[1, 0],
tp_num_parts=[3, 2],
zero_start_dp_rank=0,
zero_offsets=[0, 1])
orig_tensor = torch.rand(4, 6)
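# expect a nested [tp_rank][dp_rank] layout: 6 TP shards, each split into 2 uneven ZeRO pieces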
tensors = [
chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
for chunk in t.contiguous().reshape(-1).split([1, 3])
]
unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2
for tp_rank in range(6):
for dp_rank in range(2):
assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank])
def test_unmerge_param_dummy() -> None:
redist_meta = ParamRedistMeta(1, 1)
orig_tensor = torch.rand(4, 6)
unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1
assert torch.equal(orig_tensor, unmerged_tensors[0][0])
if __name__ == '__main__':
test_flatten_zero_param_even()
test_flatten_zero_param_uneven()
test_split_tp_param_1d_row()
test_split_tp_param_1d_col()
test_split_tp_param_2d()
test_split_tp_param_2d_reverse()
test_unmerge_param_hybrid()
test_unmerge_param_dummy()