diff --git a/colossalai/auto_parallel/solver/cost_graph.py b/colossalai/auto_parallel/solver/cost_graph.py index a5f418be4..bbbcc8cb5 100644 --- a/colossalai/auto_parallel/solver/cost_graph.py +++ b/colossalai/auto_parallel/solver/cost_graph.py @@ -18,174 +18,6 @@ class CostGraph: simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) ''' - def __init__(self, leaf_strategies, simplify=True): - self.leaf_strategies = leaf_strategies - self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] - # stores number of strategies in each node - self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies} - # extra_node_costs will store the extra costs introduced by merging nodes - self.extra_node_costs = {} - self.following_dict = {} - self.simplify = simplify - self._build_cost_graph() - - def _remove_invalid_node(self, node, attr_name): - remove_list = [] - target_node_list = getattr(node, attr_name, []) - for target_node in target_node_list: - if target_node not in self.nodes: - remove_list.append(target_node) - for element in remove_list: - target_node_list.remove(element) - - def _build_cost_graph(self): - ''' - This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be - set to node. - ''' - self.edge_costs = {} - if self.simplify: - self.merge_pair = [] - for strategies_vector in self.leaf_strategies: - # build edge_cost - dst_node = strategies_vector.node - for src_node in strategies_vector.predecessor_nodes: - if src_node not in self.nodes: - continue - node_pair = (src_node, dst_node) - # src_index = strategies_vector.predecessor_nodes.index(src_node) - edge_cost = {} - for i in range(len(strategies_vector)): - for j in range(len(src_node.strategies_vector)): - edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j] - self.edge_costs[node_pair] = edge_cost - # add parents and children attribute to node - setattr(dst_node, 'parents', strategies_vector.predecessor_nodes) - setattr(dst_node, 'children', strategies_vector.successor_nodes) - self._remove_invalid_node(dst_node, 'parents') - self._remove_invalid_node(dst_node, 'children') - - if self.simplify and strategies_vector.check_merge(): - for followed_node in strategies_vector.predecessor_nodes: - self.merge_pair.append((followed_node, dst_node)) - - def get_edge_cost(self, src_node, dst_node): - return self.edge_costs[(src_node, dst_node)] - - def merge_node(self, src_node, dst_node): - ''' - To merge dst_node into src_node, we need to do it in following steps: - - 1. For each strategy in dst_node, we need to pick an appropriate strategy - of src_node to merge, it is important because the logical resharding costs - between the parents node of src_node and merged node depend on the src_node - strategies dispatching. For example, for the graph 0->1->2, after merging node 1 - into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)] - x represents the picking strategy of node 1 merged into node 2 strategy 0. - - 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs - contains two parts, one is resharding costs between src_node strategy and dst_node strategy, - another is the origin extra costs in src_node strategy. - - 3. Build connections between new node pairs, and remove the src_node after all consumer nodes - detached from it. 
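The edge-cost construction and node-merging logic in the hunk above can be hard to follow from the diff alone. The following self-contained toy sketch (plain dicts and made-up numbers, no Colossal-AI imports, not the CostGraph API) shows the two ideas: the quadratic resharding costs between a source/destination node pair are flattened into a dict keyed by (src_strategy_index, dst_strategy_index), and merging a trivial node means picking, for each of its strategies, the cheapest strategy of the following node and carrying that cost along as an extra node cost.

# Toy illustration of the linearized edge-cost table and the merge step.
# resharding[i][j]: cost of feeding the output of src strategy j into dst strategy i.
resharding = [
    [0.0, 4.0],   # dst strategy 0
    [2.0, 0.0],   # dst strategy 1
]

edge_cost = {}
for i, row in enumerate(resharding):        # i: dst strategy index
    for j, cost in enumerate(row):          # j: src strategy index
        edge_cost[(j, i)] = cost

# Merge: for each src strategy pick the cheapest dst strategy and remember the
# picked resharding cost as an extra node cost on the surviving node.
merge_map, extra_node_cost = {}, []
for j in range(len(resharding[0])):
    best_i = min(range(len(resharding)), key=lambda i: edge_cost[(j, i)])
    merge_map[j] = best_i
    extra_node_cost.append(edge_cost[(j, best_i)])

print(merge_map)          # {0: 0, 1: 1}
print(extra_node_cost)    # [0.0, 0.0]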
- - Argument: - src_node(Node): The node will be merged into dst_node. - dst_node(Node): The node to integrate src_node. - ''' - src_node_index = dst_node.parents.index(src_node) - # build merge_map - merge_map = {} - for src_index, strategy in enumerate(src_node.strategies_vector): - min_cost = INFINITY_COST - lowest_cost_index = -1 - for dst_index, dst_strategy in enumerate(dst_node.strategies_vector): - resharding_cost = dst_strategy.resharding_costs[src_node][src_index] - if resharding_cost <= min_cost: - min_cost = resharding_cost - lowest_cost_index = dst_index - merge_map[src_index] = lowest_cost_index - - # extra_node_cost for src node - self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node] - for src_index, strategy in enumerate(src_node.strategies_vector): - target_strate_index = merge_map[src_index] - target_strategy = dst_node.strategies_vector[target_strate_index] - self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index] - if dst_node in self.extra_node_costs: - self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index] - - # add new node pair to cost graph - for child_node in dst_node.children: - new_node_pair = (src_node, child_node) - old_node_pair = (dst_node, child_node) - if new_node_pair in self.edge_costs: - continue - edge_cost = {} - for i in range(self.node_lens[src_node]): - for j in range(self.node_lens[child_node]): - dst_strate_index = merge_map[i] - # dst_strategy = dst_node.strategies_vector[dst_strate_index] - edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)] - if new_node_pair not in self.edge_costs: - self.edge_costs[new_node_pair] = edge_cost - else: - # we should accumulate the resharding costs if args of child node contain - # both src node and dst node. - for index_pair, resharding_cost in self.edge_costs[new_node_pair]: - self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair] - - # connect src node and children of dst node - dst_node.parents.remove(src_node) - src_node.children.remove(dst_node) - self.edge_costs.pop((src_node, dst_node)) - for child_node in dst_node.children: - if child_node not in src_node.children: - src_node.children.append(child_node) - if src_node not in child_node.parents: - child_node.parents.append(src_node) - # remove dst node from cost graph when dst node has no producer. - if len(dst_node.parents) == 0: - child_node.parents.remove(dst_node) - node_pair = (dst_node, child_node) - self.edge_costs.pop(node_pair) - if len(dst_node.parents) == 0: - self.following_dict[dst_node] = src_node - dst_node.children = [] - - def _reindexing_src(self, src): - if src not in self.following_dict: - return src - return self._reindexing_src(self.following_dict[src]) - - def simplify_graph(self): - if not self.simplify: - return - self.merge_pair.reverse() - for (src_node, dst_node) in self.merge_pair: - self.merge_node(src_node, dst_node) - self.merge_pair.reverse() - reindexing_following_dict = {} - for dst, src in self.following_dict.items(): - reindexing_following_dict[dst] = self._reindexing_src(src) - self.following_dict = reindexing_following_dict - - -class CostGraph_V2: - ''' - A graph data structure to simplify the edge cost graph. It has two main functions: - 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in - CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. - 2. 
To reduce the searching space, we merge computationally-trivial operators, such as - element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will - be given by the StrategiesVector depending on the type of target node and following nodes. - - Argument: - leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. - simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) - ''' - def __init__(self, leaf_strategies, simplify=True, forward_only=False): self.leaf_strategies = leaf_strategies self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] diff --git a/colossalai/auto_parallel/solver/node_handler/__init__.py b/colossalai/auto_parallel/solver/node_handler/__init__.py new file mode 100644 index 000000000..9aad0b91a --- /dev/null +++ b/colossalai/auto_parallel/solver/node_handler/__init__.py @@ -0,0 +1,16 @@ +from .dot_handler import LinearFunctionHandler, LinearModuleHandler +from .layer_norm_handler import LayerNormModuleHandler +from .batch_norm_handler import BatchNormModuleHandler +from .conv_handler import ConvModuleHandler, ConvFunctionHandler +from .where_handler import WhereHandler +from .unary_elementwise_handler import UnaryElementwiseHandler +from .reshape_handler import ReshapeHandler +from .placeholder_handler import PlacehodlerHandler +from .output_handler import OuputHandler +from .normal_pooling_handler import NormPoolingHandler + +__all__ = [ + 'LinearFunctionHandler', 'LinearModuleHandler', 'LayerNormModuleHandler', 'BatchNormModuleHandler', + 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', + 'OuputHandler', 'WhereHandler', 'NormPoolingHandler' +] diff --git a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler_v2.py b/colossalai/auto_parallel/solver/node_handler/batch_norm_handler.py similarity index 92% rename from colossalai/auto_parallel/solver/op_handler/batch_norm_handler_v2.py rename to colossalai/auto_parallel/solver/node_handler/batch_norm_handler.py index 185327c94..4a5e0fdec 100644 --- a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler_v2.py +++ b/colossalai/auto_parallel/solver/node_handler/batch_norm_handler.py @@ -1,8 +1,8 @@ import torch import torch.nn.functional as F from .node_handler import ModuleHandler, NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData -from ..strategy import BatchNormStrategyGenerator, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData +from ..strategy import BatchNormStrategyGenerator, StrategyGenerator from typing import List, Dict from .registry import operator_registry @@ -17,7 +17,7 @@ class BatchNormModuleHandler(ModuleHandler): A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd module. 
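The new node_handler package introduced above pairs each supported torch operator with a handler class through a decorator-based registry. Below is a minimal sketch of that dispatch pattern; the real operator_registry may differ in details, and the string key here merely stands in for the torch module/function types registered in the diff.

# Minimal sketch of decorator-based operator -> handler dispatch.
class Registry:
    def __init__(self):
        self.store = {}

    def register(self, source):
        def wrapper(handler_cls):
            self.store[source] = handler_cls
            return handler_cls
        return wrapper

    def get(self, source):
        return self.store[source]

operator_registry = Registry()

@operator_registry.register('batch_norm')    # the diff registers torch types, e.g. torch.nn.BatchNorm2d
class ToyBatchNormHandler:
    def get_strategy_generator(self):
        return []                             # handlers return a list of strategy generators

print(operator_registry.get('batch_norm') is ToyBatchNormHandler)   # True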
""" - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(BatchNormStrategyGenerator(op_data_mapping, self.device_mesh)) diff --git a/colossalai/auto_parallel/solver/op_handler/broadcast.py b/colossalai/auto_parallel/solver/node_handler/broadcast.py similarity index 100% rename from colossalai/auto_parallel/solver/op_handler/broadcast.py rename to colossalai/auto_parallel/solver/node_handler/broadcast.py diff --git a/colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py b/colossalai/auto_parallel/solver/node_handler/conv_handler.py similarity index 94% rename from colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py rename to colossalai/auto_parallel/solver/node_handler/conv_handler.py index 7085c3d2b..2074eeb1b 100644 --- a/colossalai/auto_parallel/solver/op_handler/conv_handler_v2.py +++ b/colossalai/auto_parallel/solver/node_handler/conv_handler.py @@ -1,8 +1,8 @@ import torch import torch.nn.functional as F from .node_handler import ModuleHandler, NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData -from ..strategy import ConvStrategyGenerator, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData +from ..strategy import ConvStrategyGenerator, StrategyGenerator from typing import List, Dict from .registry import operator_registry @@ -17,7 +17,7 @@ class ConvModuleHandler(ModuleHandler): A ConvModuleHandler which deals with the sharding strategies for nn.Convxd module. """ - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh)) @@ -47,7 +47,7 @@ class ConvModuleHandler(ModuleHandler): mapping['bias'] = physical_bias_operand return mapping - def post_process(self, strategy: ShardingStrategy_V2): + def post_process(self, strategy: ShardingStrategy): """ Convert the sharding spec of the weight parameter back to its original shape. """ @@ -78,7 +78,7 @@ class ConvFunctionHandler(NodeHandler): A ConvFunctionHandler which deals with the sharding strategies for nn.functional.ConvXd functions. """ - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(ConvStrategyGenerator(op_data_mapping, self.device_mesh)) @@ -120,7 +120,7 @@ class ConvFunctionHandler(NodeHandler): mapping['bias'] = physical_bias_operand return mapping - def post_process(self, strategy: ShardingStrategy_V2): + def post_process(self, strategy: ShardingStrategy): """ Convert the sharding spec of the weight parameter back to its original shape. 
""" diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py b/colossalai/auto_parallel/solver/node_handler/dot_handler.py similarity index 95% rename from colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py rename to colossalai/auto_parallel/solver/node_handler/dot_handler.py index d18a6e88a..015f71ebd 100644 --- a/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py +++ b/colossalai/auto_parallel/solver/node_handler/dot_handler.py @@ -2,8 +2,8 @@ import torch import torch.nn.functional as F from colossalai.tensor.sharding_spec import ShardingException from .node_handler import ModuleHandler, NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData -from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2, BatchedMatMulStrategyGenerator +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData +from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator, BatchedMatMulStrategyGenerator from typing import List, Dict, Union from .registry import operator_registry from copy import deepcopy @@ -18,7 +18,7 @@ class LinearModuleHandler(ModuleHandler): A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. """ - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh)) @@ -53,7 +53,7 @@ class LinearModuleHandler(ModuleHandler): mapping['bias'] = physical_bias_operand return mapping - def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]: + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: """ Convert the sharding spec from the logical shape to the physical shape. """ @@ -101,7 +101,7 @@ class LinearFunctionHandler(NodeHandler): A LinearModuleHandler which deals with the sharding strategies for nn.Linear module. """ - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh)) @@ -140,7 +140,7 @@ class LinearFunctionHandler(NodeHandler): mapping['bias'] = physical_bias_operand return mapping - def post_process(self, strategy: ShardingStrategy_V2): + def post_process(self, strategy: ShardingStrategy): """ Convert the sharding spec of the weight parameter back to its original shape. 
""" @@ -200,7 +200,7 @@ class BMMFunctionHandler(NodeHandler): mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} return mapping - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: generators = [] op_data_mapping = self.get_operation_data_mapping() generators = [] diff --git a/colossalai/auto_parallel/solver/op_handler/getitem_handler.py b/colossalai/auto_parallel/solver/node_handler/getitem_handler.py similarity index 88% rename from colossalai/auto_parallel/solver/op_handler/getitem_handler.py rename to colossalai/auto_parallel/solver/node_handler/getitem_handler.py index 71022ccdd..9c4c8fdf0 100644 --- a/colossalai/auto_parallel/solver/op_handler/getitem_handler.py +++ b/colossalai/auto_parallel/solver/node_handler/getitem_handler.py @@ -1,7 +1,7 @@ import torch from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector -from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector +from ..strategy import TensorStrategyGenerator, TensorTupleStrategyGenerator, StrategyGenerator from typing import List, Dict from .registry import operator_registry import operator @@ -15,7 +15,7 @@ class GetItemHandler(NodeHandler): A GetItemHandler which deals with the sharding strategies for operator.getitem. """ - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] if isinstance(op_data_mapping["input"].data, torch.Tensor): diff --git a/colossalai/auto_parallel/solver/op_handler/layer_norm_handler_v2.py b/colossalai/auto_parallel/solver/node_handler/layer_norm_handler.py similarity index 89% rename from colossalai/auto_parallel/solver/op_handler/layer_norm_handler_v2.py rename to colossalai/auto_parallel/solver/node_handler/layer_norm_handler.py index 8125265a2..1bcb55daa 100644 --- a/colossalai/auto_parallel/solver/op_handler/layer_norm_handler_v2.py +++ b/colossalai/auto_parallel/solver/node_handler/layer_norm_handler.py @@ -1,7 +1,7 @@ import torch from .node_handler import ModuleHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData -from ..strategy import LayerNormGenerator, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData +from ..strategy import LayerNormGenerator, StrategyGenerator from typing import List, Dict from .registry import operator_registry @@ -14,7 +14,7 @@ class LayerNormModuleHandler(ModuleHandler): A LayerNormModuleHandler which deals with the sharding strategies for nn.LayerNorm module. 
""" - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(LayerNormGenerator(op_data_mapping, self.device_mesh)) diff --git a/colossalai/auto_parallel/solver/op_handler/node_handler.py b/colossalai/auto_parallel/solver/node_handler/node_handler.py similarity index 94% rename from colossalai/auto_parallel/solver/op_handler/node_handler.py rename to colossalai/auto_parallel/solver/node_handler/node_handler.py index e667049db..6634a75dc 100644 --- a/colossalai/auto_parallel/solver/op_handler/node_handler.py +++ b/colossalai/auto_parallel/solver/node_handler/node_handler.py @@ -3,8 +3,8 @@ from torch.fx.node import Node from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager from typing import Dict, List, Union -from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem -from ..strategy import StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, StrategiesVector, OperationData, TrainCycleItem +from ..strategy import StrategyGenerator from .._utils import generate_resharding_costs @@ -30,7 +30,7 @@ class NodeHandler(ABC): self.device_mesh = device_mesh self.strategies_vector = strategies_vector - def update_resharding_cost(self, strategy: ShardingStrategy_V2) -> None: + def update_resharding_cost(self, strategy: ShardingStrategy) -> None: """ Compute the resharding costs and save the costs in the ShardingStrategy object. """ @@ -97,13 +97,13 @@ class NodeHandler(ABC): return self.strategies_vector - def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]: + def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]: # tranform the strategy generated # e.g. to process the sharding strategy for the transposed weights return strategy @abstractmethod - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: """ Define which generators should be used by this NodeHandler object. 
""" diff --git a/colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py b/colossalai/auto_parallel/solver/node_handler/normal_pooling_handler.py similarity index 88% rename from colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py rename to colossalai/auto_parallel/solver/node_handler/normal_pooling_handler.py index 59baa9631..7238085a5 100644 --- a/colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py +++ b/colossalai/auto_parallel/solver/node_handler/normal_pooling_handler.py @@ -1,12 +1,12 @@ import torch import torch.nn.functional as F from .node_handler import ModuleHandler, NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData -from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData +from ..strategy import NormalPoolStrategyGenerator, StrategyGenerator from typing import List, Dict from .registry import operator_registry -__all__ = ['LinearModuleHandler', 'LinearFunctionHandler'] +__all__ = ['NormPoolingHandler'] @operator_registry.register(torch.nn.MaxPool1d) @@ -20,7 +20,7 @@ class NormPoolingHandler(ModuleHandler): A NormPoolingHandler which deals with the sharding strategies for nn.MaxPoolxd module. """ - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh)) diff --git a/colossalai/auto_parallel/solver/op_handler/output_handler.py b/colossalai/auto_parallel/solver/node_handler/output_handler.py similarity index 89% rename from colossalai/auto_parallel/solver/op_handler/output_handler.py rename to colossalai/auto_parallel/solver/node_handler/output_handler.py index 55ff2f843..a268bcc04 100644 --- a/colossalai/auto_parallel/solver/op_handler/output_handler.py +++ b/colossalai/auto_parallel/solver/node_handler/output_handler.py @@ -1,7 +1,7 @@ import torch from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector -from colossalai.auto_parallel.solver.strategy import StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector +from colossalai.auto_parallel.solver.strategy import StrategyGenerator from colossalai.auto_parallel.solver.strategy.output_generator import OutputGenerator from typing import List, Dict from .registry import operator_registry @@ -14,7 +14,7 @@ class OuputHandler(NodeHandler): A OuputHandler which deals with the sharding strategies for Output Node. 
""" - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(OutputGenerator(op_data_mapping, self.device_mesh, self.predecessor_node)) diff --git a/colossalai/auto_parallel/solver/op_handler/placeholder_handler.py b/colossalai/auto_parallel/solver/node_handler/placeholder_handler.py similarity index 86% rename from colossalai/auto_parallel/solver/op_handler/placeholder_handler.py rename to colossalai/auto_parallel/solver/node_handler/placeholder_handler.py index 4c4f0a83a..ab6b02a7b 100644 --- a/colossalai/auto_parallel/solver/op_handler/placeholder_handler.py +++ b/colossalai/auto_parallel/solver/node_handler/placeholder_handler.py @@ -1,7 +1,7 @@ import torch from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData -from colossalai.auto_parallel.solver.strategy import StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData +from colossalai.auto_parallel.solver.strategy import StrategyGenerator from colossalai.auto_parallel.solver.strategy.placeholder_generator import PlaceholderGenerator from typing import List, Dict from .registry import operator_registry @@ -14,7 +14,7 @@ class PlacehodlerHandler(NodeHandler): A PlacehodlerHandler which deals with the sharding strategies for Placeholder Node. """ - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(PlaceholderGenerator(op_data_mapping, self.device_mesh)) diff --git a/colossalai/auto_parallel/solver/op_handler/registry.py b/colossalai/auto_parallel/solver/node_handler/registry.py similarity index 100% rename from colossalai/auto_parallel/solver/op_handler/registry.py rename to colossalai/auto_parallel/solver/node_handler/registry.py diff --git a/colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py b/colossalai/auto_parallel/solver/node_handler/reshape_handler.py similarity index 80% rename from colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py rename to colossalai/auto_parallel/solver/node_handler/reshape_handler.py index 76ce1a766..8bb779290 100644 --- a/colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py +++ b/colossalai/auto_parallel/solver/node_handler/reshape_handler.py @@ -1,23 +1,23 @@ import torch from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector -from ..strategy import ReshapeGenerator, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector +from ..strategy import ReshapeGenerator, StrategyGenerator from typing import List, Dict from .registry import operator_registry import operator -__all__ = ['ReshapeHandler_V2'] +__all__ = ['ReshapeHandler'] @operator_registry.register(torch.reshape) @operator_registry.register(torch.flatten) @operator_registry.register(torch.Tensor.permute) -class ReshapeHandler_V2(NodeHandler): +class ReshapeHandler(NodeHandler): """ A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. 
""" - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) diff --git a/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py b/colossalai/auto_parallel/solver/node_handler/unary_elementwise_handler.py similarity index 82% rename from colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py rename to colossalai/auto_parallel/solver/node_handler/unary_elementwise_handler.py index 75b59f827..73ea4e6b9 100644 --- a/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py +++ b/colossalai/auto_parallel/solver/node_handler/unary_elementwise_handler.py @@ -1,22 +1,22 @@ import torch from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector -from ..strategy import UnaryElementwiseGenerator, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector +from ..strategy import UnaryElementwiseGenerator, StrategyGenerator from typing import List, Dict from .registry import operator_registry import operator -__all__ = ['UnaryElementwiseHandler_V2'] +__all__ = ['UnaryElementwiseHandler'] @operator_registry.register(torch.abs) @operator_registry.register(torch.nn.ReLU) -class UnaryElementwiseHandler_V2(NodeHandler): +class UnaryElementwiseHandler(NodeHandler): """ A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op. """ - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append(UnaryElementwiseGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) diff --git a/colossalai/auto_parallel/solver/op_handler/utils.py b/colossalai/auto_parallel/solver/node_handler/utils.py similarity index 100% rename from colossalai/auto_parallel/solver/op_handler/utils.py rename to colossalai/auto_parallel/solver/node_handler/utils.py diff --git a/colossalai/auto_parallel/solver/op_handler/where_handler_v2.py b/colossalai/auto_parallel/solver/node_handler/where_handler.py similarity index 93% rename from colossalai/auto_parallel/solver/op_handler/where_handler_v2.py rename to colossalai/auto_parallel/solver/node_handler/where_handler.py index 3dbe3f463..1e97ea919 100644 --- a/colossalai/auto_parallel/solver/op_handler/where_handler_v2.py +++ b/colossalai/auto_parallel/solver/node_handler/where_handler.py @@ -1,7 +1,7 @@ import torch from .node_handler import NodeHandler -from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector -from ..strategy import WhereGenerator, StrategyGenerator_V2 +from ..sharding_strategy import ShardingStrategy, OperationDataType, OperationData, StrategiesVector +from ..strategy import WhereGenerator, StrategyGenerator from .broadcast import recover_sharding_spec_for_broadcast_shape from typing import List, Dict from .registry import operator_registry @@ -17,7 +17,7 @@ class WhereHandler(NodeHandler): A WhereHandler which deals with the sharding strategies for torch.where. 
""" - def get_strategy_generator(self) -> List[StrategyGenerator_V2]: + def get_strategy_generator(self) -> List[StrategyGenerator]: logical_op_data_mapping, _ = self.get_operation_data_mapping() generators = [] generators.append(WhereGenerator(logical_op_data_mapping, self.device_mesh)) @@ -73,7 +73,7 @@ class WhereHandler(NodeHandler): self.strategies_vector = list(strategies_vector) return self.strategies_vector - def post_process(self, strategy: ShardingStrategy_V2): + def post_process(self, strategy: ShardingStrategy): logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping() for key in logical_op_data_mapping.keys(): logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]] diff --git a/colossalai/auto_parallel/solver/op_handler/__init__.py b/colossalai/auto_parallel/solver/op_handler/__init__.py deleted file mode 100644 index 9c7e2e595..000000000 --- a/colossalai/auto_parallel/solver/op_handler/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from .operator_handler import OperatorHandler -from .dot_handler import DotHandler -from .conv_handler import ConvHandler -from .batch_norm_handler import BatchNormHandler -from .reshape_handler import ReshapeHandler -from .bcast_op_handler import BcastOpHandler -from .embedding_handler import EmbeddingHandler -from .unary_elementwise_handler import UnaryElementwiseHandler -from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler -from .layer_norm_handler_v2 import LayerNormModuleHandler -from .batch_norm_handler_v2 import BatchNormModuleHandler -from .conv_handler_v2 import ConvModuleHandler, ConvFunctionHandler -from .where_handler_v2 import WhereHandler -from .unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2 -from .reshape_handler_v2 import ReshapeHandler_V2 -from .placeholder_handler import PlacehodlerHandler -from .output_handler import OuputHandler - -__all__ = [ - 'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler', - 'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler', - 'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', - 'UnaryElementwiseHandler_V2', 'ReshapeHandler_V2', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler' -] diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py index 990710aef..5973c7250 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/solver/sharding_strategy.py @@ -13,37 +13,7 @@ from typing import Dict, List, Union, Tuple, Any from torch.fx.node import Node from .constants import * -__all__ = ['ShardingStrategy', 'StrategiesVector'] - - -@dataclass -class ShardingStrategy: - ''' - ShardingStrategy is a structure containing sharding strategies of inputs and output of this node - and costs information using in solver. - - Argument: - name(str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'. - output_sharding_spec(ShardingSpec): ShardingSpec of the output node. 
- compute_cost(float): Computation cost to complete this strategy.(default to 0) - communication_cost(float): Communication cost to complete this strategy.(default to 0) - memory_cost(float): Memory cost of the output node using this strategy.(default to 0) - resharding_costs(Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list - with j-th strategy in its strategies_vector transforms to sharding spec wanted in this - strategy.(default to None) - input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes. - ''' - - name: str - # TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor. - output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]] - compute_cost: float = 0. - communication_cost: float = 0. - memory_cost: float = 0. - resharding_costs: Dict[Node, List[float]] = None - # sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input. - # Therefore, we could process them at the specific op(operator.getitem) - input_shardings: List[ShardingSpec] = None +__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector'] class OperationDataType(Enum): @@ -111,7 +81,7 @@ class MemoryCost: @dataclass -class ShardingStrategy_V2: +class ShardingStrategy: """ ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node. @@ -178,13 +148,13 @@ class ShardingStrategy_V2: communication_cost = deepcopy(self.communication_cost) memory_cost = deepcopy(self.memory_cost) - return ShardingStrategy_V2(name=self.name, - sharding_specs=sharding_specs, - compute_cost=compute_cost, - communication_cost=communication_cost, - memory_cost=memory_cost, - communication_actions=communication_actions, - resharding_costs=resharding_costs) + return ShardingStrategy(name=self.name, + sharding_specs=sharding_specs, + compute_cost=compute_cost, + communication_cost=communication_cost, + memory_cost=memory_cost, + communication_actions=communication_actions, + resharding_costs=resharding_costs) class StrategiesVector(list): diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py index 0da540cae..f1bfa78bb 100644 --- a/colossalai/auto_parallel/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/solver/strategies_constructor.py @@ -1,16 +1,14 @@ from torch.fx import Graph, Node -from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler -from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy_V2 +from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.auto_parallel.solver.op_handler.registry import operator_registry -from colossalai.auto_parallel.solver.op_handler.placeholder_handler import PlacehodlerHandler -from colossalai.auto_parallel.solver.op_handler.output_handler import OuputHandler +from colossalai.auto_parallel.solver.node_handler.registry import operator_registry +from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler +from 
colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler from .options import SolverOptions from . import ShardingStrategy, StrategiesVector -from .op_handler import * +from .node_handler import * from .constants import * from copy import deepcopy import math @@ -20,7 +18,7 @@ from typing import Dict, List from ._utils import generate_sharding_spec, generate_resharding_costs import builtins -__all__ = ['StrategiesConstructor', 'StrategiesConstructor_V2'] +__all__ = ['StrategiesConstructor'] class StrategiesConstructor: @@ -33,412 +31,6 @@ class StrategiesConstructor: solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching. """ - def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): - self.graph = graph - assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' - self.root_module = self.graph.owning_module - self.nodes = list(graph.nodes) - self.device_mesh = device_mesh - self.leaf_strategies = [] - self.strategy_map = {} - self.solver_options = solver_options - - def remove_duplicated_strategy(self, strategies_vector): - ''' - In build_strategies_and_cost method, we may produce some duplicated strategies. - In this method, we will remove the duplicated strategies depending on the strategies name. - ''' - name_checklist = [] - remove_list = [] - for strategy in strategies_vector: - if strategy.name not in name_checklist: - name_checklist.append(strategy.name) - else: - remove_list.append(strategy) - - for strategy in remove_list: - strategies_vector.remove(strategy) - - def _is_bcast_matmul(self, node): - is_bcast_matmul = False - if node.target is torch.matmul and len(node.args) == 2: - lhs_data = node.args[0]._meta_data - rhs_data = node.args[1]._meta_data - if lhs_data.dim() >= 3 and rhs_data.dim() >= 3: - is_bcast_matmul = True - return is_bcast_matmul - - def build_strategies_and_cost(self): - for node in self.nodes: - strategies_vector = StrategiesVector(node) - input_nodes_len = 0 - for check_node in strategies_vector.predecessor_nodes: - if isinstance(check_node._meta_data, torch.Tensor): - input_nodes_len += 1 - # input_nodes_len = len(strategies_vector.predecessor_nodes) - # placeholder node - if node.op == 'placeholder': - # For placeholder nodes, if solver_options.fast is True, we just let them in - # fully replicate status, then strategies of following node will be treated equally due - # to replicate status has no resharding cost to other status. At the same time, the searching - # space is smaller than enumerating all the possible sharding spec for the placeholder node. - # Otherwise, all the possible sharding spec for the placeholder node will be enumerated. - - if self.solver_options.fast: - # create sharding strategy for placeholder - name = 'Replica Placeholder' - dim_partition_dict = {} - output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) - # TODO: use meta_info_prop to profile memory cost - memory_cost = 0 - sharding_strategy_placeholder = ShardingStrategy(name, - output_sharding_spec, - memory_cost=memory_cost) - strategies_vector.append(sharding_strategy_placeholder) - - # get_attr node - if node.op == 'get_attr': - # Same as placeholder nodes, if solver_options.fast is True, we just let them in - # fully replicate status, then strategies of following node will be treated equally due - # to replicate status has no resharding cost to other status. 
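The fast mode described in the comments above keeps placeholder and get_attr nodes fully replicated (an empty dim -> mesh-axes mapping) instead of enumerating every sharding spec, which shrinks the search space. The toy enumeration below gives a feel for the saving on a rank-2 tensor over a 2-axis device mesh, under the simplifying assumption that each tensor dimension is sharded along at most one mesh axis; it is not the library's enumeration code.

from itertools import combinations, permutations

# Toy count of candidate layouts for a rank-2 tensor on a 2-axis mesh,
# assuming at most one mesh axis per tensor dimension.
mesh_axes, tensor_dims = [0, 1], [0, 1]
layouts = []
for k in range(len(tensor_dims) + 1):
    for dims in combinations(tensor_dims, k):
        for axes in permutations(mesh_axes, k):
            layouts.append({d: [a] for d, a in zip(dims, axes)})

print(len(layouts))   # 7 candidates; fast mode keeps only {} (fully replicated)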
At the same time, the searching - # space is smaller than enumerating all the possible sharding spec for the get_attr node. - # Otherwise, all the possible sharding spec for the get_attr node will be enumerated. - if self.solver_options.fast: - # create sharding strategy for get_attr - name = 'Replica Attribute' - dim_partition_dict = {} - output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) - # TODO: use meta_info_prop to profile memory cost - memory_cost = 0 - sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost) - strategies_vector.append(sharding_strategy_attribute) - - # call_module node - if node.op == 'call_module': - - target = node.target - submod = self.root_module.get_submodule(target) - submod_type = type(submod) - - # conv module - if submod_type in CONV_MODULE_OP: - # use ConvHandler to create sharding strategies for conv module node - conv_handler = ConvHandler(node, self.device_mesh, strategies_vector) - conv_handler.register_strategy() - - # linear module - elif submod_type in LINEAR_MODULE_OP: - # use DotHandler to create sharding strategies for linear module node - dot_handler = DotHandler(node, self.device_mesh, strategies_vector) - dot_handler.register_strategy() - - # element-wise module - elif submod_type in ELEMENTWISE_MODULE_OP: - unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) - unary_elementwise_handler.register_strategy() - - # BatchNormNd module - elif submod_type in BATCHNORM_MODULE_OP: - # create sharding strategy for element-wise module - norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector) - norm_handler.register_strategy() - # for strategy in norm_handler.strategies_vector: - # print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') - # assert False - - # MaxPool module - elif submod_type in POOL_MODULE_OP: - # TODO: add sharding constraints on image dimension - # e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension - - # create sharding strategy for element-wise module - assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.' - input_node = strategies_vector.predecessor_nodes[0] - # For element-wise module, we keep the sharding spec of output node same as - # the input. Therefore, the different strategies of input node with same - # output sharding spec will generate same strategy for element-wise module. - sharding_spec_checklist = [] - for strategy in input_node.strategies_vector: - # It looks a little bit confusing, the input of the processing node - # is the output of the input_node. - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, - ShardingSpec), f'The input node should NOT be a tuple of tensor.' 
- if input_sharding_spec in sharding_spec_checklist: - continue - - sharding_spec_checklist.append(input_sharding_spec) - dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) - output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) - - name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' - - # TODO: use meta_info_prop to profile memory cost and compute cost - compute_cost = node._meta_data.numel() - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec]) - - sharding_strategy = ShardingStrategy(name, - output_sharding_spec, - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[input_sharding_spec]) - strategies_vector.append(sharding_strategy) - - # embedding module - elif submod_type in EMBEDDING_MODULE_OP: - embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector) - embedding_handler.register_strategy() - - # layernorm module - elif submod_type in LAYERNORM_MODULE_OP: - layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector) - layernorm_handler.register_strategy() - # other module - else: - raise RuntimeError(f'{submod_type} module is NOT supported now.') - - # call_function node - if node.op == 'call_function': - target = node.target - # conv function - if target in CONV_FUNC_OP: - # use ConvHandler to create sharding strategies for conv node - # TODO: the operator_handler does NOT support function node processing now. - conv_handler = ConvHandler(node, self.device_mesh, strategies_vector) - conv_handler.register_strategy() - - # linear function - elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node): - # use DotHandler to create sharding strategies for linear node - # TODO: the operator_handler does NOT support function node processing now. - linear_handler = DotHandler(node, self.device_mesh, strategies_vector) - linear_handler.register_strategy() - - # where function - elif target == torch.where: - if input_nodes_len == 1: - # both of x and y are scalar - pass - - elif input_nodes_len == 2: - # one of x or y is type of scalar - pass - - else: - # general case - where_handler = WhereHandler(node, self.device_mesh, strategies_vector) - where_handler.register_strategy() - - # reshape function - elif target in RESHAPE_FUNC_OP: - # use ReshapeHandler to create sharding strategies for rehsape node - reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector) - reshape_handler.register_strategy() - - # element-wise function - elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1): - unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) - unary_elementwise_handler.register_strategy() - - # bcast op - elif target in BCAST_FUNC_OP: - if isinstance(node._meta_data, torch.Tensor): - bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector) - bcast_op_handler.register_strategy() - - # torch.var_mean - elif target == torch.var_mean: - dim = node.kwargs['dim'] - input_tensor_node = strategies_vector.predecessor_nodes[0] - for strategy in input_tensor_node.strategies_vector: - input_sharding_spec = strategy.output_sharding_spec - assert isinstance(input_sharding_spec, - ShardingSpec), f'The input node should NOT be a tuple of tensor.' 
- entire_shape_input = input_sharding_spec.entire_shape - dim_partition_dict_input = input_sharding_spec.dim_partition_dict - name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})' - if dim in dim_partition_dict_input: - # We need to make the action dimension in replicate status - dim_partition_dict_for_input = deepcopy(dim_partition_dict_input) - dim_partition_dict_for_input.pop(dim) - new_input_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_input, - dim_partition_dict=dim_partition_dict_for_input) - entire_shape_output = deepcopy(entire_shape_input) - entire_shape_output.pop(dim) - dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input) - output_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_output, - dim_partition_dict=dim_partition_dict_for_input) - # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. - compute_cost = 0 - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - [new_input_sharding_spec]) - sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[new_input_sharding_spec]) - - else: - entire_shape_output = deepcopy(entire_shape_input) - entire_shape_output.pop(dim) - dim_partition_dict_for_output = deepcopy(dim_partition_dict_input) - output_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_output, - dim_partion_dict=dim_partition_dict_input) - # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. - compute_cost = 0 - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec]) - sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[input_sharding_spec]) - - strategies_vector.append(sharding_strategy) - - # operator.getitem - elif target == operator.getitem: - index = node.args[1] - input_tensor_node = strategies_vector.predecessor_nodes[0] - for strategy in input_tensor_node.strategies_vector: - if isinstance(strategy.output_sharding_spec, ShardingSpec): - input_sharding_spec = strategy.output_sharding_spec - else: - input_sharding_spec = strategy.output_sharding_spec[index] - assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.' - dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict) - entire_shape_output = deepcopy(input_sharding_spec.entire_shape) - output_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_output, - dim_partition_dict=dim_partition_dict_for_output) - # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. - compute_cost = 0 - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - [input_sharding_spec], - index=index) - # to prevent the resharding happening, set their resharding cost to inf. 
- resharding_costs[input_tensor_node] = [ - cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node] - ] - sharding_strategy = ShardingStrategy(name, - output_sharding_spec, - compute_cost=compute_cost, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=[strategy.output_sharding_spec]) - strategies_vector.append(sharding_strategy) - - # torch.arange function - elif target == torch.arange: - name = f'FULLY REPLICATED ARANGE' - entire_shape_output = node._meta_data.shape - dim_partition_dict_for_output = {} - output_sharding_spec = ShardingSpec(self.device_mesh, - entire_shape_output, - dim_partition_dict=dim_partition_dict_for_output) - memory_cost = node._meta_data.numel() - sharding_strategy = ShardingStrategy(name, - output_sharding_spec, - compute_cost=0, - memory_cost=memory_cost) - strategies_vector.append(sharding_strategy) - - # op list to be processed to support gpt2 - elif target in (builtins.getattr, operator.le, torch.addmm): - pass - # other function - else: - raise RuntimeError(f'{target} function is NOT supported now.') - - # call_method node - if node.op == 'call_method': - method = getattr(node.args[0]._meta_data.__class__, node.target) - if method in (torch.Tensor.size,): - pass - elif method in ELEMENTWISE_METHOD_OP: - unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) - unary_elementwise_handler.register_strategy() - - elif method in RESHAPE_METHOD_OP: - reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector) - reshape_handler.register_strategy() - # print(strategies_vector) - # if len(strategies_vector) == 0: - # print(node) - # assert False - else: - raise RuntimeError(f'{method} function is NOT supported now.') - - # output node - if node.op == 'output': - if self.solver_options.fast: - # create sharding strategy for output - name = 'Replica Output' - input_nodes = strategies_vector.predecessor_nodes - input_sharding_specs = [] - for input_node in input_nodes: - dim_partition_dict_for_input = {} - entire_shape = input_node._meta_data.shape - sharding_spec = ShardingSpec(self.device_mesh, - entire_shape, - dim_partition_dict=dim_partition_dict_for_input) - input_sharding_specs.append(sharding_spec) - - dim_partition_dict = {} - output_sharding_spec = input_sharding_specs - # TODO: use meta_info_prop to profile memory cost - memory_cost = 0 - resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, - input_sharding_specs) - - # clear the resharding cost for the output node - # TODO: we may remove this in final version - for prev_node, resharding_cost_list in resharding_costs.items(): - resharding_costs[prev_node] = [0] * len(resharding_cost_list) - - sharding_strategy_attribute = ShardingStrategy(name, - output_sharding_spec, - memory_cost=memory_cost, - resharding_costs=resharding_costs, - input_shardings=tuple(input_sharding_specs)) - strategies_vector.append(sharding_strategy_attribute) - - self.remove_duplicated_strategy(strategies_vector) - setattr(node, 'strategies_vector', strategies_vector) - self.leaf_strategies.append(strategies_vector) - self.strategy_map[node] = strategies_vector - - # remove no strategy nodes - remove_list = [] - for strategies_vector in self.leaf_strategies: - if len(strategies_vector) == 0: - remove_list.append(strategies_vector.node) - for node in remove_list: - if node.strategies_vector in self.leaf_strategies: - self.leaf_strategies.remove(node.strategies_vector) - if node in 
self.strategy_map: - self.strategy_map.pop(node) - - -class StrategiesConstructor_V2: - """ - StrategiesConstructor is used to construct the parallelization plan for the model execution. - - Args: - graph (Graph): a Graph object used for analysis and strategy generation. - device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. - solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching. - """ - def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): self.graph = graph assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' diff --git a/colossalai/auto_parallel/solver/strategy/__init__.py b/colossalai/auto_parallel/solver/strategy/__init__.py index a71b0e03e..9d0e98c01 100644 --- a/colossalai/auto_parallel/solver/strategy/__init__.py +++ b/colossalai/auto_parallel/solver/strategy/__init__.py @@ -1,4 +1,4 @@ -from .strategy_generator import StrategyGenerator_V2 +from .strategy_generator import StrategyGenerator from .matmul_strategy_generator import DotProductStrategyGenerator, MatVecStrategyGenerator, LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator from .conv_strategy_generator import ConvStrategyGenerator from .batch_norm_generator import BatchNormStrategyGenerator @@ -11,11 +11,10 @@ from .normal_pooling_generator import NormalPoolStrategyGenerator from .placeholder_generator import PlaceholderGenerator from .output_generator import OutputGenerator - __all__ = [ - 'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', - 'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', - 'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', - 'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', - 'WhereGenerator', 'ReshapeGenerator', 'NormalPoolStrategyGenerator' + 'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator', + 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator', + 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', + 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', + 'ReshapeGenerator', 'NormalPoolStrategyGenerator' ] diff --git a/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py b/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py index 3e7302c27..1964c6eb8 100644 --- a/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py +++ b/colossalai/auto_parallel/solver/strategy/batch_norm_generator.py @@ -1,8 +1,8 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator_V2 +from .strategy_generator import StrategyGenerator from typing import List from .._utils import exception_handler import copy @@ -10,7 +10,7 @@ import copy __all__ = ['BatchNormStrategyGenerator'] -class BatchNormStrategyGenerator(StrategyGenerator_V2): +class BatchNormStrategyGenerator(StrategyGenerator): """ A StrategyGenerator which deals 
with the sharding strategies of batch normalization. @@ -37,7 +37,7 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2): assert input_op_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' - def update_compute_cost(self, strategy: ShardingStrategy_V2): + def update_compute_cost(self, strategy: ShardingStrategy): ''' Compute the computation cost per device with this specific strategy. @@ -64,7 +64,7 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2): compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2): + def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { 'input': self._compute_size_in_bytes(strategy, "input"), 'other': self._compute_size_in_bytes(strategy, "other"), diff --git a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py index 58c76cc96..fcab52012 100644 --- a/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py +++ b/colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py @@ -1,15 +1,15 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator_V2 +from .strategy_generator import StrategyGenerator from typing import List from .._utils import exception_handler import warnings import copy -class ConvStrategyGenerator(StrategyGenerator_V2): +class ConvStrategyGenerator(StrategyGenerator): """ ConvStrategyGenerator is a generic class to generate strategies. The operation data is defined as `output = input x other + bias`. @@ -30,7 +30,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): assert input_op_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' - def update_compute_cost(self, strategy: ShardingStrategy_V2): + def update_compute_cost(self, strategy: ShardingStrategy): ''' Compute the computation cost per device with this specific strategy. 
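As a rough illustration of what the `update_memory_cost` overrides above compute: each operand's per-device footprint is its sharded shape's element count times the element size, so splitting a dimension over the mesh shrinks the local bytes accordingly. This is a minimal self-contained sketch; the shapes and the fp32 element size are illustrative assumptions, not values taken from the generators.

from functools import reduce
import operator

def bytes_per_device(sharded_shape, elem_size=4):
    # local element count on one device * bytes per element (fp32 assumed)
    return reduce(operator.mul, sharded_shape) * elem_size

# BatchNorm2d input of global shape [64, 256, 32, 32], batch dim split over 4 devices:
full = bytes_per_device((64, 256, 32, 32))
sharded = bytes_per_device((16, 256, 32, 32))
print(full // sharded)    # 4: per-device memory shrinks with the number of shards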
@@ -70,7 +70,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2): + def update_memory_cost(self, strategy: ShardingStrategy): forward_size_mapping = { 'input': self._compute_size_in_bytes(strategy, "input"), 'other': self._compute_size_in_bytes(strategy, "other"), @@ -455,7 +455,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy_V2]: + def generate(self) -> List[ShardingStrategy]: strategies = [] # SS = SR x RS strategies.append(self.split_input_batch_weight_out_channel(0, 1)) diff --git a/colossalai/auto_parallel/solver/strategy/getitem_generator.py b/colossalai/auto_parallel/solver/strategy/getitem_generator.py index 0e1287eae..646213032 100644 --- a/colossalai/auto_parallel/solver/strategy/getitem_generator.py +++ b/colossalai/auto_parallel/solver/strategy/getitem_generator.py @@ -1,6 +1,6 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import FollowingStrategyGenerator from typing import List @@ -28,11 +28,11 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2): + def update_compute_cost(self, strategy: ShardingStrategy): compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2): + def update_memory_cost(self, strategy: ShardingStrategy): ''' Compute the memory cost per device with this specific strategy. ''' diff --git a/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py b/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py index 3049b5b4c..00bb0a8ca 100644 --- a/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py +++ b/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py @@ -1,8 +1,8 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator_V2 +from .strategy_generator import StrategyGenerator from typing import List from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding import copy @@ -10,7 +10,7 @@ import copy __all__ = ['LayerNormGenerator'] -class LayerNormGenerator(StrategyGenerator_V2): +class LayerNormGenerator(StrategyGenerator): """ LayerNormGenerator is a generic class to generate strategies for LayerNorm operation. The operation data is defined as `output = input x other + bias`. 
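The `# SS = SR x RS` comments in `ConvStrategyGenerator.generate` above use the same notation as strategy names such as 'S0S1 = S0R x RS1'. As a sketch of how that notation maps onto `dim_partition_dict`s (assuming NCHW activations and an OIHW weight layout; the exact dicts built by `split_input_batch_weight_out_channel(0, 1)` may differ):

# 'S0S1 = S0R x RS1' for a conv, an illustrative guess:
dim_partition_dicts = {
    'input':  {0: [0]},          # batch dim sharded along mesh axis 0       -> S0 R R R
    'other':  {0: [1]},          # out-channel dim sharded along mesh axis 1 -> S1 R R R
    'output': {0: [0], 1: [1]},  # batch on axis 0, out-channel on axis 1    -> S0 S1 R R
}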
@@ -23,7 +23,7 @@ class LayerNormGenerator(StrategyGenerator_V2): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2): + def update_compute_cost(self, strategy: ShardingStrategy): ''' Compute the computation cost per device with this specific strategy. @@ -54,7 +54,7 @@ class LayerNormGenerator(StrategyGenerator_V2): compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2): + def update_memory_cost(self, strategy: ShardingStrategy): ''' Compute the memory cost per device with this specific strategy. ''' diff --git a/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py index a5a9ec58a..806959bb3 100644 --- a/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py @@ -1,13 +1,13 @@ from audioop import bias import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator_V2 +from .strategy_generator import StrategyGenerator from typing import List -class MatMulStrategyGenerator(StrategyGenerator_V2): +class MatMulStrategyGenerator(StrategyGenerator): """ MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases. The operation data is defined as `output = input x other + bias`. @@ -17,7 +17,7 @@ class MatMulStrategyGenerator(StrategyGenerator_V2): def has_bias(self): return 'bias' in self.op_data - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: size_mapping = { 'input': self._compute_size_in_bytes(strategy, "input"), 'other': self._compute_size_in_bytes(strategy, "other"), @@ -53,7 +53,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): other_op_data = self.op_data['other'] assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1 - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device() fwd_compute_cost = sharded_input_shape[0] bwd_compute_cost = sharded_input_shape * 2 @@ -88,7 +88,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy_V2]: + def generate(self) -> List[ShardingStrategy]: strategy_list = [] # do not split dimensions for dot product @@ -139,7 +139,7 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy_V2]: + def generate(self) -> List[ShardingStrategy]: strategy_list = [] # no split @@ -154,7 +154,7 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): - def update_compute_cost(self, 
strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # C = AB # C: [M, N], A: [M, P], B: [P, N] # fwd cost = MNP (only count mul) @@ -172,7 +172,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): total=fwd_compute_cost + bwd_compute_cost) strategy.compute_cost = compute_cost - def generate(self) -> List[ShardingStrategy_V2]: + def generate(self) -> List[ShardingStrategy]: strategies = [] # SS = SR x RS @@ -500,7 +500,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): other_op_data = self.op_data['other'] assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2 - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape) def split_one_batch_dim(self, mesh_dim): @@ -645,7 +645,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) - def generate(self) -> List[ShardingStrategy_V2]: + def generate(self) -> List[ShardingStrategy]: strategy_list = [] device_mesh_is_1d = True if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape: diff --git a/colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py b/colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py index a6d416797..7ea2d7e0a 100644 --- a/colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py +++ b/colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py @@ -1,14 +1,14 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator_V2 +from .strategy_generator import StrategyGenerator from typing import List from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding import copy -class NormalPoolStrategyGenerator(StrategyGenerator_V2): +class NormalPoolStrategyGenerator(StrategyGenerator): """ NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd. The reason we call this normal pool is AvgPoolxd and MaxPoolxd are taking the kernel size element from image, @@ -26,7 +26,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator_V2): assert input_op_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].' - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem: + def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: ''' Compute the computation cost per device with this specific strategy. 
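To make the `fwd cost = MNP` rule in `LinearProjectionStrategyGenerator` above concrete, here is a tiny worked version of the estimate; the 2x backward factor is an assumption for the sketch (gradients with respect to A and B each cost another matmul of the same size).

def linear_projection_flops(m, n, p):
    # C = A @ B with A: [M, P], B: [P, N], counting only multiplications
    fwd = m * n * p
    bwd = 2 * m * n * p   # dA = dC @ B^T and dB = A^T @ dC
    return fwd, bwd, fwd + bwd

# per-device cost when M is sharded over 4 devices: M shrinks from 1024 to 256
print(linear_projection_flops(256, 1024, 512))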
@@ -54,7 +54,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator_V2): compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost) return compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: forward_size_mapping = { 'input': self._compute_size_in_bytes(strategy, "input"), 'output': self._compute_size_in_bytes(strategy, "output") @@ -101,7 +101,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator_V2): return dim_partition_list - def generate(self) -> List[ShardingStrategy_V2]: + def generate(self) -> List[ShardingStrategy]: strategy_list = [] dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1) diff --git a/colossalai/auto_parallel/solver/strategy/output_generator.py b/colossalai/auto_parallel/solver/strategy/output_generator.py index a6cde6b64..bfd2ee9fd 100644 --- a/colossalai/auto_parallel/solver/strategy/output_generator.py +++ b/colossalai/auto_parallel/solver/strategy/output_generator.py @@ -1,6 +1,6 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import OutputStrategyGenerator from typing import List @@ -18,11 +18,11 @@ class OutputGenerator(OutputStrategyGenerator): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2): + def update_compute_cost(self, strategy: ShardingStrategy): compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2): + def update_memory_cost(self, strategy: ShardingStrategy): ''' Compute the memory cost per device with this specific strategy. ''' diff --git a/colossalai/auto_parallel/solver/strategy/placeholder_generator.py b/colossalai/auto_parallel/solver/strategy/placeholder_generator.py index b5c65e615..5e1940166 100644 --- a/colossalai/auto_parallel/solver/strategy/placeholder_generator.py +++ b/colossalai/auto_parallel/solver/strategy/placeholder_generator.py @@ -1,8 +1,8 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator_V2 +from .strategy_generator import StrategyGenerator from typing import List from .._utils import exception_handler import copy @@ -10,7 +10,7 @@ import copy __all__ = ['PlaceholderGenerator'] -class PlaceholderGenerator(StrategyGenerator_V2): +class PlaceholderGenerator(StrategyGenerator): """ PlaceholderGenerator is a generic class to generate strategies for placeholder node. 
""" @@ -18,11 +18,11 @@ class PlaceholderGenerator(StrategyGenerator_V2): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2): + def update_compute_cost(self, strategy: ShardingStrategy): compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2): + def update_memory_cost(self, strategy: ShardingStrategy): ''' Compute the memory cost per device with this specific strategy. ''' diff --git a/colossalai/auto_parallel/solver/strategy/reshape_generator.py b/colossalai/auto_parallel/solver/strategy/reshape_generator.py index 401764aed..4ec45f5d3 100644 --- a/colossalai/auto_parallel/solver/strategy/reshape_generator.py +++ b/colossalai/auto_parallel/solver/strategy/reshape_generator.py @@ -1,6 +1,6 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import FollowingStrategyGenerator from typing import List @@ -17,11 +17,11 @@ class ReshapeGenerator(FollowingStrategyGenerator): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2): + def update_compute_cost(self, strategy: ShardingStrategy): compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2): + def update_memory_cost(self, strategy: ShardingStrategy): ''' Compute the memory cost per device with this specific strategy. ''' diff --git a/colossalai/auto_parallel/solver/strategy/strategy_generator.py b/colossalai/auto_parallel/solver/strategy/strategy_generator.py index ec7d96298..06bfe2a35 100644 --- a/colossalai/auto_parallel/solver/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/solver/strategy/strategy_generator.py @@ -7,12 +7,12 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.device.device_mesh import DeviceMesh from typing import Dict, List, Union, Any -from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType +from ..sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem, OperationDataType from torch.fx import Node import copy -class StrategyGenerator_V2(ABC): +class StrategyGenerator(ABC): """ StrategyGenerator is used to generate the same group of sharding strategies. 
@@ -38,9 +38,7 @@ class StrategyGenerator_V2(ABC): """ sharding_specs = self.replace_op_name_with_op_data(sharding_spec_mapping) communication_actions = self.replace_op_name_with_op_data(communication_action_mapping) - return ShardingStrategy_V2(name=name, - sharding_specs=sharding_specs, - communication_actions=communication_actions) + return ShardingStrategy(name=name, sharding_specs=sharding_specs, communication_actions=communication_actions) def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]): """ @@ -85,7 +83,7 @@ class StrategyGenerator_V2(ABC): sharding_spec=sharding_spec, logical_process_axis=logical_process_axis) - def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ Compute the communication cost involved in the forward and backward iteration. """ @@ -113,20 +111,20 @@ class StrategyGenerator_V2(ABC): return strategy @abstractmethod - def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ Customize this method to compute the computation flops. """ pass @abstractmethod - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: + def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: """ Customize this method to compute the memory cost in bytes. """ pass - def _compute_size_in_bytes(self, strategy: ShardingStrategy_V2, key: str): + def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str): """ Compute the size of a tensor in bytes. @@ -142,7 +140,7 @@ class StrategyGenerator_V2(ABC): return reduce(operator.mul, sharded_shape) * size_per_elem_bytes @abstractmethod - def generate(self) -> List[ShardingStrategy_V2]: + def generate(self) -> List[ShardingStrategy]: """ Generate all possible sharding strategies for this operation. """ @@ -157,7 +155,7 @@ class StrategyGenerator_V2(ABC): pass -class FollowingStrategyGenerator(StrategyGenerator_V2): +class FollowingStrategyGenerator(StrategyGenerator): """ FollowingStrategyGenerator is used to generate the sharding strategies which depends on its predecessor node. @@ -171,7 +169,7 @@ class FollowingStrategyGenerator(StrategyGenerator_V2): self.predecessor_node = predecessor_node -class OutputStrategyGenerator(StrategyGenerator_V2): +class OutputStrategyGenerator(StrategyGenerator): """ OutputStrategyGenerator is used to generate the sharding strategies for Output Node. 
""" diff --git a/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py index 99db359e4..2a9220ca3 100644 --- a/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py +++ b/colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py @@ -1,6 +1,6 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern from .strategy_generator import FollowingStrategyGenerator from typing import List @@ -18,11 +18,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2): + def update_compute_cost(self, strategy: ShardingStrategy): compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2): + def update_memory_cost(self, strategy: ShardingStrategy): ''' Compute the memory cost per device with this specific strategy. ''' diff --git a/colossalai/auto_parallel/solver/strategy/where_generator.py b/colossalai/auto_parallel/solver/strategy/where_generator.py index bceb3c42a..bbf987ef2 100644 --- a/colossalai/auto_parallel/solver/strategy/where_generator.py +++ b/colossalai/auto_parallel/solver/strategy/where_generator.py @@ -1,8 +1,8 @@ import operator from functools import reduce -from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost +from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern -from .strategy_generator import StrategyGenerator_V2, FollowingStrategyGenerator +from .strategy_generator import StrategyGenerator, FollowingStrategyGenerator from typing import List from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding import copy @@ -10,7 +10,7 @@ import copy __all__ = ['WhereGenerator'] -class WhereGenerator(StrategyGenerator_V2): +class WhereGenerator(StrategyGenerator): """ WhereGenerator is a generic class to generate strategies for Where operation. """ @@ -18,11 +18,11 @@ class WhereGenerator(StrategyGenerator_V2): def validate(self) -> bool: return super().validate() - def update_compute_cost(self, strategy: ShardingStrategy_V2): + def update_compute_cost(self, strategy: ShardingStrategy): compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2): + def update_memory_cost(self, strategy: ShardingStrategy): ''' Compute the memory cost per device with this specific strategy. 
''' diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py new file mode 100644 index 000000000..a081ce69c --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py @@ -0,0 +1,6 @@ +from .options import SolverOptions +from .strategies_constructor import StrategiesConstructor +from .sharding_strategy import ShardingStrategy, StrategiesVector +from .cost_graph import CostGraph +from .solver import Solver +from .graph_analysis import GraphAnalyser \ No newline at end of file diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py new file mode 100644 index 000000000..378a14d03 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py @@ -0,0 +1,139 @@ +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +import torch +from torch.fx.node import Node +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.device.device_mesh import DeviceMesh +from typing import Union, Dict, List, Optional +import warnings +from functools import reduce +import functools +import operator +from .constants import INFINITY_COST + + +def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh, + dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + """ + Generate the sharding spec of the tensor based on the given dim_partition_dict. + + + Args: + input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node. + device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. + dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding. + """ + + if isinstance(input_, Node): + assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data' + meta_tensor = input_._meta_data + assert meta_tensor is not None, "The given node's _meta_data attribute is None" + shape = meta_tensor.shape + elif isinstance(input_, torch.Tensor): + shape = input_.shape + else: + raise TypeError( + f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.' + ) + for dim_index, sharding_index_list in dim_partition_dict.items(): + sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list] + sharding_size = reduce(operator.mul, sharding_list, 1) + assert shape[ + dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.' + + sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict) + return sharding_spec + + +def generate_resharding_costs(nodes: List[Node], + sharding_specs: List[ShardingSpec], + count_backward: Optional[bool] = True, + dtype: Optional[torch.dtype] = None, + index=None): + ''' + Compute the resharding costs with this specific strategy. + + Argument: + nodes (List[Node]): a list of nodes + sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes. + count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference. + dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. 
+    '''
+    # The resharding_cost of weight is counted due to sharing weight cases.
+    resharding_costs = {}
+    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+    # shape consistency manager is a singleton class
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    for input_node, input_spec in zip(nodes, sharding_specs):
+        resharding_costs[input_node] = []
+        for strategy in input_node.strategies_vector:
+            input_sharding_spec = strategy.output_sharding_spec
+            if not isinstance(input_sharding_spec, ShardingSpec):
+                assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
+                input_sharding_spec = input_sharding_spec[index]
+            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
+            try:
+                # compute the resharding cost
+                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
+                    input_sharding_spec, input_spec)
+
+                # we need to multiply by the element size of the dtype to get the correct communication cost
+                resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
+            except AssertionError as e:
+                warnings.warn(f'{e}')
+                resharding_cost = INFINITY_COST
+            resharding_costs[input_node].append(resharding_cost)
+    return resharding_costs
+
+
+def exception_handler(func):
+    """
+    A decorator which catches the AssertionError raised by the wrapped function and turns it
+    into a warning, so that an invalid sharding strategy is skipped instead of aborting the
+    whole strategy construction.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            rst = func(*args, **kwargs)
+            return rst
+        except AssertionError as e:
+            warnings.warn(f'{e}')
+
+    return wrapper
+
+
+def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
+    dim_partition_list = []
+    # enumerate all the 2D sharding cases
+    for i in range(dim_size):
+        for j in range(i + 1, dim_size):
+            dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
+            dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
+            dim_partition_list.append(dim_partition_dict_0)
+            dim_partition_list.append(dim_partition_dict_1)
+    for i in range(dim_size):
+        dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
+        dim_partition_list.append(dim_partition_dict_flatten)
+
+    return dim_partition_list
+
+
+def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
+    dim_partition_list = []
+    # enumerate all the 1D sharding cases
+    for i in range(dim_size):
+        dim_partition_dict_0 = {i: [mesh_dim_0]}
+        dim_partition_list.append(dim_partition_dict_0)
+
+    return dim_partition_list
+
+
+def generate_sharding_size(dim_partition_dict, device_mesh):
+    total_sharding_size = 1
+    for mesh_dim_list in dim_partition_dict.values():
+        mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
+        sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
+        total_sharding_size *= sharding_size
+
+    return total_sharding_size
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/constants.py b/colossalai/auto_parallel/tensor_shard/deprecated/constants.py
new file mode 100644
index 000000000..91c20d343
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/deprecated/constants.py
@@ -0,0 +1,83 @@
+import torch
+import operator
+
+__all__ = [
+    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
+    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
+    'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
+]
+
+ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
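To make the helper defined in `_utils.py` above concrete, this is the full list `enumerate_all_possible_2d_sharding` produces for a 2-D tensor on a two-axis mesh; each dict maps a tensor dimension to the mesh axes it is sharded along, with the doubled axis in the last two entries being the flattened 'S01' case.

enumerate_all_possible_2d_sharding(mesh_dim_0=0, mesh_dim_1=1, dim_size=2)
# [{0: [0], 1: [1]},    # S0 S1
#  {0: [1], 1: [0]},    # S1 S0
#  {0: [0, 1]},         # S01 R
#  {1: [0, 1]}]         # R  S01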
+ELEMENTWISE_FUNC_OP = [ + torch.abs, + torch.cos, + torch.exp, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, + # softmax should not be here + torch.nn.functional.softmax +] +ELEMENTWISE_METHOD_OP = [ + torch.Tensor.to, + torch.Tensor.type, + # TODO: contiguous maybe need some extra processes. + torch.Tensor.contiguous +] +RESHAPE_FUNC_OP = [torch.flatten, torch.reshape] +RESHAPE_METHOD_OP = [ + torch.Tensor.view, + torch.Tensor.unsqueeze, + torch.Tensor.split, + torch.Tensor.permute, + torch.Tensor.transpose, +] +BCAST_FUNC_OP = [ + torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub, + operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh +] +CONV_MODULE_OP = [ + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d +] +CONV_FUNC_OP = [ + torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d +] +EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding] +LINEAR_MODULE_OP = [torch.nn.Linear] +LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm] +BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm] +LAYERNORM_MODULE_OP = [torch.nn.LayerNorm] +POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d] +NON_PARAM_FUNC_OP = [ + torch.flatten, + torch.reshape, + torch.abs, + torch.cos, + torch.exp, + operator.neg, + torch.multiply, + torch.nn.functional.relu, + torch.nn.functional.dropout, + torch.flatten, + torch.where, + operator.pow, + torch.pow, + torch.tanh, + torch.add, + torch.sub, + torch.mul, + torch.div, + torch.floor_divide, + torch.true_divide, + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + # softmax should not be here + torch.nn.functional.softmax +] + +INFINITY_COST = 1e13 diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py b/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py new file mode 100644 index 000000000..239d02115 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py @@ -0,0 +1,172 @@ +from typing import List +import math +from torch.fx.node import Node +from .constants import INFINITY_COST + + +class CostGraph: + ''' + A graph data structure to simplify the edge cost graph. It has two main functions: + 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in + CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. + 2. To reduce the searching space, we merge computationally-trivial operators, such as + element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will + be given by the StrategiesVector depending on the type of target node and following nodes. + + Argument: + leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. + simplify(bool, optional): The generated cost graph will be simplified if it is true. 
(default to True) + ''' + + def __init__(self, leaf_strategies, simplify=True): + self.leaf_strategies = leaf_strategies + self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] + # stores number of strategies in each node + self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies} + # extra_node_costs will store the extra costs introduced by merging nodes + self.extra_node_costs = {} + self.following_dict = {} + self.simplify = simplify + self._build_cost_graph() + + def _remove_invalid_node(self, node, attr_name): + remove_list = [] + target_node_list = getattr(node, attr_name, []) + for target_node in target_node_list: + if target_node not in self.nodes: + remove_list.append(target_node) + for element in remove_list: + target_node_list.remove(element) + + def _build_cost_graph(self): + ''' + This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be + set to node. + ''' + self.edge_costs = {} + if self.simplify: + self.merge_pair = [] + for strategies_vector in self.leaf_strategies: + # build edge_cost + dst_node = strategies_vector.node + for src_node in strategies_vector.predecessor_nodes: + if src_node not in self.nodes: + continue + node_pair = (src_node, dst_node) + # src_index = strategies_vector.predecessor_nodes.index(src_node) + edge_cost = {} + for i in range(len(strategies_vector)): + for j in range(len(src_node.strategies_vector)): + edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j] + self.edge_costs[node_pair] = edge_cost + # add parents and children attribute to node + setattr(dst_node, 'parents', strategies_vector.predecessor_nodes) + setattr(dst_node, 'children', strategies_vector.successor_nodes) + self._remove_invalid_node(dst_node, 'parents') + self._remove_invalid_node(dst_node, 'children') + + if self.simplify and strategies_vector.check_merge(): + for followed_node in strategies_vector.predecessor_nodes: + self.merge_pair.append((followed_node, dst_node)) + + def get_edge_cost(self, src_node, dst_node): + return self.edge_costs[(src_node, dst_node)] + + def merge_node(self, src_node, dst_node): + ''' + To merge dst_node into src_node, we need to do it in following steps: + + 1. For each strategy in dst_node, we need to pick an appropriate strategy + of src_node to merge, it is important because the logical resharding costs + between the parents node of src_node and merged node depend on the src_node + strategies dispatching. For example, for the graph 0->1->2, after merging node 1 + into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)] + x represents the picking strategy of node 1 merged into node 2 strategy 0. + + 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs + contains two parts, one is resharding costs between src_node strategy and dst_node strategy, + another is the origin extra costs in src_node strategy. + + 3. Build connections between new node pairs, and remove the src_node after all consumer nodes + detached from it. + + Argument: + src_node(Node): The node will be merged into dst_node. + dst_node(Node): The node to integrate src_node. 
+        '''
+        src_node_index = dst_node.parents.index(src_node)
+        # build merge_map
+        merge_map = {}
+        for src_index, strategy in enumerate(src_node.strategies_vector):
+            min_cost = INFINITY_COST
+            lowest_cost_index = -1
+            for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
+                resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
+                if resharding_cost <= min_cost:
+                    min_cost = resharding_cost
+                    lowest_cost_index = dst_index
+            merge_map[src_index] = lowest_cost_index
+
+        # extra_node_cost for src node
+        self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
+        for src_index, strategy in enumerate(src_node.strategies_vector):
+            target_strate_index = merge_map[src_index]
+            target_strategy = dst_node.strategies_vector[target_strate_index]
+            self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index]
+            if dst_node in self.extra_node_costs:
+                self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]
+
+        # add new node pair to cost graph
+        for child_node in dst_node.children:
+            new_node_pair = (src_node, child_node)
+            old_node_pair = (dst_node, child_node)
+            if new_node_pair in self.edge_costs:
+                continue
+            edge_cost = {}
+            for i in range(self.node_lens[src_node]):
+                for j in range(self.node_lens[child_node]):
+                    dst_strate_index = merge_map[i]
+                    # dst_strategy = dst_node.strategies_vector[dst_strate_index]
+                    edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
+            if new_node_pair not in self.edge_costs:
+                self.edge_costs[new_node_pair] = edge_cost
+            else:
+                # we should accumulate the resharding costs if args of child node contain
+                # both src node and dst node.
+                for index_pair, resharding_cost in self.edge_costs[new_node_pair].items():
+                    self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]
+
+        # connect src node and children of dst node
+        dst_node.parents.remove(src_node)
+        src_node.children.remove(dst_node)
+        self.edge_costs.pop((src_node, dst_node))
+        for child_node in dst_node.children:
+            if child_node not in src_node.children:
+                src_node.children.append(child_node)
+            if src_node not in child_node.parents:
+                child_node.parents.append(src_node)
+            # remove dst node from cost graph when dst node has no producer.
+ if len(dst_node.parents) == 0: + child_node.parents.remove(dst_node) + node_pair = (dst_node, child_node) + self.edge_costs.pop(node_pair) + if len(dst_node.parents) == 0: + self.following_dict[dst_node] = src_node + dst_node.children = [] + + def _reindexing_src(self, src): + if src not in self.following_dict: + return src + return self._reindexing_src(self.following_dict[src]) + + def simplify_graph(self): + if not self.simplify: + return + self.merge_pair.reverse() + for (src_node, dst_node) in self.merge_pair: + self.merge_node(src_node, dst_node) + self.merge_pair.reverse() + reindexing_following_dict = {} + for dst, src in self.following_dict.items(): + reindexing_following_dict[dst] = self._reindexing_src(src) + self.following_dict = reindexing_following_dict diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py new file mode 100644 index 000000000..831e7eadd --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py @@ -0,0 +1,163 @@ +from dataclasses import dataclass +from torch.fx.node import Node +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from collections import OrderedDict as ODict +from typing import List, OrderedDict, Union, Any +from colossalai.fx.passes.utils import get_node_module + +__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] + + +@dataclass +class LiveVariable: + """ + LiveVariable is a data structure to store the meta information of a variable for liveness analysis. + """ + name: str + node: Node + is_inplace: bool + + +class LiveVariableVector(list): + """ + LiveVariableVector is a data structure to store the list of LiveVariable objects. + """ + + def exists(self, name) -> bool: + """ + Check if a variable has already existed in the current list by name. + """ + for var in self: + if name == var.name: + return True + return False + + def get(self, name) -> LiveVariable: + for var in self: + if name == var.name: + return var + raise KeyError(f"Variable {name} is not found") + + def copy(self) -> "LiveVariableVector": + """ + Create a copy of this vector + """ + vector = LiveVariableVector() + for var in self: + vector.append(var) + return vector + + +@dataclass +class LiveStage: + """ + LiveStage is a data structure to record the living variables at this current node. + """ + name: str + node: Node + all_live_vars: LiveVariableVector + unique_live_vars: LiveVariableVector + + +class GraphAnalyser: + + def __init__(self, gm: GraphModule): + self._gm = gm + self._graph = gm.graph + + @property + def gm(self) -> GraphModule: + """ + Return the GraphModule object associated with this analyser. + """ + return self._gm + + @property + def graph(self) -> Graph: + """ + Return the Graph object associated with this analyser. + """ + return self._graph + + def liveness_analysis(self) -> List[LiveStage]: + """ + Analyse the graph to obtain the variable liveness information. This function returns + an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. + """ + compute_nodes = self.graph.nodes + liveness_list = [] + + # checked: record all variables created since the first stage + # all: record the live variables only exist until the current stage. + # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage. + # unique: record the unique live variables only exist until the current stage. 
+ # this is different from `all list` as some variables are duplicated. + checked_variables = LiveVariableVector() + all_live_variables = LiveVariableVector() + unique_live_vars = LiveVariableVector() + + for idx, node in enumerate(compute_nodes): + ############################# + # find new living variables # + ############################# + # detect whether the current op is an in-place op + # if it is an in-place op, we would deem it as a duplciate var + is_inplace = False + if node.op == 'call_function': + # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) + if node.kwargs.get('inplace', False): + is_inplace = True + elif node.op == 'call_module': + # to check if this is an inplace op such as torch.nn.Relu(inplace=True) + module = get_node_module(node) + if getattr(module, 'inplace', False): + is_inplace = True + + # add the output var + meta = getattr(node, '_meta_data', None) + live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace) + if not is_inplace: + unique_live_vars.append(live_var) + checked_variables.append(live_var) + all_live_variables.append(live_var) + + # check if any input is not checked yet + for arg in node.args: + if not isinstance(arg, Node): + continue + arg_name = arg.name + if not checked_variables.exists(arg_name): + live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False) + all_live_variables.append(live_var_from_arg) + checked_variables.append(live_var_from_arg) + unique_live_vars.append(live_var_from_arg) + + # TODO: add the logic to remove live variables + # this should be completed if we are able to trace the backward compute graph + + # add this stage to liveness dict + stage = LiveStage(name=node.name, + node=node, + all_live_vars=all_live_variables.copy(), + unique_live_vars=unique_live_vars.copy()) + # if a LiveStage is covered by another LiveStage, we just keep the larger one. 
+ replace = False + for index, prev_stage in enumerate(liveness_list): + all_covered = True + for ele in prev_stage.unique_live_vars: + if ele not in stage.unique_live_vars: + all_covered = False + break + if all_covered: + replace = True + break + if replace: + liveness_list[index] = stage + else: + liveness_list.append(stage) + + return liveness_list + + def get_alias_set(self): + pass diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py new file mode 100644 index 000000000..efcaae795 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py @@ -0,0 +1,14 @@ +from .operator_handler import OperatorHandler +from .dot_handler import DotHandler +from .conv_handler import ConvHandler +from .batch_norm_handler import BatchNormHandler +from .reshape_handler import ReshapeHandler +from .bcast_op_handler import BcastOpHandler +from .embedding_handler import EmbeddingHandler +from .unary_elementwise_handler import UnaryElementwiseHandler +from .where_handler import WhereHandler + +__all__ = [ + 'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler', + 'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler' +] diff --git a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py similarity index 99% rename from colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py index 207f66107..76de2d149 100644 --- a/colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py @@ -1,9 +1,9 @@ import operator from functools import reduce import torch -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler -from colossalai.auto_parallel.solver._utils import exception_handler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler __all__ = ['BatchNormHandler'] diff --git a/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py similarity index 99% rename from colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py index 1ca5d6559..eca6bed42 100644 --- a/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py @@ -2,13 +2,13 @@ import operator from functools import reduce import warnings import torch -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec from copy import deepcopy from typing import Dict, List -from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_1d_sharding, 
enumerate_all_possible_2d_sharding +from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding __all__ = ['BcastOpHandler'] diff --git a/colossalai/auto_parallel/solver/op_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py similarity index 99% rename from colossalai/auto_parallel/solver/op_handler/conv_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py index 028561755..1208f86d3 100644 --- a/colossalai/auto_parallel/solver/op_handler/conv_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py @@ -2,9 +2,9 @@ import operator from functools import reduce import warnings import torch -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler -from colossalai.auto_parallel.solver._utils import exception_handler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler __all__ = ['ConvHandler'] diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py similarity index 99% rename from colossalai/auto_parallel/solver/op_handler/dot_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py index b6be639f4..549b58df8 100644 --- a/colossalai/auto_parallel/solver/op_handler/dot_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py @@ -2,11 +2,11 @@ import operator import torch import torch.nn as nn import torch.nn.functional as F -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP from functools import reduce -from colossalai.auto_parallel.solver._utils import exception_handler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler from enum import Enum from .strategy_generator import StrategyGenerator, IntermediateStrategy from typing import List diff --git a/colossalai/auto_parallel/solver/op_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py similarity index 97% rename from colossalai/auto_parallel/solver/op_handler/embedding_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py index aac28a11d..45c001b60 100644 --- a/colossalai/auto_parallel/solver/op_handler/embedding_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py @@ -2,13 +2,13 @@ import operator from functools import reduce import warnings import torch -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec from copy import deepcopy from typing import 
Dict, List -from colossalai.auto_parallel.solver._utils import exception_handler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler __all__ = ['EmbeddingHandler'] diff --git a/colossalai/auto_parallel/solver/op_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py similarity index 97% rename from colossalai/auto_parallel/solver/op_handler/layer_norm_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py index 9b41f37be..0d28875c7 100644 --- a/colossalai/auto_parallel/solver/op_handler/layer_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py @@ -1,9 +1,9 @@ import operator from functools import reduce import torch -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler -from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size +from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size __all__ = ['LayerNormHandler'] diff --git a/colossalai/auto_parallel/solver/op_handler/operator_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py similarity index 98% rename from colossalai/auto_parallel/solver/op_handler/operator_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py index c0d5e9143..79f72d8d5 100644 --- a/colossalai/auto_parallel/solver/op_handler/operator_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py @@ -7,7 +7,7 @@ from typing import Dict, List from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec from .._utils import generate_resharding_costs, generate_sharding_spec -from colossalai.auto_parallel.solver.constants import * +from colossalai.auto_parallel.tensor_shard.deprecated.constants import * from ..sharding_strategy import StrategiesVector diff --git a/colossalai/auto_parallel/solver/op_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py similarity index 95% rename from colossalai/auto_parallel/solver/op_handler/reshape_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py index 53ff73a90..2fc619c52 100644 --- a/colossalai/auto_parallel/solver/op_handler/reshape_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py @@ -1,11 +1,11 @@ import colorsys from .operator_handler import OperatorHandler from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.tensor.shape_consistency import ShapeConsistencyManager from copy import deepcopy import math -from colossalai.auto_parallel.solver._utils import exception_handler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler import 
warnings import torch from ..constants import INFINITY_COST diff --git a/colossalai/auto_parallel/solver/op_handler/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py similarity index 100% rename from colossalai/auto_parallel/solver/op_handler/strategy_generator.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py diff --git a/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py similarity index 93% rename from colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py index 205923fa0..57ad9e262 100644 --- a/colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py @@ -2,15 +2,15 @@ import operator from functools import reduce import warnings import torch -from colossalai.auto_parallel.solver.constants import INFINITY_COST -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.constants import INFINITY_COST +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec from copy import deepcopy from typing import Dict, List import math -from colossalai.auto_parallel.solver._utils import exception_handler +from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler __all__ = ['UnaryElementwiseHandler'] diff --git a/colossalai/auto_parallel/solver/op_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py similarity index 97% rename from colossalai/auto_parallel/solver/op_handler/where_handler.py rename to colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py index 17a8df55c..dddd91786 100644 --- a/colossalai/auto_parallel/solver/op_handler/where_handler.py +++ b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py @@ -2,13 +2,13 @@ import operator from functools import reduce import warnings import torch -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from .operator_handler import OperatorHandler from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec from copy import deepcopy from typing import Dict, List -from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding +from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding __all__ = ['WhereHandler'] diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/options.py b/colossalai/auto_parallel/tensor_shard/deprecated/options.py new file mode 100644 index 000000000..2d34f5c64 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/options.py @@ -0,0 +1,11 
@@ +from dataclasses import dataclass + +__all__ = ['SolverOptions'] + + +@dataclass +class SolverOptions: + """ + SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search. + """ + fast: bool = False diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py new file mode 100644 index 000000000..d468c858e --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py @@ -0,0 +1,91 @@ +from copy import deepcopy +from dataclasses import dataclass +from abc import ABC, abstractmethod +from enum import Enum +import operator +import torch +from functools import reduce + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec +from typing import Dict, List, Union, Tuple, Any +from torch.fx.node import Node +from .constants import * + +__all__ = ['ShardingStrategy', 'StrategiesVector'] + + +@dataclass +class ShardingStrategy: + ''' + ShardingStrategy is a structure containing sharding strategies of inputs and output of this node + and costs information using in solver. + + Argument: + name(str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'. + output_sharding_spec(ShardingSpec): ShardingSpec of the output node. + compute_cost(float): Computation cost to complete this strategy.(default to 0) + communication_cost(float): Communication cost to complete this strategy.(default to 0) + memory_cost(float): Memory cost of the output node using this strategy.(default to 0) + resharding_costs(Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list + with j-th strategy in its strategies_vector transforms to sharding spec wanted in this + strategy.(default to None) + input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes. + ''' + + name: str + # TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor. + output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]] + compute_cost: float = 0. + communication_cost: float = 0. + memory_cost: float = 0. + resharding_costs: Dict[Node, List[float]] = None + # sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input. + # Therefore, we could process them at the specific op(operator.getitem) + input_shardings: List[ShardingSpec] = None + + +class StrategiesVector(list): + ''' + Each node in fx graph will have a corresponding StrategiesVector, to store all the possible + strategies of the node. + + Argument: + node (Node): node for which the list of sharding strategies are generated. 
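+
+    Example:
+        A rough usage sketch (``node`` is assumed to be a torch.fx Node of a traced graph and
+        ``strategy`` an already-built ShardingStrategy; both names are illustrative):
+
+            strategies_vector = StrategiesVector(node)
+            strategies_vector.append(strategy)    # behaves like a plain Python list
+            strategies_vector.check_merge()       # True if the node can be merged into its producer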
+    '''
+
+    def __init__(self, node: Node):
+        super().__init__()
+        self.node = node
+        # fetch its input and output nodes
+        # TODO: placeholder input nodes
+        self.predecessor_nodes = list(node._input_nodes.keys())
+        if self.node.op == 'output':
+            self.predecessor_nodes = list(node._input_nodes.keys())[:1]
+        self.successor_nodes = list(node.users.keys())
+
+    def check_merge(self):
+        merge_label = False
+        if self.node.op == 'call_module':
+            target = self.node.target
+            root_module = self.node.graph.owning_module
+            submod = root_module.get_submodule(target)
+            submod_type = type(submod)
+            # merge elementwise module node into source nodes
+            # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
+            if submod_type in ELEMENTWISE_MODULE_OP:
+                merge_label = True
+
+        if self.node.op == 'call_function':
+            # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
+            if self.node.target in ELEMENTWISE_FUNC_OP:
+                merge_label = True
+            # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
+            if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
+                merge_label = True
+            # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
+            if self.node.target in RESHAPE_FUNC_OP:
+                merge_label = True
+
+        return merge_label
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/solver.py b/colossalai/auto_parallel/tensor_shard/deprecated/solver.py
new file mode 100644
index 000000000..2167e6ac2
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/deprecated/solver.py
@@ -0,0 +1,467 @@
+import warnings
+
+import time
+import numpy as np
+import multiprocessing
+from torch.fx.node import Node
+from torch.fx.graph import Graph
+from .graph_analysis import GraphAnalyser
+from .cost_graph import CostGraph
+from .strategies_constructor import StrategiesConstructor
+from typing import Dict
+from .constants import INFINITY_COST
+try:
+    import pulp
+    from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus
+except ImportError:
+    warnings.warn('please install pulp')
+
+__all__ = ['Solver']
+
+
+class Solver:
+
+    def __init__(self,
+                 graph: Graph,
+                 strategies_constructor: StrategiesConstructor,
+                 cost_graph: CostGraph,
+                 graph_analyser: GraphAnalyser,
+                 memory_budget: float = -1.0,
+                 solution_numbers: int = 1,
+                 memory_increasing_coefficient: float = 1.3):
+        '''
+        Solver class will integrate information provided by the components and use an ILP solver to find a possibly optimal strategy combination for the target computing graph.
+
+        Argument:
+            graph: The computing graph to be optimized.
+            strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
+            cost_graph: A graph data structure to simplify the edge cost graph.
+            graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
+            memory_budget: Memory constraint for the solution.
+            solution_numbers: If solution_numbers is larger than one, the solver will return a series of solutions based on different memory budgets.
+            memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate a new memory budget.
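+
+        Example:
+            A rough sketch of how these components are usually wired together (``gm`` is assumed to
+            be a GraphModule traced with ColoTracer and ``device_mesh`` a DeviceMesh; the names are
+            illustrative, not part of this class):
+
+                solver_options = SolverOptions(fast=True)
+                strategies_constructor = StrategiesConstructor(gm.graph, device_mesh, solver_options)
+                strategies_constructor.build_strategies_and_cost()
+                cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+                cost_graph.simplify_graph()
+                graph_analyser = GraphAnalyser(gm)
+                solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+                last_s_val, e_val, objective, status = solver.call_solver_serialized_args()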
+ ''' + self.graph = graph + self.strategies_constructor = strategies_constructor + self.cost_graph = cost_graph + self.graph_analyser = graph_analyser + self.leaf_strategies = self.strategies_constructor.leaf_strategies + self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies] + self.strategy_map = self.strategies_constructor.strategy_map + self.memory_budget = memory_budget + self.solution_numbers = solution_numbers + if self.solution_numbers > 1: + self.memory_increasing_coefficient = memory_increasing_coefficient + else: + self.memory_increasing_coefficient = 1 + self.liveness_list = self.graph_analyser.liveness_analysis() + self.node_index_dict = self._generate_node_index_dict() + # The last solution vector of auto sharding. + self.last_s_val = None + # The last objective value of the best ILP solution. + self.last_objective = None + + def _recover_merged_node_strategy(self): + ''' + During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node. + Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged + node. + ''' + for node_index, node in enumerate(self.nodes): + if node.strategies_vector.check_merge(): + # the merged node has only one input, and its strategies follow the input sharding strategy + input_strategies_vector = node.args[0].strategies_vector + input_best_strategy_index = self.last_s_val[node_index - 1] + input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec + for strategy_index, strategy in enumerate(node.strategies_vector): + if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence: + self.last_s_val[node_index] = strategy_index + break + + def _generate_node_index_dict(self) -> Dict[Node, int]: + node_index_dict = {} + for index, strategies_vector in enumerate(self.leaf_strategies): + node_index_dict[strategies_vector.node] = index + return node_index_dict + + def _prepare_data_for_solver(self): + ''' + Extract information from components for solver. 
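+
+        The returned tuple packs, in this order: node_nums, memory_budget, strategies_len,
+        following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs,
+        memory_costs, resharding_costs, alias_convert_costs and s_init_np, matching the positional
+        parameters of _call_solver_serialized_args.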
+ ''' + node_nums = len(self.leaf_strategies) + memory_budget = self.memory_budget + + # prepare strategies_len + strategies_len = [] + for node in self.nodes: + strategies_len.append(self.cost_graph.node_lens[node]) + strategies_len = np.array(strategies_len) + + # prepare following_nodes + following_nodes = self.cost_graph.following_dict + index_following_nodes = {} + for src, target in following_nodes.items(): + src_index = self.node_index_dict[src] + target_index = self.node_index_dict[target] + index_following_nodes[src_index] = target_index + following_nodes = index_following_nodes + for index in range(node_nums): + if index not in following_nodes: + following_nodes[index] = -1 + + # prepare edge_pairs and resharding costs + edge_pairs = [] + resharding_costs = [] + for pairs, edge_cost in self.cost_graph.edge_costs.items(): + src_node = pairs[0] + dst_node = pairs[1] + src_node_index = self.node_index_dict[src_node] + dst_node_index = self.node_index_dict[dst_node] + edge_pairs.append(src_node_index) + edge_pairs.append(dst_node_index) + + for i in range(strategies_len[src_node_index]): + for j in range(strategies_len[dst_node_index]): + resharding_costs.append(edge_cost[(i, j)]) + edge_pairs = np.array(edge_pairs) + resharding_costs = np.array(resharding_costs) + + # prepare liveness_set + liveness_set = self.liveness_list + + # omit alias_set now + alias_set = None + alias_convert_costs = None + + # prepare compute_costs, communication_costs and memory_costs + compute_costs = [] + communication_costs = [] + memory_costs = [] + extra_node_costs = self.cost_graph.extra_node_costs + for strategies_vector in self.leaf_strategies: + node = strategies_vector.node + for index, strategy in enumerate(strategies_vector): + compute_costs.append(strategy.compute_cost) + # node in extra_node_costs means it has some extra communication + # cost from node merging, so we need to add those extra communication + # cost into + if node in extra_node_costs: + origin_communication_cost = strategy.communication_cost + extra_node_cost = extra_node_costs[node][index] + communication_cost = origin_communication_cost + extra_node_cost + communication_costs.append(communication_cost) + else: + communication_costs.append(strategy.communication_cost) + # temporarily we just consider the forward memory cost + memory_cost = strategy.memory_cost + if isinstance(memory_cost, tuple): + memory_costs.append(memory_cost[0]) + else: + memory_costs.append(memory_cost) + compute_costs = np.array(compute_costs) + communication_costs = np.array(communication_costs) + memory_costs = np.array(memory_costs) + + # omit initial value for nodes + s_init_np = None + + return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np + + def _call_solver_serialized_args(self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + alias_convert_costs, + s_init_np=None): + """ + Call the solver with serialized arguments. + """ + + tic = time.time() + + for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]: + assert isinstance(x, np.ndarray) + assert len(strategies_len) == node_nums, "strategies_len" + + def get_non_zero_index(binary_vector): + """ + Get the index of non-zero item in a vector. 
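+            Exactly one element is expected to be non-zero; the assertion below guards this.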
+ """ + ct = 0 + ret = None + for i, elem in enumerate(binary_vector): + if pulp.value(elem): + ret = i + ct += 1 + + assert ct == 1 + return ret + + # 0. Unpack flatten numpy arrays + s_follow = following_nodes + + E = edge_pairs.reshape((-1, 2)) # noqa + r = [] + pt = 0 + edge_set = set() + for (i, j) in E: + prod_length = strategies_len[i] * strategies_len[j] + + if (i, j) in edge_set: + raise ValueError(f"Duplicated edges: {(i, j)}") + + edge_set.add((i, j)) + r.append(resharding_costs[pt:pt + prod_length]) + pt += prod_length + assert pt == len(resharding_costs) + + ###################### + # omit alias set now # + ###################### + + # A = alias_set.reshape((-1, 2)) # noqa + # for (i, j) in A: + # prod_length = strategies_len[i] * strategies_len[j] + # v.append(alias_convert_costs[pt:pt + prod_length]) + # pt += prod_length + # assert pt == len(alias_convert_costs) + + # L = [] # noqa + # pt = node_nums + # for i in range(node_nums): + # length = liveness_set[i] + # L.append(liveness_set[pt:pt + length]) + # pt += length + # assert pt == len(liveness_set) + v = [] + pt = 0 + + c = [] + d = [] + m = [] + pt = 0 + for i in range(node_nums): + length = strategies_len[i] + c.append(compute_costs[pt:pt + length]) + d.append(communication_costs[pt:pt + length]) + m.append(memory_costs[pt:pt + length]) + pt += length + assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" + assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}" + assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}" + + # 1. Create variables + + ############################# + # create variables for node # + ############################# + s = [] + num_nodes = 0 + reverse_follow_backpatch = [] + for i in range(node_nums): + if s_follow[i] < 0: + if strategies_len[i] == 1: + s.append([1]) + else: + num_nodes += 1 + s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary")) + else: + if s_follow[i] < len(s): + s.append(s[s_follow[i]]) + else: + s.append(None) + reverse_follow_backpatch.append(i) + + for i in reverse_follow_backpatch: + s[i] = s[s_follow[i]] + + ############################# + # create variables for edge # + ############################# + e = [] + num_edges = 0 + for (idx, (i, j)) in enumerate(E): + if len(s[i]) == 1: + e.append(s[j]) + elif len(s[j]) == 1: + e.append(s[i]) + else: + num_edges += 1 + e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary")) + assert len(e[idx]) == len(r[idx]) + for element in s: + assert len(element) > 0 + # 2. Set initial value + ###################################### + # set a initial value for warm start # + ###################################### + if s_init_np is not None: + s_init = s_init_np.reshape((-1, 3)) + for (idx, value, fix) in s_init: + for i in range(len(s[idx])): + s[idx][i].setInitialValue(i == value) + if fix: + s[idx][i].fixValue() + + # 3. 
Objective + prob = LpProblem("myProblem", LpMinimize) + ################################################################### + # computing the node cost(computing cost and communication cost) # + ################################################################### + obj = 0 + for i in range(node_nums): + assert len(s[i]) == len(c[i]) + assert len(s[i]) == len(d[i]) + + obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i]) + + ############################################# + # computing the edge cost(resharding cost) # + ############################################# + for i in range(len(E)): + assert len(e[i]) == len(r[i]) + obj += lpDot(e[i], r[i]) + + prob += obj + + # 4. Constraints + # (a). specified by `cat="Binary"` + + # (b) + ################################################# + # make sure each node only choose one strategy # + ################################################# + for i in range(node_nums): + if s_follow[i] < 0: + prob += lpSum(s[i]) == 1 + + # (c) + ################################################# + # compute memory consumption with liveness set # + ################################################# + if memory_budget > 0: + for liveness_stage in liveness_set: + mem = 0 + for live_variable in liveness_stage.unique_live_vars: + node_index = self.node_index_dict[live_variable.node] + mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index]))) + prob += mem <= memory_budget + + # (d). specified by `cat="Binary"` + + for (idx, (i, j)) in enumerate(E): + if strategies_len[i] == 1 or strategies_len[j] == 1: + continue + + # (e) + prob += lpSum(e[idx]) == 1 + + # (f) + for row in range(len(s[i])): + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row] + + # (g) + for col in range(len(s[j])): + R = len(s[i]) # noqa + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col] + + # (h) + ###################### + # omit alias set now # + ###################### + + # alias_set = set() + # for (idx, (i, j)) in enumerate(A): + # R = len(s[i]) # noqa + # C = len(s[j]) # noqa + # if (i, j) in alias_set: + # raise ValueError(f"Duplicated edges: {(i, j)}") + + # alias_set.add((i, j)) + # alias_set.add((j, i)) + + # for row in range(len(s[i])): + # for col in range(len(s[j])): + # if v[idx][row * C + col] > 0.5: + # prob += s[i][row] + s[j][col] <= 1 + + verbose = True + + msg = verbose + time_limit = 600 + assert "COIN_CMD" in pulp.listSolvers( + onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'") + + solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count()) + # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit) + prob.solve(solver) + + status = prob.status + objective = pulp.value(prob.objective) + objective = float(objective) if objective is not None else -1.0 + if verbose: + print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t" + f"Time: {time.time() - tic}") + print(f"#nodes: {num_nodes}, #edges: {num_edges}") + + if prob.status in [pulp.LpStatusInfeasible]: + raise RuntimeError("Cannot run the function under the given memory budget. 
" + "Please increase the memory budget.") + + # Get and check results + s_val = np.full((node_nums,), -1, dtype=np.int32) + for i in range(node_nums): + s_val[i] = get_non_zero_index(s[i]) + + e_val = np.full((len(E),), -1, dtype=np.int32) + for (idx, (i, j)) in enumerate(E): + e_val[idx] = get_non_zero_index(e[idx]) + i_spec_index = e_val[idx] // len(s[j]) + j_spec_index = e_val[idx] % len(s[j]) + assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" + assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" + if verbose and r[idx][e_val[idx]] > 0: + print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") + + self.last_s_val = list(s_val) + self._recover_merged_node_strategy() + self.last_objective = objective + + if objective > INFINITY_COST: + warnings.warn("Detect unexpected behaviors in the auto-sharding pass.") + + return self.last_s_val, e_val, self.last_objective, status + + def call_solver_serialized_args(self): + """ + Call the solver with serialized arguments and handle python errors. Additionally, + we could give a serious of solutions with different memory budget. + """ + if self.solution_numbers == 1: + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + + return ret + + origin_memory_budget = self.memory_budget + memory_budget_list = [ + origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers) + ] + ret_list = [] + for memory_budget in memory_budget_list: + self.memory_budget = memory_budget + args = self._prepare_data_for_solver() + ret = self._call_solver_serialized_args(*args) + ret_list.append(ret) + + return ret_list diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py new file mode 100644 index 000000000..528d37977 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py @@ -0,0 +1,423 @@ +from torch.fx import Graph, Node +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from .options import SolverOptions +from .sharding_strategy import ShardingStrategy, StrategiesVector +from .op_handler import * +from .constants import * +from copy import deepcopy +import math +import torch +import operator +from typing import Dict, List +from ._utils import generate_sharding_spec, generate_resharding_costs +import builtins + +__all__ = ['StrategiesConstructor'] + + +class StrategiesConstructor: + """ + StrategiesConstructor is used to construct the parallelization plan for the model execution. + + Args: + graph (Graph): a Graph object used for analysis and strategy generation. + device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster. + solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching. 
+ """ + + def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions): + self.graph = graph + assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' + self.root_module = self.graph.owning_module + self.nodes = list(graph.nodes) + self.device_mesh = device_mesh + self.leaf_strategies = [] + self.strategy_map = {} + self.solver_options = solver_options + + def remove_duplicated_strategy(self, strategies_vector): + ''' + In build_strategies_and_cost method, we may produce some duplicated strategies. + In this method, we will remove the duplicated strategies depending on the strategies name. + ''' + name_checklist = [] + remove_list = [] + for strategy in strategies_vector: + if strategy.name not in name_checklist: + name_checklist.append(strategy.name) + else: + remove_list.append(strategy) + + for strategy in remove_list: + strategies_vector.remove(strategy) + + def _is_bcast_matmul(self, node): + is_bcast_matmul = False + if node.target is torch.matmul and len(node.args) == 2: + lhs_data = node.args[0]._meta_data + rhs_data = node.args[1]._meta_data + if lhs_data.dim() >= 3 and rhs_data.dim() >= 3: + is_bcast_matmul = True + return is_bcast_matmul + + def build_strategies_and_cost(self): + for node in self.nodes: + strategies_vector = StrategiesVector(node) + input_nodes_len = 0 + for check_node in strategies_vector.predecessor_nodes: + if isinstance(check_node._meta_data, torch.Tensor): + input_nodes_len += 1 + # input_nodes_len = len(strategies_vector.predecessor_nodes) + # placeholder node + if node.op == 'placeholder': + # For placeholder nodes, if solver_options.fast is True, we just let them in + # fully replicate status, then strategies of following node will be treated equally due + # to replicate status has no resharding cost to other status. At the same time, the searching + # space is smaller than enumerating all the possible sharding spec for the placeholder node. + # Otherwise, all the possible sharding spec for the placeholder node will be enumerated. + + if self.solver_options.fast: + # create sharding strategy for placeholder + name = 'Replica Placeholder' + dim_partition_dict = {} + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + sharding_strategy_placeholder = ShardingStrategy(name, + output_sharding_spec, + memory_cost=memory_cost) + strategies_vector.append(sharding_strategy_placeholder) + + # get_attr node + if node.op == 'get_attr': + # Same as placeholder nodes, if solver_options.fast is True, we just let them in + # fully replicate status, then strategies of following node will be treated equally due + # to replicate status has no resharding cost to other status. At the same time, the searching + # space is smaller than enumerating all the possible sharding spec for the get_attr node. + # Otherwise, all the possible sharding spec for the get_attr node will be enumerated. 
+ if self.solver_options.fast: + # create sharding strategy for get_attr + name = 'Replica Attribute' + dim_partition_dict = {} + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost) + strategies_vector.append(sharding_strategy_attribute) + + # call_module node + if node.op == 'call_module': + + target = node.target + submod = self.root_module.get_submodule(target) + submod_type = type(submod) + + # conv module + if submod_type in CONV_MODULE_OP: + # use ConvHandler to create sharding strategies for conv module node + conv_handler = ConvHandler(node, self.device_mesh, strategies_vector) + conv_handler.register_strategy() + + # linear module + elif submod_type in LINEAR_MODULE_OP: + # use DotHandler to create sharding strategies for linear module node + dot_handler = DotHandler(node, self.device_mesh, strategies_vector) + dot_handler.register_strategy() + + # element-wise module + elif submod_type in ELEMENTWISE_MODULE_OP: + unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) + unary_elementwise_handler.register_strategy() + + # BatchNormNd module + elif submod_type in BATCHNORM_MODULE_OP: + # create sharding strategy for element-wise module + norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector) + norm_handler.register_strategy() + # for strategy in norm_handler.strategies_vector: + # print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}') + # assert False + + # MaxPool module + elif submod_type in POOL_MODULE_OP: + # TODO: add sharding constraints on image dimension + # e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension + + # create sharding strategy for element-wise module + assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.' + input_node = strategies_vector.predecessor_nodes[0] + # For element-wise module, we keep the sharding spec of output node same as + # the input. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for element-wise module. + sharding_spec_checklist = [] + for strategy in input_node.strategies_vector: + # It looks a little bit confusing, the input of the processing node + # is the output of the input_node. + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, + ShardingSpec), f'The input node should NOT be a tuple of tensor.' 
+ if input_sharding_spec in sharding_spec_checklist: + continue + + sharding_spec_checklist.append(input_sharding_spec) + dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) + output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict) + + name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' + + # TODO: use meta_info_prop to profile memory cost and compute cost + compute_cost = node._meta_data.numel() + memory_cost = 0 + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) + + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_sharding_spec]) + strategies_vector.append(sharding_strategy) + + # embedding module + elif submod_type in EMBEDDING_MODULE_OP: + embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector) + embedding_handler.register_strategy() + + # layernorm module + elif submod_type in LAYERNORM_MODULE_OP: + layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector) + layernorm_handler.register_strategy() + # other module + else: + raise RuntimeError(f'{submod_type} module is NOT supported now.') + + # call_function node + if node.op == 'call_function': + target = node.target + # conv function + if target in CONV_FUNC_OP: + # use ConvHandler to create sharding strategies for conv node + # TODO: the operator_handler does NOT support function node processing now. + conv_handler = ConvHandler(node, self.device_mesh, strategies_vector) + conv_handler.register_strategy() + + # linear function + elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node): + # use DotHandler to create sharding strategies for linear node + # TODO: the operator_handler does NOT support function node processing now. + linear_handler = DotHandler(node, self.device_mesh, strategies_vector) + linear_handler.register_strategy() + + # where function + elif target == torch.where: + if input_nodes_len == 1: + # both of x and y are scalar + pass + + elif input_nodes_len == 2: + # one of x or y is type of scalar + pass + + else: + # general case + where_handler = WhereHandler(node, self.device_mesh, strategies_vector) + where_handler.register_strategy() + + # reshape function + elif target in RESHAPE_FUNC_OP: + # use ReshapeHandler to create sharding strategies for rehsape node + reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector) + reshape_handler.register_strategy() + + # element-wise function + elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1): + unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) + unary_elementwise_handler.register_strategy() + + # bcast op + elif target in BCAST_FUNC_OP: + if isinstance(node._meta_data, torch.Tensor): + bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector) + bcast_op_handler.register_strategy() + + # torch.var_mean + elif target == torch.var_mean: + dim = node.kwargs['dim'] + input_tensor_node = strategies_vector.predecessor_nodes[0] + for strategy in input_tensor_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, + ShardingSpec), f'The input node should NOT be a tuple of tensor.' 
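+                        # NOTE: torch.var_mean reduces over `dim` and returns a (var, mean) tuple, so the
+                        # branch below drops `dim` from the sharded dimensions (when it is sharded) and from
+                        # the output shape, and reuses the resulting spec for both outputs.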
+                        entire_shape_input = input_sharding_spec.entire_shape
+                        dim_partition_dict_input = input_sharding_spec.dim_partition_dict
+                        if dim in dim_partition_dict_input:
+                            # We need to make the reduced dimension fall back to replicate status
+                            dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
+                            dim_partition_dict_for_input.pop(dim)
+                            new_input_sharding_spec = ShardingSpec(self.device_mesh,
+                                                                   entire_shape_input,
+                                                                   dim_partition_dict=dim_partition_dict_for_input)
+                            entire_shape_output = deepcopy(entire_shape_input)
+                            entire_shape_output.pop(dim)
+                            dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
+                            output_sharding_spec = ShardingSpec(self.device_mesh,
+                                                                entire_shape_output,
+                                                                dim_partition_dict=dim_partition_dict_for_output)
+                            name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
+                            # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
+                            compute_cost = 0
+                            memory_cost = 0
+                            resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                                         [new_input_sharding_spec])
+                            sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
+                                                                 compute_cost=compute_cost,
+                                                                 memory_cost=memory_cost,
+                                                                 resharding_costs=resharding_costs,
+                                                                 input_shardings=[new_input_sharding_spec])
+
+                        else:
+                            entire_shape_output = deepcopy(entire_shape_input)
+                            entire_shape_output.pop(dim)
+                            dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
+                            output_sharding_spec = ShardingSpec(self.device_mesh,
+                                                                entire_shape_output,
+                                                                dim_partition_dict=dim_partition_dict_for_output)
+                            name = f'{input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
+                            # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
+                            compute_cost = 0
+                            memory_cost = 0
+                            resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                                         [input_sharding_spec])
+                            sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
+                                                                 compute_cost=compute_cost,
+                                                                 memory_cost=memory_cost,
+                                                                 resharding_costs=resharding_costs,
+                                                                 input_shardings=[input_sharding_spec])
+
+                        strategies_vector.append(sharding_strategy)
+
+                # operator.getitem
+                elif target == operator.getitem:
+                    index = node.args[1]
+                    input_tensor_node = strategies_vector.predecessor_nodes[0]
+                    for strategy in input_tensor_node.strategies_vector:
+                        if isinstance(strategy.output_sharding_spec, ShardingSpec):
+                            input_sharding_spec = strategy.output_sharding_spec
+                        else:
+                            input_sharding_spec = strategy.output_sharding_spec[index]
+                        assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
+                        dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
+                        entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
+                        output_sharding_spec = ShardingSpec(self.device_mesh,
+                                                            entire_shape_output,
+                                                            dim_partition_dict=dim_partition_dict_for_output)
+                        # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
+                        compute_cost = 0
+                        memory_cost = 0
+                        resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                                     [input_sharding_spec],
+                                                                     index=index)
+                        # to prevent the resharding happening, set their resharding cost to inf.
+ resharding_costs[input_tensor_node] = [ + cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node] + ] + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[strategy.output_sharding_spec]) + strategies_vector.append(sharding_strategy) + + # torch.arange function + elif target == torch.arange: + name = f'FULLY REPLICATED ARANGE' + entire_shape_output = node._meta_data.shape + dim_partition_dict_for_output = {} + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partition_dict=dim_partition_dict_for_output) + memory_cost = node._meta_data.numel() + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=0, + memory_cost=memory_cost) + strategies_vector.append(sharding_strategy) + + # op list to be processed to support gpt2 + elif target in (builtins.getattr, operator.le, torch.addmm): + pass + # other function + else: + raise RuntimeError(f'{target} function is NOT supported now.') + + # call_method node + if node.op == 'call_method': + method = getattr(node.args[0]._meta_data.__class__, node.target) + if method in (torch.Tensor.size,): + pass + elif method in ELEMENTWISE_METHOD_OP: + unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector) + unary_elementwise_handler.register_strategy() + + elif method in RESHAPE_METHOD_OP: + reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector) + reshape_handler.register_strategy() + # print(strategies_vector) + # if len(strategies_vector) == 0: + # print(node) + # assert False + else: + raise RuntimeError(f'{method} function is NOT supported now.') + + # output node + if node.op == 'output': + if self.solver_options.fast: + # create sharding strategy for output + name = 'Replica Output' + input_nodes = strategies_vector.predecessor_nodes + input_sharding_specs = [] + for input_node in input_nodes: + dim_partition_dict_for_input = {} + entire_shape = input_node._meta_data.shape + sharding_spec = ShardingSpec(self.device_mesh, + entire_shape, + dim_partition_dict=dim_partition_dict_for_input) + input_sharding_specs.append(sharding_spec) + + dim_partition_dict = {} + output_sharding_spec = input_sharding_specs + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes, + input_sharding_specs) + + # clear the resharding cost for the output node + # TODO: we may remove this in final version + for prev_node, resharding_cost_list in resharding_costs.items(): + resharding_costs[prev_node] = [0] * len(resharding_cost_list) + + sharding_strategy_attribute = ShardingStrategy(name, + output_sharding_spec, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=tuple(input_sharding_specs)) + strategies_vector.append(sharding_strategy_attribute) + + self.remove_duplicated_strategy(strategies_vector) + setattr(node, 'strategies_vector', strategies_vector) + self.leaf_strategies.append(strategies_vector) + self.strategy_map[node] = strategies_vector + + # remove no strategy nodes + remove_list = [] + for strategies_vector in self.leaf_strategies: + if len(strategies_vector) == 0: + remove_list.append(strategies_vector.node) + for node in remove_list: + if node.strategies_vector in self.leaf_strategies: + self.leaf_strategies.remove(node.strategies_vector) + if node in 
self.strategy_map: + self.strategy_map.pop(node) diff --git a/tests/test_auto_parallel/test_broadcast.py b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py similarity index 93% rename from tests/test_auto_parallel/test_broadcast.py rename to tests/test_auto_parallel/test_tensor_shard/test_broadcast.py index 09c698ee0..1a9279a78 100644 --- a/tests/test_auto_parallel/test_broadcast.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_broadcast.py @@ -1,5 +1,5 @@ import torch -from colossalai.auto_parallel.solver.op_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape +from colossalai.auto_parallel.solver.node_handler.broadcast import is_broadcastable, get_broadcast_shape, recover_sharding_spec_for_broadcast_shape from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_cost_graph.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py similarity index 96% rename from tests/test_auto_parallel/test_cost_graph.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py index 5b5bcb5d9..a244329c0 100644 --- a/tests/test_auto_parallel/test_cost_graph.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py @@ -6,9 +6,9 @@ import pytest from colossalai.fx.tracer.tracer import ColoTracer from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor -from colossalai.auto_parallel.solver.cost_graph import CostGraph -from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from copy import deepcopy diff --git a/tests/test_auto_parallel/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py similarity index 95% rename from tests/test_auto_parallel/test_batch_norm_handler.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py index 4869ecbfa..2d3e71551 100644 --- a/tests/test_auto_parallel/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py @@ -6,8 +6,8 @@ import pytest from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.batch_norm_handler import BatchNormHandler +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_bcast_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py similarity index 93% 
rename from tests/test_auto_parallel/test_bcast_handler.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py index 023d3ac15..f83d7ceb7 100644 --- a/tests/test_auto_parallel/test_bcast_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py @@ -3,8 +3,8 @@ from torch.fx import GraphModule import torch.nn as nn import pytest -from colossalai.auto_parallel.solver.options import SolverOptions -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor from colossalai.fx.tracer.tracer import ColoTracer from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_bcast_matmul.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py similarity index 89% rename from tests/test_auto_parallel/test_bcast_matmul.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py index 72f0a3afe..27120f0ba 100644 --- a/tests/test_auto_parallel/test_bcast_matmul.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py @@ -3,8 +3,8 @@ from torch.fx import GraphModule import torch.nn as nn import pytest -from colossalai.auto_parallel.solver.options import SolverOptions -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor from colossalai.fx.tracer.tracer import ColoTracer from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py similarity index 95% rename from tests/test_auto_parallel/test_conv_handler.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py index c66e85883..09afbdef1 100644 --- a/tests/test_auto_parallel/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py @@ -6,8 +6,8 @@ import pytest from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_dot_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py similarity index 95% rename from 
tests/test_auto_parallel/test_dot_handler.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py index 856e462de..e901b84a3 100644 --- a/tests/test_auto_parallel/test_dot_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py @@ -6,8 +6,8 @@ import pytest from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py similarity index 88% rename from tests/test_auto_parallel/test_layer_norm_handler.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py index afab3934f..40e227cb5 100644 --- a/tests/test_auto_parallel/test_layer_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py @@ -2,13 +2,13 @@ import torch from torch.fx import GraphModule import torch.nn as nn import pytest -from colossalai.auto_parallel.solver import sharding_strategy +from colossalai.auto_parallel.tensor_shard.deprecated import sharding_strategy from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.layer_norm_handler import LayerNormHandler +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py similarity index 90% rename from tests/test_auto_parallel/test_reshape_handler.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py index ac9cfad6d..c895dff4e 100644 --- a/tests/test_auto_parallel/test_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py @@ -3,8 +3,8 @@ from torch.fx import GraphModule import torch.nn as nn import pytest -from colossalai.auto_parallel.solver.options import SolverOptions -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import 
StrategiesConstructor from colossalai.fx.tracer.tracer import ColoTracer from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py similarity index 92% rename from tests/test_auto_parallel/test_where_handler.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py index 6eb8c75b5..1fd8fff2d 100644 --- a/tests/test_auto_parallel/test_where_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py @@ -3,8 +3,8 @@ from torch.fx import GraphModule import torch.nn as nn import pytest -from colossalai.auto_parallel.solver.options import SolverOptions -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor from colossalai.fx.tracer.tracer import ColoTracer from colossalai.device.device_mesh import DeviceMesh diff --git a/tests/test_auto_parallel/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py similarity index 87% rename from tests/test_auto_parallel/test_shape_consistency_pass.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py index 6cb46c1de..b15497d2c 100644 --- a/tests/test_auto_parallel/test_shape_consistency_pass.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py @@ -9,15 +9,15 @@ from colossalai.initialize import launch from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use from colossalai.logging import disable_existing_loggers -from colossalai.auto_parallel.solver.cost_graph import CostGraph -from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor from colossalai.fx.tracer.tracer import ColoTracer from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass -from colossalai.auto_parallel.solver import Solver -from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated import Solver +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions class ConvModel(nn.Module): diff --git a/tests/test_auto_parallel/test_solver.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py similarity index 85% rename from tests/test_auto_parallel/test_solver.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py index ce8d2ba09..df640050a 100644 --- a/tests/test_auto_parallel/test_solver.py +++ 
b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py @@ -6,12 +6,12 @@ import pytest from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor -from colossalai.auto_parallel.solver.cost_graph import CostGraph -from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser from copy import deepcopy -from colossalai.auto_parallel.solver import Solver -from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated import Solver +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions class ConvModel(nn.Module): diff --git a/tests/test_auto_parallel/test_solver_with_gpt.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py similarity index 83% rename from tests/test_auto_parallel/test_solver_with_gpt.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py index 9001d2ce3..ac0ce1b87 100644 --- a/tests/test_auto_parallel/test_solver_with_gpt.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py @@ -4,17 +4,17 @@ import torch.nn as nn import pytest from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor -from colossalai.auto_parallel.solver.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph from copy import deepcopy -from colossalai.auto_parallel.solver import Solver +from colossalai.auto_parallel.tensor_shard.deprecated import Solver import transformers -from colossalai.auto_parallel.solver.constants import * -from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser -from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.constants import * +from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions BATCH_SIZE = 8 SEQ_LENGHT = 8 diff --git a/tests/test_auto_parallel/test_solver_with_mlp.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py similarity index 84% rename from tests/test_auto_parallel/test_solver_with_mlp.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py index 5a850eee7..7ba63951d 100644 --- a/tests/test_auto_parallel/test_solver_with_mlp.py +++ 
b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py @@ -4,17 +4,17 @@ import torch.nn as nn import pytest from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor -from colossalai.auto_parallel.solver.cost_graph import CostGraph +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph from copy import deepcopy -from colossalai.auto_parallel.solver import Solver +from colossalai.auto_parallel.tensor_shard.deprecated import Solver from torchvision.models import resnet34, resnet50 -from colossalai.auto_parallel.solver.constants import * -from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser -from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.constants import * +from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions class MLP(torch.nn.Module): diff --git a/tests/test_auto_parallel/test_strategies_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py similarity index 89% rename from tests/test_auto_parallel/test_strategies_constructor.py rename to tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py index ef8dbde03..7886de5ad 100644 --- a/tests/test_auto_parallel/test_strategies_constructor.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py @@ -6,11 +6,11 @@ import pytest from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST -from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST +from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor -from colossalai.auto_parallel.solver.options import SolverOptions +from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions from copy import deepcopy diff --git a/tests/test_auto_parallel/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py similarity index 100% rename from tests/test_auto_parallel/test_liveness_analysis.py rename to tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py diff --git a/tests/test_auto_parallel/test_node_handler/test_batch_norm_handler_v2.py 
similarity index 97%
rename from tests/test_auto_parallel/test_node_handler/test_batch_norm_handler_v2.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
index e0d98c758..3bfb5e875 100644
--- a/tests/test_auto_parallel/test_node_handler/test_batch_norm_handler_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.batch_norm_handler_v2 import BatchNormModuleHandler
+from colossalai.auto_parallel.solver.node_handler.batch_norm_handler import BatchNormModuleHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
diff --git a/tests/test_auto_parallel/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
similarity index 98%
rename from tests/test_auto_parallel/test_node_handler/test_bmm_handler.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
index dfe50a0e7..ad45ee3f1 100644
--- a/tests/test_auto_parallel/test_node_handler/test_bmm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
@@ -2,7 +2,7 @@ import pytest
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import BMMFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.dot_handler import BMMFunctionHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
diff --git a/tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
similarity index 98%
rename from tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
index 56bae372a..28643cdf0 100644
--- a/tests/test_auto_parallel/test_node_handler/test_conv_handler_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvModuleHandler, ConvFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvModuleHandler, ConvFunctionHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
diff --git a/tests/test_auto_parallel/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
similarity index 95%
rename from tests/test_auto_parallel/test_node_handler/test_getitem_handler.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
index 08877eb25..97c03eae0 100644
--- a/tests/test_auto_parallel/test_node_handler/test_getitem_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
@@ -2,8 +2,8 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.getitem_handler import GetItemHandler
-from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.getitem_handler import GetItemHandler
+from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
diff --git a/tests/test_auto_parallel/test_node_handler/test_layer_norm_handler_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py
similarity index 96%
rename from tests/test_auto_parallel/test_node_handler/test_layer_norm_handler_v2.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py
index 9bb7882bd..3c942bd5e 100644
--- a/tests/test_auto_parallel/test_node_handler/test_layer_norm_handler_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py
@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.layer_norm_handler_v2 import LayerNormModuleHandler
+from colossalai.auto_parallel.solver.node_handler.layer_norm_handler import LayerNormModuleHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
diff --git a/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
similarity index 97%
rename from tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
index 059e661da..4870a2ce1 100644
--- a/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
@@ -2,8 +2,8 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy_V2
+from colossalai.auto_parallel.solver.node_handler.dot_handler import LinearModuleHandler, LinearFunctionHandler
+from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.sharding_spec import ShardingSpec
@@ -83,7 +83,7 @@ def test_linear_module_handler():
     assert 'RS1 = RR x RS1' in strategy_name_list

     for strategy in strategies_vector:
-        strategy: ShardingStrategy_V2
+        strategy: ShardingStrategy
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
@@ -164,7 +164,7 @@ def test_linear_function_handler():
     assert 'RS1 = RR x RS1' in strategy_name_list

     for strategy in strategies_vector:
-        strategy: ShardingStrategy_V2
+        strategy: ShardingStrategy
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
diff --git a/tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
similarity index 95%
rename from tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
index c0dd02722..63ca627d4 100644
--- a/tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.normal_pooling_handler import NormPoolingHandler
+from colossalai.auto_parallel.solver.node_handler.normal_pooling_handler import NormPoolingHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 import pytest
diff --git a/tests/test_auto_parallel/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
similarity index 95%
rename from tests/test_auto_parallel/test_node_handler/test_output_handler.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
index 48fd4d2c5..e16bd6ba9 100644
--- a/tests/test_auto_parallel/test_node_handler/test_output_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.output_handler import OuputHandler
+from colossalai.auto_parallel.solver.node_handler.output_handler import OuputHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
diff --git a/tests/test_auto_parallel/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py
similarity index 95%
rename from tests/test_auto_parallel/test_node_handler/test_placeholder_handler.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py
index 68f9aff14..66f31635c 100644
--- a/tests/test_auto_parallel/test_node_handler/test_placeholder_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.solver.node_handler.placeholder_handler import PlacehodlerHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
diff --git a/tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py
similarity index 88%
rename from tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py
index 8ae352778..3249d10ee 100644
--- a/tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
-from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler_V2
+from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
@@ -48,9 +48,9 @@ def test_reshape_handler():
                                       strategies_vector=conv_strategies_vector)
     conv_handler.register_strategy(compute_resharding_cost=False)
     setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
-    reshape_handler = ReshapeHandler_V2(node=reshape_node,
-                                        device_mesh=device_mesh,
-                                        strategies_vector=reshape_strategies_vector)
+    reshape_handler = ReshapeHandler(node=reshape_node,
+                                     device_mesh=device_mesh,
+                                     strategies_vector=reshape_strategies_vector)

     reshape_handler.register_strategy(compute_resharding_cost=False)
diff --git a/tests/test_auto_parallel/test_node_handler/test_unary_element_wise_handler_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
similarity index 87%
rename from tests/test_auto_parallel/test_node_handler/test_unary_element_wise_handler_v2.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
index 7d8f6f10b..f79c81197 100644
--- a/tests/test_auto_parallel/test_node_handler/test_unary_element_wise_handler_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py
@@ -2,8 +2,8 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2
-from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
+from colossalai.auto_parallel.solver.node_handler.unary_elementwise_handler import UnaryElementwiseHandler
+from colossalai.auto_parallel.solver.node_handler.conv_handler import ConvFunctionHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
@@ -50,9 +50,9 @@ def test_elementwise_handler():
                                       strategies_vector=conv_strategies_vector)
     conv_handler.register_strategy(compute_resharding_cost=False)
     setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
-    relu_handler = UnaryElementwiseHandler_V2(node=relu_mod_node,
-                                              device_mesh=device_mesh,
-                                              strategies_vector=relu_strategies_vector)
+    relu_handler = UnaryElementwiseHandler(node=relu_mod_node,
+                                           device_mesh=device_mesh,
+                                           strategies_vector=relu_strategies_vector)

     relu_handler.register_strategy(compute_resharding_cost=False)
diff --git a/tests/test_auto_parallel/test_node_handler/test_where_handler_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py
similarity index 97%
rename from tests/test_auto_parallel/test_node_handler/test_where_handler_v2.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py
index 8e039472f..a81f1695d 100644
--- a/tests/test_auto_parallel/test_node_handler/test_where_handler_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py
@@ -2,7 +2,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
-from colossalai.auto_parallel.solver.op_handler.where_handler_v2 import WhereHandler
+from colossalai.auto_parallel.solver.node_handler.where_handler import WhereHandler
 from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
diff --git a/tests/test_auto_parallel/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
similarity index 93%
rename from tests/test_auto_parallel/test_solver_with_resnet_v2.py
rename to tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
index 21dcbad56..a8e90ba0b 100644
--- a/tests/test_auto_parallel/test_solver_with_resnet_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
@@ -7,10 +7,10 @@ from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor_V2
-from colossalai.auto_parallel.solver.cost_graph import CostGraph_V2
+from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.solver.cost_graph import CostGraph
 from copy import deepcopy
-from colossalai.auto_parallel.solver.solver import Solver_V2
+from colossalai.auto_parallel.solver.solver import Solver
 from torchvision.models import resnet34, resnet50
 from colossalai.auto_parallel.solver.constants import *
 from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
@@ -60,12 +60,12 @@ def test_cost_graph():
     graph_analyser = GraphAnalyser(gm)
     liveness_list = graph_analyser.liveness_analysis()
     solver_options = SolverOptions(fast=True)
-    strategies_constructor = StrategiesConstructor_V2(graph, device_mesh, solver_options)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
-    cost_graph = CostGraph_V2(strategies_constructor.leaf_strategies)
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
-    solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
     ret = solver.call_solver_serialized_args()
     print(ret[0])
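
For orientation, a minimal usage sketch of the renamed entry points (StrategiesConstructor, CostGraph, Solver), mirroring the test_cost_graph flow in test_solver_with_resnet_v2.py above; the resnet50 model, the input shape, and the 2x2 device mesh are illustrative assumptions rather than part of this diff.

# Minimal sketch of the post-rename solver pipeline, following the test_cost_graph
# flow shown above. Model choice, input shape, and mesh shape are assumptions.
import torch
from torch.fx import GraphModule
from torchvision.models import resnet50

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor  # formerly StrategiesConstructor_V2
from colossalai.auto_parallel.solver.cost_graph import CostGraph  # formerly CostGraph_V2
from colossalai.auto_parallel.solver.solver import Solver  # formerly Solver_V2
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser

# Trace the model symbolically with meta tensors (no real memory allocated).
model = resnet50(num_classes=10)
meta_args = {'x': torch.rand(4, 3, 224, 224).to('meta')}
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)

# A 2x2 logical device mesh over 4 devices (assumed shape).
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))

# Build sharding strategies, simplify the cost graph, then run the ILP solver.
graph_analyser = GraphAnalyser(gm)
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
print(ret[0])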