[hotfix] fix aten default bug (#2158)

YuliangLiu0306
2022-12-20 22:40:46 +08:00
committed by GitHub
parent a4b4bb01d6
commit 16335cb537
10 changed files with 133 additions and 118 deletions

View File

@@ -207,9 +207,9 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
     assert input_sharding_spec.sharding_sequence == output_sharding_spec.sharding_sequence


+@run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('op', [torch.add])
 @parameterize('other_dim', [1, 2])
-@run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_binary_elementwise_handler(op, other_dim):
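
The test-decorator hunks in this commit all make the same fix: @run_on_environment_flag(name='AUTO_PARALLEL') is moved to, or newly added at, the top of each test's decorator stack. Python applies stacked decorators bottom-up, so the topmost decorator becomes the outermost wrapper and its logic runs first on every call. A minimal standalone sketch (hypothetical trace decorator, not ColossalAI code) demonstrating that ordering:

import functools

def trace(label):
    # Hypothetical decorator used only to make the wrapping order visible.
    def decorate(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            print(f'enter {label}')
            return func(*args, **kwargs)
        return wrapper
    return decorate

@trace('top')      # applied last -> outermost -> prints first
@trace('bottom')   # applied first -> innermost -> prints second
def f():
    print('body')

f()                # prints: enter top / enter bottom / body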

View File

@@ -203,8 +203,8 @@ def check_1d_device_mesh(rank, module, world_size, port):
     assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_bmm_handler(module):

View File

@@ -23,6 +23,7 @@ class GetItemFromTensorModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tensor_handler():
     model = GetItemFromTensorModel()
     tracer = ColoTracer()
@@ -96,6 +97,7 @@ class GetItemFromTupleModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_getitem_from_tuple_handler():
     model = GetItemFromTupleModel()
     tracer = ColoTracer()
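
run_on_environment_flag itself is imported from colossalai.testing.pytest_wrapper (see the import hunks below). A plausible minimal sketch of such a gate, assuming it skips the wrapped test unless the named environment variable is enabled (an illustration, not the library's actual source):

import os
from functools import wraps

import pytest

def run_on_environment_flag(name):
    # Sketch only: the assumed behavior is to skip the test when the
    # flag environment variable is absent or not truthy.
    def decorate(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if os.environ.get(name) != '1':
                pytest.skip(f'environment flag {name} is not set')
            return func(*args, **kwargs)
        return wrapper
    return decorate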

View File

@@ -308,8 +308,8 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_linear_handler(input_shape, bias=False):

View File

@@ -2,15 +2,15 @@ import pytest
 import torch
 import torch.nn as nn

-from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
-    NormPoolingHandler
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
+from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_norm_pool_handler():
     model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
     tracer = ColoTracer()

View File

@@ -20,6 +20,7 @@ class ReshapeModel(nn.Module):
         return reshape_node


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_reshape_handler():
     model = ReshapeModel()
     tracer = ColoTracer()

View File

@@ -5,6 +5,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handl
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class TensorConstructorModel(nn.Module):
@@ -18,6 +19,7 @@ class TensorConstructorModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_where_handler():
     model = TensorConstructorModel()
     tracer = ColoTracer()

View File

@@ -22,6 +22,7 @@ class ReLuModel(nn.Module):
         return relu_node


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_elementwise_handler():
     model = ReLuModel()
     tracer = ColoTracer()

View File

@@ -10,6 +10,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 def _param_resharding_cost_assertion(node):
@@ -51,6 +52,7 @@ class ConvModel(torch.nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_linear_module():
     model = LinearModel(4, 8)
     physical_mesh_id = torch.arange(0, 4)
@@ -86,6 +88,7 @@ def test_linear_module():
     _param_resharding_cost_assertion(linear_node)


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_module():
     model = ConvModel(3, 6, 2)
     physical_mesh_id = torch.arange(0, 4)
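
Taken together, each patched test ends up with the environment gate outermost, as in the first hunk above:

# Decorator stack after the patch (copied from the first hunk); the
# environment gate now wraps everything the other decorators produce.
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('op', [torch.add])
@parameterize('other_dim', [1, 2])
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_binary_elementwise_handler(op, other_dim):
    ...

Presumably these tests are then opted into with something like AUTO_PARALLEL=1 pytest ... (the exact flag value the wrapper expects is an assumption).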