mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-25 06:52:46 +00:00
adapt codegen to prepose node
This commit is contained in:
parent
522f017418
commit
d309e9338b
@ -1198,7 +1198,7 @@ class FlowTracer(object):
|
|||||||
chunk_node_list.remove(n)
|
chunk_node_list.remove(n)
|
||||||
non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list)
|
non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list)
|
||||||
for i in non_chunk_inputs:
|
for i in non_chunk_inputs:
|
||||||
if i not in chunk_info["inputs"] and i not in prepose_nodes:
|
if i not in chunk_info["inputs"]:
|
||||||
chunk_info["inputs_non_chunk"].append(i)
|
chunk_info["inputs_non_chunk"].append(i)
|
||||||
|
|
||||||
return chunk_info
|
return chunk_info
|
||||||
@ -1425,6 +1425,7 @@ class MemoryEstimator(object):
|
|||||||
) / (1024**2)
|
) / (1024**2)
|
||||||
|
|
||||||
# determine chunk ratio for current node
|
# determine chunk ratio for current node
|
||||||
|
# TODO: adapt to prepose node memory
|
||||||
if chunk_within:
|
if chunk_within:
|
||||||
chunk_ratio = self._get_chunk_ratio(
|
chunk_ratio = self._get_chunk_ratio(
|
||||||
node,
|
node,
|
||||||
@ -1602,7 +1603,6 @@ class ChunkRegionSearch(object):
|
|||||||
chunk_infos = []
|
chunk_infos = []
|
||||||
for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
|
for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
|
||||||
if len(start_traces) > 1:
|
if len(start_traces) > 1:
|
||||||
# TODO: implement multi input chunk
|
|
||||||
continue
|
continue
|
||||||
for start_node, start_trace in start_traces.items():
|
for start_node, start_trace in start_traces.items():
|
||||||
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
|
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
|
||||||
@ -1831,7 +1831,6 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
|||||||
|
|
||||||
# if a node has a user node which is not in the node list
|
# if a node has a user node which is not in the node list
|
||||||
# we treat that user node as the node receiving the current node output
|
# we treat that user node as the node receiving the current node output
|
||||||
# TODO: it is unsafe to remove non compute node here
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
for output_node in node.users.keys():
|
for output_node in node.users.keys():
|
||||||
if (
|
if (
|
||||||
@ -1901,6 +1900,8 @@ def emit_code_with_chunk(
|
|||||||
chunk_outputs = [i["outputs"][0] for i in chunk_search]
|
chunk_outputs = [i["outputs"][0] for i in chunk_search]
|
||||||
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
|
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
|
||||||
|
|
||||||
|
chunk_prepose_nodes = [i["args"]["prepose_nodes"] for i in chunk_search]
|
||||||
|
|
||||||
node_idx = 0
|
node_idx = 0
|
||||||
region_idx = 0
|
region_idx = 0
|
||||||
within_chunk_region = False
|
within_chunk_region = False
|
||||||
@ -1911,7 +1912,11 @@ def emit_code_with_chunk(
|
|||||||
if node_idx in chunk_starts:
|
if node_idx in chunk_starts:
|
||||||
within_chunk_region = True
|
within_chunk_region = True
|
||||||
region_idx = chunk_starts.index(node_idx)
|
region_idx = chunk_starts.index(node_idx)
|
||||||
|
# add prepose nodes
|
||||||
|
for i in chunk_prepose_nodes[region_idx]:
|
||||||
|
prepose_node = node_list[_find_idx_by_name(i.name, node_list)]
|
||||||
|
emit_node_func(prepose_node, body)
|
||||||
|
delete_unused_value_func(prepose_node, body, chunk_inputs_names)
|
||||||
# add for loop
|
# add for loop
|
||||||
body.append(
|
body.append(
|
||||||
_gen_loop_start(
|
_gen_loop_start(
|
||||||
@ -1922,20 +1927,22 @@ def emit_code_with_chunk(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if within_chunk_region:
|
if within_chunk_region:
|
||||||
emit_node_func(node, body)
|
if any(node.name == i.name for i in chunk_prepose_nodes[region_idx]):
|
||||||
# replace input var with chunk var
|
pass
|
||||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
else:
|
||||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
emit_node_func(node, body)
|
||||||
if idx == node_idx:
|
# replace input var with chunk var
|
||||||
chunk_slice = _gen_chunk_slice_dim(
|
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||||
dim, "chunk_idx", _get_node_shape(input_node)
|
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||||
)
|
if idx == node_idx:
|
||||||
body[-1] = _replace_name(
|
chunk_slice = _gen_chunk_slice_dim(
|
||||||
body[-1], input_node.name, input_node.name + chunk_slice
|
dim, "chunk_idx", _get_node_shape(input_node)
|
||||||
)
|
)
|
||||||
body[-1] = " " + body[-1]
|
body[-1] = _replace_name(
|
||||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
body[-1], input_node.name, input_node.name + chunk_slice
|
||||||
|
)
|
||||||
|
body[-1] = " " + body[-1]
|
||||||
|
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||||
else:
|
else:
|
||||||
emit_node_func(node, body)
|
emit_node_func(node, body)
|
||||||
if node_idx not in chunk_inputs:
|
if node_idx not in chunk_inputs:
|
||||||
|
Loading…
Reference in New Issue
Block a user