Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 12:47:21 +00:00)
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
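Every file below follows the same refactor: the old tests wired up functools.partial, free_port() and torch.multiprocessing.spawn by hand (usually paired with a rerun_on_exception decorator matching the "Address already in use" message), while the new tests call the spawn helper from colossalai.testing together with @rerun_if_address_is_in_use(). A minimal before/after sketch of that pattern follows; check_something and the two test functions are hypothetical stand-ins, and the sketch assumes spawn(fn, nprocs, **kwargs) picks a free port and forwards extra keyword arguments to fn, as the diffs below suggest.

# Hypothetical sketch of the pattern applied in this commit (not a real test file).
from functools import partial

import torch.multiprocessing as mp

from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port


def check_something(rank, world_size, port):
    # per-rank body: colossalai.launch(...), then the actual assertions
    ...


# before: build a partial, pick a port, and manage worker processes by hand
def test_something_before():
    world_size = 4
    run_func = partial(check_something, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


# after: the spawn helper handles port selection and process management,
# and the decorator replaces the regex-based rerun_on_exception
@rerun_if_address_is_in_use()
def test_something_after():
    spawn(check_something, 4)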
@@ -4,8 +4,10 @@
import pytest
import torch
import torch.nn.functional as F

from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.random import add_seed, seed, set_mode, reset_seeds
from colossalai.context.random import add_seed, reset_seeds, seed, set_mode
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils.activation_checkpoint import checkpoint


@@ -39,8 +41,9 @@ def forward_inplace(x, weight):


@pytest.mark.gpu
@pytest.mark.parametrize("use_reentrant", [True, False])
@pytest.mark.parametrize("cpu_offload", [True, False])
@clear_cache_before_run()
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload, use_reentrant):

    # as seed manager is singleton
@@ -1,80 +1,77 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint
from functools import partial

import colossalai.nn as col_nn
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_1d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_checkpoint_1d():
    world_size = 8
    run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == "__main__":
    test_checkpoint_1d()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_1d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_1d():
    spawn(check_checkpoint_1d, 8)


if __name__ == "__main__":
    test_checkpoint_1d()
@@ -1,80 +1,77 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint
from functools import partial

import colossalai.nn as col_nn
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, get_current_device, is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_2d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_checkpoint_2d():
    world_size = 8
    run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == "__main__":
    test_checkpoint_2d()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_2d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_2d():
    spawn(check_checkpoint_2d, 8)


if __name__ == "__main__":
    test_checkpoint_2d()
@@ -1,80 +1,77 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint
from functools import partial

import colossalai.nn as col_nn
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, get_current_device, is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_2p5d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_checkpoint_2p5d():
    world_size = 8
    run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == "__main__":
    test_checkpoint_2p5d()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_2p5d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_2p5d():
    spawn(check_checkpoint_2p5d, 8)


if __name__ == "__main__":
    test_checkpoint_2p5d()
@@ -1,80 +1,77 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint
from functools import partial

import colossalai.nn as col_nn
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, get_current_device, is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint
from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_3d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_checkpoint_3d():
    world_size = 8
    run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == "__main__":
    test_checkpoint_3d()
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import pprint

import pytest
import torch
import torch.nn as nn

import colossalai.nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn
from colossalai.utils import is_using_pp
from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint


def build_pipeline(model):
    from colossalai.pipeline.utils import partition_uniform

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    depth = len(model)
    start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0]
    layers = []
    for i in range(depth):
        if start <= i < end:
            layers.append(model[i])
        else:
            layers.append(nn.Identity())
    return nn.Sequential(*tuple(layers))


def check_equal(A, B):
    assert torch.allclose(A, B, rtol=1e-3, atol=1e-2)


def check_checkpoint_3d(rank, world_size, port):
    config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),)

    disable_existing_loggers()
    launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")

    m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
    sd1 = m1.state_dict()
    if gpc.get_global_rank() == 0:
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n")
    save_checkpoint("test.pt", 0, m1)

    m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4))
    if is_using_pp():
        m2 = build_pipeline(m2)

    load_checkpoint("test.pt", m2)
    sd2 = m2.state_dict()
    if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
        sd2 = gather_pipeline_parallel_state_dict(sd2)
        print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n")

    if gpc.get_global_rank() == 0:
        for k, v in sd1.items():
            assert k in sd2
            check_equal(v, sd2[k].to(torch.device("cpu")))


@pytest.mark.dist
@pytest.mark.skip("takes too long")
@skip_if_not_enough_gpus(min_gpus=8)
@rerun_if_address_is_in_use()
def test_checkpoint_3d():
    spawn(check_checkpoint_3d, 8)


if __name__ == "__main__":
    test_checkpoint_3d()
@@ -3,20 +3,19 @@ from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.io import load, save
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta)
from torch import Tensor
from torch.nn import Module
from torch.optim import Adam, Optimizer

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.io import load, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta


def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    assert set(a.keys()) == set(b.keys())


@@ -120,14 +119,13 @@ def test_save_global_load_global(max_shard_size_gb: float):
    check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict())


def run_dist(rank, world_size, port, func):
def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    func()
    test_fn()


def launch_dist(fn, world_size: int):
    proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
    mp.spawn(proc_fn, nprocs=world_size)
    spawn(run_dist, world_size, test_fn=fn)


def save_dist(dir_name: str, zero: bool):
@@ -1,18 +1,18 @@
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import save, merge
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tempfile import TemporaryDirectory
from torch.optim import Adam
from functools import partial
import torch
import os
from functools import partial
from tempfile import TemporaryDirectory

import pytest
import colossalai
import torch.nn as nn
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import merge, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta


class DummyModel(nn.Module):


@@ -52,7 +52,7 @@ def test_merge_global():
        assert len(os.listdir(output_dir)) == 0


def run_dist(rank, world_size, port, func):
def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={'parallel': {
        'tensor': {
            'mode': '1d',


@@ -64,7 +64,7 @@ def run_dist(rank, world_size, port, func):
                      host='localhost',
                      port=port,
                      backend='nccl')
    func()
    test_fn()


def run_save_dist(dir_name: str, zero: bool):


@@ -100,8 +100,7 @@ def test_merge_tp_dp(zero: bool):
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name, zero)
        world_size = 4
        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
        mp.spawn(proc_fn, nprocs=world_size)
        spawn(run_dist, world_size, test_fn=fn)
    with TemporaryDirectory() as output_dir:
        merge(dir_name, output_dir)
        assert len(os.listdir(output_dir)) == 5
@@ -2,19 +2,23 @@ import os
from functools import partial
from tempfile import TemporaryDirectory

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from torch.optim import Adam

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
from colossalai.utils.checkpoint_io.io import redist, save
from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta,
                                                 RedistMeta)
from torch.optim import Adam
from colossalai.utils.checkpoint_io.meta import (
    ParamDistMeta,
    ParamRedistMeta,
    PipelineRedistMeta,
    RankRedistMeta,
    RedistMeta,
)


class DummyModel(nn.Module):


@@ -105,7 +109,7 @@ def test_global_to_dist():
        check_checkpoint_shape(output_dir)


def run_dist(rank, world_size, port, func):
def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={'parallel': {
        'tensor': {
            'mode': '1d',


@@ -117,7 +121,7 @@ def run_dist(rank, world_size, port, func):
                      host='localhost',
                      port=port,
                      backend='nccl')
    func()
    test_fn()


def run_save_dist(dir_name: str, zero: bool):


@@ -133,8 +137,7 @@ def test_dist_to_dist(zero: bool):
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name, zero)
        world_size = 4
        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
        mp.spawn(proc_fn, nprocs=world_size)
        spawn(run_dist, world_size, test_fn=fn)
    with TemporaryDirectory() as output_dir:
        redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
        if not zero:
@@ -3,21 +3,24 @@ from functools import partial
from tempfile import TemporaryDirectory
from typing import Dict

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME,
                                                     OTHER_CKPT_FILE_NAME)
from colossalai.utils.checkpoint_io.io import save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from torch import Tensor
from torch.optim import Adam

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint_io.constant import (
    GLOBAL_META_FILE_NAME,
    META_CKPT_FILE_NAME,
    MODEL_CKPT_FILE_NAME,
    OTHER_CKPT_FILE_NAME,
)
from colossalai.utils.checkpoint_io.io import save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta


def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
    assert set(a.keys()) == set(b.keys())


@@ -104,9 +107,9 @@ def test_save_global_shard():
        })


def run_dist(rank, world_size, port, func):
def run_dist(rank, world_size, port, test_fn):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    func()
    test_fn()


def run_save_dist(dir_name):


@@ -124,8 +127,7 @@ def test_save_dist():
    with TemporaryDirectory() as dir_name:
        fn = partial(run_save_dist, dir_name)
        world_size = 2
        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
        mp.spawn(proc_fn, nprocs=world_size)
        spawn(run_dist, world_size, test_fn=fn)
        assert len(os.listdir(dir_name)) == 8
        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
        assert len(global_meta['meta']) == 2
@@ -1,20 +1,17 @@
import os
import shutil
from copy import deepcopy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR

import colossalai
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext


@@ -202,13 +199,7 @@ def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_ddp=use_ddp,
                       use_mp_reload=use_mp_reload,
                       test_scheduler=test_scheduler)
    mp.spawn(run_func, nprocs=world_size)
    spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler)


if __name__ == '__main__':
@@ -1,15 +1,13 @@
import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.zero.legacy.sharded_param import ShardedTensor


def run_tensor_move(rank):
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
def run_tensor_move(rank, world_size, port):
    colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl')

    src_t = torch.ones(2, 3).cuda()
    tgt_t = torch.zeros(2, 3)


@@ -36,7 +34,7 @@ def run_tensor_move(rank):

@rerun_if_address_is_in_use()
def test_tensor_move():
    mp.spawn(run_tensor_move, nprocs=1)
    spawn(run_tensor_move, 1)


if __name__ == '__main__':
@@ -5,6 +5,7 @@ import torch
from einops import rearrange

from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize

if HAS_MEM_EFF_ATTN:
    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention


@@ -22,7 +23,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
@clear_cache_before_run()
@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD


@@ -42,7 +44,8 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
@clear_cache_before_run()
@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD


@@ -65,7 +68,8 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
@clear_cache_before_run()
@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD


@@ -84,7 +88,8 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):


@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
@clear_cache_before_run()
@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
    D = H * D_HEAD
@@ -1,17 +1,14 @@
from functools import partial
from typing import Optional

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.common import print_rank_0

try:


@@ -105,9 +102,7 @@ def run_dist(rank, world_size, port) -> None:
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_lazy_init():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
    spawn(run_dist, 4)


if __name__ == '__main__':
@@ -1,12 +1,9 @@
import pytest

import colossalai
from colossalai.testing import spawn
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity
from colossalai.utils import free_port

from functools import partial
import torch.multiprocessing as mp
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction


def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():


@@ -24,8 +21,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [3, 4])
def test_memory_utils(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
    spawn(run_dist, world_size)


if __name__ == '__main__':
@@ -1,16 +1,15 @@
from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers
from colossalai.utils import free_port, get_current_device
from torch.nn.utils import clip_grad_norm_
from functools import partial
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.common import clip_grad_norm
from torch.nn.parameter import Parameter
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.utils.common import clip_grad_norm


def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8):


@@ -71,8 +70,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_clip_grad(world_size: int):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
    spawn(run_dist, world_size)


if __name__ == '__main__':
@@ -1,21 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import checkpoint, clip_grad_norm_fp32
from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2


@@ -106,8 +104,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_zero_clip_grad():
    world_size = 4
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
    spawn(run_dist, world_size)


if __name__ == '__main__':