Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[test] remove useless tests (#4359)
* [test] remove legacy zero test
* [test] remove lazy distribute test
* [test] remove outdated checkpoint io
@@ -1,120 +0,0 @@
import torch
import torch.nn as nn
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.utils import build_checkpoints
from torch.optim import Adam


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def test_global_model():
    model = DummyModel()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model)
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 0
    assert meta['dist_meta'] is None
    orig_state_dict = model.state_dict()
    global_state_dict = model_checkpoints[0]
    assert set(orig_state_dict.keys()) == set(global_state_dict.keys())
    for k, v in orig_state_dict.items():
        assert torch.equal(v, global_state_dict[k])


def test_global_model_shard():
    model = DummyModel()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model)
    assert len(model_checkpoints) == 2
    assert len(optimizer_checkpoints) == 0
    assert meta['dist_meta'] is None
    orig_state_dict = model.state_dict()
    assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys())
    assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0
    for k, v in orig_state_dict.items():
        for state_dict in model_checkpoints:
            if k in state_dict:
                assert torch.equal(v, state_dict[k])


def test_global_optimizer():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
    assert len(optimizer_checkpoints) == 1
    assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
    for state in meta['paired_os'].values():
        for k, is_paired in state.items():
            if k == 'step':
                assert not is_paired
            else:
                assert is_paired
    orig_state_dict = optimizer.state_dict()
    state_dict = optimizer_checkpoints[0]
    for k, orig_state in orig_state_dict['state'].items():
        state = state_dict['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    assert orig_state_dict['param_groups'] == state_dict['param_groups']


def test_global_optimizer_shard():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer)
    assert len(optimizer_checkpoints) == 2
    assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1]
    orig_state_dict = optimizer.state_dict()
    assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set(
        optimizer_checkpoints[1]['state'].keys())
    assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0
    for k, orig_state in orig_state_dict['state'].items():
        state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][
            'state'] else optimizer_checkpoints[1]['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2

    assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups']


def test_dist_model_optimizer():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 1
    assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0]
    assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state']
    dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
    assert dist_meta == meta['dist_meta']
    assert len(model_checkpoints) == 1
    assert len(optimizer_checkpoints) == 1


if __name__ == '__main__':
    test_global_model()
    test_global_model_shard()
    test_global_optimizer()
    test_global_optimizer_shard()
    test_dist_model_optimizer()
@@ -1,186 +0,0 @@
from copy import deepcopy
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from torch.optim import Adam, Optimizer

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.io import load, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta


def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    assert set(a.keys()) == set(b.keys())
    for k, v in a.items():
        assert torch.equal(v, b[k])


def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
    assert set(a['state'].keys()) == set(b['state'].keys())
    for k, state in a['state'].items():
        b_state = b['state'][k]
        for v1, v2 in zip(state.values(), b_state.values()):
            if isinstance(v1, Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    if not ignore_param_groups:
        assert a['param_groups'] == b['param_groups']


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim(shard: bool = False, zero: bool = False):
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0):
    with torch.no_grad():
        for p in model.parameters():
            p.fill_(scalar)
        for state in optimizer.state.values():
            for v in state.values():
                if isinstance(v, Tensor):
                    v.fill_(scalar)


def get_dist_metas(nprocs: int, zero: bool = False):
    dp_world_size = nprocs // 2
    dist_metas = []
    for rank in range(nprocs):
        if zero:
            dist_metas.append({
                'fc.weight':
                    ParamDistMeta(rank // 2,
                                  dp_world_size,
                                  rank % 2,
                                  2,
                                  tp_shard_dims=[1],
                                  tp_num_parts=[2],
                                  zero_numel=10,
                                  zero_orig_shape=[1, 10]),
                'fc.bias':
                    ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
            })
        else:
            dist_metas.append({
                'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
                'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
            })
    return dist_metas


def get_redist_meta(nprocs: int):
    dp_world_size = nprocs // 2
    rank_meta = {
        'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
        'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
    }
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_world_size, 1)
    }
    return RedistMeta(rank_meta, [], param_meta)


@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0])
def test_save_global_load_global(max_shard_size_gb: float):
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb)
        new_model, new_optimizer = prepare_model_optim()
        load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb)
        check_model_state_dict(model.state_dict(), new_model.state_dict())
        check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict())


def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_fn()


def launch_dist(fn, world_size: int):
    spawn(run_dist, world_size, test_fn=fn)


def save_dist(dir_name: str, zero: bool):
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    reset_model_optim(model, optimizer)
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    save(dir_name, model, optimizer, dist_meta=get_dist_metas(world_size, zero)[rank])


