Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-24 03:03:37 +00:00
[autoparallel] adapt autoparallel with new analyzer (#3261)
* [autoparallel] adapt autoparallel with new analyzer
* fix all node handler tests
* polish
* polish
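Across the per-file diffs below, the change follows one pattern: the tests stop importing ColoGraphModule and ColoTracer from colossalai.fx and instead use the new analyzer's tracer and graph module, plus an explicit shape-propagation pass over the meta arguments. A minimal sketch of the new flow is shown here; the nn.Linear toy model and its "input" meta argument are hypothetical stand-ins for the handler-specific models used in the tests:

import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer

# Hypothetical toy model; the tests wrap their own handler-specific modules.
model = nn.Linear(16, 32)

# The new analyzer tracer is built with bias_addition_split=True so bias terms
# are split into separate add nodes in the traced graph.
tracer = ColoTracer(bias_addition_split=True)
meta_args = {"input": torch.rand(4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)

# Shape metadata is now populated by an explicit pass over the meta arguments
# rather than during tracing.
shape_prop_pass(gm, *meta_args.values())

The diffs also fold the multi-line meta_args dictionaries previously passed inline to trace() into single assignments, so the same dictionary can be reused by shape_prop_pass.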
@@ -1,22 +1,20 @@
from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing_extensions import Self

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use
@@ -96,7 +94,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
meta_arg_names=meta_arg_names,
node_type='bias_module')

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %m1 : torch.Tensor [#users=1] = placeholder[target=m1]
@@ -109,6 +107,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port)
# return add
graph = tracer.trace(model, meta_args=meta_args_for_tracer)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args_for_tracer.values())
# [input_1, m1, m2, addmm, output]
node_list = list(graph.nodes)
linear_node = node_list[4]

@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import BatchNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -38,13 +40,15 @@ def check_bn_module_handler(rank, world_size, port):
strategy_number=strategy_number,
input_args=[input],
meta_arg_names=['input'])
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')})
meta_args = {"input": torch.rand(4, 16, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
bn_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(bn_mod_node)

@@ -1,14 +1,14 @@
from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from typing_extensions import Self

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
@@ -17,12 +17,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -66,7 +64,7 @@ def check_linear_module_handler(rank, world_size, port):
meta_arg_names=meta_arg_names,
node_type='bias_module')

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %weight : [#users=1] = get_attr[target=weight]
@@ -74,8 +72,10 @@ def check_linear_module_handler(rank, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %bias), kwargs = {})
# return add
graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')})
meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node)

@@ -1,13 +1,13 @@
from faulthandler import disable
from functools import partial
from xml.dom import WrongDocumentErr

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from typing_extensions import Self

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
@@ -16,12 +16,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
@@ -62,9 +60,11 @@ def check_linear_module_handler(rank, bias, world_size, port):
meta_arg_names=meta_arg_names,
node_type='bias_module')

tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {"x": torch.rand(4, 4, 4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

linear_mod_node = list(graph.nodes)[3]
strategies_vector = StrategiesVector(linear_mod_node)

@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import BinaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -52,10 +54,11 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size
input_args=input_args,
meta_arg_names=meta_arg_names)

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x1': torch.rand(4, 4).to('meta'), 'x2': torch.rand([4] * other_dim).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

op_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(op_node)
@@ -172,12 +175,11 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo
strategy_number=strategy_number,
input_args=input_args,
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x1': torch.rand(4, 4).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
print(graph)
# assert False
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

if model_cls == BEOpModelWithNodeConst:
op_node = list(graph.nodes)[2]

@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -52,13 +54,11 @@ def check_2d_device_mesh(rank, module, world_size, port):
strategy_number=strategy_number,
input_args=input_args,
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"x1": torch.rand(4, 8, 16).to('meta'),
'x2': torch.rand(4, 16, 8).to('meta')
})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node)
@@ -147,13 +147,11 @@ def check_1d_device_mesh(rank, module, world_size, port):
strategy_number=strategy_number,
input_args=input_args,
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"x1": torch.rand(4, 8, 16).to('meta'),
'x2': torch.rand(4, 16, 8).to('meta')
})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x1': torch.rand(4, 8, 16).to('meta'), 'x2': torch.rand(4, 16, 8).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_mod_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_mod_node)
@@ -205,6 +203,7 @@ def check_1d_device_mesh(rank, module, world_size, port):

@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bmm_handler(module):

@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -41,9 +43,11 @@ def check_conv_module_handler(rank, bias, world_size, port):
strategy_number=strategy_number,
input_args=[input],
meta_arg_names=['input'])
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
conv_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(conv_mod_node)
@@ -178,7 +182,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
meta_arg_names=meta_arg_names,
input_kwargs=input_kwargs)

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others]
@@ -189,6 +193,7 @@ def check_conv_function_handler(rank, bias, world_size, port):
meta_args['bias'] = torch.rand(16).to('meta')
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

if bias:
conv_mod_node = list(graph.nodes)[3]

@@ -1,11 +1,13 @@
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@@ -23,19 +25,20 @@ class ReshapeModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_reshape_handler():
model = ReshapeModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
# return view
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 64, 64).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'),
})
meta_args = {
"input": torch.rand(4, 4, 64, 64).to('meta'),
"other": torch.rand(16, 4, 3, 3).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)
@@ -67,13 +70,13 @@ def test_reshape_handler():
assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])
assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62])

assert mapping['output'].name == "view"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([2, 30752])
assert mapping['output'].data.shape == torch.Size([2, 123008])
assert mapping['output'].type == OperationDataType.OUTPUT

# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.

@@ -5,13 +5,15 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.embedding_handler import (
EmbeddingFunctionHandler,
EmbeddingModuleHandler,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -60,9 +62,11 @@ def check_embedding_module_handler(rank, world_size, port):
input_args=[input],
meta_arg_names=['input'])

tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 16).to('meta')})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
embedding_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(embedding_node)
@@ -171,18 +175,19 @@ def check_embedding_function_handler(rank, world_size, port):
input_args=input_args,
meta_arg_names=meta_arg_names,
input_kwargs=input_kwargs)
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %others : torch.Tensor [#users=1] = placeholder[target=others]
# %embedding : [#users=1] = call_function[target=torch.nn.functional.embedding](args = (%input_1, %others), kwargs = {padding_idx: None, max_norm: None, norm_type: 2.0, scale_grad_by_freq: False, sparse: False})
# return embedding
meta_args = {
"input": torch.rand(4, 16, 16).to('meta'),
"input": torch.randint(NUM_EMBEDDINGS, (4, 16, 16)).to('meta'),
"others": torch.rand(NUM_EMBEDDINGS, EMBEDDING_DIMS).to('meta')
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

embedding_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(embedding_node)

@@ -1,10 +1,13 @@
import pytest
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer

class GetattrModel(nn.Module):
@@ -18,15 +21,18 @@ class GetattrModel(nn.Module):
return weight

@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_getattr_handler():
model = GetattrModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=0] = placeholder[target=input]
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# return conv_weight
graph = tracer.trace(model, meta_args={'input': torch.rand(4, 4, 64, 64).to('meta')})
meta_args = {'input': torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

@@ -5,13 +5,15 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@@ -58,15 +60,15 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port):
meta_arg_names=['input', 'other'],
node_type='following')

tracer = ColoTracer()

graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'),
})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {
"input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)

gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *list(meta_args.values()))
linear_mod_node = list(graph.nodes)[2]
getitem_mod_node = list(graph.nodes)[3]
getitem_strategies_vector = StrategiesVector(getitem_mod_node)
@@ -129,10 +131,12 @@ def test_getitem_from_tuple_handler():
# %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0})
# %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
# return getitem
graph = tracer.trace(model, meta_args={
meta_args = {
"input": torch.rand(4, 4, 64, 64).to('meta'),
})
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)

@@ -5,10 +5,12 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
@@ -40,13 +42,15 @@ def check_ln_module_handler(rank, world_size, port):
strategy_number=strategy_number,
input_args=input_args,
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
meta_args = {"input": torch.rand(4, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

ln_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(ln_mod_node)

@@ -5,6 +5,9 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
@@ -13,7 +16,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -49,9 +51,11 @@ def check_linear_module_handler(rank, bias, input_shape, world_size, port):
input_args=input_args,
meta_arg_names=meta_arg_names)

tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {"input": torch.rand(input_shape).cuda()}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

linear_mod_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(linear_mod_node)
@@ -196,13 +200,12 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port):
input_args=input_args,
meta_arg_names=meta_arg_names)

tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"input": torch.rand(input_shape).to('meta'),
'others': torch.rand(32, 16).to('meta')
})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'input': torch.rand(input_shape).to('meta'), 'others': torch.rand(32, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

if bias:
linear_func_node = list(graph.nodes)[3]
else:

@@ -2,6 +2,9 @@ import pytest
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.matmul_handler import (
MatMulHandler,
MatMulType,
@@ -15,7 +18,6 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.utils import parameterize
@@ -57,9 +59,11 @@ def test_matmul_node_handler(tensor_shapes):
model = MatMulModule()

tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"x1": x1.to('meta'), 'x2': x2.to('meta')})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {"x1": x1.to('meta'), 'x2': x2.to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)

print(graph)
@@ -124,7 +128,6 @@ def test_matmul_node_handler(tensor_shapes):
input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
output_sharding_spec = strategy.get_sharding_spec_by_name('matmul')

if matmul_type == MatMulType.DOT:
# dot product will produce a scaler
# results should fulfill:
@@ -159,7 +162,10 @@ def test_matmul_node_handler(tensor_shapes):
if len(other_shape) > 1:
assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if len(input_shape) > 1:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
if len(other_shape) == 1:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-1]
else:
assert input_sharding_spec.sharding_sequence[-2] == output_sharding_spec.sharding_sequence[-2]
if len(other_shape) > 2:
assert other_sharding_spec.sharding_sequence[-2] == input_sharding_spec.sharding_sequence[-1]

@@ -2,10 +2,12 @@ import pytest
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag
@@ -13,14 +15,16 @@ from colossalai.testing.pytest_wrapper import run_on_environment_flag
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
# return _0
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)

gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)

@@ -1,10 +1,13 @@
import pytest
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -18,19 +21,20 @@ class OutputModel(nn.Module):
return x, y

@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('output_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
def test_output_handler(output_option):
model = OutputModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %x : torch.Tensor [#users=2] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# return (x, mul)
graph = tracer.trace(model, meta_args={
"x": torch.rand(4, 4, 64, 64).to('meta'),
})
meta_args = {'x': torch.rand(4, 4, 64, 64).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)

@@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -88,7 +90,7 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
input_args=[input, other],
meta_arg_names=['input', 'other'],
node_type='following')
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvReshapeModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -96,11 +98,11 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {bias: None})
# %permute : [#users=1] = call_function[target=torch.permute](args = (%conv2d, (0, 2, 1, 3)), kwargs = {})
# return permute
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 8, 66, 66).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'),
})
meta_args = {
'input': torch.rand(8, 8, 66, 66).to('meta'),
'other': torch.rand(16, 8, 3, 3).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)

if model_cls.__name__ == 'LinearReshapeModel':
# graph():
@@ -109,13 +111,14 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %permute : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
# return permute
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'),
})
meta_args = {
'input': torch.rand(8, 16, 64, 32).to('meta'),
'other': torch.rand(64, 32).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)

gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

previous_mod_node = list(graph.nodes)[2]
reshape_node = list(graph.nodes)[3]

@@ -1,10 +1,13 @@
import pytest
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -17,18 +20,21 @@ class PlaceholderModel(nn.Module):
return input

@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
@parameterize('placeholder_option', ['distributed', 'replicated'])
@rerun_if_address_is_in_use()
def test_placeholder_handler(placeholder_option):
model = PlaceholderModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# return input_1
graph = tracer.trace(model, meta_args={
meta_args = {
"input": torch.rand(4, 4, 64, 64).to('meta'),
})
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)

@@ -1,17 +1,15 @@
from functools import partial

import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.options import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize

class LinearModel(nn.Module):
@@ -30,13 +28,11 @@ def check_shard_option(shard_option):
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 4, 16).to('meta'),
'others': torch.rand(32, 16).to('meta')
})
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'input': torch.rand(4, 4, 4, 16).to('meta'), 'others': torch.rand(32, 16).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
linear_func_node = list(graph.nodes)[2]
strategies_vector = StrategiesVector(linear_func_node)

@@ -6,11 +6,13 @@ import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.softmax_handler import SoftmaxHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -54,7 +56,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
input_args=[input, other],
meta_arg_names=['input', 'other'],
node_type='following')
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)

# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -62,13 +64,14 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %softmax : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {})
# return split
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'),
})
meta_args = {
'input': torch.rand(8, 16, 64, 32).to('meta'),
'other': torch.rand(64, 32).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)

gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

previous_mod_node = list(graph.nodes)[2]
split_node = list(graph.nodes)[3]

@@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -76,7 +78,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
input_args=[input, other],
meta_arg_names=['input', 'other'],
node_type='following')
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvSplitModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -84,11 +86,11 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {})
# return split
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 8, 66, 66).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'),
})
meta_args = {
'input': torch.rand(8, 8, 66, 66).to('meta'),
'other': torch.rand(16, 8, 3, 3).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)

if model_cls.__name__ == 'LinearSplitModel':
# graph():
@@ -97,13 +99,14 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {})
# return split
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'),
})
meta_args = {
'input': torch.rand(8, 16, 64, 32).to('meta'),
'other': torch.rand(64, 32).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)

gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

previous_mod_node = list(graph.nodes)[2]
split_node = list(graph.nodes)[3]

@@ -5,12 +5,13 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
|
||||
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
||||
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
|
||||
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.sum_handler import SumHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
@@ -58,7 +59,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
|
||||
meta_arg_names=['input', 'other'],
|
||||
node_type='following')
|
||||
|
||||
tracer = ColoTracer()
|
||||
tracer = ColoTracer(bias_addition_split=True)
|
||||
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
@@ -66,12 +67,13 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
|
||||
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
|
||||
# %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%linear,), kwargs = {})
|
||||
# return sum_1
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(8, 16, 64, 32).to('meta'),
|
||||
"other": torch.rand(64, 32).to('meta'),
|
||||
})
|
||||
meta_args = {
|
||||
"input": torch.rand(8, 16, 64, 32).to('meta'),
|
||||
"other": torch.rand(64, 32).to('meta'),
|
||||
}
|
||||
graph = tracer.trace(model, meta_args=meta_args)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
shape_prop_pass(gm, *meta_args.values())
|
||||
|
||||
previous_mod_node = list(graph.nodes)[2]
|
||||
sum_node = list(graph.nodes)[3]
|
||||
@@ -116,107 +118,107 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port):
|
||||
|
||||
# check strategy name
|
||||
if sum_dims == (0, 2) and keepdim == False:
|
||||
assert '[R, R, R, S1] -> [R, S1]_0' in strategy_name_list
|
||||
assert '[R, S0, R, S1] -> [S0, S1]_1' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, S1]_2' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, S0]_3' in strategy_name_list
|
||||
assert '[R, S1, R, S0] -> [S1, S0]_4' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, S0]_5' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_6' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> [S0, R]_7' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_0' in strategy_name_list
|
||||
assert '[R, S01, R, R] -> [S01, R]_1' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_2' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_3' in strategy_name_list
|
||||
assert '[R, R, R, S01] -> [R, S01]_4' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, S1]_5' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, S0]_6' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_7' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_8' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_9' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> [S1, R]_10' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_11' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, S1]_12' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, S0]_13' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_14' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_15' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, S0]_9' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, S1]_10' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, S1]_11' in strategy_name_list
|
||||
assert '[R, S0, R, S1] -> [S0, S1]_12' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, S1]_13' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, S0]_14' in strategy_name_list
|
||||
assert '[R, S1, R, S0] -> [S1, S0]_15' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, S0]_16' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, S1]_17' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_18' in strategy_name_list
|
||||
assert '[R, S01, R, R] -> [S01, R]_19' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_17' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> [S0, R]_18' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_19' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_20' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_21' in strategy_name_list
|
||||
assert '[R, R, R, S01] -> [R, S01]_22' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> [S1, R]_21' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_22' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R]_23' in strategy_name_list
|
||||
|
||||
if sum_dims == (0, 2) and keepdim == True:
assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_2' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_5' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_13' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list

if sum_dims == 1 and keepdim == False:
assert '[S0, R, R, S1] -> [S0, R, S1]_0' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, S1]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, S0]_3' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, S0]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R]_6' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R]_0' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, S01]_4' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_5' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R]_11' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_8' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_9' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_10' in strategy_name_list
assert '[S0, R, R, S1] -> [S0, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R]_20' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, S1]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, S0]_14' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, S0]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, S0]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, S01]_22' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R]_23' in strategy_name_list

if sum_dims == 1 and keepdim == True:
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_23' in strategy_name_list

@@ -1,10 +1,12 @@
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@@ -22,7 +24,7 @@ class TensorConstructorModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_where_handler():
model = TensorConstructorModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %x : torch.Tensor [#users=2] = placeholder[target=x]
# %size : [#users=1] = call_method[target=size](args = (%x,), kwargs = {})
@@ -30,10 +32,10 @@ def test_where_handler():
# %arange : [#users=1] = call_function[target=torch.arange](args = (%getitem,), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%x, %arange), kwargs = {})
# return add
graph = tracer.trace(model, meta_args={
"x": torch.rand(10).to('meta'),
})
meta_args = {'x': torch.rand(10).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)

@@ -1,12 +1,13 @@
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@@ -25,19 +26,20 @@ class ReLuModel(nn.Module):
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_elementwise_handler():
model = ReLuModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {})
# return act
graph = tracer.trace(model,
meta_args={
"input": torch.rand(4, 4, 64, 64).to('meta'),
"other": torch.rand(4, 16, 3, 3).to('meta'),
})
meta_args = {
'input': torch.rand(4, 4, 64, 64).to('meta'),
'other': torch.rand(16, 4, 3, 3).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)
@@ -69,13 +71,13 @@ def test_elementwise_handler():

assert mapping['input'].name == "conv2d"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
assert mapping['input'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])
assert mapping['input'].logical_shape == torch.Size([4, 16, 62, 62])

assert mapping['output'].name == "act"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 4, 62, 62])
assert mapping['output'].data.shape == torch.Size([4, 16, 62, 62])
assert mapping['output'].type == OperationDataType.OUTPUT

# getitem is a following strategy handler, so the number of strategies is equal to the predecessor node.

@@ -5,12 +5,14 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
@@ -74,7 +76,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
input_args=[input, other],
meta_arg_names=['input', 'other'],
node_type='following')
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
if model_cls.__name__ == 'ConvViewModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -82,11 +84,8 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
# return view
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 8, 66, 66).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'),
})
meta_args = {'input': torch.rand(8, 8, 66, 66).to('meta'), 'other': torch.rand(16, 8, 3, 3).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)

if model_cls.__name__ == 'LinearViewModel':
# graph():
@@ -95,13 +94,14 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
# return view
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'),
})
meta_args = {
'input': torch.rand(8, 16, 64, 32).to('meta'),
'other': torch.rand(64, 32).to('meta'),
}
graph = tracer.trace(model, meta_args=meta_args)

gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())

previous_mod_node = list(graph.nodes)[2]
view_node = list(graph.nodes)[3]

@@ -1,12 +1,13 @@
import pytest
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \
WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear


class ConvModel(nn.Module):
@@ -19,22 +20,24 @@ class ConvModel(nn.Module):
return output


@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_where_handler():
model = ConvModel()
tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
# graph():
# %condition : torch.Tensor [#users=1] = placeholder[target=condition]
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %y : torch.Tensor [#users=1] = placeholder[target=y]
# %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
# return where
graph = tracer.trace(model,
meta_args={
"condition": torch.rand(4, 4, 64, 64).to('meta'),
"x": torch.rand(4, 1, 64, 64).to('meta'),
"y": torch.rand(1, 4, 64, 64).to('meta')
})
meta_args = {
'condition': torch.rand(4, 4, 64, 64).to('meta'),
'x': torch.rand(4, 1, 64, 64).to('meta'),
'y': torch.rand(1, 4, 64, 64).to('meta')
}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
shape_prop_pass(gm, *meta_args.values())
physical_mesh_id = torch.arange(0, 4)

mesh_shape = (2, 2)

@@ -4,6 +4,9 @@ from typing import Dict, List
import torch
from torch.fx import GraphModule

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import SolverOptions
@@ -11,7 +14,6 @@ from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import to_global
from colossalai.testing.comparison import assert_close

@@ -79,14 +81,16 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs,
grad_to_shard_dict)

tracer = ColoTracer()
tracer = ColoTracer(bias_addition_split=True)
input_sample = {}
for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta')
input_sample[meta_arg_name] = torch.empty(input_arg.shape, dtype=input_arg.dtype).to('meta')
for meta_kwarg_name, input_kwarg in input_kwargs.items():
input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
input_sample[meta_kwarg_name] = torch.empty(input_kwarg.shape, dtype=input_kwarg.dtype).to('meta')
graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
shape_prop_pass(gm, *input_sample.values())

solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
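
All of the updated tests share the same tracing flow introduced by this commit: build a `ColoTracer` with `bias_addition_split=True`, trace with meta arguments, wrap the graph in `ColoGraphModule`, and then propagate shapes with the explicit `shape_prop_pass`. Below is a minimal sketch of that flow on a toy module; `ToyModel` and its input shape are illustrative assumptions, not part of the commit.

```python
import torch
import torch.nn as nn

from colossalai._analyzer.fx.graph_module import ColoGraphModule
from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
from colossalai._analyzer.fx.tracer.tracer import ColoTracer


class ToyModel(nn.Module):
    # hypothetical model, only to illustrate the tracing flow used in these tests
    def forward(self, x):
        return torch.relu(x)


model = ToyModel()
# the tests now construct the tracer with bias_addition_split=True
tracer = ColoTracer(bias_addition_split=True)
meta_args = {'x': torch.rand(4, 4).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
gm = ColoGraphModule(model, graph)
# shape metadata is populated by an explicit pass instead of the old ShapeProp
shape_prop_pass(gm, *meta_args.values())
```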