[autoparallel] Attach input, buffer and output tensor to MetaInfo class (#2162)

* [fx] metainfo class for auto parallel

* [fx] add unit test for linear metainfo

* [fx] fix bwd param for linear

* [fx] modify unit test

* [fx] modify unit test

* [fx] modify import

* [fx] modify import

* [fx] modify import

* [fx] move meta profiler to auto parallel

* [fx] add conv metainfo class

* [fx] restore profiler

* [fx] restore meta profiler

* [autoparallel] modify unit test

* [fx] modify unit test

* [autoparallel] add batchnorm metainfo class

* [autoparallel] fix batchnorm unit test function declaration

* [fx] restore profiler

* [fx] add relu metainfo class

* [fx] restore profiler

* [autoparallel] modify metainfo input

* [autoparallel] add pooling metainfo

* [autoparallel] add F.linear metainfo generator

* [autoparallel] add binary elementwise metainfo

* [fx] recover profiler

* [autoparallel] fix forward memory calculation

* [autoparallel] modify constants.py

* [autoparallel] remove redundant print

* [autoparallel] add F.conv metainfo

* [autoparallel] linear fix

* [autoparallel] memory estimation for communication actions

* [autoparallel] fix docstring

* [autoparallel] fix variables name

* [autoparallel] attach tensor to metainfo class

* [autoparallel] fix dangerous try except

* [autoparallel] attach memory cost to shape consistency node

* [autoparallel] attach shape consistency node's metainfo to the node

* [autoparallel] remove todo in shape consistency memory estimation

* [autoparallel] fix the annotation
This commit is contained in:
Boyuan Yao
2022-12-28 13:37:40 +08:00
committed by GitHub
parent d0bc5a1b34
commit 24246f7aa5
11 changed files with 118 additions and 44 deletions

View File

@@ -37,7 +37,7 @@ def _batchnorm_module_mem_test(rank, world_size, port):
# index of target node in computation graph
node_index = 1
# total number of target node strategies
strategy_number = 4
strategy_number = 9
mem_test_for_node_strategy(rank=rank,
model=model,
device_mesh=device_mesh,

View File

@@ -92,7 +92,7 @@ def _linear_function_mem_test(rank, world_size, port):
model=model,
device_mesh=device_mesh,
node_index=2,
strategy_number=23,
strategy_number=24,
input_args=[input],
meta_arg_names=["input"])