[test] refactor tests with spawn (#3452)

* [test] added spawn decorator

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
Frank Lee
2023-04-06 14:51:35 +08:00
committed by GitHub
parent 62f4e2eb07
commit 80eba05b0a
240 changed files with 1723 additions and 2342 deletions
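
The recurring change throughout this commit is replacing the old per-test boilerplate of functools.partial + free_port() + torch.multiprocessing.spawn with a single spawn helper from colossalai.testing. A minimal sketch of what such a helper could look like is shown below, assuming it merely picks a free port once and hands it, together with the world size, to torch.multiprocessing.spawn; the actual colossalai.testing.spawn implementation may differ in details.

from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port


def spawn(func, nprocs=1, **kwargs):
    # Launch func(rank, world_size, port, **kwargs) on nprocs processes,
    # allocating the rendezvous port once instead of inside every test.
    wrapped = partial(func, world_size=nprocs, port=free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)

Under this sketch, a test body such as mp.spawn(partial(run_dist, world_size=4, port=free_port()), nprocs=4) collapses to spawn(run_dist, 4), which is exactly the pattern applied in the hunks below.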

View File

@@ -3,7 +3,6 @@ import copy
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
import torchvision.models as tm
import colossalai
@@ -13,7 +12,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
# from colossalai.fx.passes.algorithms import solver_rotor
# from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@@ -26,8 +25,8 @@ except:
withcodegen = False
def _run_C_solver_consistency_test(rank=0):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_C_solver_consistency_test(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]:
model = M()
@@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0):
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
@rerun_if_address_is_in_use()
def test_C_solver_consistency():
mp.spawn(_run_C_solver_consistency_test, nprocs=1)
spawn(_run_C_solver_consistency_test, 1)
if __name__ == '__main__':
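
Many of the refactored tests keep the rerun_if_address_is_in_use() decorator seen above. A rough sketch of what such a decorator could do is given here, under the assumption that it simply reruns the test when distributed initialization fails because the rendezvous port is still occupied; the real colossalai.testing helper may match different exceptions or use a different retry policy.

import functools
import time


def rerun_if_address_is_in_use(max_retries=3):

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except RuntimeError as e:
                    # only retry the classic "address already in use" failure
                    if 'address already in use' not in str(e).lower() or attempt == max_retries - 1:
                        raise
                    time.sleep(1)    # give the OS time to release the port

        return wrapper

    return decorator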

View File

@@ -4,7 +4,6 @@ from typing import Callable
import pytest
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from torch.fx import GraphModule
@@ -15,7 +14,7 @@ from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, spawn
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call
assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
def _run_ckpt_solver(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_ckpt_solver(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
@@ -98,12 +97,13 @@ def _run_ckpt_solver(rank):
@pytest.mark.skip("TODO(super-dainiu): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
@rerun_if_address_is_in_use()
def test_ckpt_solver():
mp.spawn(_run_ckpt_solver, nprocs=1)
spawn(_run_ckpt_solver, 1)
def _run_ckpt_solver_torch11(rank):
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
def _run_ckpt_solver_torch11(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MODEL_LIST = [tm.densenet121]
torch.backends.cudnn.deterministic = True
@@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank):
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done")
@rerun_if_address_is_in_use()
def test_ckpt_solver_torch11():
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
spawn(_run_ckpt_solver_torch11, 1)
if __name__ == '__main__':

View File

@@ -8,6 +8,7 @@ from colossalai.fx.graph_module import ColoGraphModule
# from colossalai.fx.passes.algorithms import linearize, solver_rotor
# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.testing import clear_cache_before_run
if is_compatible_with_meta():
from colossalai.fx.profiler.tensor import MetaTensor
@@ -24,6 +25,7 @@ except:
@pytest.mark.skip(reason='TODO: modify the logger')
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
@clear_cache_before_run()
def test_linearize():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()
@@ -84,6 +86,7 @@ def test_linearize():
@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
@clear_cache_before_run()
def test_linearize_torch11():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
tracer = ColoTracer()
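
The hunks above also start decorating single-process tests with clear_cache_before_run(). A hypothetical sketch of such a decorator follows, assuming its job is to release cached memory before each test executes; the real implementation in colossalai.testing may also clear framework-internal caches.

import functools
import gc

import torch


def clear_cache_before_run():

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            gc.collect()                      # drop leftover Python objects
            if torch.cuda.is_available():
                torch.cuda.empty_cache()      # return cached CUDA blocks to the driver
            return func(*args, **kwargs)

        return wrapper

    return decorator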

View File

@@ -1,9 +1,7 @@
import time
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.utils._pytree import tree_map
import colossalai
@@ -12,8 +10,8 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize
from colossalai.utils import free_port, get_current_device
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
from tests.test_auto_parallel.test_offload.model_utils import *
from tests.test_tensor.common_utils import set_seed
@@ -140,9 +138,9 @@ def run_dist(rank, world_size, port):
@pytest.mark.skip("this test failed")
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@rerun_if_address_is_in_use()
def test_perf():
run_func = partial(run_dist, world_size=1, port=free_port())
mp.spawn(run_func, nprocs=1)
spawn(run_dist, 1)
if __name__ == '__main__':

View File

@@ -3,20 +3,20 @@ import torch.fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_map
from colossalai.auto_parallel.offload.region_manager import RegionManager
from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.auto_parallel.offload.region_manager import RegionManager
from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
from tests.test_auto_parallel.test_offload.model_utils import *
@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
@clear_cache_before_run()
@parameterize('model_name', ['gpt2_', 'bert_'])
@parameterize('memory_budget', [4000])
@parameterize('solver_name', ['syn', 'asyn'])
def solver_test(model_name: str,
memory_budget: float,
solver_name: str):
def solver_test(model_name: str, memory_budget: float, solver_name: str):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen = get_components_func()
@@ -52,11 +52,16 @@ def solver_test(model_name: str,
for region in region_list:
need_offload = region.need_offload
to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
print(
f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
)
for region in region_list.__reversed__():
need_offload = region.need_offload
to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
print(
f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}'
)
if __name__ == '__main__':
solver_test()
solver_test()

View File

@@ -6,6 +6,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import clear_cache_before_run
class TestModule(torch.nn.Module):
@@ -26,6 +27,7 @@ def insert_narrow(gm, x_node):
return gm
@clear_cache_before_run()
def test_node_args_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)

View File

@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import clear_cache_before_run
class TestModule(torch.nn.Module):
@@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_size_value_converting_pass():
model = TestModule()
physical_mesh_id = torch.arange(0, 4)

View File

@@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
@@ -13,9 +12,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
class LinearModel(torch.nn.Module):
@@ -86,11 +83,8 @@ def check_conv_module(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bias_addition_module():
world_size = 4
run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port())
mp.spawn(run_func_linear, nprocs=world_size)
run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port())
mp.spawn(run_func_conv, nprocs=world_size)
spawn(check_linear_module, 4)
spawn(check_conv_module, 4)
if __name__ == '__main__':

View File

@@ -1,9 +1,7 @@
from functools import partial
from typing import Optional, Tuple, Union
from typing import Optional, Tuple
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers.pytorch_utils import Conv1D
@@ -17,9 +15,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
HIDDEN_SIZE = 16
@@ -65,9 +61,7 @@ def check_act_ckpt(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_mlp_layer():
world_size = 4
run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_act_ckpt, 4)
if __name__ == '__main__':

View File

@@ -1,9 +1,7 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
try:
@@ -15,9 +13,7 @@ except:
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
class MLP(torch.nn.Module):
@@ -102,9 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_compatibility_with_ddp():
world_size = 4
run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_compatibility_with_ddp, 4)
if __name__ == '__main__':

View File

@@ -1,10 +1,7 @@
import copy
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
try:
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
@@ -17,10 +14,9 @@ from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port, get_current_device
from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from colossalai.utils import get_current_device
from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
class MLP(torch.nn.Module):
@@ -110,9 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_auto_parallel_with_gemini():
world_size = 4
run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_auto_parallel_with_gemini, 4)
if __name__ == '__main__':

View File

@@ -10,8 +10,7 @@ from colossalai._analyzer.fx.passes import shape_prop_pass
# from colossalai.fx.tracer.tracer import ColoTracer
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag
NUM_REPEAT_BLOCKS = 4
BATCH_SIZE = 1
@@ -81,6 +80,7 @@ class NonRepeatModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
@parameterize('model_cls', [RepeatModel, NonRepeatModel])
def test_repeat_blocks(model_cls):

View File

@@ -1,12 +1,10 @@
import copy
import random
from functools import partial
from typing import Dict
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
import transformers
from torch.fx import GraphModule
@@ -30,9 +28,8 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
BATCH_SIZE = 1
@@ -190,9 +187,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
@rerun_if_address_is_in_use()
def test_mlp_layer(model_cls):
world_size = 4
run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_attention_layer, 4, model_cls=model_cls)
if __name__ == '__main__':

View File

@@ -1,5 +1,4 @@
import torch
import torch.nn as nn
import transformers
from torch.fx import GraphModule
@@ -7,10 +6,10 @@ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing import parameterize
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
@@ -20,6 +19,7 @@ HIDDEN_DIM = 384
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM)

View File

@@ -6,6 +6,8 @@ from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import clear_cache_before_run
class LinearModel(nn.Module):
@@ -26,6 +28,7 @@ class LinearModel(nn.Module):
@pytest.mark.skip('meta tensor has some bugs in 1.11')
@clear_cache_before_run()
def test_liveness_analysis():
model = LinearModel()
tracer = ColoTracer(bias_addition_split=True)

View File

@@ -1,23 +1,14 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.meta_profiler import meta_register
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize('func', [
torch.nn.functional.softmax,
torch.nn.functional.relu,

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
@@ -10,8 +7,7 @@ from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@@ -62,9 +58,7 @@ def _binary_elementwise_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_binary_elementwise_mem_test, 4)
if __name__ == '__main__':

View File

@@ -1,17 +1,12 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@@ -25,7 +20,7 @@ class ConvFunctionModule(nn.Module):
return nn.functional.conv2d(input, self.conv_weight)
def _conv_module_mem_test(rank, bias, world_size, port):
def _conv_module_mem_test(rank, world_size, port, bias):
"""This function is for conv memory test
Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
@@ -62,9 +57,7 @@ def _conv_module_mem_test(rank, bias, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_conv_meta_concrete_info_match(bias=False):
world_size = 4
run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_conv_module_mem_test, 4, bias=bias)
def _conv_function_mem_test(rank, world_size, port):
@@ -103,9 +96,7 @@ def _conv_function_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_conv_function_concrete_info_match():
world_size = 4
run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_conv_function_mem_test, 4)
if __name__ == '__main__':
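
Note how the worker signature above changes from (rank, bias, world_size, port) to (rank, world_size, port, bias): the call sites suggest that spawn fills rank, world_size and port itself and forwards any remaining test parameters as keyword arguments. An illustrative usage under that assumption (the worker name here is made up):

import colossalai
from colossalai.testing import spawn


def _example_worker(rank, world_size, port, bias=False):
    # rank, world_size and port are supplied by spawn; bias arrives as a kwarg
    colossalai.launch(config={}, rank=rank, world_size=world_size,
                      host='localhost', port=port, backend='nccl')


spawn(_example_worker, 4, bias=True)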

View File

@@ -1,33 +1,16 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.testing.utils import clear_cache_before_run
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import meta_register
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
def test_embedding_meta_info():
meta_func = meta_register.get(torch.nn.Embedding)

View File

@@ -1,24 +1,14 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
class MyModule(nn.Module):
@@ -63,9 +53,7 @@ def _linear_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_module_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_linear_module_mem_test, 4)
def _linear_function_mem_test(rank, world_size, port):
@@ -101,9 +89,7 @@ def _linear_function_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_function_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_linear_function_mem_test, 4)
if __name__ == '__main__':

View File

@@ -1,26 +1,8 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.testing.utils import clear_cache_before_run, parameterize
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
@@ -28,6 +10,7 @@ if torch.__version__ >= '1.12.0':
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize(
'tensor_shapes',
[

View File

@@ -1,29 +1,17 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
if torch.__version__ >= '1.12.0':
from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register
from colossalai.auto_parallel.meta_profiler import meta_register
def _batchnorm_module_mem_test(rank, world_size, port):
@@ -62,9 +50,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_batchnorm_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_batchnorm_module_mem_test, 4)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations')

View File

@@ -1,17 +1,12 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.testing.utils import rerun_if_address_is_in_use, spawn
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
@@ -51,9 +46,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_adaptiveavgpool_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_adaptiveavgpool_module_mem_test, 4)
def _maxpool_module_mem_test(rank, world_size, port):
@@ -92,9 +85,7 @@ def _maxpool_module_mem_test(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_maxpool_meta_concrete_info_match():
world_size = 4
run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(_maxpool_module_mem_test, 4)
if __name__ == '__main__':

View File

@@ -1,26 +1,9 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.testing.utils import clear_cache_before_run
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
@@ -37,6 +20,7 @@ class SplitModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
def test_tensor_meta_info():
"""test tensor related meta information
We will just use torch.Tensor.split for the test

View File

@@ -1,24 +1,8 @@
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem
from colossalai.testing.utils import clear_cache_before_run
from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
if torch.__version__ >= '1.12.0':
@@ -26,6 +10,7 @@ if torch.__version__ >= '1.12.0':
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
def test_where_meta_info():
meta_func = meta_register.get(torch.where)

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
@@ -11,9 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -45,7 +40,7 @@ class AddBMMTorchFunctionModule(nn.Module):
return output
def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port):
def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = module(using_kwargs).cuda()
@@ -249,14 +244,13 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_2d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_2d_device_mesh,
module=module,
bias_shape=bias_shape,
world_size=world_size,
using_kwargs=using_kwargs,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(
check_2d_device_mesh,
4,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
)
@pytest.mark.skip("skip due to bias cases not ready")
@@ -267,14 +261,13 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs):
@parameterize('using_kwargs', [True, False])
@rerun_if_address_is_in_use()
def test_1d_device_mesh(module, bias_shape, using_kwargs):
world_size = 4
run_func = partial(check_1d_device_mesh,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(
check_1d_device_mesh,
4,
module=module,
bias_shape=bias_shape,
using_kwargs=using_kwargs,
)
if __name__ == '__main__':

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -17,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -45,7 +40,7 @@ class AddmmModel_with_param(nn.Module):
return x
def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port):
def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if model_cls == AddmmModel:
@@ -189,13 +184,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
@parameterize('model_cls', [AddmmModel, AddmmModel_with_param])
@rerun_if_address_is_in_use()
def test_addmm_handler(input_shape, model_cls):
world_size = 4
run_func_function = partial(check_addmm_function_handler,
input_shape=input_shape,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func_function, nprocs=world_size)
spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls)
if __name__ == '__main__':

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -114,9 +109,7 @@ def check_bn_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bn_module_handler():
world_size = 4
run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_bn_module_handler, 4)
if __name__ == '__main__':

View File

@@ -1,9 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -19,9 +15,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
WEIGHT_SHAPE = (32, 16)
@@ -168,9 +162,7 @@ def check_linear_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler():
world_size = 4
run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(check_linear_module_handler)
if __name__ == '__main__':

View File

@@ -1,14 +1,10 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
@@ -18,9 +14,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -35,7 +29,7 @@ class LinearModule(torch.nn.Module):
return x
def check_linear_module_handler(rank, bias, world_size, port):
def check_linear_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModule(16, 32, bias=bias).cuda()
@@ -157,9 +151,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(bias=True):
world_size = 4
run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
spawn(check_linear_module_handler, bias=bias)
if __name__ == '__main__':

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port):
def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -149,7 +144,7 @@ class BEOpModelWithIntConst(nn.Module):
return out
def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -236,13 +231,12 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_handler_with_tensor(op, other_dim):
world_size = 4
run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
op=op,
other_dim=other_dim,
world_size=world_size,
port=free_port())
mp.spawn(run_func_tensor, nprocs=world_size)
spawn(
check_binary_elementwise_handler_with_tensor,
4,
op=op,
other_dim=other_dim,
)
@run_on_environment_flag(name='AUTO_PARALLEL')
@@ -252,14 +246,13 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
world_size = 4
run_func_int = partial(check_binary_elementwise_handler_with_int,
op=op,
model_cls=model_cls,
other_dim=other_dim,
world_size=world_size,
port=free_port())
mp.spawn(run_func_int, nprocs=world_size)
spawn(
check_binary_elementwise_handler_with_int,
4,
op=op,
model_cls=model_cls,
other_dim=other_dim,
)
if __name__ == '__main__':

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,9 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -207,11 +202,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bmm_handler(module):
world_size = 4
run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port())
mp.spawn(run_func_2d, nprocs=world_size)
run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port())
mp.spawn(run_func_1d, nprocs=world_size)
spawn(check_2d_device_mesh, 4, module=module)
spawn(check_1d_device_mesh, 4, module=module)
if __name__ == '__main__':

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -13,13 +10,11 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_conv_module_handler(rank, bias, world_size, port):
def check_conv_module_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
@@ -155,7 +150,7 @@ class ConvModel(nn.Module):
return x
def check_conv_function_handler(rank, bias, world_size, port):
def check_conv_function_handler(rank, world_size, port, bias):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = ConvModel().cuda()
@@ -302,9 +297,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_module_handler(bias=False):
world_size = 4
run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_conv_module_handler, 4, bias=bias)
@run_on_environment_flag(name='AUTO_PARALLEL')
@@ -314,9 +307,7 @@ def test_conv_module_handler(bias=False):
# @parameterize('bias', [True, False])
@rerun_if_address_is_in_use()
def test_conv_function_handler(bias=False):
world_size = 4
run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_conv_function_handler, 4, bias=bias)
if __name__ == '__main__':

View File

@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHan
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReshapeModel(nn.Module):
@@ -23,6 +23,7 @@ class ReshapeModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_reshape_handler():
model = ReshapeModel()
tracer = ColoTracer(bias_addition_split=True)

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -16,9 +13,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
NUM_EMBEDDINGS = 16
@@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_module_handler():
world_size = 4
run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_embedding_module_handler, 4)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_embedding_function_handler():
world_size = 4
run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_embedding_function_handler, 4)
if __name__ == '__main__':

View File

@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import clear_cache_before_run
class GetattrModel(nn.Module):
@@ -22,6 +23,7 @@ class GetattrModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_getattr_handler():
model = GetattrModel()
tracer = ColoTracer(bias_addition_split=True)

View File

@@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import Li
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
@parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))])
def test_getitem_from_tensor_handler(getitem_index):
world_size = 4
run_func = partial(check_getitem_from_tensor_handler,
getitem_index=getitem_index,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_getitem_from_tensor_handler, 4)
class GetItemFromTupleModel(nn.Module):
@@ -123,6 +115,7 @@ class GetItemFromTupleModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_getitem_from_tuple_handler():
model = GetItemFromTupleModel()
tracer = ColoTracer()

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -11,12 +8,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_ln_module_handler():
world_size = 4
run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_ln_module_handler, 4)
if __name__ == '__main__':

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -18,14 +15,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
def check_linear_module_handler(rank, bias, input_shape, world_size, port):
def check_linear_module_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
@@ -172,7 +168,7 @@ class LinearModel(nn.Module):
return x
def check_linear_function_handler(rank, bias, input_shape, world_size, port):
def check_linear_function_handler(rank, world_size, port, bias, input_shape):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearModel().cuda()
@@ -313,19 +309,18 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_linear_handler(input_shape, bias=False):
world_size = 4
run_func_module = partial(check_linear_module_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
run_func_function = partial(check_linear_function_handler,
bias=bias,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_function, nprocs=world_size)
spawn(
check_linear_module_handler,
4,
bias=bias,
input_shape=input_shape,
)
spawn(
check_linear_function_handler,
4,
bias=bias,
input_shape=input_shape,
)
if __name__ == '__main__':
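The two spawn calls above capture the refactor applied throughout this commit: the old functools.partial + free_port() + mp.spawn(..., nprocs=world_size) boilerplate collapses into a single spawn(check_fn, world_size, **kwargs) call from colossalai.testing. A minimal sketch of such a helper, assuming it binds the world size, a free port, and any extra keyword arguments before delegating to torch.multiprocessing.spawn (the actual implementation in colossalai.testing is not shown in this diff and may differ):

from functools import partial

import torch.multiprocessing as mp

from colossalai.utils import free_port


def spawn(func, nprocs=1, **kwargs):
    # Hypothetical sketch, not the verified library code.
    # `func` must accept (rank, world_size, port, ...): mp.spawn supplies the
    # rank positionally, everything else is bound here as keyword arguments.
    wrapped = partial(func, world_size=nprocs, port=free_port(), **kwargs)
    mp.spawn(wrapped, nprocs=nprocs)

Binding the extras as keywords is also why the worker functions in this commit are reordered to take (rank, world_size, port, ...) first, with the test-specific parameters following.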

View File

@@ -18,7 +18,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.utils import parameterize
from colossalai.testing.utils import clear_cache_before_run, parameterize
class MatMulModule(nn.Module):
@@ -28,6 +28,7 @@ class MatMulModule(nn.Module):
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@clear_cache_before_run()
@parameterize(
'tensor_shapes',
[

View File

@@ -1,4 +1,3 @@
import pytest
import torch
import torch.nn as nn
@@ -8,11 +7,11 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer(bias_addition_split=True)

View File

@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize
class OutputModel(nn.Module):
@@ -23,7 +23,7 @@ class OutputModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_output_handler(output_option):
model = OutputModel()
tracer = ColoTracer(bias_addition_split=True)

View File

@@ -2,7 +2,6 @@ from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +14,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -55,7 +53,7 @@ class LinearReshapeModel(nn.Module):
return permute_node
def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port):
def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if call_function == torch.permute:
@@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
@parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))])
@parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel])
def test_view_handler(call_function, reshape_dims, model_cls):
world_size = 4
run_func = partial(check_view_handler,
call_function=call_function,
reshape_dims=reshape_dims,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(
check_view_handler,
4,
call_function=call_function,
reshape_dims=reshape_dims,
model_cls=model_cls,
)
if __name__ == '__main__':

View File

@@ -8,7 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import clear_cache_before_run, parameterize
class PlaceholderModel(nn.Module):
@@ -22,7 +22,7 @@ class PlaceholderModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_placeholder_handler(placeholder_option):
model = PlaceholderModel()
tracer = ColoTracer(bias_addition_split=True)
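In the single-process handler tests (output, placeholder, and the tracer-only tests below), @rerun_if_address_is_in_use(), which only matters for tests that bind a distributed port, is dropped in favour of @clear_cache_before_run(). A purely hypothetical sketch of what such a decorator could do, assuming it frees cached GPU memory before each test body runs (the real colossalai.testing implementation is not shown in this diff and may behave differently):

import functools
import gc

import torch


def clear_cache_before_run():
    # Hypothetical sketch of a decorator factory matching the call style
    # @clear_cache_before_run() used in these tests.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return fn(*args, **kwargs)
        return wrapper
    return decorator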

View File

@@ -1,5 +1,4 @@
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHan
from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class LinearModel(nn.Module):
@@ -108,6 +107,7 @@ def check_shard_option(shard_option):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_shard_option():
# for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
for shard_option in [ShardOption.SHARD_LAST_AXIS]:

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -33,7 +28,7 @@ class LinearSplitModel(nn.Module):
return softmax_node
def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
def check_split_handler(rank, world_size, port, softmax_dim, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = model_cls(softmax_dim=softmax_dim).cuda()
@@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
@parameterize('softmax_dim', [0, 1, 2, 3])
@parameterize('model_cls', [LinearSplitModel])
def test_split_handler(softmax_dim, model_cls):
world_size = 4
run_func = partial(check_split_handler,
softmax_dim=softmax_dim,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls)
if __name__ == '__main__':

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -47,7 +42,7 @@ class LinearSplitModel(nn.Module):
return split_node
def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = model_cls(split_size=split_size, split_dim=split_dim).cuda()
@@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
@parameterize('split_dim', [0, 1, 2])
@parameterize('model_cls', [ConvSplitModel, LinearSplitModel])
def test_split_handler(split_size, split_dim, model_cls):
world_size = 4
run_func = partial(check_split_handler,
split_size=split_size,
split_dim=split_dim,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls)
if __name__ == '__main__':

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -14,9 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -36,7 +31,7 @@ class LinearSumModel(nn.Module):
return sum_node
def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
def check_sum_handler(rank, world_size, port, sum_dims, keepdim):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda()
@@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
@parameterize('sum_dims', [(0, 2), 1])
@parameterize('keepdim', [False, True])
def test_sum_handler(sum_dims, keepdim):
world_size = 4
run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim)
if __name__ == '__main__':
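Note that @parameterize('sum_dims', ...) and @parameterize('keepdim', ...) stack multiplicatively, so test_sum_handler launches a fresh 4-process group for every parameter combination; this is presumably why the distributed tests also keep @rerun_if_address_is_in_use(), which retries when a previously used port has not yet been released. A hedged sketch of the equivalent unrolled flow, assuming the spawn helper and the check_sum_handler defined in this file (names mirror the surrounding test, not a verified API):

def test_sum_handler_unrolled():
    # Roughly what the stacked @parameterize decorators expand to.
    for sum_dims in [(0, 2), 1]:
        for keepdim in [False, True]:
            spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim)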

View File

@@ -7,7 +7,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class TensorConstructorModel(nn.Module):
@@ -22,6 +22,7 @@ class TensorConstructorModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_where_handler():
model = TensorConstructorModel()
tracer = ColoTracer(bias_addition_split=True)

View File

@@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import Conv
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
class ReLuModel(nn.Module):
@@ -24,6 +24,7 @@ class ReLuModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_elementwise_handler():
model = ReLuModel()
tracer = ColoTracer(bias_addition_split=True)

View File

@@ -1,8 +1,5 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai._analyzer.fx.graph_module import ColoGraphModule
@@ -15,9 +12,8 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDat
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)])
@parameterize('model_cls', [ConvViewModel, LinearViewModel])
def test_view_handler(tgt_shape, model_cls):
world_size = 4
run_func = partial(check_view_handler,
tgt_shape=tgt_shape,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls)
if __name__ == '__main__':

View File

@@ -8,6 +8,7 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing import clear_cache_before_run
class ConvModel(nn.Module):
@@ -21,6 +22,7 @@ class ConvModel(nn.Module):
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@clear_cache_before_run()
def test_where_handler():
model = ConvModel()
tracer = ColoTracer(bias_addition_split=True)

View File

@@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing import clear_cache_before_run, run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
@clear_cache_before_run()
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)