def load_and_check_dist(dir_name: str):
    world_size = dist.get_world_size()
    model, optimizer = prepare_model_optim(shard=True)
    reset_model_optim(model, optimizer)
    model_state_dict = deepcopy(model.state_dict())
    optimizer_state_dict = deepcopy(optimizer.state_dict())
    reset_model_optim(model, optimizer, 1)
    load(dir_name, model, optimizer, get_redist_meta(world_size), get_dist_metas(world_size))
    check_model_state_dict(model_state_dict, model.state_dict())
    check_optim_state_dict(optimizer_state_dict, optimizer.state_dict())


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_global_load_dist():
    model, optimizer = prepare_model_optim()
    reset_model_optim(model, optimizer)
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 4)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist_load_dist():
    with TemporaryDirectory() as dir_name:
        # save tp + dp
        fn = partial(save_dist, dir_name, False)
        launch_dist(fn, 2)
        # load tp + dp
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 2)
    with TemporaryDirectory() as dir_name:
        # save tp + zero
        fn = partial(save_dist, dir_name, True)
        launch_dist(fn, 4)
        # load tp + dp
        fn = partial(load_and_check_dist, dir_name)
        launch_dist(fn, 2)
        launch_dist(fn, 4)


if __name__ == '__main__':
    test_save_global_load_global(80 / 1024**3)
    test_save_global_load_global(0)
    test_save_global_load_dist()
    test_save_dist_load_dist()
@@ -1,126 +0,0 @@
import os
from functools import partial
from tempfile import TemporaryDirectory

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import merge, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim(shard: bool = False, zero: bool = False):
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def test_merge_global():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 0
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 0


def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={'parallel': {
        'tensor': {
            'mode': '1d',
            'size': 2
        }
    }},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    test_fn()


def run_save_dist(dir_name: str, zero: bool):
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    dp_world_size = dist.get_world_size() // 2
    if not zero:
        dist_metas = {
            'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
            'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
        }
    else:
        dist_metas = {
            'fc.weight':
                ParamDistMeta(rank // 2,
                              dp_world_size,
                              rank % 2,
                              2,
                              tp_shard_dims=[1],
                              tp_num_parts=[2],
                              zero_numel=10,
                              zero_orig_shape=[1, 10]),
            'fc.bias':
                ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
        }
    save(dir_name, model, optimizer, dist_meta=dist_metas)


@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_merge_tp_dp(zero: bool):
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name, zero)
        world_size = 4
        spawn(run_dist, world_size, test_fn=fn)
        with TemporaryDirectory() as output_dir:
            merge(dir_name, output_dir)
            assert len(os.listdir(output_dir)) == 5
            global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME))
            assert len(global_meta['meta']) == 1
            meta = torch.load(os.path.join(output_dir, global_meta['meta'][0]))
            assert meta['dist_meta'] is None
            assert len(meta['params']) == 2
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
            model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0]))
            assert len(model_state_dict) == 2
            assert model_state_dict['fc.weight'].size(1) == 20
            optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0]))
            assert len(optimizer_state_dict['state']) == 2
            assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
            assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20
            assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20


if __name__ == '__main__':
    test_merge_global()
    test_merge_tp_dp(False)
    test_merge_tp_dp(True)
@@ -1,101 +0,0 @@
import torch
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param


