From b904748210232bd6e1278cac0a022a2015ee084b Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Tue, 3 Jan 2023 20:28:01 +0800
Subject: [PATCH] [autoparallel] bypass MetaInfo when unavailable and modify
 BCAST_FUNC_OP metainfo (#2293)

* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline

* [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop

* [autoparallel] specify comm nodes' memory cost in construct chain

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] bypass metainfo when unavailable and modify BCAST_FUNC_OP
---
 .../meta_registry/binary_elementwise_ops.py   | 11 +++--
 .../tensor_shard/node_handler/node_handler.py | 42 +++++++++++--------
 2 files changed, 30 insertions(+), 23 deletions(-)

diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
index 15c3063b7..281a92c0d 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py
@@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
 
-    input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
+    input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
     output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
 
     # construct forward args for flop mapping
-    fwd_in_args = [input_op_data.data, other_op_data.data]
+    fwd_in_args = [opdata.data for opdata in input_op_data]
     fwd_out_args = [output_op_data.data]
 
     # calculate cost
 
     # calculate compute cost
     # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
-    fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
+    fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
     bwd_compute_cost = fwd_compute_cost * 2
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                   bwd=bwd_compute_cost,
                                   total=fwd_compute_cost + bwd_compute_cost)
 
     # calculate memory cost
-    param_mem_cost = activation_size(
-        [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
+    param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
     fwd_mem_cost = MemoryCost(
-        activation=activation_size([input_op_data.data, output_op_data.data]),
+        activation=activation_size(output_op_data.data),
         parameter=param_mem_cost,
     )
     bwd_mem_cost = MemoryCost(
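
For context on the hunk above: the BCAST_FUNC_OP metainfo no longer assumes exactly two
non-output operands, and its FLOP estimate is now taken from the torch.ops.aten.add.Tensor
entry of flop_mapping, with the backward cost fixed at twice the forward cost per the NOTE.
Assuming the elementwise counter charges roughly one operation per broadcast output element
(an assumption, not something the patch states), the counting rule reduces to this standalone
sketch in plain PyTorch; bcast_elementwise_flops is a hypothetical helper, not a ColossalAI API:

    import torch

    def bcast_elementwise_flops(*inputs: torch.Tensor):
        # One scalar op per element of the broadcast output shape.
        out_shape = torch.broadcast_shapes(*(t.shape for t in inputs))
        fwd_flops = out_shape.numel()
        # Heuristic mirrored from the patch: backward costs twice the forward pass.
        bwd_flops = 2 * fwd_flops
        return fwd_flops, bwd_flops

    x = torch.rand(4, 1, 8)
    y = torch.rand(1, 6, 8)
    print(bcast_elementwise_flops(x, y))    # (192, 384) for the 4 x 6 x 8 output

The forward activation memory is likewise charged only for the output tensor now, presumably
because the input tensors are already accounted for by the nodes that produce them.
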
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index af3cb5810..78dc58c90 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -234,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
 
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
 
         return self.strategies_vector
 
@@ -281,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
 
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
 
         return self.strategies_vector
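
The node_handler.py change keeps the existing register_strategy flow but wraps the MetaInfo
construction in a registry check, so both handlers fall back to the default cost model whenever
the target function or module has no registered meta profile, instead of failing inside MetaInfo
when no profile exists. A minimal sketch of that check-then-fall-back pattern, using hypothetical
names (meta_table, has, annotate_cost) rather than the real meta_register API:

    from typing import Any, Callable, Dict

    # Stand-in for meta_register: maps a target (function or module class) to a
    # profiled cost estimator; has() mirrors the lookup added in the patch.
    meta_table: Dict[Any, Callable[[Any], float]] = {}

    def has(target: Any) -> bool:
        return target in meta_table

    def annotate_cost(strategy: Any, target: Any, default_cost: float) -> float:
        # Use the profiled estimator only when the target (or its class) is
        # registered; otherwise keep whatever the default cost model produced.
        if has(type(target)) or has(target):
            estimator = meta_table.get(type(target)) or meta_table[target]
            return estimator(strategy)
        return default_cost

Keeping the default-cost path live means unpatched ops still receive a usable, if coarser,
estimate during strategy construction, which matches the TODO of patching the remaining torch
functions and modules incrementally.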