diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index 826225a62..27957ca63 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from typing import Dict, List, Tuple, Union
 
 import torch
 from torch.fx.node import Node
@@ -7,6 +7,7 @@ from torch.fx.node import Node
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
+    ShardingSpec,
     ShardingStrategy,
     StrategiesVector,
     TrainCycleItem,
@@ -52,12 +53,14 @@ class NodeHandler(ABC):
             node_name = str(node)
             # get the current sharding spec generated by this node handler
 
-            # TODO: we need to check this in future
-            if not isinstance(node._meta_data, torch.Tensor):
+            # we will not compute the resharding costs for nodes that are not included in the strategy.
+            # Nodes with tuple or list outputs are handled by the helper function below.
+            node_in_strategy = [op_data.name for op_data in strategy.sharding_specs.keys()]
+            if str(node) not in node_in_strategy:
                 continue
+
             op_data = strategy.get_op_data_by_name(node_name)
             current_sharding_spec = strategy.sharding_specs[op_data]
-
             # get the sharding specs for this node generated
             # in its own node handler
             assert hasattr(node, 'strategies_vector'), \
@@ -68,23 +71,64 @@ class NodeHandler(ABC):
             ]
 
             # create data structrure to store costs
-            if op_data not in resharding_costs:
+            if node not in resharding_costs:
                 resharding_costs[node] = []
 
+            def _compute_resharding_cost(
+                    prev_sharding_spec: Union[ShardingSpec,
+                                              List[ShardingSpec]], current_sharding_spec: Union[ShardingSpec,
+                                                                                                List[ShardingSpec]],
+                    data: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]) -> TrainCycleItem:
+                """
+                This is a helper function to compute the resharding cost for a specific strategy of a node.
+                """
+                if prev_sharding_spec is None:
+                    return TrainCycleItem(fwd=0, bwd=0, total=0)
+                elif isinstance(prev_sharding_spec, ShardingSpec):
+                    if isinstance(data, torch.nn.parameter.Parameter):
+                        # we will not compute the resharding cost for parameters,
+                        # since parameters are sharded before runtime and are not
+                        # converted during runtime.
+                        return TrainCycleItem(fwd=0, bwd=0, total=0)
+                    elif isinstance(data, torch.Tensor):
+                        dtype = data.dtype
+                        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+                        _, _, consistency_cost = shape_consistency_manager.shape_consistency(
+                            prev_sharding_spec, current_sharding_spec)
+
+                        resharding_cost = TrainCycleItem(fwd=consistency_cost["forward"] * size_per_elem_bytes,
+                                                         bwd=consistency_cost["backward"] * size_per_elem_bytes,
+                                                         total=consistency_cost["total"] * size_per_elem_bytes)
+                        return resharding_cost
+                    else:
+                        # This raise is used to check whether we have missed any data type.
+                        # It could be merged into the Parameter branch above, which would
+                        # mean we do not handle non-tensor arguments.
+                        raise ValueError(f'Unsupported data type {type(data)}')
+                else:
+                    assert isinstance(prev_sharding_spec, (tuple, list)), \
+                        f'prev_sharding_spec should be of type ShardingSpec, List[ShardingSpec], \
+                            or Tuple[ShardingSpec], but got {type(prev_sharding_spec)}'
+
+                    fwd_cost = 0
+                    bwd_cost = 0
+                    total_cost = 0
+                    for index, (prev_sharding_spec_item,
+                                current_sharding_spec_item) in enumerate(zip(prev_sharding_spec,
+                                                                             current_sharding_spec)):
+                        item_cost = _compute_resharding_cost(prev_sharding_spec_item, current_sharding_spec_item,
+                                                             data[index])
+                        fwd_cost += item_cost.fwd
+                        bwd_cost += item_cost.bwd
+                        total_cost += item_cost.total
+                    resharding_cost = TrainCycleItem(fwd=fwd_cost, bwd=bwd_cost, total=total_cost)
+                    return resharding_cost
+
             # for each sharding spec generated by the predecessor's node handler
             # compute the resharding cost to switch to the sharding spec generated
             # by the current node handler
             for prev_sharding_spec in prev_sharding_specs:
-                if op_data.type == OperationDataType.PARAM:
-                    resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
-                else:
-                    dtype = op_data.data.dtype
-                    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-                    _, _, resharding_cost = shape_consistency_manager.shape_consistency(
-                        prev_sharding_spec, current_sharding_spec)
-                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"] * size_per_elem_bytes,
-                                                     bwd=resharding_cost["backward"] * size_per_elem_bytes,
-                                                     total=resharding_cost["total"] * size_per_elem_bytes)
+                resharding_cost = _compute_resharding_cost(prev_sharding_spec, current_sharding_spec, op_data.data)
                 resharding_costs[node].append(resharding_cost)
         strategy.resharding_costs = resharding_costs
         return strategy
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
index ca17fbaf4..6d68521aa 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
@@ -68,32 +68,41 @@ class StrategyGenerator(ABC):
         Args:
            mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
+
+        Notes:
+            op_data.data is commonly of type torch.Tensor or torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data.
+            However, if op_data.data is of a non-iterable type, such as float or int, we should return None. If op_data.data is of an iterable type, such as
+            a list or tuple, we should return a list of ShardingSpec objects following the same rules as above.
         """
         results = {}
         for op_data_name, dim_partition_dict in mapping.items():
             if op_data_name in self.op_data:
                 op_data = self.op_data[op_data_name]
-                if isinstance(op_data.data, tuple):
-                    for data in op_data.data:
-                        assert isinstance(
-                            data, torch.Tensor), 'We cannot create a ShardingSpec object from a non-tensor object.'
-                    sharding_spec = []
-                    for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
+
+                def _to_sharding_spec(
+                        data: any, logical_shape: any,
+                        dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
+                    """
+                    This is a recursive function to convert the dim partition dict to a ShardingSpec object.
+ """ + if isinstance(data, torch.Tensor): dim_size = len(logical_shape) - dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element) - sharding_spec_element = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=logical_shape, - dim_partition_dict=dim_partition_dict_element) - sharding_spec.append(sharding_spec_element) - else: - assert isinstance( - op_data.data, torch.Tensor - ), f'op_data.data should be a torch.Tensor or Tuple[torch.Tensor], but got {type(op_data.data)}' - dim_size = len(op_data.logical_shape) - dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict) - sharding_spec = ShardingSpec(device_mesh=self.device_mesh, - entire_shape=op_data.logical_shape, - dim_partition_dict=dim_partition_dict) + dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict) + sharding_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=logical_shape, + dim_partition_dict=dim_partition_dict) + return sharding_spec + elif isinstance(data, (list, tuple)): + sharding_spec = [] + for data_element, logical_shape_element, dim_partition_dict_element in zip( + data, logical_shape, dim_partition_dict): + sharding_spec.append( + _to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element)) + return sharding_spec + else: + return None + + sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict) results[op_data_name] = sharding_spec return results @@ -285,6 +294,5 @@ class OutputStrategyGenerator(StrategyGenerator): def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh, predecessor_nodes: List[Node]): - self.op_data = operation_data_mapping - self.device_mesh = device_mesh + super().__init__(operation_data_mapping, device_mesh) self.predecessor_nodes = predecessor_nodes diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py index a70c87a13..efe484917 100644 --- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -44,10 +44,20 @@ class OperationData: def __post_init__(self): # if no logical shape is specified, use the data shape as the logical shape if self.logical_shape is None: - if isinstance(self.data, torch.Tensor): - self.logical_shape = self.data.shape - elif isinstance(self.data, tuple): - self.logical_shape = tuple([getattr(d, 'shape', None) for d in self.data]) + + def _infer_logical_shape(data: any): + """ + This function is used to infer the logical shape of the data. 
+ """ + if isinstance(data, torch.Tensor): + return data.shape + elif isinstance(data, (tuple, list)): + data_type = type(data) + return data_type([_infer_logical_shape(d) for d in data]) + else: + return None + + self.logical_shape = _infer_logical_shape(self.data) def __repr__(self) -> str: return f'OperationData(name={self.name}, type={self.type})' @@ -216,8 +226,6 @@ class StrategiesVector(list): # fetch its input and output nodes # TODO: placeholder input nodes self.predecessor_nodes = list(node._input_nodes.keys()) - if self.node.op == 'output': - self.predecessor_nodes = list(node._input_nodes.keys())[:1] self.successor_nodes = list(node.users.keys()) def check_merge(self): diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py index abddbf2b0..f1509af56 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py +++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py @@ -1,13 +1,14 @@ -from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST import torch +from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST + class CostGraph: ''' A graph data structure to simplify the edge cost graph. It has two main functions: 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. - 2. To reduce the searching space, we merge computationally-trivial operators, such as + 2. To reduce the searching space, we merge computationally-trivial operators, such as element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will be given by the StrategiesVector depending on the type of target node and following nodes. @@ -66,8 +67,6 @@ class CostGraph: children_nodes = [node for node in strategies_vector.successor_nodes] setattr(dst_node, 'parents', parent_nodes) setattr(dst_node, 'children', children_nodes) - # self._remove_invalid_node(dst_node, 'parents') - # self._remove_invalid_node(dst_node, 'children') if self.simplify and strategies_vector.check_merge(): for followed_node in strategies_vector.predecessor_nodes: @@ -79,14 +78,14 @@ class CostGraph: def merge_node(self, src_node, dst_node): ''' To merge dst_node into src_node, we need to do it in following steps: - + 1. For each strategy in dst_node, we need to pick an appropriate strategy - of src_node to merge, it is important because the logical resharding costs - between the parents node of src_node and merged node depend on the src_node + of src_node to merge, it is important because the logical resharding costs + between the parents node of src_node and merged node depend on the src_node strategies dispatching. For example, for the graph 0->1->2, after merging node 1 into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)] x represents the picking strategy of node 1 merged into node 2 strategy 0. - + 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs contains two parts, one is resharding costs between src_node strategy and dst_node strategy, another is the origin extra costs in src_node strategy. @@ -98,10 +97,9 @@ class CostGraph: src_node(Node): The node will be merged into dst_node. dst_node(Node): The node to integrate src_node. 
         '''
-        src_node_index = dst_node.parents.index(src_node)
         # build merge_map
         merge_map = {}
-        for src_index, strategy in enumerate(src_node.strategies_vector):
+        for src_index, _ in enumerate(src_node.strategies_vector):
             min_cost = INFINITY_COST
             lowest_cost_index = -1
             for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
@@ -139,7 +137,6 @@
             for i in range(self.node_lens[src_node]):
                 for j in range(self.node_lens[child_node]):
                     dst_strate_index = merge_map[i]
-                    # dst_strategy = dst_node.strategies_vector[dst_strate_index]
                     edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
             if new_node_pair not in self.edge_costs:
                 self.edge_costs[new_node_pair] = edge_cost
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index b934ef2ea..6342feeee 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -1,3 +1,4 @@
+import builtins
 import math
 import operator
 from copy import deepcopy
@@ -13,6 +14,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
     operator_registry,
 )
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
+from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.device.device_mesh import DeviceMesh
 
 from .options import DataloaderOption, SolverOptions
@@ -49,10 +51,6 @@ class StrategiesConstructor:
         name_checklist = []
         remove_list = []
         for strategy in strategies_vector:
-            if strategy is None:
-                print(strategies_vector.node.name)
-                print(strategies_vector)
-                assert False
             if strategy.name not in name_checklist:
                 name_checklist.append(strategy.name)
             else:
@@ -64,10 +62,33 @@
         """
         This method is to build the strategy vector for each node in the computation graph.
""" + + def _check_no_strategy_for_node(node): + if node.op in ('placeholder', 'get_attr', 'output'): + return False + + def _check_no_strategy_for_data(data): + label = True + if isinstance(data, torch.Tensor): + return False + elif isinstance(data, (tuple, list)): + for d in data: + label = label and _check_no_strategy_for_data(d) + return label + + return _check_no_strategy_for_data(node._meta_data) + + no_strategy_node = [] for node in self.nodes: strategies_vector = StrategiesVector(node) + + print(node) + if _check_no_strategy_for_node(node): + no_strategy_node.append(node) + pass + # placeholder node - if node.op == 'placeholder': + elif node.op == 'placeholder': if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED: placeholder_option = 'distributed' else: @@ -80,7 +101,7 @@ class StrategiesConstructor: placeholder_handler.register_strategy() # get_attr node - if node.op == 'get_attr': + elif node.op == 'get_attr': getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector) getattr_handler.register_strategy() @@ -114,10 +135,19 @@ class StrategiesConstructor: output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option) output_handler.register_strategy() - if len(strategies_vector) <= 0: - print(node.name) - assert len(strategies_vector) > 0 self.remove_duplicated_strategy(strategies_vector) setattr(node, 'strategies_vector', strategies_vector) self.leaf_strategies.append(strategies_vector) self.strategy_map[node] = strategies_vector + + # remove no strategy nodes + remove_list = [] + for strategies_vector in self.leaf_strategies: + if len(strategies_vector) == 0: + remove_list.append(strategies_vector.node) + + for node in remove_list: + if node.strategies_vector in self.leaf_strategies: + self.leaf_strategies.remove(node.strategies_vector) + if node in self.strategy_map: + self.strategy_map.pop(node) diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py index 043147b9f..63c48195d 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -6,7 +6,7 @@ from .broadcast import ( recover_sharding_spec_for_broadcast_shape, ) from .factory import generate_resharding_costs, generate_sharding_spec -from .misc import check_sharding_spec_validity, ignore_sharding_exception +from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map from .sharding import ( enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, @@ -19,5 +19,5 @@ __all__ = [ 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', - 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands' + 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map' ] diff --git a/colossalai/auto_parallel/tensor_shard/utils/misc.py b/colossalai/auto_parallel/tensor_shard/utils/misc.py index 967847390..9e402dab7 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/misc.py +++ b/colossalai/auto_parallel/tensor_shard/utils/misc.py @@ -1,11 +1,12 @@ import functools +from typing import Any, Callable, Dict, List, Tuple, Type, Union import torch from 
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
 
-__all__ = ['ignore_sharding_exception']
+__all__ = ['ignore_sharding_exception', 'pytree_map']
 
 
 def ignore_sharding_exception(func):
@@ -70,3 +71,27 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
     # make sure the entire shape matches the physical tensor shape
     assert sharding_spec.entire_shape == tensor.shape, \
         f'The entire_shape of the sharding spec {sharding_spec.entire_shape} does not match the tensor shape {tensor.shape}'
+
+
+def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
+    """Process an object recursively, like a pytree.
+
+    Args:
+        obj (:class:`Any`): object to process
+        fn (:class:`Callable`): a function applied to each sub-object in obj
+        process_types (:class:`type | tuple[type]`): types on which fn should be applied
+        map_all (:class:`bool`): if True, fn is applied to elements of any type
+
+    Returns:
+        :class:`Any`: an object with the same structure as `obj`, where every element whose type is in process_types has been mapped by `fn`
+    """
+    if isinstance(obj, dict):
+        return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
+    elif isinstance(obj, tuple):
+        return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
+    elif isinstance(obj, list):
+        return list(pytree_map(o, fn, process_types, map_all) for o in obj)
+    elif isinstance(obj, process_types):
+        return fn(obj)
+    else:
+        return fn(obj) if map_all else obj
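
Below is a minimal usage sketch of the new `pytree_map` utility added in `misc.py`, illustrating the `process_types` and `map_all` switches. The nested structure, variable names, and the doubling lambdas are hypothetical examples, not part of the patch; only the `pytree_map` import path follows the `__init__.py` export added above.

import torch

from colossalai.auto_parallel.tensor_shard.utils import pytree_map

# a hypothetical nested structure mixing tensors and plain Python scalars
nested = {'a': (torch.ones(2), [torch.zeros(3), 4]), 'b': 5}

# only elements matching process_types are passed through fn; the ints stay untouched
doubled_tensors = pytree_map(nested, lambda t: t * 2, process_types=torch.Tensor)

# with map_all=True, fn is applied to every leaf regardless of its type
doubled_all = pytree_map(nested, lambda x: x * 2, map_all=True)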