[autoparallel] new metainfoprop based on metainfo class (#2179)

* [autoparallel] new metainfoprop to combine SPMD solver and checkpoint solver

* [autoparallel] new metainfoprop to combine SPMD solver and checkpoint solver

* [autoparallel] modify placeholder handler

* [autoparallel] modify metainfoprop

* [autoparallel] fix function typo

* [autoparallel] fix placeholder handler
This commit is contained in:
Boyuan Yao
2022-12-28 13:35:08 +08:00
committed by GitHub
parent 78509124d3
commit d0bc5a1b34
4 changed files with 185 additions and 0 deletions

View File

@@ -79,6 +79,10 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
str(node))
# attach the corresponding metainfo if node has the attribute `metainfo_vector`
if hasattr(node, 'metainfo_vector'):
setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index])
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
# the dict to record comm actions of nodes