def test_unflatten_zero_param_even() -> None:
    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).chunk(4))
    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, unflattened_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_unflatten_zero_param_uneven() -> None:
    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)]
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, unflattened_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_1d_row() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_1d_col() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)]
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_2d() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)]
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_gather_tp_param_2d_reverse() -> None:
    dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)]
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    gathered_tensor = gather_tp_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, gathered_tensor)
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_merge_param_hybrid() -> None:
    dist_metas = [
        ParamDistMeta(i % 2,
                      2,
                      i // 2,
                      6,
                      tp_shard_dims=[1, 0],
                      tp_num_parts=[3, 2],
                      zero_numel=4,
                      zero_orig_shape=[2, 2]) for i in range(12)
    ]
    orig_tensor = torch.rand(4, 6)
    tensors = [
        chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
        for chunk in t.contiguous().reshape(-1).split([1, 3])
    ]
    merged_tensor = merge_param(tensors, dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


def test_merge_param_dummy() -> None:
    dist_metas = [ParamDistMeta(0, 1, 0, 1)]
    orig_tensor = torch.rand(4, 6)
    merged_tensor = merge_param([orig_tensor], dist_metas)
    assert torch.equal(orig_tensor, merged_tensor)


if __name__ == '__main__':
    test_unflatten_zero_param_even()
    test_unflatten_zero_param_uneven()
    test_gather_tp_param_1d_row()
    test_gather_tp_param_1d_col()
    test_gather_tp_param_2d()
    test_gather_tp_param_2d_reverse()
    test_merge_param_hybrid()
    test_merge_param_dummy()
@@ -1,152 +0,0 @@
import os
from functools import partial
from tempfile import TemporaryDirectory

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import redist, save
from colossalai.utils.checkpoint_io.meta import (
    ParamDistMeta,
    ParamRedistMeta,
    PipelineRedistMeta,
    RankRedistMeta,
    RedistMeta,
)


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim(shard: bool = False, zero: bool = False):
    model = DummyModel()
    if shard:
        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
    if zero:
        dp_rank = dist.get_rank() // 2
        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
        if dp_rank != 0:
            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def get_dist_metas(nprocs: int, zero: bool = False):
    dp_world_size = nprocs // 2
    dist_metas = []
    for rank in range(nprocs):
        if zero:
            dist_metas.append({
                'fc.weight':
                    ParamDistMeta(rank // 2,
                                  dp_world_size,
                                  rank % 2,
                                  2,
                                  tp_shard_dims=[1],
                                  tp_num_parts=[2],
                                  zero_numel=10,
                                  zero_orig_shape=[1, 10]),
                'fc.bias':
                    ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
            })
        else:
            dist_metas.append({
                'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
                'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
            })
    return dist_metas


def get_redist_meta(nprocs: int):
    dp_world_size = nprocs // 2
    rank_meta = {
        'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
        'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
    }
    param_meta = {
        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
        'fc.bias': ParamRedistMeta(dp_world_size, 1)
    }
    return RedistMeta(rank_meta, [], param_meta)


def check_checkpoint_shape(dir_name: str):
    global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
    for meta_name in global_meta['meta']:
        meta = torch.load(os.path.join(dir_name, meta_name))
        assert meta['dist_meta'] is not None
        assert len(meta['params']) == 2
        assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
        model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
        assert len(model_state_dict) == 2
        assert model_state_dict['fc.weight'].size(1) == 10
        optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
        assert len(optimizer_state_dict['state']) == 2
        assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
        assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10
        assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10


def test_global_to_dist():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        with TemporaryDirectory() as output_dir:
            redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
            check_checkpoint_shape(output_dir)


def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={'parallel': {
        'tensor': {
            'mode': '1d',
            'size': 2
        }
    }},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    test_fn()


def run_save_dist(dir_name: str, zero: bool):
    model, optimizer = prepare_model_optim(shard=True, zero=zero)
    rank = dist.get_rank()
    save(dir_name, model, optimizer, dist_meta=get_dist_metas(4, zero)[rank])


@pytest.mark.dist
@pytest.mark.parametrize("zero", [False, True])
@rerun_if_address_is_in_use()
def test_dist_to_dist(zero: bool):
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name, zero)
        world_size = 4
        spawn(run_dist, world_size, test_fn=fn)
        with TemporaryDirectory() as output_dir:
            redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
            if not zero:
                assert len(os.listdir(output_dir)) == 0
            else:
                check_checkpoint_shape(output_dir)


if __name__ == '__main__':
    test_global_to_dist()
    test_dist_to_dist(False)
    test_dist_to_dist(True)
@@ -1,149 +0,0 @@
import os
from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.optim import Adam

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import (
    GLOBAL_META_FILE_NAME,
    META_CKPT_FILE_NAME,
    MODEL_CKPT_FILE_NAME,
    OTHER_CKPT_FILE_NAME,
)
from colossalai.utils.checkpoint_io.io import save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta


def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    assert set(a.keys()) == set(b.keys())
    for k, v in a.items():
        assert torch.equal(v, b[k])


def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
    assert set(a['state'].keys()) == set(b['state'].keys())
    for k, state in a['state'].items():
        b_state = b['state'][k]
        for v1, v2 in zip(state.values(), b_state.values()):
            if isinstance(v1, Tensor):
                assert torch.equal(v1, v2)
            else:
                assert v1 == v2
    if not ignore_param_groups:
        assert a['param_groups'] == b['param_groups']


class DummyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(20, 1)


def prepare_model_optim():
    model = DummyModel()
    for p in model.parameters():
        p.grad = torch.ones_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    return model, optimizer


def test_overwrite():
    model = DummyModel()
    with TemporaryDirectory() as dir_name:
        with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f:
            pass
        with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'):
            save(dir_name, model)


