mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[autoparallel] refactored the autoparallel module for organization (#1706)
* [autoparallel] refactored the autoparallel module for organization * polish code
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import torch
|
||||
from colossalai.auto_parallel.solver.node_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (get_broadcast_shape, is_broadcastable,
|
||||
recover_sharding_spec_for_broadcast_shape)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
|
||||
def test_is_broadcastable():
|
||||
|
@@ -1,7 +1,8 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
|
||||
|
||||
class LinearModel(nn.Module):
|
||||
|
@@ -1,10 +1,12 @@
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.batch_norm_handler import BatchNormModuleHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.batch_norm_handler import \
|
||||
BatchNormModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
|
||||
|
||||
def test_bn_module_handler():
|
||||
|
@@ -1,10 +1,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import \
|
||||
BMMFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
|
@@ -1,10 +1,11 @@
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvModuleHandler, ConvFunctionHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import (ConvFunctionHandler, ConvModuleHandler)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
|
||||
|
||||
def test_conv_module_handler():
|
||||
|
@@ -1,11 +1,14 @@
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.getitem_handler import GetItemHandler
|
||||
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
|
||||
ConvFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import \
|
||||
GetItemHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
|
||||
|
||||
class GetItemModel(nn.Module):
|
||||
|
@@ -1,10 +1,12 @@
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.layer_norm_handler import LayerNormModuleHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import \
|
||||
LayerNormModuleHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
|
||||
|
||||
def test_ln_module_handler():
|
||||
|
@@ -1,10 +1,12 @@
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.dot_handler import LinearModuleHandler, LinearFunctionHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.dot_handler import (LinearFunctionHandler, LinearModuleHandler)
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
|
||||
StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
|
||||
|
@@ -1,11 +1,13 @@
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import NormPoolingHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import \
|
||||
NormPoolingHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
import pytest
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
|
@@ -1,9 +1,11 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import \
|
||||
OuputHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
|
||||
|
||||
class OutputModel(nn.Module):
|
||||
|
@@ -1,9 +1,11 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import \
|
||||
PlacehodlerHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
|
||||
|
||||
class PlaceholderModel(nn.Module):
|
||||
|
@@ -1,10 +1,13 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
|
||||
from colossalai.auto_parallel.solver.node_handler.reshape_handler import ReshapeHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
|
||||
ConvFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import \
|
||||
ReshapeHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
|
||||
|
||||
class ReshapeModel(nn.Module):
|
||||
|
@@ -1,11 +1,14 @@
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
|
||||
from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
|
||||
ConvFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
|
||||
UnaryElementwiseHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
|
||||
|
||||
class ReLuModel(nn.Module):
|
||||
|
@@ -1,10 +1,12 @@
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
from colossalai.auto_parallel.solver.node_handler.where_handler import WhereHandler
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \
|
||||
WhereHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
|
@@ -1,24 +1,22 @@
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
|
||||
StrategiesConstructor)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
|
||||
from colossalai.auto_parallel.solver.solver import Solver_V2
|
||||
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
|
||||
solution_annotatation_pass)
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
|
||||
|
||||
class ConvModel(nn.Module):
|
||||
@@ -61,7 +59,7 @@ def check_apply(rank, world_size, port):
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
solution = list(ret[0])
|
||||
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
|
@@ -1,20 +1,13 @@
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
from torchvision.models import resnet50
|
||||
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
|
||||
from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
|
||||
StrategiesConstructor)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.solver.cost_graph import CostGraph
|
||||
from copy import deepcopy
|
||||
from colossalai.auto_parallel.solver.solver import Solver
|
||||
from torchvision.models import resnet34, resnet50
|
||||
from colossalai.auto_parallel.solver.constants import *
|
||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.auto_parallel.solver.options import SolverOptions
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user