[autoparallel] add binary elementwise metainfo for auto parallel (#2058)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator

* [autoparallel] add binary elementwise metainfo

* [fx] recover profiler

* [autoparallel] fix forward memory calculation

* [autoparallel] modify constants.py

* [autoparallel] remove redundant print
Boyuan Yao
2022-12-04 15:18:51 +08:00
committed by GitHub
parent 4b40fbd743
commit 616da17fab
9 changed files with 164 additions and 11 deletions


@@ -1,5 +1,12 @@
import operator

import torch
import torch.nn as nn

from ..tensor_shard.constants import *

# list of inplace operations
INPLACE_MODULE = [nn.ReLU]

# list of operations that do not save forward activations
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]
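For intuition, both lists are meant to be checked by membership against a node's target: modules by class, functions and operators by identity. A minimal sketch of that check, using a hypothetical helper name that is not part of this commit (the real logic lives in MetaInfo further down):

import operator

import torch
import torch.nn as nn

INPLACE_MODULE = [nn.ReLU]
NO_SAVE_ACTIVATION = [torch.add, torch.sub, operator.add, operator.sub]


def saves_forward_activation(target) -> bool:
    # Hypothetical helper, for illustration only: modules are matched by
    # class, functions/operators by identity, mirroring the try/except
    # lookup in MetaInfo shown later in this diff.
    key = target.__class__ if isinstance(target, nn.Module) else target
    return key not in NO_SAVE_ACTIVATION


assert saves_forward_activation(nn.ReLU()) is True
assert saves_forward_activation(torch.add) is False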


@@ -1,4 +1,5 @@
from .activation import *
from .binary_elementwise_ops import *
from .conv import *
from .linear import *
from .norm import *


@@ -0,0 +1,65 @@
from typing import List, Tuple

import torch

from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
from colossalai.fx.profiler.opcount import flop_mapping

from ..constants import BCAST_FUNC_OP
from ..registry import meta_register

__all__ = ['binary_elementwise_meta_info']


@meta_register.register(BCAST_FUNC_OP)
def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
    """Meta information generator for binary elementwise operations

    NOTE: Some binary elementwise operations discard their input activations after computation because those
    tensors are not needed for back propagation. For example, the two tensors passed to `torch.add` are discarded
    right after the add is done. `MetaInfo` exposes a simple flag to identify this behavior, which is critical
    for accurate memory estimation.

    Returns:
        Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
    """
    input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
    output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))

    # construct forward args for flop mapping
    fwd_in_args = [input_op_data.data, other_op_data.data]
    fwd_out_args = [output_op_data.data]

    # calculate compute cost
    # NOTE: we set bwd_compute_cost to two times fwd_compute_cost in this case;
    # the flop counter registered for adaptive_avg_pool2d scales with element count,
    # so it is borrowed here as an approximation for a binary elementwise op
    fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
    bwd_compute_cost = fwd_compute_cost * 2
    compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

    # calculate memory cost
    param_mem_cost = activation_size(
        [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
    fwd_mem_cost = MemoryCost(
        activation=activation_size([input_op_data.data, output_op_data.data]),
        parameter=param_mem_cost,
    )
    bwd_mem_cost = MemoryCost(
        activation=activation_size(fwd_in_args),
        parameter=param_mem_cost,
    )

    # total cost
    total_mem_cost = MemoryCost(
        activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
        parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
    )
    memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

    # store fwd_in
    fwd_in = fwd_in_args

    return compute_cost, memory_cost, fwd_in
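For a concrete sense of the memory estimate above, the numbers reduce to element counts times element size. A back-of-the-envelope sketch under assumed shapes (fp32 add of two (64, 128) tensors, no broadcasting, no parameter operands), with a rough stand-in for activation_size:

import torch

# Assumed example inputs and output of a broadcastable add.
a = torch.rand(64, 128)
b = torch.rand(64, 128)
out = a + b


def approx_activation_size(tensors) -> int:
    # Rough stand-in for colossalai.fx.profiler.memory_utils.activation_size:
    # total bytes occupied by the listed tensors.
    return sum(t.numel() * t.element_size() for t in tensors)


fwd_activation = approx_activation_size([a, out])    # input + output, as in fwd_mem_cost
bwd_activation = approx_activation_size([a, b])      # the two forward inputs, as in bwd_mem_cost
print(fwd_activation, bwd_activation)                # 65536 65536 (bytes) for this shape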


@@ -13,7 +13,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
)
from colossalai.tensor.sharding_spec import ShardingSpec
-from .constants import INPLACE_MODULE
+from .constants import INPLACE_MODULE, NO_SAVE_ACTIVATION
from .registry import meta_register
__all__ = ['MetaInfo']
@@ -35,6 +35,9 @@ class MetaInfo:
        # list of input tensors
        self.fwd_in: list[OperationData]

        # bool type to indicate whether the function will save forward activation
        self.save_fwd_in: bool

        # sharding strategy
        self._strategy = strategy
@@ -95,10 +98,16 @@ class MetaInfo:
        try:
            # module
            meta_func = meta_register.get(self._target.__class__)

            # check whether the target is in the list of ops that do not save forward activation
            self.save_fwd_in = self._target.__class__ not in NO_SAVE_ACTIVATION
        except:
            # function
            meta_func = meta_register.get(self._target)

            # check whether the target is in the list of ops that do not save forward activation
            self.save_fwd_in = self._target not in NO_SAVE_ACTIVATION

        # construct args for meta_func
        args = [self.compute_sharded_tensor(k, v) for k, v in self._strategy.sharding_specs.items()]
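The try/except above appears to rely on the registry raising for unknown targets: the module-class lookup is attempted first, and the plain-function lookup is the fallback. A simplified re-implementation of that registry pattern, for illustration only (the actual meta_register is defined in the meta profiler's registry module and may differ in detail):

from typing import Callable, Dict


class MetaRegistry:
    """Simplified sketch of a target -> meta-info-generator registry."""

    def __init__(self) -> None:
        self._table: Dict[object, Callable] = {}

    def register(self, targets) -> Callable:
        # Usable as a decorator, e.g. @meta_register.register(BCAST_FUNC_OP),
        # where BCAST_FUNC_OP is a list of broadcastable binary functions.
        targets = targets if isinstance(targets, (list, tuple)) else [targets]

        def wrapper(func: Callable) -> Callable:
            for target in targets:
                self._table[target] = func
            return func

        return wrapper

    def get(self, target) -> Callable:
        # Raises KeyError for unregistered targets; that is what lets the
        # module-class lookup above fall back to the plain-function lookup.
        return self._table[target]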