Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2026-05-05 12:24:38 +00:00
[test] refactor tests with spawn (#3452)
* [test] added spawn decorator
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
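Every file touched by this commit follows the same migration: the manual `functools.partial` + `free_port()` + `mp.spawn(...)` boilerplate is replaced by the `spawn` helper from `colossalai.testing`, worker functions are standardized to a `(rank, world_size, port)` signature, and decorators such as `rerun_if_address_is_in_use` or `clear_cache_before_run` are added. A minimal before/after sketch of the pattern, using hypothetical function names rather than code from this commit:

```python
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_something(rank, world_size, port):
    # Worker signature is now (rank, world_size, port); the spawn helper fills
    # these in, so the per-test free_port() plumbing is no longer needed.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # ... actual test assertions go here ...


@rerun_if_address_is_in_use()
def test_something():
    # Before: run_func = partial(check_something, world_size=4, port=free_port())
    #         mp.spawn(run_func, nprocs=4)
    # After: pass the worker and nprocs; extra kwargs (e.g. model_cls=...) are forwarded.
    spawn(check_something, 4)
```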
@@ -3,7 +3,6 @@ import copy
 import pytest
 import torch
 import torch.fx
-import torch.multiprocessing as mp
 import torchvision.models as tm
 
 import colossalai
@@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
 # from colossalai.fx.passes.algorithms import solver_rotor
 # from colossalai.fx.passes.algorithms.operation import Sequence
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
@@ -26,8 +25,8 @@ except:
     withcodegen = False


-def _run_C_solver_consistency_test(rank=0):
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def _run_C_solver_consistency_test(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
         model = M()
@@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0):

 @pytest.mark.skip("TODO(lyl): refactor all tests.")
 @pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
+@rerun_if_address_is_in_use()
 def test_C_solver_consistency():
-    mp.spawn(_run_C_solver_consistency_test, nprocs=1)
+    spawn(_run_C_solver_consistency_test, 1)


 if __name__ == '__main__':
@@ -4,7 +4,6 @@ from typing import Callable
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torchvision.models as tm
 from torch.fx import GraphModule
 
@@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.graph_module import ColoGraphModule
 # from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 
 if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
@@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
     assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'


-def _run_ckpt_solver(rank):
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def _run_ckpt_solver(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     MODEL_LIST = [tm.densenet121]

     torch.backends.cudnn.deterministic = True
@@ -98,12 +97,13 @@ def _run_ckpt_solver(rank):

 @pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+@rerun_if_address_is_in_use()
 def test_ckpt_solver():
-    mp.spawn(_run_ckpt_solver, nprocs=1)
+    spawn(_run_ckpt_solver, 1)


-def _run_ckpt_solver_torch11(rank):
-    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+def _run_ckpt_solver_torch11(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     MODEL_LIST = [tm.densenet121]

     torch.backends.cudnn.deterministic = True
@@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank):

 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
 @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
+@rerun_if_address_is_in_use()
 def test_ckpt_solver_torch11():
-    mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
+    spawn(_run_ckpt_solver_torch11, 1)


 if __name__ == '__main__':
@@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
 # from colossalai.fx.passes.algorithms import linearize, solver_rotor
 # from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.testing import clear_cache_before_run
 
 if is_compatible_with_meta():
     from colossalai.fx.profiler.tensor import MetaTensor
@@ -24,6 +25,7 @@ except:
 @pytest.mark.skip(reason='TODO: modify the logger')
 @pytest.mark.skip("TODO(lyl): refactor all tests.")
 @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
+@clear_cache_before_run()
 def test_linearize():
     MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
     tracer = ColoTracer()
@@ -84,6 +86,7 @@ def test_linearize():
 @pytest.mark.skip("TODO(lyl): refactor all tests.")
 @pytest.mark.skip(reason="torch11 meta tensor not implemented")
 @pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
+@clear_cache_before_run()
 def test_linearize_torch11():
     MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
     tracer = ColoTracer()
@@ -1,9 +1,7 @@
 import time
-from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.utils._pytree import tree_map
 
 import colossalai
@@ -12,8 +10,8 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
 from colossalai.auto_parallel.offload.solver import NOT_NVML
 from colossalai.fx.profiler import parameter_size
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize
-from colossalai.utils import free_port, get_current_device
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
 from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
 from tests.test_auto_parallel.test_offload.model_utils import *
 from tests.test_tensor.common_utils import set_seed
@@ -140,9 +138,9 @@ def run_dist(rank, world_size, port):

 @pytest.mark.skip("this test failed")
 @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+@rerun_if_address_is_in_use()
 def test_perf():
-    run_func = partial(run_dist, world_size=1, port=free_port())
-    mp.spawn(run_func, nprocs=1)
+    spawn(run_dist, 1)


 if __name__ == '__main__':
@@ -3,20 +3,20 @@ import torch.fx
 from torch.fx import GraphModule
 from torch.utils._pytree import tree_map
 
+from colossalai.auto_parallel.offload.region_manager import RegionManager
+from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory
 from colossalai.fx import ColoTracer, is_compatible_with_meta
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.auto_parallel.offload.region_manager import RegionManager
-from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
-from colossalai.testing import parameterize
+from colossalai.testing import clear_cache_before_run, parameterize
 from tests.test_auto_parallel.test_offload.model_utils import *


 @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+@clear_cache_before_run()
 @parameterize('model_name', ['gpt2_', 'bert_'])
 @parameterize('memory_budget', [4000])
 @parameterize('solver_name', ['syn', 'asyn'])
-def solver_test(model_name: str,
-                memory_budget: float,
-                solver_name: str):
+def solver_test(model_name: str, memory_budget: float, solver_name: str):

     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, data_gen = get_components_func()
@@ -52,11 +52,16 @@ def solver_test(model_name: str,
     for region in region_list:
         need_offload = region.need_offload
         to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
-        print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
+        print(
+            f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
+        )
     for region in region_list.__reversed__():
         need_offload = region.need_offload
         to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
-        print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
+        print(
+            f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
+        )


 if __name__ == '__main__':
     solver_test()
@@ -6,6 +6,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import clear_cache_before_run
 
 
 class TestModule(torch.nn.Module):
@@ -26,6 +27,7 @@ def insert_narrow(gm, x_node):
     return gm


+@clear_cache_before_run()
 def test_node_args_converting_pass():
     model = TestModule()
     physical_mesh_id = torch.arange(0, 4)
@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import clear_cache_before_run
 
 
 class TestModule(torch.nn.Module):
@@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node):


 @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
+@clear_cache_before_run()
 def test_size_value_converting_pass():
     model = TestModule()
     physical_mesh_id = torch.arange(0, 4)
@@ -2,7 +2,6 @@ from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 
 try:
     from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
@@ -13,9 +12,7 @@ except:
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn


 class LinearModel(torch.nn.Module):
@@ -86,11 +83,8 @@ def check_conv_module(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bias_addition_module():
-    world_size = 4
-    run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
-    mp.spawn(run_func_linear, nprocs=world_size)
-    run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
-    mp.spawn(run_func_conv, nprocs=world_size)
+    spawn(check_linear_module, 4)
+    spawn(check_conv_module, 4)


 if __name__ == '__main__':
@@ -1,9 +1,7 @@
-from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.utils.checkpoint import checkpoint
 from transformers.pytorch_utils import Conv1D
@@ -17,9 +15,7 @@ except:
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
 
 HIDDEN_SIZE = 16
 
@@ -65,9 +61,7 @@ def check_act_ckpt(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_mlp_layer():
-    world_size = 4
-    run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_act_ckpt, 4)


 if __name__ == '__main__':
@@ -1,9 +1,7 @@
 import copy
-from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 try:
@@ -15,9 +13,7 @@ except:
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
+from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn


 class MLP(torch.nn.Module):
@@ -102,9 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_compatibility_with_ddp():
-    world_size = 4
-    run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_compatibility_with_ddp, 4)


 if __name__ == '__main__':
@@ -1,10 +1,7 @@
 import copy
-from functools import partial
 
 import pytest
 import torch
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 try:
     from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
@@ -17,10 +14,9 @@ from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.tensor.process_group import ProcessGroup
-from colossalai.testing import assert_close, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port, get_current_device
-from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
+from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
+from colossalai.utils import get_current_device
+from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper


 class MLP(torch.nn.Module):
@@ -110,9 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_auto_parallel_with_gemini():
-    world_size = 4
-    run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_auto_parallel_with_gemini, 4)


 if __name__ == '__main__':
@@ -10,8 +10,7 @@ from colossalai._analyzer.fx.passes import shape_prop_pass
 # from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai._analyzer.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
-from colossalai.testing import parameterize
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag
 
 NUM_REPEAT_BLOCKS = 4
 BATCH_SIZE = 1
@@ -81,6 +80,7 @@ class NonRepeatModel(nn.Module):


 @run_on_environment_flag(name='AUTO_PARALLEL')
+@clear_cache_before_run()
 @parameterize('model_cls', [RepeatModel, NonRepeatModel])
 def test_repeat_blocks(model_cls):
@@ -1,12 +1,10 @@
 import copy
 import random
-from functools import partial
 from typing import Dict
 
 import numpy as np
 import pytest
 import torch
-import torch.multiprocessing as mp
 import transformers
 from torch.fx import GraphModule
 
@@ -30,9 +28,8 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.shape_consistency import to_global
-from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
+from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 
 BATCH_SIZE = 1
@@ -190,9 +187,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
 @parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
 @rerun_if_address_is_in_use()
 def test_mlp_layer(model_cls):
-    world_size = 4
-    run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
+    spawn(check_attention_layer, 4, model_cls=model_cls)


 if __name__ == '__main__':
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from torch.fx import GraphModule
|
||||
|
||||
@@ -7,10 +6,10 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
|
||||
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
|
||||
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
|
||||
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.testing import clear_cache_before_run, parameterize
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
|
||||
|
||||
@@ -20,6 +19,7 @@ HIDDEN_DIM = 384
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@clear_cache_before_run()
|
||||
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
|
||||
def test_self_attention_block(model_cls):
|
||||
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)
|
||||
|
||||
@@ -6,6 +6,8 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
from colossalai._analyzer.fx.passes import shape_prop_pass
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing import clear_cache_before_run
|
||||
|
||||
|
||||
class LinearModel(nn.Module):
|
||||
@@ -26,6 +28,7 @@ class LinearModel(nn.Module):
|
||||
|
||||
|
||||
@pytest.mark.skip('meta tensor has some bugs in 1.11')
|
||||
@clear_cache_before_run()
|
||||
def test_liveness_analysis():
|
||||
model = LinearModel()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
||||
@@ -1,23 +1,14 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler import meta_register
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
|
||||
from colossalai.testing.utils import clear_cache_before_run, parameterize
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
@parameterize('func', [
|
||||
torch.nn.functional.softmax,
|
||||
torch.nn.functional.relu,
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
@@ -10,8 +7,7 @@ from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
|
||||
@@ -62,9 +58,7 @@ def _binary_elementwise_mem_test(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_binary_elementwise_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_binary_elementwise_mem_test, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
|
||||
@@ -25,7 +20,7 @@ class ConvFunctionModule(nn.Module):
|
||||
return nn.functional.conv2d(input, self.conv_weight)
|
||||
|
||||
|
||||
def _conv_module_mem_test(rank, bias, world_size, port):
|
||||
def _conv_module_mem_test(rank, world_size, port, bias):
|
||||
"""This function is for conv memory test
|
||||
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
|
||||
|
||||
@@ -62,9 +57,7 @@ def _conv_module_mem_test(rank, bias, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_conv_meta_concrete_info_match(bias=False):
|
||||
world_size = 4
|
||||
run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_conv_module_mem_test, 4, bias=bias)
|
||||
|
||||
|
||||
def _conv_function_mem_test(rank, world_size, port):
|
||||
@@ -103,9 +96,7 @@ def _conv_function_mem_test(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_conv_function_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_conv_function_mem_test, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,33 +1,16 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
|
||||
from colossalai.testing.utils import clear_cache_before_run
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
|
||||
from colossalai.auto_parallel.meta_profiler import meta_register
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
def test_embedding_meta_info():
|
||||
meta_func = meta_register.get(torch.nn.Embedding)
|
||||
|
||||
|
||||
@@ -1,24 +1,14 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
|
||||
|
||||
|
||||
class MyModule(nn.Module):
|
||||
|
||||
@@ -63,9 +53,7 @@ def _linear_module_mem_test(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_module_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_linear_module_mem_test, 4)
|
||||
|
||||
|
||||
def _linear_function_mem_test(rank, world_size, port):
|
||||
@@ -101,9 +89,7 @@ def _linear_function_mem_test(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_function_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_linear_function_mem_test, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,26 +1,8 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
|
||||
from colossalai.testing.utils import clear_cache_before_run, parameterize
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
@@ -28,6 +10,7 @@ if torch.__version__ >= '1.12.0':
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
@parameterize(
|
||||
'tensor_shapes',
|
||||
[
|
||||
|
||||
@@ -1,29 +1,17 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
|
||||
from colossalai.auto_parallel.meta_profiler import meta_register
|
||||
|
||||
|
||||
def _batchnorm_module_mem_test(rank, world_size, port):
|
||||
@@ -62,9 +50,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_batchnorm_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_batchnorm_module_mem_test, 4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations')
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
|
||||
|
||||
|
||||
@@ -51,9 +46,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_adaptiveavgpool_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_adaptiveavgpool_module_mem_test, 4)
|
||||
|
||||
|
||||
def _maxpool_module_mem_test(rank, world_size, port):
|
||||
@@ -92,9 +85,7 @@ def _maxpool_module_mem_test(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_maxpool_meta_concrete_info_match():
|
||||
world_size = 4
|
||||
run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(_maxpool_module_mem_test, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,26 +1,9 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
|
||||
from colossalai.testing.utils import clear_cache_before_run
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
@@ -37,6 +20,7 @@ class SplitModule(nn.Module):
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
def test_tensor_meta_info():
|
||||
"""test tensor related meta information
|
||||
We will just use torch.Tensor.split for the test
|
||||
|
||||
@@ -1,24 +1,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
MemoryCost,
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
ShardingStrategy,
|
||||
StrategiesVector,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
|
||||
from colossalai.testing.utils import clear_cache_before_run
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
|
||||
|
||||
if torch.__version__ >= '1.12.0':
|
||||
@@ -26,6 +10,7 @@ if torch.__version__ >= '1.12.0':
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
|
||||
@clear_cache_before_run()
|
||||
def test_where_meta_info():
|
||||
meta_func = meta_register.get(torch.where)
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
|
||||
@@ -11,9 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
@@ -45,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
|
||||
def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = module(using_kwargs).cuda()
|
||||
@@ -249,14 +244,13 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por
|
||||
@parameterize('using_kwargs', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_2d_device_mesh(module, bias_shape, using_kwargs):
|
||||
world_size = 4
|
||||
run_func = partial(check_2d_device_mesh,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
world_size=world_size,
|
||||
using_kwargs=using_kwargs,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(
|
||||
check_2d_device_mesh,
|
||||
4,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
using_kwargs=using_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("skip due to bias cases not ready")
|
||||
@@ -267,14 +261,13 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs):
|
||||
@parameterize('using_kwargs', [True, False])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_1d_device_mesh(module, bias_shape, using_kwargs):
|
||||
world_size = 4
|
||||
run_func = partial(check_1d_device_mesh,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
using_kwargs=using_kwargs,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(
|
||||
check_1d_device_mesh,
|
||||
4,
|
||||
module=module,
|
||||
bias_shape=bias_shape,
|
||||
using_kwargs=using_kwargs,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
@@ -17,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
@@ -45,7 +40,7 @@ class AddmmModel_with_param(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
|
||||
def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
if model_cls == AddmmModel:
|
||||
@@ -189,13 +184,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
|
||||
@parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_addmm_handler(input_shape, model_cls):
|
||||
world_size = 4
|
||||
run_func_function = partial(check_addmm_function_handler,
|
||||
input_shape=input_shape,
|
||||
model_cls=model_cls,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_function, nprocs=world_size)
|
||||
spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
@@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
@@ -114,9 +109,7 @@ def check_bn_module_handler(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bn_module_handler():
|
||||
world_size = 4
|
||||
run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
spawn(check_bn_module_handler, 4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
@@ -19,9 +15,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
WEIGHT_SHAPE = (32, 16)
|
||||
@@ -168,9 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_handler():
|
||||
world_size = 4
|
||||
run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(check_linear_module_handler)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
@@ -18,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
@@ -35,7 +29,7 @@ class LinearModule(torch.nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def check_linear_module_handler(rank, bias, world_size, port):
|
||||
def check_linear_module_handler(rank, world_size, port, bias):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = LinearModule(16, 32, bias=bias).cuda()
|
||||
@@ -157,9 +151,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_linear_handler(bias=True):
|
||||
world_size = 4
|
||||
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_module, nprocs=world_size)
|
||||
spawn(check_linear_module_handler, bias=bias)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
|
||||
def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
@@ -149,7 +144,7 @@ class BEOpModelWithIntConst(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
|
||||
def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
@@ -236,13 +231,12 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_binary_elementwise_handler_with_tensor(op, other_dim):
|
||||
world_size = 4
|
||||
run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
|
||||
op=op,
|
||||
other_dim=other_dim,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_tensor, nprocs=world_size)
|
||||
spawn(
|
||||
check_binary_elementwise_handler_with_tensor,
|
||||
4,
|
||||
op=op,
|
||||
other_dim=other_dim,
|
||||
)
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@@ -252,14 +246,13 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
|
||||
world_size = 4
|
||||
run_func_int = partial(check_binary_elementwise_handler_with_int,
|
||||
op=op,
|
||||
model_cls=model_cls,
|
||||
other_dim=other_dim,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_int, nprocs=world_size)
|
||||
spawn(
|
||||
check_binary_elementwise_handler_with_int,
|
||||
4,
|
||||
op=op,
|
||||
model_cls=model_cls,
|
||||
other_dim=other_dim,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
@@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
@@ -207,11 +202,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bmm_handler(module):
|
||||
world_size = 4
|
||||
run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_2d, nprocs=world_size)
|
||||
run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func_1d, nprocs=world_size)
|
||||
spawn(check_2d_device_mesh, 4, module=module)
|
||||
spawn(check_1d_device_mesh, 4, module=module)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,8 +1,5 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

def check_conv_module_handler(rank, bias, world_size, port):
def check_conv_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
@@ -155,7 +150,7 @@ class ConvModel(nn.Module):
return x

def check_conv_function_handler(rank, bias, world_size, port):
def check_conv_function_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = ConvModel().cuda()
@@ -302,9 +297,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_module_handler(bias=False):
world_size = 4
run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_conv_module_handler, 4, bias=bias)

@run_on_environment_flag(name='AUTO_PARALLEL')
@@ -314,9 +307,7 @@ def test_conv_module_handler(bias=False):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_function_handler(bias=False):
world_size = 4
run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_conv_function_handler, 4, bias=bias)

if __name__ == '__main__':

@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHan
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag

class ReshapeModel(nn.Module):
@@ -23,6 +23,7 @@ class ReshapeModel(nn.Module):

@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_reshape_handler():
model = ReshapeModel()
tracer = ColoTracer(bias_addition_split=True)

@@ -1,8 +1,5 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -16,9 +13,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

NUM_EMBEDDINGS = 16
@@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_module_handler():
world_size = 4
run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_embedding_module_handler, 4)

@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_function_handler():
world_size = 4
run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_embedding_function_handler, 4)

if __name__ == '__main__':

@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import clear_cache_before_run

class GetattrModel(nn.Module):
@@ -22,6 +23,7 @@ class GetattrModel(nn.Module):

@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_getattr_handler():
model = GetattrModel()
tracer = ColoTracer(bias_addition_split=True)

@@ -2,7 +2,6 @@ from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import Li
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

@@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
def test_getitem_from_tensor_handler(getitem_index):
world_size = 4
run_func = partial(check_getitem_from_tensor_handler,
getitem_index=getitem_index,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_getitem_from_tensor_handler, 4)

class GetItemFromTupleModel(nn.Module):
@@ -123,6 +115,7 @@ class GetItemFromTupleModel(nn.Module):

@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_getitem_from_tuple_handler():
model = GetItemFromTupleModel()
tracer = ColoTracer()

@@ -1,8 +1,5 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -11,12 +8,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

@@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_ln_module_handler():
world_size = 4
run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_ln_module_handler, 4)

if __name__ == '__main__':

@@ -1,8 +1,5 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -18,14 +15,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

def check_linear_module_handler(rank, bias, input_shape, world_size, port):
def check_linear_module_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
@@ -172,7 +168,7 @@ class LinearModel(nn.Module):
return x

def check_linear_function_handler(rank, bias, input_shape, world_size, port):
def check_linear_function_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModel().cuda()
@@ -313,19 +309,18 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(input_shape, bias=False):
world_size = 4
run_func_module = partial(check_linear_module_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
run_func_function = partial(check_linear_function_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_function, nprocs=world_size)
spawn(
check_linear_module_handler,
4,
bias=bias,
input_shape=input_shape,
)
spawn(
check_linear_function_handler,
4,
bias=bias,
input_shape=input_shape,
)

if __name__ == '__main__':

@@ -18,7 +18,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.utils import parameterize
from colossalai.testing.utils import clear_cache_before_run, parameterize

class MatMulModule(nn.Module):
@@ -28,6 +28,7 @@ class MatMulModule(nn.Module):

@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize(
'tensor_shapes',
[

@@ -1,4 +1,3 @@
import pytest
import torch
import torch.nn as nn

@@ -8,11 +7,11 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag

@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer(bias_addition_split=True)

@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize

class OutputModel(nn.Module):
@@ -23,7 +23,7 @@ class OutputModel(nn.Module):

@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_output_handler(output_option):
model = OutputModel()
tracer = ColoTracer(bias_addition_split=True)

@@ -2,7 +2,6 @@ from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +14,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

@@ -55,7 +53,7 @@ class LinearReshapeModel(nn.Module):
return permute_node

def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if call_function == torch.permute:
@@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
def test_view_handler(call_function, reshape_dims, model_cls):
world_size = 4
run_func = partial(check_view_handler,
call_function=call_function,
reshape_dims=reshape_dims,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(
check_view_handler,
4,
call_function=call_function,
reshape_dims=reshape_dims,
model_cls=model_cls,
)

if __name__ == '__main__':

@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize

class PlaceholderModel(nn.Module):
@@ -22,7 +22,7 @@ class PlaceholderModel(nn.Module):

@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_placeholder_handler(placeholder_option):
model = PlaceholderModel()
tracer = ColoTracer(bias_addition_split=True)

@@ -1,5 +1,4 @@
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHan
from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag

class LinearModel(nn.Module):
@@ -108,6 +107,7 @@ def check_shard_option(shard_option):

@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_shard_option():
# for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
for shard_option in [ShardOption.SHARD_LAST_AXIS]:

@@ -1,8 +1,5 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

@@ -33,7 +28,7 @@ class LinearSplitModel(nn.Module):
return softmax_node

def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
def check_split_handler(rank, world_size, port, softmax_dim, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = model_cls(softmax_dim=softmax_dim).cuda()
@@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
@parameterize('softmax_dim', [0, 1, 2, 3])
@parameterize('model_cls', [LinearSplitModel])
def test_split_handler(softmax_dim, model_cls):
world_size = 4
run_func = partial(check_split_handler,
softmax_dim=softmax_dim,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls)

if __name__ == '__main__':

@@ -1,8 +1,5 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

@@ -47,7 +42,7 @@ class LinearSplitModel(nn.Module):
return split_node

def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = model_cls(split_size=split_size, split_dim=split_dim).cuda()
@@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
@parameterize('split_dim', [0, 1, 2])
@parameterize('model_cls', [ConvSplitModel, LinearSplitModel])
def test_split_handler(split_size, split_dim, model_cls):
world_size = 4
run_func = partial(check_split_handler,
split_size=split_size,
split_dim=split_dim,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls)

if __name__ == '__main__':

@@ -1,8 +1,5 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -14,9 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

@@ -36,7 +31,7 @@ class LinearSumModel(nn.Module):
return sum_node

def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
def check_sum_handler(rank, world_size, port, sum_dims, keepdim):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()
@@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
@parameterize('sum_dims', [(0, 2), 1])
@parameterize('keepdim', [False, True])
def test_sum_handler(sum_dims, keepdim):
world_size = 4
run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim)

if __name__ == '__main__':

@@ -7,7 +7,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag

class TensorConstructorModel(nn.Module):
@@ -22,6 +22,7 @@ class TensorConstructorModel(nn.Module):

@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_where_handler():
model = TensorConstructorModel()
tracer = ColoTracer(bias_addition_split=True)

@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag

class ReLuModel(nn.Module):
@@ -24,6 +24,7 @@ class ReLuModel(nn.Module):

@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_elementwise_handler():
model = ReLuModel()
tracer = ColoTracer(bias_addition_split=True)

@@ -1,8 +1,5 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +12,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy

@@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)])
@parameterize('model_cls', [ConvViewModel, LinearViewModel])
def test_view_handler(tgt_shape, model_cls):
world_size = 4
run_func = partial(check_view_handler,
tgt_shape=tgt_shape,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls)

if __name__ == '__main__':

@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import clear_cache_before_run

class ConvModel(nn.Module):
@@ -21,6 +22,7 @@ class ConvModel(nn.Module):

@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_where_handler():
model = ConvModel()
tracer = ColoTracer(bias_addition_split=True)

@@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag

@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
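
For context, the launcher refactor applied across these tests follows a single pattern: the old code built a functools.partial with an explicit world_size and a freshly allocated port and handed it to torch.multiprocessing, while the new code delegates that plumbing to colossalai.testing.spawn. The sketch below is illustrative only and is not part of the commit; the worker and test names are hypothetical, and it assumes spawn(func, nprocs, **kwargs) starts nprocs workers and supplies rank, world_size, and a free port to each of them.

# Illustrative sketch only; check_something / test_something_* are hypothetical names.
from functools import partial

import torch.multiprocessing as mp

from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import free_port


def check_something(rank, world_size, port, bias):
    # worker body: launch the distributed backend, build the model, run assertions
    ...


# Old style: allocate a port up front, bind arguments with partial, call mp.spawn.
def test_something_old(bias=False):
    world_size = 4
    run_func = partial(check_something, world_size=world_size, port=free_port(), bias=bias)
    mp.spawn(run_func, nprocs=world_size)


# New style: the testing helper handles port allocation and process creation.
@rerun_if_address_is_in_use()
def test_something_new(bias=False):
    spawn(check_something, 4, bias=bias)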