ColossalAI/colossalai/_analyzer/_subclasses/_monkey_patch.py
Super Daniel fff98f06ed
[analyzer] a minimal implementation of static graph analyzer (#2852)
* [hotfix] meta tensor default device.

* [siu] add experimental submodules to main branch.

* [siu]

* [siu]

* [analyzer] init.

* [analyzer] readme.

* [analyzer] readme.

* [analyzer] readme.

* [analyzer] readme.

* [test] add test.

* Update symbolic_trace.py

* mark skip tests.

* try except.

* try except.

* try except.

* s

* init

* init

* fix

* skip

* skip

---------

Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com>
Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
2023-03-10 13:21:05 +08:00

89 lines
1.8 KiB
Python

import torch
import torch.distributed as dist
# Shorthand for the ATen operator namespace; the operator lists below are built from it.
aten = torch.ops.aten

# Every constant defined in this module is exported explicitly, because the
# leading underscore would otherwise hide it from ``from ... import *``.
__all__ = [
    # factory-method name lists
    "_TorchFactoryMethod",
    "_TorchOverrideableFactoryMethod",
    "_TorchNonOverrideableFactoryMethod",
    # tensor attribute names
    "_TensorPropertyMethod",
    # communication method names
    "_DistCommMethod",
    # ATen operator overload lists
    "_AliasATen",
    "_InplaceATen",
    "_MaybeInplaceATen",
]

# Factory names are kept in two groups.
# NOTE(review): the overrideable / non-overrideable split presumably mirrors how
# each name is patched elsewhere in the package -- confirm against the callers.
_TorchOverrideableFactoryMethod = [
    "empty", "eye", "full", "ones", "rand", "randn", "zeros"
]
_TorchNonOverrideableFactoryMethod = [
    "arange", "finfo", "linspace", "logspace", "randint", "randperm", "tensor"
]

# Both groups combined into a fresh list, overrideable names first.
_TorchFactoryMethod = [
    *_TorchOverrideableFactoryMethod,
    *_TorchNonOverrideableFactoryMethod,
]

# Attribute names of a tensor.
_TensorPropertyMethod = [
    "dtype",
    "shape",
    "device",
    "requires_grad",
    "grad",
    "grad_fn",
    "data",
]

# Communication method names.
# NOTE(review): these look like ``torch.distributed`` collectives -- confirm.
_DistCommMethod = [
    "all_gather", "all_reduce", "all_to_all", "broadcast",
    "gather", "reduce", "reduce_scatter", "scatter",
]

# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
# ATen overloads collected under the "alias" category.
_AliasATen = [
    aten.detach.default,
    aten.detach_.default,
    aten.t.default,
    aten.transpose.int,
    aten.view.default,
    aten._unsafe_view.default,
    aten._reshape_alias.default,
]

# In-place arithmetic overloads: for each op, the ``Tensor`` overload followed
# by the ``Scalar`` overload (add_, sub_, mul_, div_, pow_ -- in that order).
_InplaceATen = [
    getattr(getattr(aten, op_name), overload_name)
    for op_name in ("add_", "sub_", "mul_", "div_", "pow_")
    for overload_name in ("Tensor", "Scalar")
]

# Named ``MaybeInplace`` because these ops call ``as_strided()`` or ``slice()``.
_MaybeInplaceATen = [
    aten.diagonal.default,
    aten.expand.default,
    aten.select.int,
    aten.slice.Tensor,
    aten.split.Tensor,
    aten.squeeze.default,
    aten.permute.default,
    aten.unsqueeze.default,
    aten.as_strided.default,
]