[autoparallel] add torch.nn.ReLU metainfo (#1868)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input
This commit is contained in:
Boyuan Yao
2022-11-16 23:12:31 +08:00
committed by GitHub
parent 8c66a1d0aa
commit 7c7921f71b
9 changed files with 151 additions and 8 deletions

View File

@@ -20,7 +20,7 @@ __all__ = ['linear_meta_info']
@meta_register.register(torch.nn.Linear)
def linear_meta_info(*args) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""torch.nn.Linear meta info generator
The atens graph of torch.nn.Linear with bias is
graph():