mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[analyzer] a minimal implementation of static graph analyzer (#2852)
* [hotfix] meta tensor default device. * [siu] add experimental submodules to main branch. * [siu] * [siu] * [analyzer] init. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [analyzer] readme. * [test] add test. * Update symbolic_trace.py * mark skip tests. * try except. * try except. * try except. * s * init * init * fix * skip * skip --------- Co-authored-by: Daniel Shao <superdainiu@MININT-PVARVID.fareast.corp.microsoft.com> Co-authored-by: Daniel Shao <superdainiu@Daniels-Mac.local>
This commit is contained in:
88
colossalai/_analyzer/_subclasses/_monkey_patch.py
Normal file
88
colossalai/_analyzer/_subclasses/_monkey_patch.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
__all__ = [
|
||||
"_TorchFactoryMethod",
|
||||
"_TorchOverrideableFactoryMethod",
|
||||
"_TorchNonOverrideableFactoryMethod",
|
||||
"_TensorPropertyMethod",
|
||||
"_DistCommMethod",
|
||||
"_AliasATen",
|
||||
"_InplaceATen",
|
||||
"_MaybeInplaceATen",
|
||||
]
|
||||
|
||||
_TorchOverrideableFactoryMethod = [
|
||||
"empty",
|
||||
"eye",
|
||||
"full",
|
||||
"ones",
|
||||
"rand",
|
||||
"randn",
|
||||
"zeros",
|
||||
]
|
||||
|
||||
_TorchNonOverrideableFactoryMethod = [
|
||||
"arange",
|
||||
"finfo",
|
||||
"linspace",
|
||||
"logspace",
|
||||
"randint",
|
||||
"randperm",
|
||||
"tensor",
|
||||
]
|
||||
|
||||
_TorchFactoryMethod = _TorchOverrideableFactoryMethod + _TorchNonOverrideableFactoryMethod
|
||||
|
||||
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
|
||||
|
||||
_DistCommMethod = [
|
||||
"all_gather",
|
||||
"all_reduce",
|
||||
"all_to_all",
|
||||
"broadcast",
|
||||
"gather",
|
||||
"reduce",
|
||||
"reduce_scatter",
|
||||
"scatter",
|
||||
]
|
||||
|
||||
# TODO: dive deep here
|
||||
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
|
||||
_AliasATen = [
|
||||
aten.detach.default,
|
||||
aten.detach_.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten._reshape_alias.default,
|
||||
]
|
||||
|
||||
_InplaceATen = [
|
||||
aten.add_.Tensor,
|
||||
aten.add_.Scalar,
|
||||
aten.sub_.Tensor,
|
||||
aten.sub_.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.mul_.Scalar,
|
||||
aten.div_.Tensor,
|
||||
aten.div_.Scalar,
|
||||
aten.pow_.Tensor,
|
||||
aten.pow_.Scalar,
|
||||
]
|
||||
|
||||
# use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
|
||||
_MaybeInplaceATen = [
|
||||
aten.diagonal.default,
|
||||
aten.expand.default,
|
||||
aten.select.int,
|
||||
aten.slice.Tensor,
|
||||
aten.split.Tensor,
|
||||
aten.squeeze.default,
|
||||
aten.permute.default,
|
||||
aten.unsqueeze.default,
|
||||
aten.as_strided.default,
|
||||
]
|
Reference in New Issue
Block a user