[autoparallel] refactor runtime pass (#2644)

* [autoparallel] refactor runtime pass

* add unit test

* polish
This commit is contained in:
YuliangLiu0306
2023-02-15 10:36:19 +08:00
committed by GitHub
parent 89f8975fb8
commit cb2c6a2415
5 changed files with 352 additions and 214 deletions

View File

@@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [
torch.nn.ReLU,
torch.nn.Softmax,
]
# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args.
# This list could be extended if any other method has the same
# argument style as view and reshape.
SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]