Mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] collated all deprecated files (#1700)
* [autoparallel] collated all deprecated files

* polish code
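The diff below is a mechanical relocation: tests that exercise the legacy solver now import it from colossalai.auto_parallel.tensor_shard.deprecated, while the newer "v2" handlers move from op_handler.*_v2 modules into node_handler.* and drop their _V2 class suffixes. As an orientation aid (not part of the commit itself), here is a minimal before/after sketch built only from import paths that appear in the hunks below:

# old paths (removed by this commit)
# from colossalai.auto_parallel.solver.options import SolverOptions
# from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler

# new paths (introduced by this commit)
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler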
@@ -1,5 +1,5 @@
 import torch
-from colossalai.auto_parallel.solver.op_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape
+from colossalai.auto_parallel.solver.node_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh

@@ -6,9 +6,9 @@ import pytest

 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from copy import deepcopy

@@ -6,8 +6,8 @@ import pytest
 from colossalai.fx.proxy import ColoProxy
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.batch_norm_handler import BatchNormHandler
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -3,8 +3,8 @@ from torch.fx import GraphModule
 import torch.nn as nn
 import pytest

-from colossalai.auto_parallel.solver.options import SolverOptions
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh

@@ -3,8 +3,8 @@ from torch.fx import GraphModule
 import torch.nn as nn
 import pytest

-from colossalai.auto_parallel.solver.options import SolverOptions
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh

@@ -6,8 +6,8 @@ import pytest
 from colossalai.fx.proxy import ColoProxy
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -6,8 +6,8 @@ import pytest
 from colossalai.fx.proxy import ColoProxy
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -2,13 +2,13 @@ import torch
 from torch.fx import GraphModule
 import torch.nn as nn
 import pytest
-from colossalai.auto_parallel.solver import sharding_strategy
+from colossalai.auto_parallel.tensor_shard.deprecated import sharding_strategy

 from colossalai.fx.proxy import ColoProxy
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.layer_norm_handler import LayerNormHandler
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -3,8 +3,8 @@ from torch.fx import GraphModule
 import torch.nn as nn
 import pytest

-from colossalai.auto_parallel.solver.options import SolverOptions
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh

@@ -3,8 +3,8 @@ from torch.fx import GraphModule
 import torch.nn as nn
 import pytest

-from colossalai.auto_parallel.solver.options import SolverOptions
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh

@@ -9,15 +9,15 @@ from colossalai.initialize import launch
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.logging import disable_existing_loggers
-from colossalai.auto_parallel.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor

 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
-from colossalai.auto_parallel.solver import Solver
-from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated import Solver
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions


 class ConvModel(nn.Module):

@@ -6,12 +6,12 @@ import pytest
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
 from copy import deepcopy
-from colossalai.auto_parallel.solver import Solver
-from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated import Solver
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions


 class ConvModel(nn.Module):

@@ -4,17 +4,17 @@ import torch.nn as nn
 import pytest

 from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.solver.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
 from copy import deepcopy
-from colossalai.auto_parallel.solver import Solver
+from colossalai.auto_parallel.tensor_shard.deprecated import Solver
 import transformers
-from colossalai.auto_parallel.solver.constants import *
-from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
+from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions

 BATCH_SIZE = 8
 SEQ_LENGHT = 8

@@ -4,17 +4,17 @@ import torch.nn as nn
 import pytest

 from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.solver.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
 from copy import deepcopy
-from colossalai.auto_parallel.solver import Solver
+from colossalai.auto_parallel.tensor_shard.deprecated import Solver
 from torchvision.models import resnet34, resnet50
-from colossalai.auto_parallel.solver.constants import *
-from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
+from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions


 class MLP(torch.nn.Module):

@@ -6,11 +6,11 @@ import pytest
 from colossalai.fx.proxy import ColoProxy
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from copy import deepcopy

@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.batch_norm_handler_v2 import BatchNormModuleHandler
+from colossalai.auto_parallel.solver.node_handler.batch_norm_handler import BatchNormModuleHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -2,7 +2,7 @@ import pytest
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import BMMFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvModuleHandler, ConvFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvModuleHandler, ConvFunctionHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -2,8 +2,8 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.getitem_handler import GetItemHandler
-from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.getitem_handler import GetItemHandler
+from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.layer_norm_handler_v2 import LayerNormModuleHandler
+from colossalai.auto_parallel.solver.node_handler.layer_norm_handler import LayerNormModuleHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -2,8 +2,8 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy_V2
+from colossalai.auto_parallel.solver.node_handler.dot_handler import LinearModuleHandler, LinearFunctionHandler
+from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec

@@ -83,7 +83,7 @@ def test_linear_module_handler():
     assert 'RS1 = RR x RS1' in strategy_name_list

     for strategy in strategies_vector:
-        strategy: ShardingStrategy_V2
+        strategy: ShardingStrategy
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')

@@ -164,7 +164,7 @@ def test_linear_function_handler():
     assert 'RS1 = RR x RS1' in strategy_name_list

     for strategy in strategies_vector:
-        strategy: ShardingStrategy_V2
+        strategy: ShardingStrategy
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')

@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.normal_pooling_handler import NormPoolingHandler
+from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import NormPoolingHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 import pytest

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.output_handler import OuputHandler
+from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
-from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler_V2
+from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -48,9 +48,9 @@ def test_reshape_handler():
                                       strategies_vector=conv_strategies_vector)
     conv_handler.register_strategy(compute_resharding_cost=False)
     setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
-    reshape_handler = ReshapeHandler_V2(node=reshape_node,
-                                        device_mesh=device_mesh,
-                                        strategies_vector=reshape_strategies_vector)
+    reshape_handler = ReshapeHandler(node=reshape_node,
+                                     device_mesh=device_mesh,
+                                     strategies_vector=reshape_strategies_vector)

     reshape_handler.register_strategy(compute_resharding_cost=False)

@@ -2,8 +2,8 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2
-from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
+from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -50,9 +50,9 @@ def test_elementwise_handler():
                                       strategies_vector=conv_strategies_vector)
     conv_handler.register_strategy(compute_resharding_cost=False)
     setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
-    relu_handler = UnaryElementwiseHandler_V2(node=relu_mod_node,
-                                              device_mesh=device_mesh,
-                                              strategies_vector=relu_strategies_vector)
+    relu_handler = UnaryElementwiseHandler(node=relu_mod_node,
+                                           device_mesh=device_mesh,
+                                           strategies_vector=relu_strategies_vector)

     relu_handler.register_strategy(compute_resharding_cost=False)

@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.where_handler_v2 import WhereHandler
+from colossalai.auto_parallel.solver.node_handler.where_handler import WhereHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh

@@ -7,10 +7,10 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor_V2
-from colossalai.auto_parallel.solver.cost_graph import CostGraph_V2
+from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.solver.cost_graph import CostGraph
 from copy import deepcopy
-from colossalai.auto_parallel.solver.solver import Solver_V2
+from colossalai.auto_parallel.solver.solver import Solver
 from torchvision.models import resnet34, resnet50
 from colossalai.auto_parallel.solver.constants import *
 from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser

@@ -60,12 +60,12 @@ def test_cost_graph():
     graph_analyser = GraphAnalyser(gm)
     liveness_list = graph_analyser.liveness_analysis()
     solver_options = SolverOptions(fast=True)
-    strategies_constructor = StrategiesConstructor_V2(graph, device_mesh, solver_options)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()

-    cost_graph = CostGraph_V2(strategies_constructor.leaf_strategies)
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)

     ret = solver.call_solver_serialized_args()
     print(ret[0])
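For context, the last hunk above drives the full solver pipeline through the renamed (un-suffixed) classes. A condensed sketch of that flow, taken from the test code shown in the diff; gm, graph, and device_mesh are assumed to be built earlier in the test (tracing and device-mesh setup are outside the hunks shown), and the SolverOptions import path is likewise an assumption since it lies outside these hunks:

from colossalai.auto_parallel.solver.options import SolverOptions          # assumed path, not shown in this diff
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.solver import Solver
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser

# gm (GraphModule), graph, and device_mesh come from the test's setup, not shown here.
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()   # serialized sharding solution, printed by the test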