[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions
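The refactor replaces the per-test `functools.partial` + `torch.multiprocessing.spawn` boilerplate (including the manual `free_port()` call) with a single `spawn` helper from `colossalai.testing`, invoked as `spawn(check_fn, nprocs, **kwargs)` in the diffs below. The following is a minimal sketch of what such a helper can look like, assuming it only allocates a free port and forwards keyword arguments; the names `spawn_sketch` and `_free_port` are illustrative, not the library's actual implementation.

```python
# Minimal sketch of a spawn-style test helper (illustrative only; NOT the
# actual colossalai.testing implementation). It captures the pattern the
# refactor removes from every test: pick a free port, bind the extra keyword
# arguments, and launch one process per rank.
import socket
from functools import partial

import torch.multiprocessing as mp


def _free_port() -> int:
    # Ask the OS for an unused TCP port (hypothetical helper for this sketch).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))
        return sock.getsockname()[1]


def spawn_sketch(test_fn, nprocs: int = 1, **kwargs) -> None:
    """Run ``test_fn(rank, world_size, port, **kwargs)`` in ``nprocs`` processes."""
    wrapped = partial(test_fn, world_size=nprocs, port=_free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)
```

With a helper of this shape, a test shrinks from `mp.spawn(partial(check_fn, world_size=8, port=free_port()), nprocs=8)` to `spawn(check_fn, 8)`, which matches the call sites in the refactored files below.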

View File

@@ -4,8 +4,10 @@
import pytest
import torch
import torch.nn.functional as F
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
from colossalai.context.random import add_seed, reset_seeds, seed, set_mode
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils.activation_checkpoint import checkpoint
@@ -39,8 +41,9 @@ def forward_inplace(x, weight):
@pytest.mark.gpu
@pytest.mark.parametrize("use_reentrant", [True, False])
@pytest.mark.parametrize("cpu_offload", [True, False])
@clear_cache_before_run()
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload, use_reentrant):
# as seed manager is singleton
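In hunks like the one above (and in the flash-attention tests further down), the pytest-level `@pytest.mark.parametrize` markers are swapped for ColossalAI's own `parameterize` decorator together with `clear_cache_before_run`, so the argument combinations are expanded inside the test function itself rather than at pytest collection time. A hedged sketch of how such a stackable decorator can be written; `parameterize_sketch` is an illustrative name, not the library's code, and it omits the comma-separated multi-parameter form used in some of the tests below.

```python
from functools import wraps


def parameterize_sketch(name: str, values):
    """Call the wrapped test once per value of ``name`` (illustrative only)."""

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Stacked instances compose: each loops over its own values and
            # forwards the accumulated keyword arguments to the next layer.
            for value in values:
                kwargs[name] = value
                fn(*args, **kwargs)

        return wrapper

    return decorator
```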

View File

@@ -1,80 +1,77 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pprint
from functools import partial
import colossalai.nn as col_nn
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
def build_pipeline(model):
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
depth = len(model)
start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
layers = []
for i in range(depth):
if start <= i < end:
layers.append(model[i])
else:
layers.append(nn.Identity())
return nn.Sequential(*tuple(layers))
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
def check_checkpoint_1d(rank, world_size, port):
config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),)
disable_existing_loggers()
launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
if gpc.get_global_rank() == 0:
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
save_checkpoint("test.pt", 0, m1)
m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
if is_using_pp():
m2 = build_pipeline(m2)
load_checkpoint("test.pt", m2)
sd2 = m2.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
sd2 = gather_pipeline_parallel_state_dict(sd2)
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")
if gpc.get_global_rank() == 0:
for k, v in sd1.items():
assert k in sd2
check_equal(v, sd2[k].to(torch.device("cpu")))
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_checkpoint_1d():
world_size = 8
run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == "__main__":
test_checkpoint_1d()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pprint
import pytest
import torch
import torch.nn as nn
import colossalai.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
def build_pipeline(model):
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
depth = len(model)
start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
layers = []
for i in range(depth):
if start <= i < end:
layers.append(model[i])
else:
layers.append(nn.Identity())
return nn.Sequential(*tuple(layers))
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
def check_checkpoint_1d(rank, world_size, port):
config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),)
disable_existing_loggers()
launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
if gpc.get_global_rank() == 0:
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
save_checkpoint("test.pt", 0, m1)
m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
if is_using_pp():
m2 = build_pipeline(m2)
load_checkpoint("test.pt", m2)
sd2 = m2.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
sd2 = gather_pipeline_parallel_state_dict(sd2)
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")
if gpc.get_global_rank() == 0:
for k, v in sd1.items():
assert k in sd2
check_equal(v, sd2[k].to(torch.device("cpu")))
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_1d():
spawn(check_checkpoint_1d, 8)
if __name__ == "__main__":
test_checkpoint_1d()
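The refactor also collapses `@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")` into the dedicated `@rerun_if_address_is_in_use()` decorator. A rough sketch of the relationship between the two, assuming a generic retry-on-matching-exception decorator; the `_sketch` names are hypothetical and this is not the library's implementation.

```python
import re
from functools import wraps


def rerun_on_exception_sketch(exception_type=Exception, pattern=None, max_try: int = 5):
    """Re-run the wrapped function while a matching exception is raised (illustrative)."""

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(max_try):
                try:
                    return fn(*args, **kwargs)
                except exception_type as exc:
                    # Re-raise immediately if the message does not match the
                    # pattern or the retry budget is exhausted.
                    if pattern is not None and not re.search(pattern, str(exc)):
                        raise
                    if attempt == max_try - 1:
                        raise

        return wrapper

    return decorator


def rerun_if_address_is_in_use_sketch():
    # The dedicated decorator bakes the common distributed-test failure mode
    # (a stale rendezvous port) into the generic retry decorator.
    return rerun_on_exception_sketch(pattern=".*Address already in use.*")
```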

View File

@@ -1,80 +1,77 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pprint
from functools import partial
import colossalai.nn as col_nn
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, get_current_device, is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
def build_pipeline(model):
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
depth = len(model)
start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
layers = []
for i in range(depth):
if start <= i < end:
layers.append(model[i])
else:
layers.append(nn.Identity())
return nn.Sequential(*tuple(layers))
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
def check_checkpoint_2d(rank, world_size, port):
config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),)
disable_existing_loggers()
launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
if gpc.get_global_rank() == 0:
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
save_checkpoint("test.pt", 0, m1)
m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
if is_using_pp():
m2 = build_pipeline(m2)
load_checkpoint("test.pt", m2)
sd2 = m2.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
sd2 = gather_pipeline_parallel_state_dict(sd2)
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")
if gpc.get_global_rank() == 0:
for k, v in sd1.items():
assert k in sd2
check_equal(v, sd2[k].to(torch.device("cpu")))
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_checkpoint_2d():
world_size = 8
run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == "__main__":
test_checkpoint_2d()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pprint
import pytest
import torch
import torch.nn as nn
import colossalai.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
def build_pipeline(model):
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
depth = len(model)
start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
layers = []
for i in range(depth):
if start <= i < end:
layers.append(model[i])
else:
layers.append(nn.Identity())
return nn.Sequential(*tuple(layers))
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
def check_checkpoint_2d(rank, world_size, port):
config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),)
disable_existing_loggers()
launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
if gpc.get_global_rank() == 0:
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
save_checkpoint("test.pt", 0, m1)
m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
if is_using_pp():
m2 = build_pipeline(m2)
load_checkpoint("test.pt", m2)
sd2 = m2.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
sd2 = gather_pipeline_parallel_state_dict(sd2)
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")
if gpc.get_global_rank() == 0:
for k, v in sd1.items():
assert k in sd2
check_equal(v, sd2[k].to(torch.device("cpu")))
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_2d():
spawn(check_checkpoint_2d, 8)
if __name__ == "__main__":
test_checkpoint_2d()

View File

@@ -1,80 +1,77 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pprint
from functools import partial
import colossalai.nn as col_nn
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, get_current_device, is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
def build_pipeline(model):
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
depth = len(model)
start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
layers = []
for i in range(depth):
if start <= i < end:
layers.append(model[i])
else:
layers.append(nn.Identity())
return nn.Sequential(*tuple(layers))
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
def check_checkpoint_2p5d(rank, world_size, port):
config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),)
disable_existing_loggers()
launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
if gpc.get_global_rank() == 0:
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
save_checkpoint("test.pt", 0, m1)
m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
if is_using_pp():
m2 = build_pipeline(m2)
load_checkpoint("test.pt", m2)
sd2 = m2.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
sd2 = gather_pipeline_parallel_state_dict(sd2)
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")
if gpc.get_global_rank() == 0:
for k, v in sd1.items():
assert k in sd2
check_equal(v, sd2[k].to(torch.device("cpu")))
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_checkpoint_2p5d():
world_size = 8
run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == "__main__":
test_checkpoint_2p5d()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pprint
import pytest
import torch
import torch.nn as nn
import colossalai.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
def build_pipeline(model):
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
depth = len(model)
start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
layers = []
for i in range(depth):
if start <= i < end:
layers.append(model[i])
else:
layers.append(nn.Identity())
return nn.Sequential(*tuple(layers))
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
def check_checkpoint_2p5d(rank, world_size, port):
config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),)
disable_existing_loggers()
launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
if gpc.get_global_rank() == 0:
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
save_checkpoint("test.pt", 0, m1)
m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
if is_using_pp():
m2 = build_pipeline(m2)
load_checkpoint("test.pt", m2)
sd2 = m2.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
sd2 = gather_pipeline_parallel_state_dict(sd2)
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")
if gpc.get_global_rank() == 0:
for k, v in sd1.items():
assert k in sd2
check_equal(v, sd2[k].to(torch.device("cpu")))
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_2p5d():
spawn(check_checkpoint_2p5d, 8)
if __name__ == "__main__":
test_checkpoint_2p5d()

View File

@@ -1,80 +1,77 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pprint
from functools import partial
import colossalai.nn as col_nn
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, get_current_device, is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus
def build_pipeline(model):
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
depth = len(model)
start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
layers = []
for i in range(depth):
if start <= i < end:
layers.append(model[i])
else:
layers.append(nn.Identity())
return nn.Sequential(*tuple(layers))
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
def check_checkpoint_3d(rank, world_size, port):
config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),)
disable_existing_loggers()
launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
if gpc.get_global_rank() == 0:
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
save_checkpoint("test.pt", 0, m1)
m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
if is_using_pp():
m2 = build_pipeline(m2)
load_checkpoint("test.pt", m2)
sd2 = m2.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
sd2 = gather_pipeline_parallel_state_dict(sd2)
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")
if gpc.get_global_rank() == 0:
for k, v in sd1.items():
assert k in sd2
check_equal(v, sd2[k].to(torch.device("cpu")))
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_checkpoint_3d():
world_size = 8
run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == "__main__":
test_checkpoint_3d()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import pprint
import pytest
import torch
import torch.nn as nn
import colossalai.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
def build_pipeline(model):
from colossalai.pipeline.utils import partition_uniform
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
depth = len(model)
start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
layers = []
for i in range(depth):
if start <= i < end:
layers.append(model[i])
else:
layers.append(nn.Identity())
return nn.Sequential(*tuple(layers))
def check_equal(A, B):
assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)
def check_checkpoint_3d(rank, world_size, port):
config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),)
disable_existing_loggers()
launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
sd1 = m1.state_dict()
if gpc.get_global_rank() == 0:
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
save_checkpoint("test.pt", 0, m1)
m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
if is_using_pp():
m2 = build_pipeline(m2)
load_checkpoint("test.pt", m2)
sd2 = m2.state_dict()
if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
sd2 = gather_pipeline_parallel_state_dict(sd2)
print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")
if gpc.get_global_rank() == 0:
for k, v in sd1.items():
assert k in sd2
check_equal(v, sd2[k].to(torch.device("cpu")))
@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_3d():
spawn(check_checkpoint_3d, 8)
if __name__ == "__main__":
test_checkpoint_3d()

View File

@@ -3,20 +3,19 @@ from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.io import load, save
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta)
from torch import Tensor
from torch.nn import Module
from torch.optim import Adam, Optimizer
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.io import load, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta
def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
assert set(a.keys()) == set(b.keys())
@@ -120,14 +119,13 @@ def test_save_global_load_global(max_shard_size_gb: float):
check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict())
def run_dist(rank, world_size, port, func):
def run_dist(rank, world_size, port, test_fn):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
func()
test_fn()
def launch_dist(fn, world_size: int):
proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
mp.spawn(proc_fn, nprocs=world_size)
spawn(run_dist, world_size, test_fn=fn)
def save_dist(dir_name: str, zero: bool):

View File

@@ -1,18 +1,18 @@
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import save, merge
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tempfile import TemporaryDirectory
from torch.optim import Adam
from functools import partial
import torch
import os
from functools import partial
from tempfile import TemporaryDirectory
import pytest
import colossalai
import torch.nn as nn
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.optim import Adam
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import merge, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
class DummyModel(nn.Module):
@@ -52,7 +52,7 @@ def test_merge_global():
assert len(os.listdir(output_dir)) == 0
def run_dist(rank, world_size, port, func):
def run_dist(rank, world_size, port, test_fn):
colossalai.launch(config={'parallel': {
'tensor': {
'mode': '1d',
@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port, func):
host='localhost',
port=port,
backend='nccl')
func()
test_fn()
def run_save_dist(dir_name: str, zero: bool):
@@ -100,8 +100,7 @@ def test_merge_tp_dp(zero: bool):
with TemporaryDirectory() as dir_name:
fn = partial(run_save_dist, dir_name, zero)
world_size = 4
proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
mp.spawn(proc_fn, nprocs=world_size)
spawn(run_dist, world_size, test_fn=fn)
with TemporaryDirectory() as output_dir:
merge(dir_name, output_dir)
assert len(os.listdir(output_dir)) == 5

View File

@@ -2,19 +2,23 @@ import os
from functools import partial
from tempfile import TemporaryDirectory
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from torch.optim import Adam
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import redist, save
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta,
RedistMeta)
from torch.optim import Adam
from colossalai.utils.checkpoint_io.meta import (
ParamDistMeta,
ParamRedistMeta,
PipelineRedistMeta,
RankRedistMeta,
RedistMeta,
)
class DummyModel(nn.Module):
@@ -105,7 +109,7 @@ def test_global_to_dist():
check_checkpoint_shape(output_dir)
def run_dist(rank, world_size, port, func):
def run_dist(rank, world_size, port, test_fn):
colossalai.launch(config={'parallel': {
'tensor': {
'mode': '1d',
@@ -117,7 +121,7 @@ def run_dist(rank, world_size, port, func):
host='localhost',
port=port,
backend='nccl')
func()
test_fn()
def run_save_dist(dir_name: str, zero: bool):
@@ -133,8 +137,7 @@ def test_dist_to_dist(zero: bool):
with TemporaryDirectory() as dir_name:
fn = partial(run_save_dist, dir_name, zero)
world_size = 4
proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
mp.spawn(proc_fn, nprocs=world_size)
spawn(run_dist, world_size, test_fn=fn)
with TemporaryDirectory() as output_dir:
redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
if not zero:

View File

@@ -3,21 +3,24 @@ from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME,
OTHER_CKPT_FILE_NAME)
from colossalai.utils.checkpoint_io.io import save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from torch import Tensor
from torch.optim import Adam
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import (
GLOBAL_META_FILE_NAME,
META_CKPT_FILE_NAME,
MODEL_CKPT_FILE_NAME,
OTHER_CKPT_FILE_NAME,
)
from colossalai.utils.checkpoint_io.io import save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
assert set(a.keys()) == set(b.keys())
@@ -104,9 +107,9 @@ def test_save_global_shard():
})
def run_dist(rank, world_size, port, func):
def run_dist(rank, world_size, port, test_fn):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
func()
test_fn()
def run_save_dist(dir_name):
@@ -124,8 +127,7 @@ def test_save_dist():
with TemporaryDirectory() as dir_name:
fn = partial(run_save_dist, dir_name)
world_size = 2
proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
mp.spawn(proc_fn, nprocs=world_size)
spawn(run_dist, world_size, test_fn=fn)
assert len(os.listdir(dir_name)) == 8
global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
assert len(global_meta['meta']) == 2

View File

@@ -1,20 +1,17 @@
import os
import shutil
from copy import deepcopy
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR
import colossalai
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
@@ -202,13 +199,7 @@ def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
run_func = partial(run_dist,
world_size=world_size,
port=free_port(),
use_ddp=use_ddp,
use_mp_reload=use_mp_reload,
test_scheduler=test_scheduler)
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler)
if __name__ == '__main__':

View File

@@ -1,15 +1,13 @@
import torch
import torch.multiprocessing as mp
import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.zero.legacy.sharded_param import ShardedTensor
def run_tensor_move(rank):
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
def run_tensor_move(rank, world_size, port):
colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl')
src_t = torch.ones(2, 3).cuda()
tgt_t = torch.zeros(2, 3)
@@ -36,7 +34,7 @@ def run_tensor_move(rank):
@rerun_if_address_is_in_use()
def test_tensor_move():
mp.spawn(run_tensor_move, nprocs=1)
spawn(run_tensor_move, 1)
if __name__ == '__main__':

View File

@@ -5,6 +5,7 @@ import torch
from einops import rearrange
from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize
if HAS_MEM_EFF_ATTN:
from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
@@ -22,7 +23,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
@clear_cache_before_run()
@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
D = H * D_HEAD
@@ -42,7 +44,8 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
@clear_cache_before_run()
@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
D = H * D_HEAD
@@ -65,7 +68,8 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
@clear_cache_before_run()
@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
D = H * D_HEAD
@@ -84,7 +88,8 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
@clear_cache_before_run()
@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
D = H * D_HEAD

View File

@@ -1,17 +1,14 @@
from functools import partial
from typing import Optional
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.common import print_rank_0
try:
@@ -105,9 +102,7 @@ def run_dist(rank, world_size, port) -> None:
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_lazy_init():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, 4)
if __name__ == '__main__':

View File

@@ -1,12 +1,9 @@
import pytest
import colossalai
from colossalai.testing import spawn
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
from colossalai.utils import free_port
from functools import partial
import torch.multiprocessing as mp
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
@@ -24,8 +21,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [3, 4])
def test_memory_utils(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':

View File

@@ -1,16 +1,15 @@
from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, get_current_device
from torch.nn.utils import clip_grad_norm_
from functools import partial
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.common import clip_grad_norm
from torch.nn.parameter import Parameter
from torch.nn.utils import clip_grad_norm_
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.common import clip_grad_norm
def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):
@@ -71,8 +70,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_clip_grad(world_size: int):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':

View File

@@ -1,21 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import checkpoint, clip_grad_norm_fp32
from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
@@ -106,8 +104,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
world_size = 4
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(run_dist, world_size)
if __name__ == '__main__':