diff --git a/colossalai/auto_parallel/solver/op_handler/where_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/where_handler_v2.py
new file mode 100644
index 000000000..3dbe3f463
--- /dev/null
+++ b/colossalai/auto_parallel/solver/op_handler/where_handler_v2.py
@@ -0,0 +1,87 @@
+import torch
+from .node_handler import NodeHandler
+from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector
+from ..strategy import WhereGenerator, StrategyGenerator_V2
+from .broadcast import recover_sharding_spec_for_broadcast_shape
+from typing import List, Dict
+from .registry import operator_registry
+import operator
+import copy
+
+__all__ = ['WhereHandler']
+
+
+@operator_registry.register(torch.where)
+class WhereHandler(NodeHandler):
+    """
+    A WhereHandler which deals with the sharding strategies for torch.where.
+    """
+
+    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
+        logical_op_data_mapping, _ = self.get_operation_data_mapping()
+        generators = []
+        generators.append(WhereGenerator(logical_op_data_mapping, self.device_mesh))
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # generate strategies with the broadcast (logical) shape shared by all operands
+        # the strategies will be transformed back to the physical shapes in self.post_process
+        physical_condition_operand = OperationData(name=str(self.node.args[0]),
+                                                   type=OperationDataType.ARG,
+                                                   data=self.node.args[0]._meta_data)
+        physical_x_operand = OperationData(name=str(self.node.args[1]),
+                                           type=OperationDataType.ARG,
+                                           data=self.node.args[1]._meta_data)
+        physical_y_operand = OperationData(name=str(self.node.args[2]),
+                                           type=OperationDataType.ARG,
+                                           data=self.node.args[2]._meta_data)
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+        physical_mapping = {
+            "condition": physical_condition_operand,
+            "x": physical_x_operand,
+            "y": physical_y_operand,
+            "output": physical_output
+        }
+        logical_shape_for_all = self.node._meta_data.shape
+        logical_mapping = {}
+        for key, physical_operand in physical_mapping.items():
+            logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand,
+                                                                                    logical_shape_for_all)
+
+        return logical_mapping, physical_mapping
+
+    def convert_physical_operand_to_logical_operand(self, physical_operand, target_shape):
+        logical_operand = copy.deepcopy(physical_operand)
+        logical_operand.logical_shape = target_shape
+        return logical_operand
+
+    def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
+        """
+        Register different sharding strategies for the current node.
+        """
+        strategy_generators = self.get_strategy_generator()
+
+        for generator in strategy_generators:
+            strategies = generator.generate()
+            # recover the physical sharding specs from the logical (broadcast) ones
+            strategies = list(map(self.post_process, strategies))
+            # compute the resharding costs based on the previous node's
+            # strategies if specified
+            if compute_resharding_cost:
+                strategies = list(map(self.update_resharding_cost, strategies))
+            self.strategies_vector.extend(strategies)
+
+        return self.strategies_vector
+
+    def post_process(self, strategy: ShardingStrategy_V2):
+        logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
+        for key in logical_op_data_mapping.keys():
+            logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
+            logical_shape = logical_op_data_mapping[key].logical_shape
+            physical_shape = physical_op_data_mapping[key].logical_shape
+            physical_sharding_spec = recover_sharding_spec_for_broadcast_shape(logical_sharding_spec, logical_shape,
+                                                                               physical_shape)
+            strategy.sharding_specs.pop(logical_op_data_mapping[key])
+            strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
+        strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
+        return strategy
diff --git a/colossalai/auto_parallel/solver/strategy/__init__.py b/colossalai/auto_parallel/solver/strategy/__init__.py
index b65da3f16..a7bffb2e8 100644
--- a/colossalai/auto_parallel/solver/strategy/__init__.py
+++ b/colossalai/auto_parallel/solver/strategy/__init__.py
@@ -5,11 +5,12 @@ from .batch_norm_generator import BatchNormStrategyGenerator
 from .unary_elementwise_generator import UnaryElementwiseGenerator
 from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
 from .layer_norm_generator import LayerNormGenerator
+from .where_generator import WhereGenerator
 from .reshape_generator import ReshapeGenerator
 
 __all__ = [
     'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
     'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
     'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
-    'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator'
+    'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'WhereGenerator', 'ReshapeGenerator'
 ]
diff --git a/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py b/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py
index 97130a0b9..3049b5b4c 100644
--- a/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py
+++ b/colossalai/auto_parallel/solver/strategy/layer_norm_generator.py
@@ -163,7 +163,7 @@ class LayerNormGenerator(StrategyGenerator_V2):
 
     def generate(self):
         '''
-        Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
+        Generate all possible strategies for a LayerNorm node and record them in the strategies_vector.
         '''
         strategy_list = []
         input_data_dim = len(self.op_data["input"].logical_shape)
diff --git a/colossalai/auto_parallel/solver/strategy/where_generator.py b/colossalai/auto_parallel/solver/strategy/where_generator.py
new file mode 100644
index 000000000..bceb3c42a
--- /dev/null
+++ b/colossalai/auto_parallel/solver/strategy/where_generator.py
@@ -0,0 +1,99 @@
+import operator
+from functools import reduce
+from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from .strategy_generator import StrategyGenerator_V2, FollowingStrategyGenerator
+from typing import List
+from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
+import copy
+
+__all__ = ['WhereGenerator']
+
+
+class WhereGenerator(StrategyGenerator_V2):
+    """
+    WhereGenerator is a generic class to generate sharding strategies for the where operation.
+    """
+
+    def validate(self) -> bool:
+        return super().validate()
+
+    def update_compute_cost(self, strategy: ShardingStrategy_V2):
+        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)    # small constant placeholder, as where is a cheap elementwise op
+        strategy.compute_cost = compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy_V2):
+        '''
+        Compute the memory cost per device with this specific strategy.
+        '''
+        forward_size_mapping = {
+            'condition': self._compute_size_in_bytes(strategy, "condition"),
+            'x': self._compute_size_in_bytes(strategy, "x"),
+            'y': self._compute_size_in_bytes(strategy, "y"),
+            'output': self._compute_size_in_bytes(strategy, "output")
+        }
+
+        backward_size_mapping = copy.deepcopy(forward_size_mapping)
+        backward_size_mapping.pop("output")
+        # compute fwd cost incurred
+        # fwd_cost = condition + x + y + output
+        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
+
+        # compute bwd cost incurred
+        # bwd_cost = condition_grad + x_grad + y_grad
+        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
+
+        # compute total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
+
+    def _generate_strategy_with_dim_partition(self, dim_partition):
+        dim_partition_dict_mapping = {
+            "condition": dim_partition,
+            "x": dim_partition,
+            "y": dim_partition,
+            "output": dim_partition
+        }
+
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+        name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}'
+        communication_action_mapping = {}
+
+        strategy = self.get_sharding_strategy(name=name,
+                                              sharding_spec_mapping=sharding_spec_mapping,
+                                              communication_action_mapping=communication_action_mapping)
+
+        return strategy
+
+    def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_length):
+        dim_partition_list = []
+        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, dimension_length))
+        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, dimension_length))
+        dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dimension_length))
+        # append {} for the non-split (fully replicated) case
+        dim_partition_list.append({})
+
+        return dim_partition_list
+
+    def generate(self):
+        '''
+        Generate all possible strategies for a where node and record them in the strategies_vector.
+        '''
+        strategy_list = []
+
+        dimension_length = len(self.op_data["output"].logical_shape)
+        dim_partition_list = self.enumerate_all_possible_output_spec(0, 1, dimension_length)
+        for dim_partition in dim_partition_list:
+            strategy = self._generate_strategy_with_dim_partition(dim_partition)
+            strategy_list.append(strategy)
+
+        for strategy in strategy_list:
+            self.update_communication_cost(strategy)
+            self.update_compute_cost(strategy)
+            self.update_memory_cost(strategy)
+
+        return strategy_list
diff --git a/tests/test_auto_parallel/test_node_handler/test_where_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_where_handler_v2.py
new file mode 100644
index 000000000..eee32a9ea
--- /dev/null
+++ b/tests/test_auto_parallel/test_node_handler/test_where_handler_v2.py
@@ -0,0 +1,85 @@
+from colossalai.fx.tracer.meta_patch.patched_module import linear
+import torch
+import torch.nn as nn
+from colossalai.fx import ColoTracer, ColoGraphModule
+from colossalai.auto_parallel.solver.op_handler.where_handler_v2 import WhereHandler
+from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class ConvModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, condition, x, y):
+        output = torch.where(condition, x, y)
+        return output
+
+
+def test_where_handler():
+    model = ConvModel()
+    tracer = ColoTracer()
+    # graph():
+    #     %condition : torch.Tensor [#users=1] = placeholder[target=condition]
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %y : torch.Tensor [#users=1] = placeholder[target=y]
+    #     %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
+    #     return where
+    graph = tracer.trace(model,
+                         meta_args={
+                             "condition": torch.rand(4, 4, 64, 64).to('meta'),
+                             "x": torch.rand(4, 1, 64, 64).to('meta'),
+                             "y": torch.rand(1, 4, 64, 64).to('meta')
+                         })
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    where_node = list(graph.nodes)[3]
+    strategies_vector = StrategiesVector(where_node)
+
+    # build handler
+    handler = WhereHandler(node=where_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
+
+    # check operation data mapping
+    mapping, _ = handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.logical_shape is not None
+        assert op_data.data is not None
+
+    assert mapping['condition'].name == "condition"
+    assert mapping['condition'].data.is_meta
+    assert mapping['condition'].data.shape == torch.Size([4, 4, 64, 64])
+    assert mapping['condition'].type == OperationDataType.ARG
+    assert mapping['condition'].logical_shape == torch.Size([4, 4, 64, 64])
+
+    assert mapping['x'].name == "x"
+    assert mapping['x'].data.is_meta
+    assert mapping['x'].data.shape == torch.Size([4, 1, 64, 64])
+    assert mapping['x'].type == OperationDataType.ARG
+    assert mapping['x'].logical_shape == torch.Size([4, 4, 64, 64])
+
+    assert mapping['y'].name == "y"
+    assert mapping['y'].data.is_meta
+    assert mapping['y'].data.shape == torch.Size([1, 4, 64, 64])
+    assert mapping['y'].type == OperationDataType.ARG
+    assert mapping['y'].logical_shape == torch.Size([4, 4, 64, 64])
+
+    assert mapping['output'].name == "where"
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([4, 4, 64, 64])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+    # 4*3 + 4*3/2*2 + 1
+    assert len(strategy_name_list) == 25
+
+
+if __name__ == '__main__':
+    test_where_handler()
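
A note on the logical-to-physical recovery performed by WhereHandler.post_process: strategies are generated against the broadcast (logical) shape shared by all operands, but an operand whose physical dimension was broadcast from size 1 cannot actually be partitioned along that dimension, so each logical sharding spec has to be mapped back onto the operand's physical shape. The sketch below is a minimal, self-contained illustration of that idea only; the helper name recover_broadcast_partition and the {tensor_dim: [mesh_dims]} partition format are assumptions for this example, and the function is not colossalai's recover_sharding_spec_for_broadcast_shape, which operates on full sharding specs.

# Conceptual sketch only -- not the library implementation.
def recover_broadcast_partition(logical_partition: dict, logical_shape, physical_shape) -> dict:
    """Map a dim partition defined on the logical (broadcast) shape back onto the
    physical shape, dropping shardings that land on broadcast dimensions."""
    offset = len(logical_shape) - len(physical_shape)    # leading dims added by broadcasting
    physical_partition = {}
    for logical_dim, mesh_dims in logical_partition.items():
        physical_dim = logical_dim - offset
        # keep the sharding only if the dim exists physically and was not broadcast from size 1
        if physical_dim >= 0 and physical_shape[physical_dim] == logical_shape[logical_dim]:
            physical_partition[physical_dim] = mesh_dims
    return physical_partition


# x in the test has physical shape [4, 1, 64, 64] but logical shape [4, 4, 64, 64]:
# a logical sharding of dim 1 must be dropped, while shardings of dims 0, 2 and 3 survive.
assert recover_broadcast_partition({1: [0]}, [4, 4, 64, 64], [4, 1, 64, 64]) == {}
assert recover_broadcast_partition({0: [0], 2: [1]}, [4, 4, 64, 64], [4, 1, 64, 64]) == {0: [0], 2: [1]}

With the 4-dimensional output and the 2x2 mesh used in the test, enumerating the single-mesh-dim shardings, the two-mesh-dim shardings and the fully replicated case yields the 25 strategies asserted above.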