From 30e50c8b4a6bb0bf25c97ad520c06c04c813188f Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 27 Sep 2022 12:06:25 +0800 Subject: [PATCH] [autoparallel] implemented all matmul strategy generator (#1650) --- .../solver/op_handler/dot_handler_v2.py | 20 +- .../solver/op_handler/node_handler.py | 6 +- .../auto_parallel/solver/sharding_strategy.py | 6 + .../strategy/matmul_strategy_generator.py | 379 ++++++++++++++++-- .../solver/strategy/strategy_generator.py | 17 +- .../test_linear_handler_v2.py | 82 +++- .../test_torchrec_model/test_deepfm_model.py | 3 - .../test_torchrec_model/test_dlrm_model.py | 3 - 8 files changed, 440 insertions(+), 76 deletions(-) rename tests/test_auto_parallel/{ => test_node_handler}/test_linear_handler_v2.py (56%) diff --git a/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py b/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py index c5b858fe2..e591559b2 100644 --- a/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py +++ b/colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py @@ -50,8 +50,16 @@ class LinearModuleHandler(ModuleHandler): if op_data.name == "weight": assert op_data.logical_shape != op_data.data.shape dim_partition_dict = sharding_spec.dim_partition_dict + # switch first and last dim of the linear module weight - dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0] + first_dim_partition = dim_partition_dict.pop(-1, None) + last_dim_partition = dim_partition_dict.pop(0, None) + + if first_dim_partition: + dim_partition_dict[0] = first_dim_partition + + if last_dim_partition: + dim_partition_dict[-1] = last_dim_partition # re-init the sharding spec sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict) @@ -111,8 +119,16 @@ class LinearFunctionHandler(NodeHandler): if op_data.name == str(self.node.args[1]): assert op_data.logical_shape != op_data.data.shape dim_partition_dict = sharding_spec.dim_partition_dict + # switch first and last dim of the linear module weight - dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0] + first_dim_partition = dim_partition_dict.pop(-1, None) + last_dim_partition = dim_partition_dict.pop(0, None) + + if first_dim_partition: + dim_partition_dict[0] = first_dim_partition + + if last_dim_partition: + dim_partition_dict[-1] = last_dim_partition # re-init the sharding spec sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict) diff --git a/colossalai/auto_parallel/solver/op_handler/node_handler.py b/colossalai/auto_parallel/solver/op_handler/node_handler.py index a509664fc..f18f61b5d 100644 --- a/colossalai/auto_parallel/solver/op_handler/node_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/node_handler.py @@ -33,12 +33,12 @@ class NodeHandler(ABC): Register different sharding strategies for the current node. 
""" strategy_generators = self.get_strategy_generator() - operand_mapping = self.get_operation_data_mapping() for generator in strategy_generators: - strategies = generator.generate(operand_mapping) + strategies = generator.generate() self.strategies_vector.extend(strategies) - self.strategies_vector = map(self.post_process, self.strategies_vector) + strategies_vector = map(self.post_process, self.strategies_vector) + self.strategies_vector = list(strategies_vector) return self.strategies_vector def post_process(self, strategy: ShardingStrategy_V2): diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py index c63aae863..4df256568 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/solver/sharding_strategy.py @@ -75,6 +75,12 @@ class OperationData: if self.logical_shape is None: self.logical_shape = self.data.shape + def __repr__(self) -> str: + return f'OperationData(name={self.name}, type={self.type})' + + def __hash__(self) -> int: + return hash(f'{self.name}-{self.type}') + @dataclass class TrainCycleItem: diff --git a/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py index d1b561cb5..89bab5f44 100644 --- a/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py @@ -1,7 +1,4 @@ -from cmath import log -from distutils.log import Log import operator -import torch from functools import reduce from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost from colossalai.tensor.shape_consistency import CollectiveCommPattern @@ -9,17 +6,148 @@ from .strategy_generator import StrategyGenerator_V2 from typing import List -class DotProductStrategyGenerator(StrategyGenerator_V2): - """TODO: to be implemented""" - pass +class MatMulStrategyGenerator(StrategyGenerator_V2): + """ + MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases. + The operation data is defined as `output = input x other + bias`. 
+    """
+
+    @property
+    def has_bias(self):
+        return 'bias' in self.op_data
+
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        size_mapping = {
+            'input': self._compute_size_in_bytes(strategy, "input"),
+            'other': self._compute_size_in_bytes(strategy, "other"),
+            'output': self._compute_size_in_bytes(strategy, "output")
+        }
+
+        if self.has_bias:
+            bias_size = self._compute_size_in_bytes(strategy, "bias")
+            size_mapping['bias'] = bias_size
+
+        # compute fwd cost incurred
+        # fwd_cost = input + other + bias + output
+        fwd_activation_cost = sum([v for k, v in size_mapping.items() if not self.is_param(k)])
+        fwd_parameter_cost = sum([v for k, v in size_mapping.items() if self.is_param(k)])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
+
+        # compute bwd cost incurred
+        # bwd_cost = input_grad + other_grad + bias_grad
+        bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
+
+        # compute total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
+                                    parameter=fwd_parameter_cost + 0)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
 
 
-class MatVecStrategyGenerator(StrategyGenerator_V2):
-    """TODO: to be implemented"""
-    pass
+class DotProductStrategyGenerator(MatMulStrategyGenerator):
+
+    def validate(self) -> bool:
+        input_op_data = self.op_data['input']
+        other_op_data = self.op_data['other']
+        assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
+
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
+        fwd_compute_cost = sharded_input_shape[0]
+        bwd_compute_cost = fwd_compute_cost * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+                                      bwd=bwd_compute_cost,
+                                      total=fwd_compute_cost + bwd_compute_cost)
+        strategy.compute_cost = compute_cost
+
+    def no_split(self):
+        name = 'R = R dot R'
+        dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+        communication_action_mapping = {}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_one_dim(self, mesh_dim):
+        name = f'R = S{mesh_dim} dot S{mesh_dim}'
+
+        # get sharding spec
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication action
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {"output": output_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def generate(self) -> List[ShardingStrategy_V2]:
+        strategy_list = []
+
+        # do not split dimensions for dot product
+        # R = R dot R
+        strategy_list.append(self.no_split())
+
+        # split two tensors in the same dimensions
+        # S = S dot S
+        strategy_list.append(self.split_one_dim(0))
+
strategy_list.append(self.split_one_dim(1)) + + return strategy_list -class LinearProjectionStrategyGenerator(StrategyGenerator_V2): +class MatVecStrategyGenerator(MatMulStrategyGenerator): + + def validate(self) -> bool: + input_op_data = self.op_data['input'] + other_op_data = self.op_data['other'] + assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1 + + def no_split(self): + name = "R = R x R" + dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + def split_input_batch(self, mesh_dim): + name = f'S{mesh_dim}R = S{mesh_dim}R x R' + + # get sharding spec + dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}} + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) + + # get communication action + other_comm_spec = self.get_communication_spec( + sharding_spec=sharding_spec_mapping['other'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim) + communication_action_mapping = {'other': other_comm_spec} + return self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def generate(self) -> List[ShardingStrategy_V2]: + strategy_list = [] + + # no split + strategy_list.append(self.no_split()) + + # split the batch dim for the first tensor only + strategy_list.append(self.split_input_batch(0)) + strategy_list.append(self.split_input_batch(1)) + + return strategy_list + + +class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: # C = AB @@ -39,23 +167,6 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): total=fwd_compute_cost + bwd_compute_cost) strategy.compute_cost = compute_cost - def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2: - input_size = self._compute_size_in_bytes(strategy, "input") - other_size = self._compute_size_in_bytes(strategy, "input") - - if "bias" in self.op_data: - bias_size = self._compute_size_in_bytes(strategy, "bias") - else: - bias_size = 0 - output_size = self._compute_size_in_bytes(strategy, "output") - - fwd_mem_cost = MemoryCost(activation=output_size, parameter=other_size + bias_size) - bwd_mem_cost = MemoryCost(activation=input_size + other_size + bias_size, parameter=other_size) - total_mem_cost = MemoryCost(activation=input_size + 2 * output_size + bias_size, - parameter=other_size + bias_size) - memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) - strategy.memory_cost = memory_cost - def generate(self) -> List[ShardingStrategy_V2]: strategies = [] @@ -104,7 +215,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): 0: [mesh_dim_0] }, "other": { - self.dim_q: [mesh_dim_1] + -1: [mesh_dim_1] }, "bias": { -1: [mesh_dim_1] @@ -143,7 +254,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): -1: [mesh_dim_1] }, "other": { - self.dim_p: [mesh_dim_1] + 0: [mesh_dim_1] }, "bias": {}, "output": { @@ -159,7 +270,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): logical_process_axis=mesh_dim_0) output_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping["output"], - 
communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_1) communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec} @@ -177,8 +288,8 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): -1: [mesh_dim_0] }, "other": { - self.dim_p: [mesh_dim_0], - self.dim_q: [mesh_dim_1] + 0: [mesh_dim_0], + -1: [mesh_dim_1] }, "bias": { -1: [mesh_dim_1] @@ -192,7 +303,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): # get communication actions output_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['output'], - communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim_0) input_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['input'], @@ -212,7 +323,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): -1: [mesh_dim] }, "other": { - self.dim_p: [mesh_dim] + 0: [mesh_dim] }, "bias": {}, "output": {}, @@ -223,7 +334,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): # get communication action output_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['output'], - communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=mesh_dim) communication_action_mapping = {'output': output_comm_spec} return self.get_sharding_strategy(name=name, @@ -237,7 +348,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): dim_partition_dict_mapping = { "input": {}, "other": { - self.dim_q: [mesh_dim] + -1: [mesh_dim] }, "bias": { -1: [mesh_dim] @@ -294,7 +405,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): -1: [mesh_dim_0, mesh_dim_1] }, "other": { - self.dim_p: [mesh_dim_0, mesh_dim_1] + 0: [mesh_dim_0, mesh_dim_1] }, "bias": {}, "output": {}, @@ -304,7 +415,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): # get communication action output_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['output'], - communication_pattern=CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, + communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, logical_process_axis=[mesh_dim_0, mesh_dim_1]) communication_action_mapping = {'output': output_comm_spec} @@ -319,7 +430,7 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): dim_partition_dict_mapping = { "input": {}, "other": { - self.dim_q: [mesh_dim_0, mesh_dim_1] + -1: [mesh_dim_0, mesh_dim_1] }, "bias": { -1: [mesh_dim_0, mesh_dim_1] @@ -359,6 +470,190 @@ class LinearProjectionStrategyGenerator(StrategyGenerator_V2): assert bias_data.logical_shape[-1] == other_data.logical_shape[-1] -class BatchedMatMulStrategyGenerator(StrategyGenerator_V2): - """TODO: to be implemented""" - pass +class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): + """ + Generate sharding strategies for the batched matrix multiplication. 
+
+    A batched matrix multiplication can be viewed as
+    [b, i, k] x [b, k, j] -> [b, i, j]
+    """
+
+    def validate(self) -> bool:
+        input_op_data = self.op_data['input']
+        other_op_data = self.op_data['other']
+        assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2
+
+    def split_one_batch_dim(self):
+        device_mesh_is_1d = True
+        if len(self.device_mesh.mesh_shape) == 1:
+            mesh_dim = 0
+        elif 1 in self.device_mesh.mesh_shape:
+            mesh_dim = self.device_mesh.mesh_shape.index(1)
+        else:
+            device_mesh_is_1d = False
+
+        if device_mesh_is_1d:
+            name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+
+            # get sharding_spec
+            dim_partition_dict = {
+                "input": {
+                    0: [mesh_dim]
+                },
+                "other": {
+                    0: [mesh_dim]
+                },
+                "bias": {},
+                "output": {
+                    0: [mesh_dim]
+                }
+            }
+            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+            # get communication actions
+            communication_action_mapping = {}
+            return self.get_sharding_strategy(name=name,
+                                              sharding_spec_mapping=sharding_spec_mapping,
+                                              communication_action_mapping=communication_action_mapping)
+        else:
+            return None
+
+    def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0, mesh_dim_1]
+            },
+            "bias": {},
+            "output": {
+                0: [mesh_dim_0, mesh_dim_1]
+            }
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        communication_action_mapping = {}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0],
+                -2: [mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0]
+            },
+            "bias": {},
+            "output": {
+                0: [mesh_dim_0],
+                -2: [mesh_dim_1]
+            }
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        other_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['other'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1)
+        communication_action_mapping = {'other': other_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0]
+            },
+            "other": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            },
+            "bias": {
+                -1: [mesh_dim_1]
+            },
+            "output": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            }
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        input_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['input'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1)
+        communication_action_mapping = {'input': input_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
+        name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
+        dim_partition_dict = {
+            "input": {
+                0: [mesh_dim_0],
+                -1: [mesh_dim_1]
+            },
+            "other": {
+                0: [mesh_dim_0],
+                -2: [mesh_dim_1]
+            },
+            "bias": {},
+            # the k dim is contracted and all-reduced, so the output is only sharded along the batch dim
+            "output": {
+                0: [mesh_dim_0]
+            }
+        }
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        output_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_1)
+        communication_action_mapping = {'output': output_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
+
+    def generate(self) -> List[ShardingStrategy_V2]:
+        strategy_list = []
+
+        # split only the batch dimension
+        # Sb = Sb x Sb
+        # can be None as it is only for 1D device mesh
+        strategy = self.split_one_batch_dim()
+        if strategy:
+            strategy_list.append(strategy)
+
+        # split batch dim of two inputs and the i dim of the first tensor
+        # SbSi = SbSi x Sb
+        strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
+        strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
+
+        # split batch dim of two inputs and the j dim of the second tensor
+        # SbSj = Sb x SbSj
+        strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
+        strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
+
+        # split batch dim of two inputs and the k dim of two inputs
+        # Sb = SbSk x SbSk, need to all-reduce by k dim
+        strategy_list.append(self.split_batch_dim_both_contract(0, 1))
+        strategy_list.append(self.split_batch_dim_both_contract(1, 0))
+
+        # split two batch dims
+        strategy_list.append(self.split_two_batch_dim(0, 1))
+        strategy_list.append(self.split_two_batch_dim(1, 0))
+
+        return strategy_list
diff --git a/colossalai/auto_parallel/solver/strategy/strategy_generator.py b/colossalai/auto_parallel/solver/strategy/strategy_generator.py
index 6b73ba0ce..e5221c755 100644
--- a/colossalai/auto_parallel/solver/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/solver/strategy/strategy_generator.py
@@ -7,7 +7,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List, Union, Any
-from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
+from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType
 
 
 class StrategyGenerator_V2(ABC):
@@ -21,6 +21,10 @@
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh
 
+    def is_param(self, op_data_name):
+        other_data = self.op_data[op_data_name]
+        return other_data.type == OperationDataType.PARAM
+
     def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
                               communication_action_mapping: Dict[str, CommSpec]):
         """
@@ -80,7 +84,7 @@
         Compute the communication cost involved in the forward and backward iteration.
""" - comm_cost = TrainCycleItem(fwd=0, bwd=0) + comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0) def _compute_and_add(data: OperationData, comm_spec: CommSpec): num_ele_in_comm = comm_spec.get_comm_cost() @@ -92,7 +96,7 @@ class StrategyGenerator_V2(ABC): # TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost. # it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used, # so total cost is either for fwd or bwd. - if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD: + if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: comm_cost.fwd += cost elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: comm_cost.fwd += cost @@ -102,9 +106,12 @@ class StrategyGenerator_V2(ABC): # check if communication action exists # if so, loop over each action and compute the cost of each action if strategy.communication_actions is not None: - for operand, comm_spec in strategy.communication_actions: + for operand, comm_spec in strategy.communication_actions.items(): _compute_and_add(operand, comm_spec) + # update the total cost + comm_cost.total = comm_cost.fwd + comm_cost.bwd + # update the communication cost attribute in-place strategy.communication_cost = comm_cost return strategy @@ -146,7 +153,7 @@ class StrategyGenerator_V2(ABC): pass @abstractmethod - def validate(self, *args, **kwargs) -> bool: + def validate(self) -> bool: """ Validate if the operands are of desired shape. If True, means this generator can be used for the current operation. diff --git a/tests/test_auto_parallel/test_linear_handler_v2.py b/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py similarity index 56% rename from tests/test_auto_parallel/test_linear_handler_v2.py rename to tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py index 3b020cba3..a69ff986d 100644 --- a/tests/test_auto_parallel/test_linear_handler_v2.py +++ b/tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py @@ -8,9 +8,9 @@ from colossalai.device.device_mesh import DeviceMesh def test_linear_module_handler(): - model = nn.Sequential(nn.Linear(10, 20).to('meta')) + model = nn.Sequential(nn.Linear(16, 32).to('meta')) tracer = ColoTracer() - graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')}) + graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')}) gm = ColoGraphModule(model, graph) physical_mesh_id = torch.arange(0, 4) @@ -34,32 +34,55 @@ def test_linear_module_handler(): assert mapping['input'].name == "input_1" assert mapping['input'].data.is_meta - assert mapping['input'].data.shape == torch.Size([4, 10]) + assert mapping['input'].data.shape == torch.Size([4, 16]) assert mapping['input'].type == OperationDataType.ARG - assert mapping['input'].logical_shape == torch.Size([4, 10]) + assert mapping['input'].logical_shape == torch.Size([4, 16]) assert mapping['other'].name == "weight" assert mapping['other'].data.is_meta - assert mapping['other'].data.shape == torch.Size([20, 10]) + assert mapping['other'].data.shape == torch.Size([32, 16]) assert mapping['other'].type == OperationDataType.PARAM - assert mapping['other'].logical_shape == torch.Size([10, 20]) + assert mapping['other'].logical_shape == torch.Size([16, 32]) assert mapping['bias'].name == "bias" assert mapping['bias'].data.is_meta - assert mapping['bias'].data.shape == torch.Size([20]) + assert mapping['bias'].data.shape == torch.Size([32]) assert 
mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['bias'].logical_shape == torch.Size([32])
 
     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
+    strategies_vector = handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
+
 
 def test_linear_function_handler():
-    model = nn.Linear(10, 20).to('meta')
+    model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
 
     physical_mesh_id = torch.arange(0, 4)
@@ -77,27 +100,50 @@
 
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])
 
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
 
     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['bias'].logical_shape == torch.Size([32])
 
     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
+    strategies_vector = handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x
RS + assert 'RS0 = RR x RS0' in strategy_name_list + assert 'RS1 = RR x RS1' in strategy_name_list + if __name__ == '__main__': test_linear_module_handler() diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index eb4761af8..0f1f294e4 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -1,6 +1,3 @@ -from curses import meta -from math import dist -from xml.dom import HierarchyRequestErr from colossalai.fx.tracer import meta_patch from colossalai.fx.tracer.tracer import ColoTracer from colossalai.fx.tracer.meta_patch.patched_function import python_ops diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index fdf880866..5999a1abf 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -1,6 +1,3 @@ -from curses import meta -from math import dist -from xml.dom import HierarchyRequestErr from colossalai.fx.tracer import meta_patch from colossalai.fx.tracer.tracer import ColoTracer from colossalai.fx.tracer.meta_patch.patched_function import python_ops
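
The following stand-alone sketch illustrates the weight-dimension swap performed in dot_handler_v2.py above: nn.Linear stores its weight with shape [out_features, in_features], while the matmul logically consumes [in_features, out_features], so the mesh axes recorded for dim 0 and dim -1 of the sharding spec must trade places. The helper name is illustrative only and not part of the patch; unlike the patch, it compares against None explicitly so that an empty partition list would also survive the swap.

    def swap_first_last_dim_partition(dim_partition_dict):
        # the axes sharding the logical last dim move to physical dim 0 and
        # vice versa; pop() with a default keeps unsharded dims absent
        new_first = dim_partition_dict.pop(-1, None)
        new_last = dim_partition_dict.pop(0, None)
        if new_first is not None:
            dim_partition_dict[0] = new_first
        if new_last is not None:
            dim_partition_dict[-1] = new_last
        return dim_partition_dict

    assert swap_first_last_dim_partition({0: [0], -1: [1]}) == {0: [1], -1: [0]}
    assert swap_first_last_dim_partition({0: [0, 1]}) == {-1: [0, 1]}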
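
To make the bookkeeping in MatMulStrategyGenerator.update_memory_cost concrete, the unsharded byte counts for the nn.Linear(16, 32) module used in the tests above work out as follows, assuming fp32 (4 bytes per element):

    input  [4, 16]  ->  256 B  (activation, type ARG)
    other  [32, 16] -> 2048 B  (parameter)
    bias   [32]     ->  128 B  (parameter)
    output [4, 32]  ->  512 B  (activation, type OUTPUT)

    fwd: activation = input + output = 768 B,  parameter = other + bias = 2176 B
    bwd: activation = input_grad + other_grad + bias_grad = 256 + 2048 + 128 = 2432 B

Sharded strategies scale each of these terms by the per-device shard sizes via _compute_size_in_bytes.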