mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[autockpt] make it work. (#2257)
This commit is contained in:
@@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
|
||||
return meta_info
|
||||
|
||||
|
||||
def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_spec_dict) -> MetaInfo:
|
||||
def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo:
|
||||
"""
|
||||
This method is used to construct `MetaInto` for shape consistency node
|
||||
"""
|
||||
@@ -62,8 +62,8 @@ def _runtime_apply_meta_info(node: Node, original_sharding_spec_dict, sharding_s
|
||||
# extract node index and user node index
|
||||
args = node.args
|
||||
node_index, user_node_index = args[3], args[4]
|
||||
origin_sharding_spec, target_sharding_spec = original_sharding_spec_dict[node_index], sharding_spec_dict[
|
||||
node_index][user_node_index]
|
||||
origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][
|
||||
user_node_index]
|
||||
|
||||
return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec)
|
||||
|
||||
@@ -98,16 +98,16 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M
|
||||
return meta_info
|
||||
|
||||
|
||||
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, original_sharding_spec_dict: Dict,
|
||||
comm_actions_dict: Dict):
|
||||
def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict,
|
||||
comm_actions_dict: Dict) -> GraphModule:
|
||||
"""
|
||||
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
|
||||
"""
|
||||
for node in gm.graph.nodes:
|
||||
if node.target == runtime_apply:
|
||||
setattr(node, 'best_metainfo',
|
||||
_runtime_apply_meta_info(node, original_sharding_spec_dict, sharding_spec_dict))
|
||||
setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
|
||||
elif node.target == runtime_comm_spec_apply:
|
||||
setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
|
||||
else:
|
||||
pass
|
||||
return gm
|
||||
|
Reference in New Issue
Block a user