mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-09 15:35:54 +00:00
finish node reorder
This commit is contained in:
parent
884a228ea6
commit
51ef8384c1
@ -1238,7 +1238,7 @@ class MemoryEstimator(object):
|
|||||||
|
|
||||||
def estimate_chunk_inference_mem(
|
def estimate_chunk_inference_mem(
|
||||||
self,
|
self,
|
||||||
gm: torch.fx.GraphModule,
|
node_list,
|
||||||
chunk_infos=None,
|
chunk_infos=None,
|
||||||
):
|
):
|
||||||
act_memory = 0.0
|
act_memory = 0.0
|
||||||
@ -1247,7 +1247,6 @@ class MemoryEstimator(object):
|
|||||||
active_node_list = []
|
active_node_list = []
|
||||||
active_node_list_log = []
|
active_node_list_log = []
|
||||||
not_contiguous_list = []
|
not_contiguous_list = []
|
||||||
node_list = list(gm.graph.nodes)
|
|
||||||
user_to_last_uses = self._get_last_usr(node_list)
|
user_to_last_uses = self._get_last_usr(node_list)
|
||||||
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
|
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
|
||||||
_delete_free_var_from_last_use(user_to_last_uses_no_free_var)
|
_delete_free_var_from_last_use(user_to_last_uses_no_free_var)
|
||||||
@ -1281,7 +1280,6 @@ class MemoryEstimator(object):
|
|||||||
) / (1024**2)
|
) / (1024**2)
|
||||||
|
|
||||||
# determine chunk ratio for current node
|
# determine chunk ratio for current node
|
||||||
# TODO: adapt to prepose node memory
|
|
||||||
if chunk_within:
|
if chunk_within:
|
||||||
chunk_ratio = self._get_chunk_ratio(
|
chunk_ratio = self._get_chunk_ratio(
|
||||||
node,
|
node,
|
||||||
@ -1371,10 +1369,7 @@ class MemoryEstimator(object):
|
|||||||
class ChunkRegionSearch(object):
|
class ChunkRegionSearch(object):
|
||||||
def __init__(self, gm) -> None:
|
def __init__(self, gm) -> None:
|
||||||
self.gm = gm
|
self.gm = gm
|
||||||
self.node_list = list(gm.graph.nodes)
|
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
||||||
self.index_tracer = IndexTracer(
|
|
||||||
self.node_list
|
|
||||||
) # node list shared in index tracer
|
|
||||||
self.index_tracer.trace_index()
|
self.index_tracer.trace_index()
|
||||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
||||||
|
|
||||||
@ -1385,7 +1380,7 @@ class ChunkRegionSearch(object):
|
|||||||
|
|
||||||
def _get_free_var(self):
|
def _get_free_var(self):
|
||||||
free_var_idx = []
|
free_var_idx = []
|
||||||
for idx, n in enumerate(self.node_list):
|
for idx, n in enumerate(self.index_tracer.node_list):
|
||||||
if n.op == "placeholder":
|
if n.op == "placeholder":
|
||||||
free_var_idx.append(idx)
|
free_var_idx.append(idx)
|
||||||
return free_var_idx
|
return free_var_idx
|
||||||
@ -1455,13 +1450,13 @@ class ChunkRegionSearch(object):
|
|||||||
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
def _find_free_dim(self, input_trace, output_trace, start_idx, end_idx):
|
||||||
start_traces = input_trace[start_idx]
|
start_traces = input_trace[start_idx]
|
||||||
end_trace = output_trace[end_idx]
|
end_trace = output_trace[end_idx]
|
||||||
end_node = self.node_list[end_idx]
|
end_node = self.index_tracer.node_list[end_idx]
|
||||||
chunk_infos = []
|
chunk_infos = []
|
||||||
for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
|
for end_dim, _ in enumerate(end_trace["idx"]):
|
||||||
if len(start_traces) > 1:
|
if len(start_traces) > 1:
|
||||||
continue
|
continue
|
||||||
for start_node, start_trace in start_traces.items():
|
for start_node, start_trace in start_traces.items():
|
||||||
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
|
for start_dim, _ in enumerate(start_trace["idx"]):
|
||||||
# dim size cannot be 1
|
# dim size cannot be 1
|
||||||
if (
|
if (
|
||||||
_get_node_shape(end_node)[end_dim] == 1
|
_get_node_shape(end_node)[end_dim] == 1
|
||||||
@ -1494,7 +1489,7 @@ class ChunkRegionSearch(object):
|
|||||||
possible_chunk_region = []
|
possible_chunk_region = []
|
||||||
output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
|
output_trace = copy.deepcopy(self.index_tracer.idx_trace_list)
|
||||||
input_trace = [] # trace of a node's input nodes
|
input_trace = [] # trace of a node's input nodes
|
||||||
for _, n in enumerate(self.node_list):
|
for _, n in enumerate(self.index_tracer.node_list):
|
||||||
cur_trace = {}
|
cur_trace = {}
|
||||||
for arg in n.args:
|
for arg in n.args:
|
||||||
if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(
|
if type(arg) == type(n) and not _is_non_compute_node_except_placeholder(
|
||||||
@ -1507,8 +1502,8 @@ class ChunkRegionSearch(object):
|
|||||||
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
||||||
# skip non compute nodes
|
# skip non compute nodes
|
||||||
if _is_non_compute_node(
|
if _is_non_compute_node(
|
||||||
self.node_list[start_idx]
|
self.index_tracer.node_list[start_idx]
|
||||||
) or _is_non_compute_node(self.node_list[end_idx]):
|
) or _is_non_compute_node(self.index_tracer.node_list[end_idx]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# select free dim
|
# select free dim
|
||||||
@ -1577,7 +1572,9 @@ class ChunkRegionSearch(object):
|
|||||||
init_mem_peak,
|
init_mem_peak,
|
||||||
_,
|
_,
|
||||||
active_node,
|
active_node,
|
||||||
) = self.memory_estimator.estimate_chunk_inference_mem(self.gm)
|
) = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
|
self.index_tracer.node_list
|
||||||
|
)
|
||||||
mem_peak = init_mem_peak
|
mem_peak = init_mem_peak
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -1590,7 +1587,9 @@ class ChunkRegionSearch(object):
|
|||||||
mem_peak,
|
mem_peak,
|
||||||
_,
|
_,
|
||||||
active_node,
|
active_node,
|
||||||
) = self.memory_estimator.estimate_chunk_inference_mem(self.gm, chunk_infos)
|
) = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
|
self.index_tracer.node_list, chunk_infos
|
||||||
|
)
|
||||||
if self._stop_search(init_mem_peak, mem_peak):
|
if self._stop_search(init_mem_peak, mem_peak):
|
||||||
break
|
break
|
||||||
return chunk_infos
|
return chunk_infos
|
||||||
|
Loading…
Reference in New Issue
Block a user