[autoparallel] new metainfoprop based on metainfo class (#2179)

* [autoparallel] new metainfoprop to combine SPMD solver and checkpoint solver

* [autoparallel] new metainfoprop to combine SPMD solver and checkpoint solver

* [autoparallel] modify placeholder handler

* [autoparallel] modify metainfoprop

* [autoparallel] fix function typo

* [autoparallel] fix placeholder handler
This commit is contained in:
Boyuan Yao
2022-12-28 13:35:08 +08:00
committed by GitHub
parent 78509124d3
commit d0bc5a1b34
4 changed files with 185 additions and 0 deletions

View File

@@ -111,18 +111,27 @@ class StrategiesConstructor:
submod_type = type(submod)
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
setattr(node, 'metainfo_vector', handler.metainfo_vector)
# call_function node
elif node.op == 'call_function':
target = node.target
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
setattr(node, 'metainfo_vector', handler.metainfo_vector)
# call_method node
elif node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
setattr(node, 'metainfo_vector', handler.metainfo_vector)
# output node
elif node.op == 'output':