mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[autoparallel] new metainfoprop based on metainfo class (#2179)
* [autoparallel] new metainfoprop to combine SPMD solver and checkpoint solver * [autoparallel] new metainfoprop to combine SPMD solver and checkpoint solver * [autoparallel] modify placeholder handler * [autoparallel] modify metainfoprop * [autoparallel] fix function typo * [autoparallel] fix placeholder handler
This commit is contained in:
@@ -111,18 +111,27 @@ class StrategiesConstructor:
|
||||
submod_type = type(submod)
|
||||
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
|
||||
handler.register_strategy()
|
||||
# attach metainfo_vector to node
|
||||
if hasattr(handler, 'metainfo_vector'):
|
||||
setattr(node, 'metainfo_vector', handler.metainfo_vector)
|
||||
|
||||
# call_function node
|
||||
elif node.op == 'call_function':
|
||||
target = node.target
|
||||
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
|
||||
handler.register_strategy()
|
||||
# attach metainfo_vector to node
|
||||
if hasattr(handler, 'metainfo_vector'):
|
||||
setattr(node, 'metainfo_vector', handler.metainfo_vector)
|
||||
|
||||
# call_method node
|
||||
elif node.op == 'call_method':
|
||||
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
|
||||
handler.register_strategy()
|
||||
# attach metainfo_vector to node
|
||||
if hasattr(handler, 'metainfo_vector'):
|
||||
setattr(node, 'metainfo_vector', handler.metainfo_vector)
|
||||
|
||||
# output node
|
||||
elif node.op == 'output':
|
||||
|
Reference in New Issue
Block a user