mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)
[autochunk] support transformer (#2526)
@@ -3,9 +3,12 @@ from typing import Any, Dict, Iterable, List, Tuple

 import torch

 import colossalai
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE

-if CODEGEN_AVAILABLE:
+AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
+
+if AUTOCHUNK_AVAILABLE:
     from torch.fx.graph import (
         CodeGen,
         PythonCode,
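Note: after this hunk, autochunk availability is a single flag that combines codegen support with meta-tensor compatibility. A minimal usage sketch (the import path and constructor arguments are assumptions, not shown in this diff):

    from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE, AutoChunkCodeGen  # hypothetical path

    if AUTOCHUNK_AVAILABLE:
        codegen = AutoChunkCodeGen(meta_graph)    # hypothetical arguments
    else:
        codegen = None    # fall back to torch.fx's default codegen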
@@ -272,7 +275,7 @@ def emit_code_with_chunk(
         node_idx += 1


-if CODEGEN_AVAILABLE:
+if AUTOCHUNK_AVAILABLE:

     class AutoChunkCodeGen(CodeGen):
@@ -8,7 +8,13 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+from .utils import (
+    find_chunk_compute_input_and_output_nodes,
+    get_logger,
+    get_node_shape,
+    is_non_compute_node,
+    is_non_compute_node_except_placeholder,
+)


 class SearchChunk(object):
@@ -114,6 +120,12 @@ class SearchChunk(object):
             chunk_region_start (int)
             chunk_region_end (int)
         """
+        # check if peak node already in chunkinfo
+        if chunk_regions is not None:
+            for i in chunk_regions:
+                if i["region"][0] < peak_node_idx <= i["region"][1]:
+                    return None
+
         free_vars = self._get_free_var_idx()
         free_var_num = len(free_vars)
         active_node_num = [len(i) for i in active_node]
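Note: the added early-exit skips a peak node that an already-chosen chunk region covers. The containment test is half-open on the left, which a plain Python check makes concrete:

    region = {"region": (10, 20)}
    peak_node_idx = 15
    covered = region["region"][0] < peak_node_idx <= region["region"][1]    # True; an index of exactly 10 would not count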
@@ -152,55 +164,6 @@ class SearchChunk(object):
             chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end

-    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
-        """
-        Find chunk info for a region.
-
-        We are given the region start and region end, and need to find out all chunk info for it.
-        We first loop every dim of start node and end node, to see if we can find dim pair,
-        which is linked in a flow and not computed.
-        If found, we then search flow in the whole region to find out all chunk infos.
-
-        Args:
-            input_trace (List): node's input trace in region
-            output_trace (List): node's output trace in region
-            start_idx (int): region start node index
-            end_idx (int): region end node index
-
-        Returns:
-            chunk_infos: possible regions found
-        """
-        start_traces = input_trace[start_idx]
-        end_trace = output_trace[end_idx]
-        end_node = self.trace_indice.node_list[end_idx]
-        chunk_infos = []
-        for end_dim, _ in enumerate(end_trace["indice"]):
-            if len(start_traces) > 1:
-                continue
-            for start_node, start_trace in start_traces.items():
-                for start_dim, _ in enumerate(start_trace["indice"]):
-                    # dim size cannot be 1
-                    if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
-                        continue
-                    # must have users
-                    if len(end_node.users) == 0:
-                        continue
-                    # check index source align
-                    if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
-                        continue
-                    # check index copmute
-                    if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
-                        continue
-                    # flow search
-                    chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
-                    if chunk_info is None:
-                        continue
-                    # check index copmute
-                    if not self.trace_flow.check_index_duplicate(chunk_info):
-                        continue
-                    chunk_infos.append(chunk_info)
-        return chunk_infos
-
     def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
         """
         Search every possible region within the max chunk region.
@@ -228,9 +191,8 @@ class SearchChunk(object):
             if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
                     self.trace_indice.node_list[end_idx]):
                 continue

             # select free dim
-            chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
+            chunk_info = self.trace_flow.find_chunk_info(input_trace, output_trace, start_idx, end_idx)
             if len(chunk_info) > 0:
                 possible_chunk_region.extend(chunk_info)
         return possible_chunk_region
@@ -5,6 +5,7 @@ from .utils import is_non_compute_node


 class SelectChunk(object):
+
     def __init__(
         self,
         trace_indice: TraceIndice,
@@ -17,13 +18,11 @@ class SelectChunk(object):
         self.reorder_graph = reorder_graph
         if max_memory is not None:
             self.stratge = "fit_memory"
-            self.max_memory = max_memory  # MB
+            self.max_memory = max_memory    # MB
         else:
             self.stratge = "min_memory"

-    def _select_best_chunk_region(
-        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
-    ):
+    def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak):
         if self.stratge == "min_memory":
             best_region = self._select_min_memory_chunk_region(
                 possible_chunk_regions,
@@ -44,9 +43,8 @@ class SelectChunk(object):
             raise RuntimeError()
         return best_region

-    def _select_fit_memory_chunk_region(
-        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
-    ):
+    def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
+                                        mem_peak):
         # stop chunk if max memory satisfy memory limit
         if max(mem_peak) < self.max_memory:
             return None
@@ -63,33 +61,26 @@ class SelectChunk(object):
         if len(possible_chunk_regions) == 0:
             return None

+        max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
+                                     max([i["region"][1] for i in possible_chunk_regions]))
+
         # get mem for chunk region
         regions_dict = []
         for region in possible_chunk_regions:
             cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
-                self.trace_indice.node_list, cur_region
-            )
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
-                cur_node_list, cur_chunk_infos
-            )[0]
-            cur_chunk_region_peak = cur_mem_peak[
-                max_chunk_region[0] : max_chunk_region[1] + 1
-            ]
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
+            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
             cur_chunk_region_max_peak = max(cur_chunk_region_peak)
             if cur_chunk_region_max_peak < self.max_memory:
-                regions_dict.append(
-                    {
-                        "chunk_info": region,
-                        "chunk_max_mem": cur_chunk_region_max_peak,
-                        "chunk_len": self._get_compute_node_num(
-                            region["region"][0], region["region"][1]
-                        ),
-                        "reorder_chunk_info": cur_region,
-                        "reorder_node_list": cur_node_list,
-                    }
-                )
+                regions_dict.append({
+                    "chunk_info": region,
+                    "chunk_max_mem": cur_chunk_region_max_peak,
+                    "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+                    "reorder_chunk_info": cur_region,
+                    "reorder_node_list": cur_node_list,
+                })

         # no region found
         if len(regions_dict) == 0:
             raise RuntimeError("Search failed. Try a larger memory threshold.")
@@ -113,20 +104,13 @@ class SelectChunk(object):
             chunk_size *= 2
             reorder_chunk_info["chunk_size"] = chunk_size
             cur_chunk_infos = chunk_infos + [reorder_chunk_info]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
-                chunk_region_dict["reorder_node_list"], cur_chunk_infos
-            )[0]
-            cur_chunk_max_mem = max(
-                cur_mem_peak[
-                    reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
-                    + 1
-                ]
-            )
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
+                                                                             cur_chunk_infos)[0]
+            cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])

         # search exact size
         chunk_info = chunk_region_dict["chunk_info"]
-        chunk_info["chunk_size"] = self._chunk_size_binary_search(
-            chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
-        )
+        chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
+                                                                  chunk_infos)
         return chunk_info

     def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
@@ -139,12 +123,9 @@ class SelectChunk(object):
             mid = int((left + right) / 2 + 0.5)
             chunk_info["chunk_size"] = mid
             cur_chunk_infos = chunk_infos + [chunk_info]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
-                chunk_region_dict["reorder_node_list"], cur_chunk_infos
-            )[0]
-            cur_chunk_max_mem = max(
-                cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
-            )
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
+                                                                             cur_chunk_infos)[0]
+            cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
             if cur_chunk_max_mem >= self.max_memory:
                 right = mid - gap
             else:
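Note: the doubling loop in the previous hunk plus the bisection above find the largest chunk_size whose estimated peak memory stays under the limit. A self-contained sketch of the same idea, where fits() stands in for the memory estimate (it is not part of the diff):

    def max_feasible_chunk_size(fits) -> int:
        # fits(size) -> bool: True if the estimated peak memory at this chunk size is under the cap
        size = 1
        while fits(size):    # grow until the first size that no longer fits
            size *= 2
        left, right = size // 2, size    # invariant: left fits (or is 0), right does not
        while left < right:
            mid = (left + right + 1) // 2
            if fits(mid):
                left = mid
            else:
                right = mid - 1
        return left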
@@ -153,14 +134,13 @@ class SelectChunk(object):

     def _get_compute_node_num(self, start, end):
         count = 0
-        for i in self.trace_indice.node_list[start : end + 1]:
+        for i in self.trace_indice.node_list[start:end + 1]:
             if not is_non_compute_node(i):
                 count += 1
         return count

-    def _select_min_memory_chunk_region(
-        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
-    ):
+    def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region,
+                                        mem_peak):
         # remove illegal regions
         illegal_regions = []
         for i in possible_chunk_regions:
@@ -173,37 +153,31 @@ class SelectChunk(object):
         if len(possible_chunk_regions) == 0:
             return None

+        # get max possible chunk region
+        max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
+                                     max([i["region"][1] for i in possible_chunk_regions]))
+
         # get mem for chunk region
-        regions_dict = []
+        regions_dict_list = []
         for region in possible_chunk_regions:
             cur_region = region.copy()
-            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
-                self.trace_indice.node_list, cur_region
-            )
+            cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.trace_indice.node_list, cur_region)
             cur_chunk_infos = chunk_infos + [cur_region]
-            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
-                cur_node_list, cur_chunk_infos
-            )[0]
-            cur_chunk_region_peak = cur_mem_peak[
-                max_chunk_region[0] : max_chunk_region[1] + 1
-            ]
+            cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
+            cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
             cur_chunk_region_max_peak = max(cur_chunk_region_peak)
-            regions_dict.append(
-                {
-                    "chunk_info": region,
-                    "chunk_max_mem": cur_chunk_region_max_peak,
-                    "chunk_len": self._get_compute_node_num(
-                        region["region"][0], region["region"][1]
-                    ),
-                    "reorder_chunk_info": cur_region,
-                    "reorder_node_list": cur_node_list,
-                }
-            )
+            regions_dict_list.append({
+                "chunk_info": region,
+                "chunk_max_mem": cur_chunk_region_max_peak,
+                "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+                "reorder_chunk_info": cur_region,
+                "reorder_node_list": cur_node_list,
+            })

         # select the min mem
-        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
+        chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
         best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
-        best_region = regions_dict[best_region_idx]["chunk_info"]
+        best_region = regions_dict_list[best_region_idx]["chunk_info"]
         if best_region is not None:
             best_region["chunk_size"] = 1
         return best_region
@@ -216,9 +190,7 @@ class SelectChunk(object):
             return False
         for i in chunk_infos:
             region = i["region"]
-            if not (
-                (chunk_region_start > region[1] and chunk_region_end > region[1])
-                or (chunk_region_start < region[0] and chunk_region_end < region[0])
-            ):
+            if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
+                    (chunk_region_start < region[0] and chunk_region_end < region[0])):
                 return False
         return True
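Note: the reformatted condition is the standard disjoint-interval test: a candidate region is legal only if it sits entirely to one side of every existing region. The same predicate as a standalone function:

    def disjoint(a_start, a_end, b_start, b_end):
        return (a_start > b_end and a_end > b_end) or (a_start < b_start and a_end < b_start)

    print(disjoint(10, 14, 5, 9))    # True, candidate lies fully to the right
    print(disjoint(3, 7, 5, 9))      # False, candidate straddles the boundary at 5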
@@ -8,9 +8,9 @@ from .utils import (
     find_chunk_compute_input_and_output_nodes,
     find_idx_by_name,
     flat_list,
+    get_node_name,
     get_node_shape,
     is_non_compute_node,
     is_non_compute_node_except_placeholder,
 )
@@ -79,43 +79,6 @@ class TraceFlow(object):
                 return node_dim
         return None

-    def check_index_duplicate(self, chunk_infos, return_dim=False):
-        input_dim_after_node = {}
-        for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
-            for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
-                inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
-                if inherit_dim:
-                    input_dim_after_node[k] = inherit_dim
-
-        for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
-            if is_non_compute_node_except_placeholder(node):
-                continue
-            count = 0
-            duplicate_dims = []
-            node_trace_source = self.trace_indice._find_source_trace_from_node(node)
-            for node_dim in range(len(get_node_shape(node))):
-                duplicate_dim = []
-                duplicate_flag = False
-                dim_source = node_trace_source[node_dim]
-                for k, v in dim_source.items():
-                    if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
-                        if k in input_dim_after_node and input_dim_after_node[k] in v:
-                            duplicate_flag = True
-                        duplicate_dim.append((k, v))
-                duplicate_dims.append(duplicate_dim)
-                if duplicate_flag:
-                    count += 1
-
-            if count > 1:
-                if return_dim:
-                    return False, duplicate_dims
-                else:
-                    return False
-        if return_dim:
-            return True, None
-        else:
-            return True
-
     def _assgin_single_node_flow(
         self,
         arg_node: Node,
@@ -225,9 +188,12 @@ class TraceFlow(object):
             if flow_flag == False:
                 return None

-            if len(arg_list) == 2:
-                if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
+            if len(arg_list) >= 2:
+                # need to mark fix dim
+                if any(i == get_node_name(cur_node) for i in ["add", "mul", "truediv", "sub", "where"]):
                     for arg in arg_list:
+                        if get_node_shape(arg) is None:
+                            continue
                         if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
                             continue
                         arg_chunk_dim = all_node_info[arg]["chunk_dim"]
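Note: the new get_node_shape(arg) guard lets elementwise ops with a shapeless operand pass through, for example a Python scalar that broadcasts (assumed here to carry no shape metadata in the traced graph):

    import torch

    x = torch.zeros(2, 3)
    out = x + 0.5    # "add" with one tensor arg and one scalar arg; only the tensor can carry a chunk dim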
@@ -240,9 +206,8 @@ class TraceFlow(object):
                                 return None
                             if i not in arg_fix_dim:
                                 arg_fix_dim.append(i)
-            elif "einsum" in cur_node.name:
-                pass
-            elif "matmul" in cur_node.name:
+            elif any(i == get_node_name(cur_node)
+                     for i in ["einsum", "matmul", "view", "to", "getitem", "tensor", "type"]):
                 pass
             else:
                 raise NotImplementedError()
@@ -426,7 +391,7 @@ class TraceFlow(object):
         reshape_size = {}
         chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
         for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
-            if any(i in node.name for i in ["reshape", "view"]):
+            if any(i == get_node_name(node) for i in ["reshape", "view"]):
                 reshape_args = flat_list(node.args[1:])
                 chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
                 new_shape = ""
@@ -443,3 +408,62 @@ class TraceFlow(object):
             reshape_size[node.name] = [origin_shape, new_shape]
         chunk_info["reshape_size"] = reshape_size
         return chunk_info
+
+    def find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
+        """
+        Find chunk info for a region.
+
+        We are given the region start and region end, and need to find out all chunk info for it.
+        We first loop every dim of start node and end node, to see if we can find dim pair,
+        which is linked in a flow and not computed.
+        If found, we then search flow in the whole region to find out all chunk infos.
+
+        Args:
+            input_trace (List): node's input trace in region
+            output_trace (List): node's output trace in region
+            start_idx (int): region start node index
+            end_idx (int): region end node index
+
+        Returns:
+            chunk_infos: possible regions found
+        """
+        start_traces = input_trace[start_idx]
+        if len(start_traces) > 1:    # TODO need to be removed
+            return []
+        end_trace = output_trace[end_idx]
+        end_node = self.trace_indice.node_list[end_idx]
+
+        chunk_infos = []
+        for end_dim, _ in enumerate(end_trace["indice"]):
+            for start_node, start_trace in start_traces.items():
+                for start_dim, _ in enumerate(start_trace["indice"]):
+                    if not self._check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim, end_idx):
+                        continue
+                    # flow search
+                    chunk_info = self.flow_search(start_idx, start_dim, end_idx, end_dim)
+                    if chunk_info is None:
+                        continue
+                    chunk_infos.append(chunk_info)
+        return chunk_infos
+
+    def _check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
+                                end_idx: int) -> bool:
+        """
+        check if region start and end is legal
+        """
+        # dim cannot be None
+        if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
+            return False
+        # dim size cannot be 1
+        if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
+            return False
+        # must have users
+        if len(end_node.users) == 0:
+            return False
+        # check index source align
+        if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
+            return False
+        # check index copmute
+        if not self.check_index_compute(start_idx, end_dim, end_node, end_idx):
+            return False
+        return True
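Note: chunk_info is a plain dict throughout this commit; the keys visible across the diff include "region", "inputs", "inputs_dim", "outputs", "outputs_dim", "node_chunk_dim", "chunk_size", and "reshape_size". A hedged illustration of such a record (all values made up):

    chunk_info = {
        "region": (12, 47),     # node indices bounding the chunked region
        "inputs": [...],        # input nodes feeding the region
        "inputs_dim": [...],    # chunked dim of each input
        "outputs": [...],       # output nodes of the region
        "outputs_dim": 1,       # chunked dim of the output
        "chunk_size": 1,        # filled in later by SelectChunk
    }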
@@ -3,7 +3,14 @@ from typing import Dict, List, Tuple

 from torch.fx.node import Node

-from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
+from .utils import (
+    find_first_tensor_arg,
+    find_idx_by_name,
+    flat_list,
+    get_module_node_name,
+    get_node_name,
+    get_node_shape,
+)


 class TraceIndice(object):
@@ -36,7 +43,7 @@ class TraceIndice(object):
         self.trace_range = []
         self.active_node_list = []

-    def _init_indice_trace_list(self):
+    def _init_indice_trace_list(self) -> List:
         indice_trace_list = []
         for n in self.node_list:
             if get_node_shape(n) != None:
@@ -54,7 +61,7 @@ class TraceIndice(object):
         self.trace_range = trace_range
         self.active_node_list = active_node_list

-    def _add_indice(self):
+    def _add_indice(self) -> int:
        """
        Update the count and return it. To record the idx number.
@@ -64,39 +71,30 @@ class TraceIndice(object):
         self.indice_count += 1
         return self.indice_count

-    def _del_dim(self, idx, dim_idx):
+    def _del_dim(self, idx: int, dim_idx: int) -> None:
         """
         delete a dim for indice, compute and source
         """
         self.indice_trace_list[idx]["indice"].pop(dim_idx)
         self.indice_trace_list[idx]["compute"].pop(dim_idx)
         self.indice_trace_list[idx]["source"].pop(dim_idx)

-    def _add_dim(self, node_idx, dim_idx):
+    def _add_dim(self, node_idx: int, dim_idx: int) -> None:
         """
         add a dim for indice, compute and source
         """
         self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
         self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
         self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})

-    def _transform_indice(self, node, node_dim):
-        node_idx = self._find_indice_trace_from_node(node)
-        dims = list(range(len(node_idx)))
-        return dims[node_dim]
-
-    def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim):
-        node_from_dim = self._transform_indice(node_from, node_from_dim)
-        node_to_dim = self._transform_indice(node_to, node_to_dim)
-        node_from_trace = self._find_trace_from_node(node_from)
-        node_to_trace = self._find_trace_from_node(node_to)
-        node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
-        node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
-        self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
-
-    def _inherit_all_computation(self, node_from, node_to):
-        node_from_compute = self._find_compute_trace_from_node(node_from)
-        node_to_compute = self._find_compute_trace_from_node(node_to)
-        assert len(node_from_compute) == len(node_to_compute)
-        for i in range(len(node_from_compute)):
-            self._add_source(node_from, i, node_to, i)
-            node_to_compute[i] = copy.deepcopy(node_from_compute[i])
-
-    def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
+    def _add_source(
+        self,
+        node_from: Node,
+        node_from_dim: int,
+        node_to: Node,
+        node_to_dim: int,
+        init=False,
+    ) -> None:
         node_from_dim = self._transform_indice(node_from, node_from_dim)
         node_from_trace_source = self._find_source_trace_from_node(node_from)
         node_to_dim = self._transform_indice(node_to, node_to_dim)
@@ -119,7 +117,50 @@ class TraceIndice(object):
                 if d not in node_to_trace_source[node_to_dim][node_idx]:
                     node_to_trace_source[node_to_dim][node_idx].append(d)

-    def _mark_computation_from_node(self, node_from, node_to, exclude=None):
+    def _transform_indice(self, node: Node, node_dim: int) -> int:
+        node_idx = self._find_indice_trace_from_node(node)
+        dims = list(range(len(node_idx)))
+        return dims[node_dim]
+
+    def _inherit_indice(
+        self,
+        node_from: Node,
+        node_from_dim: int,
+        node_to: Node,
+        node_to_dim: int,
+        init: bool = True,
+    ) -> None:
+        """
+        node_to's node_to_dim inherit node_from's node_from_dim by indice, compute and source
+        """
+        node_from_dim = self._transform_indice(node_from, node_from_dim)
+        node_to_dim = self._transform_indice(node_to, node_to_dim)
+        node_from_trace = self._find_trace_from_node(node_from)
+        node_to_trace = self._find_trace_from_node(node_to)
+        if init:
+            node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
+            node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
+        else:
+            for j in node_from_trace["compute"][node_from_dim]:
+                if j not in node_to_trace["compute"][node_to_dim]:
+                    node_to_trace["compute"][node_to_dim].append(j)
+        self._add_source(node_from, node_from_dim, node_to, node_to_dim, init)
+
+    def _inherit_all_indice(self, node_from: Node, node_to: Node) -> None:
+        """
+        inherit all dims with init
+        """
+        # find indice just for assert length
+        node_from_indice = self._find_indice_trace_from_node(node_from)
+        node_to_indice = self._find_indice_trace_from_node(node_to)
+        assert len(node_from_indice) == len(node_to_indice)
+        for i in range(len(node_from_indice)):
+            self._inherit_indice(node_from, i, node_to, i, init=True)
+
+    def _inherit_more_indice_from_node(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
+        """
+        inheirt indice from node without init
+        """
         if exclude == None:
             exclude = []
         else:
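Note: the rework above folds three old helpers into one parameterized method. Semantics as read from the added code:

    # _inherit_indice(..., init=True)  : overwrite the target dim's indice and deep-copy its compute log
    # _inherit_indice(..., init=False) : merge the source dim's compute entries into the target's log
    # _inherit_all_indice              : init=True over every dim (supersedes _inherit_all_computation)
    # _inherit_more_indice_from_node   : init=False over trailing dims (supersedes _mark_computation_from_node)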
@@ -130,12 +171,9 @@ class TraceIndice(object):
         for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
             if self._transform_indice(node_to, i) in exclude:
                 continue
-            self._add_source(node_from, i, node_to, i)
-            for j in node_from_compute[i]:
-                if j not in node_to_compute[i]:
-                    node_to_compute[i].append(j)
+            self._inherit_indice(node_from, i, node_to, i, init=False)

-    def _mark_computation(self, node, idx, dim):
+    def _mark_computation(self, node: Node, idx: int, dim: int) -> None:
         """
         Mark some dims of node as computed.
@@ -152,7 +190,7 @@ class TraceIndice(object):
             if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
                 self.indice_trace_list[idx]["compute"][cur_dim].append(idx)

-    def _find_trace_from_node(self, node):
+    def _find_trace_from_node(self, node: Node) -> Dict:
         """
         Find node idx and compute trace by the node.
@@ -166,7 +204,7 @@ class TraceIndice(object):
         node_dict = self.indice_trace_list[node_idx]
         return node_dict

-    def _find_source_trace_from_node(self, node):
+    def _find_source_trace_from_node(self, node: Node) -> List:
         """
         Find node source trace by the node.
@@ -180,7 +218,7 @@ class TraceIndice(object):
         node_dict = self.indice_trace_list[node_idx]
         return node_dict["source"]

-    def _find_indice_trace_from_node(self, node):
+    def _find_indice_trace_from_node(self, node) -> List:
         """
         Find node idx trace by the node.
@@ -192,7 +230,7 @@ class TraceIndice(object):
         node_idx = find_idx_by_name(node.name, self.node_list)
         return self.indice_trace_list[node_idx]["indice"]

-    def _find_compute_trace_from_node(self, node):
+    def _find_compute_trace_from_node(self, node: Node) -> List:
         """
         Find node compute trace by the node.
@@ -204,7 +242,7 @@ class TraceIndice(object):
         node_idx = find_idx_by_name(node.name, self.node_list)
         return self.indice_trace_list[node_idx]["compute"]

-    def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
+    def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
         """
         Assign node's trace as its input node.
@@ -214,15 +252,9 @@ class TraceIndice(object):
         """
         if input_node == None:
             input_node = find_first_tensor_arg(node)
-        input_node_idx = find_idx_by_name(input_node.name, self.node_list)
-        input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
-
-        new_idx_trace = copy.deepcopy(input_node_idx_trace)
-        self.indice_trace_list[node_idx]["indice"] = new_idx_trace
-
-        self._inherit_all_computation(input_node, node)
+        self._inherit_all_indice(input_node, node)

-    def _assign_all_indice(self, node: Node, node_idx: int):
+    def _assign_all_indice(self, node: Node, node_idx: int) -> None:
         """
         Add new indice for all node's dims.
@@ -238,7 +270,7 @@ class TraceIndice(object):
             new_trace.append(self._add_indice())
         self.indice_trace_list[node_idx]["indice"] = new_trace

-    def _assign_transpose_indice(self, node: Node, node_idx: int):
+    def _assign_transpose_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for transpose op.
         1. swap input's dim according to transpose args
@@ -255,7 +287,7 @@ class TraceIndice(object):
         self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
         self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])

-    def _assign_permute_indice(self, node: Node, node_idx: int):
+    def _assign_permute_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for permute op.
         1. swap input's dim according to permute args
@@ -272,7 +304,7 @@ class TraceIndice(object):
         for idx, d in enumerate(permute_dim):
             self._inherit_indice(input_node, d, node, idx)

-    def _assign_linear_indice(self, node: Node, node_idx: int):
+    def _assign_linear_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for linear op.
         1. copy trace from input node and change last indice accroding to weight
@@ -293,7 +325,23 @@ class TraceIndice(object):

         self._mark_computation(node, node_idx, [-1])

-    def _assign_matmul_indice(self, node: Node, node_idx: int):
+    def _assign_addmm_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for addmm op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        bias, input_node, weight = node.args
+
+        self._assign_indice_as_input(node, node_idx, input_node)
+        self._inherit_indice(weight, 1, node, -1)
+        self._inherit_indice(bias, -1, node, -1)
+
+        self._mark_computation(node, node_idx, [-1])
+
+    def _assign_matmul_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for matmul op.
         1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
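Note: the new addmm handler follows torch.addmm(bias, input, weight) shape semantics: rows come from the input, the last dim follows weight's dim 1, and bias broadcasts over that last dim. A quick check with stock PyTorch:

    import torch

    bias = torch.zeros(4)       # (N,)
    inp = torch.zeros(2, 3)     # (M, K)
    weight = torch.zeros(3, 4)  # (K, N)
    print(torch.addmm(bias, inp, weight).shape)    # torch.Size([2, 4])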
@@ -310,7 +358,7 @@ class TraceIndice(object):
         self._assign_indice_as_input(node, node_idx, matmul_left)
         self._inherit_indice(matmul_right, -1, node, -1)

-        self._mark_computation_from_node(matmul_right, node, [-1, -2])
+        self._inherit_more_indice_from_node(matmul_right, node, [-1, -2])
         self._mark_computation(node, node_idx, [-1])

     def _assign_layernorm_indice(self, node, idx):
@@ -341,14 +389,13 @@ class TraceIndice(object):
         for node_in in node.args:
             if type(node_in) == type(node):
                 nodes_in.append(node_in)
-                self._mark_computation_from_node(node_in, node)
-        assert len(nodes_in) <= 2
+                self._inherit_more_indice_from_node(node_in, node)

     def _assgin_no_change_indice(self, node, idx):
         self._assign_indice_as_input(node, idx)
         for node_in in node.args:
             if type(node_in) == type(node):
-                self._mark_computation_from_node(node_in, node)
+                self._inherit_more_indice_from_node(node_in, node)

     def _assign_einsum_indice(self, node, idx):
         """
@@ -365,7 +412,7 @@ class TraceIndice(object):
         left, right = patterns.split("->")
         left = left.split(",")

-        if '...' in right:
+        if "..." in right:
             replace_list = "!@#$%^&*"
             target_len = len(get_node_shape(node))
             add_len = target_len - len(right) + 3
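Note: the branch above widens an output-side ellipsis so every dim gets an explicit symbol; replace_list supplies placeholder characters that cannot collide with einsum letters. Working the add_len arithmetic from the diff on an example:

    # pattern "...ij,jk->...ik" on a 4-d left operand:
    #   target_len = 4, len("...ik") = 5, add_len = 4 - 5 + 3 = 2
    # so "..." expands to two placeholders and "...ik" becomes "!@ik"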
@@ -399,7 +446,22 @@ class TraceIndice(object):
         self._assign_indice_as_input(node, idx)
         self._mark_computation(node, idx, [node.kwargs["dim"]])

-    def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
+    def _assign_split_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for split op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        for _ in range(len(get_node_shape(node.args[0]))):
+            self._add_dim(node_idx, 0)
+        self._assign_indice_as_input(node, node_idx)
+        dim_idx = node.kwargs["dim"]
+        self._del_dim(node_idx, dim_idx)
+        self._add_dim(node_idx, dim_idx)
+
+    def _assign_unsqueeze_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for unsqueeze op.
         1. assign new indice for unsqueeze dim
@@ -416,18 +478,7 @@ class TraceIndice(object):
         dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
         self._add_dim(node_idx, dim_idx)

-    def _assign_dropout_indice(self, node: Node, node_idx: int):
-        """
-        Assign indice for unsqueeze op.
-        1. assign new indice for unsqueeze dim
-
-        Args:
-            node (node)
-            node_idx (int)
-        """
-        self._assign_indice_as_input(node, node_idx)
-
-    def _assign_ones_like_indice(self, node: Node, node_idx: int):
+    def _assign_ones_like_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for oneslike op.
         1. assign new indice for all dim
@@ -438,7 +489,7 @@ class TraceIndice(object):
         """
         self._assign_all_indice(node, node_idx)

-    def _assign_cat_indice(self, node: Node, node_idx: int):
+    def _assign_cat_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for cat op.
@@ -449,12 +500,12 @@ class TraceIndice(object):
         nodes_in = flat_list(node.args[0])
         self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
         for n in nodes_in[1:]:
-            self._mark_computation_from_node(n, node)
+            self._inherit_more_indice_from_node(n, node)
         cat_dim = node.kwargs["dim"]
         self._del_dim(node_idx, cat_dim)
         self._add_dim(node_idx, cat_dim)

-    def _assign_sum_indice(self, node: Node, node_idx: int):
+    def _assign_sum_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for sum op.
@@ -466,11 +517,46 @@ class TraceIndice(object):
         self._add_dim(node_idx, 0)
         self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
         for n in nodes_in[1:]:
-            self._mark_computation_from_node(n, node)
+            self._inherit_more_indice_from_node(n, node)
         cat_dim = node.kwargs["dim"]
         self._del_dim(node_idx, cat_dim)

-    def _assign_getitem_indice(self, node: Node, node_idx: int):
+    def _assign_arange_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for arange op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        self._assign_all_indice(node, node_idx)
+
+    def _assign_tensor_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for tensor op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        if len(get_node_shape(node)) == 0:
+            return
+        else:
+            raise NotImplementedError()
+
+    def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
+        """
+        Assign indice for embedding op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        self._del_dim(node_idx, -1)
+        self._assign_indice_as_input(node, node_idx)
+        self._add_dim(node_idx, -1)
+
+    def _assign_getitem_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for getitem.
         getitem can act like slice sometimes
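Note: _assign_embedding_indice deletes and re-adds the trailing dim around input inheritance, matching nn.Embedding's behavior of keeping every input dim and appending one hidden dim:

    import torch

    emb = torch.nn.Embedding(num_embeddings=100, embedding_dim=64)
    ids = torch.zeros(2, 16, dtype=torch.long)
    print(emb(ids).shape)    # torch.Size([2, 16, 64])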
@@ -480,6 +566,19 @@ class TraceIndice(object):
             node_idx (int)
         """
         node_args = flat_list(node.args[1:])
+
+        # deal with split
+        if get_node_name(node.args[0]) == "split":
+            self._assign_indice_as_input(node, node_idx)
+            self._del_dim(node_idx, node.args[0].kwargs["dim"])
+            self._add_dim(node_idx, node.args[0].kwargs["dim"])
+            return
+
+        # skip non tensor
+        if get_node_shape(node) is None:
+            return
+
+        # find if slice
         flag = False
         for node_arg in node_args:
             node_arg_str = str(node_arg)
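Note: the split special case exists because torch.fx records x.split(n, dim) followed by indexing as a split node plus a getitem node, and each piece keeps the rank of the input:

    import torch

    x = torch.zeros(2, 6, 4)
    parts = x.split(2, dim=1)    # tuple of 3 tensors, each of shape (2, 2, 4)
    y = parts[0]                 # traced as getitem(split, 0), same rank as x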
@@ -528,7 +627,7 @@ class TraceIndice(object):
         else:
             raise NotImplementedError()

-    def _assign_view_reshape_indice(self, node: Node, node_idx: int):
+    def _assign_view_reshape_indice(self, node: Node, node_idx: int) -> None:
         """
         Assign indice for view and reshape op.
         1. get origin shape and target shape by meta info.
@@ -536,7 +635,7 @@ class TraceIndice(object):
         3. determine changed dim, and assgin indice for generated dim.
         4. log changed dim and generated dim for restore
         5. inherit computation.
-        6. TODO: look into view list to see whether the view is associated with other,
+        6. look into view list to see whether the view is associated with other,
            if so assgin equal dim according to previous view.

         Args:
@@ -552,7 +651,7 @@ class TraceIndice(object):
             if isinstance(unflated_args[i], int):
                 target_shape.append(unflated_args[i])
             else:
-                target_shape.append(unflated_args[i].meta["fwd_out"][0])
+                target_shape.extend(unflated_args[i].meta["fwd_out"])

         # compute the value of -1
         if -1 in target_shape:
@@ -579,17 +678,36 @@ class TraceIndice(object):
             dim_from = [dim_equal.index(False)]
             dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
+            self._del_dim(node_idx, -1)
+        elif len_diff == 0:
+            # dim equal
+            dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
+            dim_from = []
+            dim_to = []
         else:
             raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")

         # get new indice
         origin_trace = self._find_indice_trace_from_node(origin_node)
         self._assign_indice_as_input(node, node_idx, origin_node)
         idx_from = [origin_trace[i] for i in dim_from]
+        dim_from.reverse()
         for i in dim_from:
             self._del_dim(node_idx, i)
         for i in dim_to:
             self._add_dim(node_idx, i)
+        dim_from.reverse()

         # search view list
         for view_node, view_dict in self.indice_view_list.items():
             if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
                     and view_dict["dim_from"] == dim_to):
+                # inheirt indice from current node
+                for dim_to_i in dim_to:
+                    for dim_from_i in dim_from:
+                        self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False)
+                # inherid indice from input node of last view
+                for dim_to_i in dim_to:
+                    self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)

         # inherit computation
         compute_log = self._find_compute_trace_from_node(origin_node)
@@ -630,7 +748,7 @@ class TraceIndice(object):
        # clear compute
        for dim_compute in trace["compute"]:
            for i in range(len(dim_compute) - 1, -1, -1):
-                if dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes:
+                if (dim_compute[i] < trace_range[0] and dim_compute[i] not in active_nodes):
                    dim_compute.pop(i)
                    continue
        # clear source
@@ -639,59 +757,82 @@ class TraceIndice(object):
                 if k < trace_range[0] and k not in active_nodes:
                     dim_source.pop(k)

-    def trace_indice(self):
+    def trace_indice(self) -> None:
         for idx, node in enumerate(self.node_list):
+            node_name = get_node_name(node)
             if node.op == "placeholder":
                 self._assign_all_indice(node, idx)
             elif node.op == "call_method":
-                if "transpose" in node.name:
+                if "transpose" == node_name:
                     self._assign_transpose_indice(node, idx)
-                elif "permute" in node.name:
+                elif "permute" == node_name:
                     self._assign_permute_indice(node, idx)
-                elif "view" in node.name or "reshape" in node.name:
+                elif "view" == node_name or "reshape" == node_name:
                     self._assign_view_reshape_indice(node, idx)
-                elif "unsqueeze" in node.name:
+                elif "unsqueeze" == node_name:
                     self._assign_unsqueeze_indice(node, idx)
-                elif any(i in node.name for i in ["to", "contiguous", "clone"]):
+                elif "split" == node_name:
+                    self._assign_split_indice(node, idx)
+                elif any(i == node_name for i in ["to", "contiguous", "clone", "type"]):
                     self._assgin_no_change_indice(node, idx)
-                elif "new_ones" in node.name:
+                elif "new_ones" == node_name:
                     self._assign_ones_like_indice(node, idx)
+                elif any(i == node_name for i in ["size"]):
+                    continue
                 else:
-                    raise NotImplementedError(node.name, "method not implemented yet!")
+                    raise NotImplementedError(node_name, "method not implemented yet!")
             elif node.op == "call_function":
-                if "linear" in node.name:
+                if "linear" == node_name:
                     self._assign_linear_indice(node, idx)
-                elif "cat" in node.name:
+                elif "cat" == node_name:
                     self._assign_cat_indice(node, idx)
-                elif "matmul" in node.name:
+                elif "matmul" == node_name:
                     self._assign_matmul_indice(node, idx)
-                elif "softmax" in node.name:
+                elif "softmax" == node_name:
                     self._assign_softmax_indice(node, idx)
-                elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
+                elif any(n == node_name for n in [
+                        "mul",
+                        "add",
+                        "sigmoid",
+                        "relu",
+                        "sub",
+                        "truediv",
+                        "pow",
+                        "dropout",
+                        "where",
+                        "tanh",
+                ]):
                     self._assign_elementwise_indice(node, idx)
-                elif "ones_like" in node.name:
+                elif "ones_like" == node_name:
                     self._assign_ones_like_indice(node, idx)
-                elif "dropout" in node.name:
-                    self._assign_dropout_indice(node, idx)
-                elif "einsum" in node.name:
+                elif "einsum" == node_name:
                     self._assign_einsum_indice(node, idx)
-                elif "sum" in node.name:
+                elif "sum" == node_name:
                     self._assign_sum_indice(node, idx)
-                elif "layer_norm" in node.name:
+                elif "layer_norm" == node_name:
                     self._assign_layernorm_indice(node, idx)
-                elif "getitem" in node.name:
+                elif "getitem" == node_name:
                     self._assign_getitem_indice(node, idx)
-                elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
+                elif "addmm" == node_name:
+                    self._assign_addmm_indice(node, idx)
+                elif "arange" == node_name:
+                    self._assign_arange_indice(node, idx)
+                elif "tensor" == node_name:
+                    self._assign_arange_indice(node, idx)
+                elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]):
                     continue
                 else:
-                    raise NotImplementedError(node.name, "function not implemented yet!")
+                    raise NotImplementedError(node_name, "function not implemented yet!")
             elif node.op == "call_module":
-                if any(n in node.name for n in ["layernorm", "norm"]):
+                node_name = get_module_node_name(node)
+                if "layernorm" == node_name:
                     self._assign_layernorm_indice(node, idx)
-                elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
+                elif "embedding" == node_name:
+                    self._assign_embedding_indice(node, idx)
+                elif any(n == node_name for n in ["sigmoid", "dropout", "relu"]):
                     self._assign_elementwise_indice(node, idx)
                 else:
-                    raise NotImplementedError(node.name, "module not implemented yet!")
+                    raise NotImplementedError(node_name, "module not implemented yet!")
             elif node.op == "get_attr":
                 self._assign_all_indice(node, idx)    # get param
             elif node.op == "output":
@@ -1,13 +1,15 @@
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

 from torch.fx.node import Node

 from colossalai.logging import get_dist_logger

+NON_COMPUTE_OP = ["placeholder", "get_attr", "output"]
+NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"]
 logger = get_dist_logger()


-def get_logger():
+def get_logger() -> Any:
     return logger
@@ -37,7 +39,7 @@ def find_first_tensor_arg(node: Node) -> Node:


 def is_non_compute_node(node: Node) -> bool:
-    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
+    if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
         return True
     if "getitem" in node.name:
         node_args = flat_list(node.args[1:])
@@ -64,33 +66,33 @@ def is_non_memory_node(node: Node) -> bool:
     return is_non_compute_node(node)


-def is_non_compute_node_except_placeholder(node):
+def is_non_compute_node_except_placeholder(node: Node) -> bool:
     if "placeholder" in node.op:
         return False
     return is_non_compute_node(node)


-def is_non_compute_node_except_placeholder_output(node):
+def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
     if "output" in node.op:
         return False
     return is_non_compute_node_except_placeholder(node)


-def find_idx_by_name(name, nodes_list):
+def find_idx_by_name(name: str, nodes_list: List) -> int:
     for idx, node in enumerate(nodes_list):
         if node.name == name:
             return idx
     raise RuntimeError("name %s not found in node list" % name)


-def delete_free_var_from_last_use(user_to_last_uses):
+def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
     for key, value in user_to_last_uses.items():
         for n in value:
             if n.op == "placeholder":
                 user_to_last_uses[key].remove(n)


-def find_chunk_all_input_nodes(nodes: List[Node]):
+def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
     """
     Find non-compute input and output node names.
     input nodes are nodes used in the list
@@ -104,7 +106,7 @@ def find_chunk_all_input_nodes(nodes: List[Node]):
     return input_nodes


-def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
+def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, List]:
     """
     Find non-compute input and output node names.
     input nodes are nodes used in the list
@@ -130,3 +132,33 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
             output_nodes.append(node)

     return input_nodes, output_nodes
+
+
+def get_module_node_name(node: Node) -> str:
+    """
+    get module class name
+    """
+    node_targets = node.target.split(".")
+    module = node.graph.owning_module
+    for i in node_targets:
+        module = getattr(module, i)
+    module_name = str(module.__class__).split(".")[-1][:-2]
+    module_name = module_name.lower()
+    return module_name
+
+
+def get_node_name(node: Node) -> str:
+    """
+    get node name
+    """
+    node_name = node.name
+    if "_" in node_name:
+        for i in range(len(node_name) - 1, -1, -1):
+            if node_name[i] == "_":
+                node_name = node_name[:i]
+                break
+            elif node_name[i] in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]:
+                continue
+            else:
+                break
+    return node_name
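Note: get_node_name strips the numeric suffix that torch.fx appends when an op name repeats, which is what lets the dispatch code above switch from substring tests to exact matches. Reading the loop:

    # "add"          -> "add"
    # "add_3"        -> "add"
    # "layer_norm_2" -> "layer_norm"    (only a trailing "_<digits>" is removed)
    # the old substring test ("add" in node.name) would also have matched "addmm_1";
    # get_node_name("addmm_1") returns "addmm", which no longer collides with "add"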