mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop
* pass t5 trace and meta_prop
* [FX] refactor experimental tracer and adapt it with hf models
* pass all mainstream model zoo
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* fix CI
* skip tests
* fix CI
* using packaging version
* polish
This commit is contained in:
@@ -11,6 +11,7 @@ from numbers import Number
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from .meta_tensor import MetaTensor
|
||||
@@ -403,134 +404,139 @@ def zero_flop_jit(*args):
|
||||
return 0
|
||||
|
||||
|
||||
flop_mapping = {
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
flop_mapping = {
|
||||
# gemm
|
||||
aten.mm.default: matmul_flop_jit,
|
||||
aten.matmul.default: matmul_flop_jit,
|
||||
aten.addmm.default: addmm_flop_jit,
|
||||
aten.bmm.default: bmm_flop_jit,
|
||||
aten.mm.default: matmul_flop_jit,
|
||||
aten.matmul.default: matmul_flop_jit,
|
||||
aten.addmm.default: addmm_flop_jit,
|
||||
aten.bmm.default: bmm_flop_jit,
|
||||
|
||||
# convolution
|
||||
aten.convolution.default: conv_flop_jit,
|
||||
aten._convolution.default: conv_flop_jit,
|
||||
aten.convolution_backward.default: conv_backward_flop_jit,
|
||||
aten.convolution.default: conv_flop_jit,
|
||||
aten._convolution.default: conv_flop_jit,
|
||||
aten.convolution_backward.default: conv_backward_flop_jit,
|
||||
|
||||
# normalization
|
||||
aten.native_batch_norm.default: batchnorm_flop_jit,
|
||||
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
|
||||
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
|
||||
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
|
||||
aten.native_layer_norm.default: norm_flop_counter(2, 0),
|
||||
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
|
||||
aten.native_batch_norm.default: batchnorm_flop_jit,
|
||||
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
|
||||
aten.cudnn_batch_norm.default: batchnorm_flop_jit,
|
||||
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
|
||||
aten.native_layer_norm.default: norm_flop_counter(2, 0),
|
||||
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
|
||||
|
||||
# pooling
|
||||
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.avg_pool3d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.max_pool1d.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool3d.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
||||
aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
|
||||
aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.embedding.default: ewise_flop_counter(1, 0),
|
||||
}
|
||||
aten.avg_pool1d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.avg_pool3d.default: ewise_flop_counter(1, 0),
|
||||
aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.max_pool1d.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool3d.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
|
||||
aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
|
||||
aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
|
||||
aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
|
||||
aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
|
||||
aten.embedding.default: ewise_flop_counter(1, 0),
|
||||
}
|
||||
|
||||
ewise_flop_aten = [
|
||||
ewise_flop_aten = [
|
||||
# basic op
|
||||
aten.add.Tensor,
|
||||
aten.add_.Tensor,
|
||||
aten.div.Tensor,
|
||||
aten.div_.Tensor,
|
||||
aten.div.Scalar,
|
||||
aten.div_.Scalar,
|
||||
aten.mul.Tensor,
|
||||
aten.mul.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.neg.default,
|
||||
aten.pow.Tensor_Scalar,
|
||||
aten.rsub.Scalar,
|
||||
aten.sum.default,
|
||||
aten.sum.dim_IntList,
|
||||
aten.mean.dim,
|
||||
aten.add.Tensor,
|
||||
aten.add_.Tensor,
|
||||
aten.div.Tensor,
|
||||
aten.div_.Tensor,
|
||||
aten.div.Scalar,
|
||||
aten.div_.Scalar,
|
||||
aten.mul.Tensor,
|
||||
aten.mul.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.neg.default,
|
||||
aten.pow.Tensor_Scalar,
|
||||
aten.rsub.Scalar,
|
||||
aten.sum.default,
|
||||
aten.sum.dim_IntList,
|
||||
aten.mean.dim,
|
||||
|
||||
# activation op
|
||||
aten.hardswish.default,
|
||||
aten.hardswish_.default,
|
||||
aten.hardswish_backward.default,
|
||||
aten.hardtanh.default,
|
||||
aten.hardtanh_.default,
|
||||
aten.hardtanh_backward.default,
|
||||
aten.hardsigmoid_backward.default,
|
||||
aten.hardsigmoid.default,
|
||||
aten.gelu.default,
|
||||
aten.gelu_backward.default,
|
||||
aten.silu.default,
|
||||
aten.silu_.default,
|
||||
aten.silu_backward.default,
|
||||
aten.sigmoid.default,
|
||||
aten.sigmoid_backward.default,
|
||||
aten._softmax.default,
|
||||
aten._softmax_backward_data.default,
|
||||
aten.relu_.default,
|
||||
aten.relu.default,
|
||||
aten.tanh.default,
|
||||
aten.tanh_backward.default,
|
||||
aten.threshold_backward.default,
|
||||
aten.hardswish.default,
|
||||
aten.hardswish_.default,
|
||||
aten.hardswish_backward.default,
|
||||
aten.hardtanh.default,
|
||||
aten.hardtanh_.default,
|
||||
aten.hardtanh_backward.default,
|
||||
aten.hardsigmoid_backward.default,
|
||||
aten.hardsigmoid.default,
|
||||
aten.gelu.default,
|
||||
aten.gelu_backward.default,
|
||||
aten.silu.default,
|
||||
aten.silu_.default,
|
||||
aten.silu_backward.default,
|
||||
aten.sigmoid.default,
|
||||
aten.sigmoid_backward.default,
|
||||
aten._softmax.default,
|
||||
aten._softmax_backward_data.default,
|
||||
aten.relu_.default,
|
||||
aten.relu.default,
|
||||
aten.tanh.default,
|
||||
aten.tanh_backward.default,
|
||||
aten.threshold_backward.default,
|
||||
|
||||
# dropout
|
||||
aten.native_dropout.default,
|
||||
aten.native_dropout_backward.default,
|
||||
aten.native_dropout.default,
|
||||
aten.native_dropout_backward.default,
|
||||
|
||||
# distribution
|
||||
aten.bernoulli_.float,
|
||||
aten.bernoulli_.float,
|
||||
|
||||
# where
|
||||
aten.where.self,
|
||||
]
|
||||
for op in ewise_flop_aten:
|
||||
flop_mapping[op] = ewise_flop_counter(1, 0)
|
||||
aten.where.self,
|
||||
]
|
||||
for op in ewise_flop_aten:
|
||||
flop_mapping[op] = ewise_flop_counter(1, 0)
|
||||
|
||||
# fix-me: this will be removed in future
|
||||
zero_flop_aten = [
|
||||
aten.as_strided.default,
|
||||
aten.as_strided_.default,
|
||||
aten.cat.default,
|
||||
aten.clone.default,
|
||||
aten.copy_.default,
|
||||
aten.detach.default,
|
||||
aten.expand.default,
|
||||
aten.empty_like.default,
|
||||
aten.new_empty.default,
|
||||
aten.new_empty_strided.default,
|
||||
aten.ones_like.default,
|
||||
aten._reshape_alias.default,
|
||||
aten.select.int,
|
||||
aten.select_backward.default,
|
||||
aten.squeeze.dim,
|
||||
aten.slice.Tensor,
|
||||
aten.slice_backward.default,
|
||||
aten.split.Tensor,
|
||||
aten.permute.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten._to_copy.default,
|
||||
aten.unsqueeze.default,
|
||||
aten.unbind.int,
|
||||
aten._unsafe_view.default,
|
||||
aten.view.default,
|
||||
aten.zero_.default,
|
||||
aten.zeros_like.default,
|
||||
]
|
||||
# fix-me: this will be removed in future
|
||||
zero_flop_aten = [
|
||||
aten.as_strided.default,
|
||||
aten.as_strided_.default,
|
||||
aten.cat.default,
|
||||
aten.clone.default,
|
||||
aten.copy_.default,
|
||||
aten.detach.default,
|
||||
aten.expand.default,
|
||||
aten.empty_like.default,
|
||||
aten.new_empty.default,
|
||||
aten.new_empty_strided.default,
|
||||
aten.ones_like.default,
|
||||
aten._reshape_alias.default,
|
||||
aten.select.int,
|
||||
aten.select_backward.default,
|
||||
aten.squeeze.dim,
|
||||
aten.slice.Tensor,
|
||||
aten.slice_backward.default,
|
||||
aten.split.Tensor,
|
||||
aten.permute.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten._to_copy.default,
|
||||
aten.unsqueeze.default,
|
||||
aten.unbind.int,
|
||||
aten._unsafe_view.default,
|
||||
aten.view.default,
|
||||
aten.zero_.default,
|
||||
aten.zeros_like.default,
|
||||
]
|
||||
|
||||
for op in zero_flop_aten:
|
||||
flop_mapping[op] = zero_flop_jit
|
||||
for op in zero_flop_aten:
|
||||
flop_mapping[op] = zero_flop_jit
|
||||
else:
|
||||
flop_mapping = {}
|
||||
elementwise_flop_aten = {}
|
||||
zero_flop_aten = {}
|
||||
|
Reference in New Issue
Block a user