mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop * pass t5 trace and meta_prop * [FX] refactor experimental tracer and adapt it with hf models * pass all mainstream model zoo * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * skip tests * fix CI * using packaging version * polish
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from packaging import version
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
@@ -49,40 +50,45 @@ _DistCommMethod = [
|
||||
"scatter",
|
||||
]
|
||||
|
||||
# TODO: dive deep here
|
||||
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
|
||||
_AliasATen = [
|
||||
aten.detach.default,
|
||||
aten.detach_.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten._reshape_alias.default,
|
||||
]
|
||||
if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
||||
# TODO: dive deep here
|
||||
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
|
||||
_AliasATen = [
|
||||
aten.detach.default,
|
||||
aten.detach_.default,
|
||||
aten.t.default,
|
||||
aten.transpose.int,
|
||||
aten.view.default,
|
||||
aten._unsafe_view.default,
|
||||
aten._reshape_alias.default,
|
||||
]
|
||||
|
||||
_InplaceATen = [
|
||||
aten.add_.Tensor,
|
||||
aten.add_.Scalar,
|
||||
aten.sub_.Tensor,
|
||||
aten.sub_.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.mul_.Scalar,
|
||||
aten.div_.Tensor,
|
||||
aten.div_.Scalar,
|
||||
aten.pow_.Tensor,
|
||||
aten.pow_.Scalar,
|
||||
]
|
||||
_InplaceATen = [
|
||||
aten.add_.Tensor,
|
||||
aten.add_.Scalar,
|
||||
aten.sub_.Tensor,
|
||||
aten.sub_.Scalar,
|
||||
aten.mul_.Tensor,
|
||||
aten.mul_.Scalar,
|
||||
aten.div_.Tensor,
|
||||
aten.div_.Scalar,
|
||||
aten.pow_.Tensor,
|
||||
aten.pow_.Scalar,
|
||||
]
|
||||
|
||||
# use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
|
||||
_MaybeInplaceATen = [
|
||||
aten.diagonal.default,
|
||||
aten.expand.default,
|
||||
aten.select.int,
|
||||
aten.slice.Tensor,
|
||||
aten.split.Tensor,
|
||||
aten.squeeze.default,
|
||||
aten.permute.default,
|
||||
aten.unsqueeze.default,
|
||||
aten.as_strided.default,
|
||||
]
|
||||
# use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
|
||||
_MaybeInplaceATen = [
|
||||
aten.diagonal.default,
|
||||
aten.expand.default,
|
||||
aten.select.int,
|
||||
aten.slice.Tensor,
|
||||
aten.split.Tensor,
|
||||
aten.squeeze.default,
|
||||
aten.permute.default,
|
||||
aten.unsqueeze.default,
|
||||
aten.as_strided.default,
|
||||
]
|
||||
else:
|
||||
_AliasATen = []
|
||||
_InplaceATen = []
|
||||
_MaybeInplaceATen = []
|
||||
|
Reference in New Issue
Block a user