def test_save_global():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer)
        assert len(os.listdir(dir_name)) == 5
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME
        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
        assert len(meta['model']) == 1
        assert len(meta['optimizer']) == 1
        model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
        check_model_state_dict(model.state_dict(), model_state_dict)
        optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
        check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict)
        other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME))
        assert len(other_state_dict) == 0


def test_save_global_shard():
    model, optimizer = prepare_model_optim()
    with TemporaryDirectory() as dir_name:
        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
        assert len(os.listdir(dir_name)) == 7
        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
        assert len(meta['model']) == 2 and len(meta['optimizer']) == 2
        model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']]
        assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0
        check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]})
        optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']]
        assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0
        assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1]
        check_optim_state_dict(
            optimizer.state_dict(), {
                'state': {
                    **optimizer_state_dicts[0]['state'],
                    **optimizer_state_dicts[1]['state']
                },
                'param_groups': optimizer_state_dicts[0]['param_groups']
            })


def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    test_fn()


def run_save_dist(dir_name):
    model, optimizer = prepare_model_optim()
    dist_metas = {
        'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1),
        'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1)
    }
    save(dir_name, model, optimizer, dist_meta=dist_metas)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_save_dist():
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name)
        world_size = 2
        spawn(run_dist, world_size, test_fn=fn)
        assert len(os.listdir(dir_name)) == 8
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 2
        for rank, meta_name in enumerate(global_meta['meta']):
            meta = torch.load(os.path.join(dir_name, meta_name))
            assert meta.get('dist_meta', None) is not None
            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
            model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
            assert len(model_state_dict) == 2
            optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
            assert len(optimizer_state_dict['state']) == 2
            assert 'param_groups' in optimizer_state_dict


if __name__ == '__main__':
    test_overwrite()
    test_save_global()
    test_save_global_shard()
    test_save_dist()
@@ -1,137 +0,0 @@
import torch
from colossalai.utils.checkpoint_io.meta import ParamRedistMeta
from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param


def test_flatten_zero_param_even() -> None:
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12])
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).chunk(4))
    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    assert len(tensors) == len(flat_tensors)
    for t, st in zip(tensors, flat_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    unmerged_tensors = unmerged_tensors[0]
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert torch.equal(t, tl)


def test_flatten_zero_param_uneven() -> None:
    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13])
    orig_tensor = torch.rand(4, 4)
    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
    assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0
    flat_tensors = flat_tensors[1:-1]
    assert len(tensors) == len(flat_tensors)
    for t, st in zip(tensors, flat_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1
    unmerged_tensors = unmerged_tensors[0]
    assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0
    unmerged_tensors = unmerged_tensors[1:-1]
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert torch.equal(t, tl)


def test_split_tp_param_1d_row() -> None:
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_split_tp_param_1d_col() -> None:
    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4])
    orig_tensor = torch.rand(4, 4)
    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_split_tp_param_2d() -> None:
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3])
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_split_tp_param_2d_reverse() -> None:
    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2])
    orig_tensor = torch.rand(4, 6)
    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
    split_tensors = split_tp_param(orig_tensor, redist_meta)
    assert len(tensors) == len(split_tensors)
    for t, st in zip(tensors, split_tensors):
        assert torch.equal(t, st)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(tensors) == len(unmerged_tensors)
    for t, tl in zip(tensors, unmerged_tensors):
        assert len(tl) == 1
        assert torch.equal(t, tl[0])


def test_unmerge_param_hybrid() -> None:
    redist_meta = ParamRedistMeta(2,
                                  6,
                                  tp_shard_dims=[1, 0],
                                  tp_num_parts=[3, 2],
                                  zero_start_dp_rank=0,
                                  zero_offsets=[0, 1])
    orig_tensor = torch.rand(4, 6)
    tensors = [
        chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
        for chunk in t.contiguous().reshape(-1).split([1, 3])
    ]
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2
    for tp_rank in range(6):
        for dp_rank in range(2):
            assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank])


def test_unmerge_param_dummy() -> None:
    redist_meta = ParamRedistMeta(1, 1)
    orig_tensor = torch.rand(4, 6)
    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
    assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1
    assert torch.equal(orig_tensor, unmerged_tensors[0][0])


if __name__ == '__main__':
    test_flatten_zero_param_even()
    test_flatten_zero_param_uneven()
    test_split_tp_param_1d_row()
    test_split_tp_param_1d_col()
    test_split_tp_param_2d()
    test_split_tp_param_2d_reverse()
    test_unmerge_param_hybrid()
    test_unmerge_param_dummy()