[fx] provide a stable but not accurate enough version of profiler. (#1547)

* [fx] compute memory stat and flop count for MetaInfoProp.

* [fx] modify node attribute.

* [fx] modify ckpt_chen.

* [fx] fix compatibility.

* [fx] fix import error.

* [fx] skip test for MetaInfoProp.

* [fx] skip if torch 1.11.0.

* [fx] recover MetaInfoProp support for PyTorch 1.11.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] fix compatibility in tests.

* [fx] fix import error.

Author:       Super Daniel
Date:         2022-09-07 11:21:04 +08:00
Committed by: GitHub
Parent:       7d49e7b2db
Commit:       4f59693207

38 changed files with 776 additions and 263 deletions
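
The new file below registers per-op FLOP/MAC estimators against meta_profiler_function, a decorator-based registry imported from ..registry that is not part of this diff. For orientation only, a minimal sketch of what such a registry might look like (the class name and method set here are hypothetical, not the actual Colossal-AI implementation):

# Hypothetical sketch of a decorator registry like the one the file below assumes.
from typing import Callable, Dict, Hashable


class _ProfilerRegistry:
    """Maps a torch callable or an operator-name string to a (flops, macs) estimator."""

    def __init__(self) -> None:
        self._estimators: Dict[Hashable, Callable] = {}

    def register(self, source: Hashable) -> Callable[[Callable], Callable]:
        # Supports both ``@registry.register(torch.add)`` and ``@registry.register('add')``.
        def wrapper(func: Callable) -> Callable:
            self._estimators[source] = func
            return func

        return wrapper

    def get(self, source: Hashable) -> Callable:
        return self._estimators[source]


meta_profiler_function = _ProfilerRegistry()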


@@ -0,0 +1,85 @@
import operator
from functools import reduce
from typing import Any, Optional, Tuple, Union

import torch

from ..registry import meta_profiler_function


def _elementwise_flops_compute(input, other):
    # copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
    if not torch.is_tensor(input):
        if torch.is_tensor(other):
            return reduce(operator.mul, other.shape), 0
        else:
            return 1, 0
    elif not torch.is_tensor(other):
        return reduce(operator.mul, input.shape), 0
    else:
        dim_input = len(input.shape)
        dim_other = len(other.shape)
        max_dim = max(dim_input, dim_other)
        final_shape = []
        for i in range(max_dim):
            in_i = input.shape[i] if i < dim_input else 1
            ot_i = other.shape[i] if i < dim_other else 1
            if in_i > ot_i:
                final_shape.append(in_i)
            else:
                final_shape.append(ot_i)
        flops = reduce(operator.mul, final_shape)
        return flops, 0


@meta_profiler_function.register(torch.add)
@meta_profiler_function.register(torch.eq)
@meta_profiler_function.register(torch.sub)
@meta_profiler_function.register(torch.mul)
@meta_profiler_function.register(torch.floor_divide)
@meta_profiler_function.register('add')  # for built-in op +
@meta_profiler_function.register('iadd')  # for built-in op +=
@meta_profiler_function.register('eq')  # for built-in op ==
@meta_profiler_function.register('sub')  # for built-in op -
@meta_profiler_function.register('isub')  # for built-in op -=
@meta_profiler_function.register('mul')  # for built-in op *
@meta_profiler_function.register('imul')  # for built-in op *=
@meta_profiler_function.register('floordiv')  # for built-in op //
@meta_profiler_function.register('ifloordiv')  # for built-in op //=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    return _elementwise_flops_compute(input, other)


@meta_profiler_function.register(torch.abs)
def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    flops = input.numel()
    macs = 0
    return flops, macs


@meta_profiler_function.register(torch.matmul)
@meta_profiler_function.register('matmul')  # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    macs = reduce(operator.mul, input.shape) * other.shape[-1]
    flops = 2 * macs
    return flops, macs


@meta_profiler_function.register(torch.bmm)
def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    macs = reduce(operator.mul, input.shape) * other.shape[-1]
    flops = 2 * macs
    return flops, macs


@meta_profiler_function.register(torch.var_mean)
def torch_var_mean(input: torch.Tensor,
                   dim: Union[int, Tuple[int, ...]],
                   unbiased: Optional[bool] = True,
                   keepdim: Optional[bool] = False,
                   *,
                   out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    assert out is None, 'saving to out is not supported yet'
    flops = input.numel() * 3
    macs = 0
    return flops, macs
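
The estimators above only read tensor shapes, so they can be sanity-checked directly on shape-only meta tensors (PyTorch's device='meta'), with no real computation or memory allocation. The snippet below is illustrative only and not part of the committed file; it assumes the functions above are importable.

# Illustrative only: exercises the estimators defined above on meta tensors.
import torch

a = torch.empty(8, 64, 128, device='meta')
b = torch.empty(8, 128, 256, device='meta')
flops, macs = torch_bmm(a, b)    # macs = 8 * 64 * 128 * 256, flops = 2 * macs
print(flops, macs)               # 33554432 16777216

x = torch.empty(32, 1024, device='meta')
y = torch.empty(32, 1024, device='meta')
print(torch_add_like_ops(x, y))  # broadcast result has 32 * 1024 elements -> (32768, 0)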