mirror of https://github.com/hpcaitech/ColossalAI.git
@@ -6,12 +6,7 @@ from torch.fx.node import Node, map_arg
 from colossalai.fx.profiler import activation_size, parameter_size

-from .utils import (
-    delete_free_var_from_last_use,
-    find_idx_by_name,
-    get_node_shape,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape, is_non_memory_node


 class EstimateMemory(object):
@@ -240,7 +235,7 @@ class EstimateMemory(object):
             elif node.op == "output":
                 continue
             # no change for non compute node
-            elif is_non_compute_node_except_placeholder(node):
+            elif is_non_memory_node(node):
                 act_memory_peak_log.append(act_memory)
             # node is a compute op
             # calculate tmp, output node and delete node memory
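A note on the hunk above: `is_non_memory_node` (defined in the utils.py hunks at the end of this change) treats getitem and output nodes as memory-neutral, so the estimator records the unchanged activation total for them. A minimal sketch of that accounting pattern, with a hypothetical `out_size` helper standing in for the profiler's `activation_size`; this is not the actual EstimateMemory loop:

    # Sketch only: hypothetical simplification of the peak-memory log.
    def peak_act_memory(nodes, out_size, is_non_memory_node) -> float:
        act_memory = 0.0
        peak_log = []
        for node in nodes:
            if is_non_memory_node(node):
                # memory-neutral node: record the unchanged total
                peak_log.append(act_memory)
            else:
                # compute node: its output stays alive for now
                act_memory += out_size(node)
                peak_log.append(act_memory)
        return max(peak_log, default=0.0)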
@@ -118,16 +118,34 @@ class TraceFlow(object):

     def _assgin_single_node_flow(
         self,
-        arg_node,
-        start_idx,
-        end_idx,
-        cur_node_dim,
-        cur_node_compute,
-        cur_node_source,
-        cur_node_fix_dim,
-        all_node_info,
-        next_node_list,
-    ):
+        arg_node: Node,
+        start_idx: int,
+        end_idx: int,
+        cur_node_dim: int,
+        cur_node_compute: Dict,
+        cur_node_source: Dict,
+        cur_node_fix_dim: List,
+        all_node_info: Dict,
+        next_node_list: List,
+    ) -> bool:
+        """
+        Given the current node and one of its arg nodes, find the arg node's
+        chunk dim and fix dim.
+
+        Args:
+            arg_node (Node): input node
+            start_idx (int): chunk region start
+            end_idx (int): chunk region end
+            cur_node_dim (int): current node chunk dim
+            cur_node_compute (Dict): current node compute dict
+            cur_node_source (Dict): current node source dict
+            cur_node_fix_dim (List): current node fix dim
+            all_node_info (Dict): all node chunk info in the chunk region
+            next_node_list (List)
+
+        Returns:
+            bool: True if this node can be added to the flow, False otherwise.
+        """
         arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
         # arg in chunk range or be inputs
         if not (start_idx <= arg_idx < end_idx):
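The guard above asks whether the arg's topological index falls inside the chunk region [start_idx, end_idx). A tiny self-contained illustration, with a simplified stand-in for `find_idx_by_name` (the real one searches `self.trace_indice.node_list`; the names below are hypothetical):

    # Simplified stand-in over a toy node-name list.
    node_names = ["x", "linear", "softmax", "dropout", "out"]

    def find_idx(name: str) -> int:
        return node_names.index(name)

    start_idx, end_idx = 1, 4              # chunk region covers linear..dropout
    arg_idx = find_idx("softmax")
    print(start_idx <= arg_idx < end_idx)  # True: the arg is inside the region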
@@ -142,6 +160,9 @@ class TraceFlow(object):
                 arg_dim = None
             else:
                 arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
+                # chunk dim should be None if shape size is 1
+                if get_node_shape(arg_node)[arg_dim] == 1:
+                    arg_dim = None
         else:
             arg_dim = None
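The new guard clears `arg_dim` when the traced shape has size 1 along that dim. The likely intuition (my reading, not stated in the patch): a size-1 dim broadcasts, so the arg never needs to be chunked along it. In plain PyTorch:

    import torch

    a = torch.randn(4, 1, 8)   # size-1 dim broadcasts
    b = torch.randn(4, 6, 8)
    # Any chunk of b along dim 1 multiplies against the same `a`,
    # so `a` needs no chunk dim there.
    chunk = b[:, 0:2, :] * a
    print(chunk.shape)         # torch.Size([4, 2, 8])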
@@ -184,7 +205,7 @@ class TraceFlow(object):

         # get all valid args
         arg_list = []
-        for arg in cur_node.args:
+        for arg in cur_node.all_input_nodes:
             if type(arg) != type(cur_node):
                 continue
             if is_non_compute_node(arg):
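Switching from `cur_node.args` to `cur_node.all_input_nodes` matters because `args` can contain non-Node constants and misses Node-valued kwargs, while `all_input_nodes` yields exactly the Node inputs drawn from both args and kwargs. A quick torch.fx check (the function name `f` is arbitrary):

    import torch
    from torch.fx import symbolic_trace

    def f(x):
        return torch.add(x, 1)

    gm = symbolic_trace(f)
    add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
    print(add_node.args)             # (x, 1): includes the constant 1
    print(add_node.all_input_nodes)  # [x]: Node inputs only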
@@ -432,6 +432,38 @@ class TraceIndice(object):
         """
         self._assign_all_indice(node, node_idx)

+    def _assign_cat_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for cat op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        nodes_in = flat_list(node.args[0])
+        self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+        for n in nodes_in[1:]:
+            self._mark_computation_from_node(n, node)
+        cat_dim = node.kwargs["dim"]
+        self._del_dim(node_idx, cat_dim)
+        self._add_dim(node_idx, cat_dim)
+
+    def _assign_sum_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for sum op.
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        nodes_in = flat_list(node.args[0])
+        self._add_dim(node_idx, 0)
+        self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+        for n in nodes_in[1:]:
+            self._mark_computation_from_node(n, node)
+        cat_dim = node.kwargs["dim"]
+        self._del_dim(node_idx, cat_dim)
+
     def _assign_getitem_indice(self, node: Node, node_idx: int):
         """
         Assign indice for getitem.
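The bookkeeping in the two new methods mirrors the shape behavior of the ops: cat preserves rank but produces a concatenated dim that matches no single input (hence `_del_dim` then `_add_dim` on the same index), while sum removes the reduced dim outright. In plain PyTorch:

    import torch

    x = torch.randn(2, 3, 4)
    y = torch.randn(2, 5, 4)
    print(torch.cat([x, y], dim=1).shape)  # torch.Size([2, 8, 4]): rank kept, dim 1 is new
    print(x.sum(dim=1).shape)              # torch.Size([2, 4]): dim 1 removed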
@@ -442,7 +474,16 @@ class TraceIndice(object):
             node_idx (int)
         """
         node_args = flat_list(node.args[1:])
-        if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
+        flag = False
+        for node_arg in node_args:
+            node_arg_str = str(node_arg)
+            if any(i == node_arg_str for i in ["None", "Ellipsis"]):
+                flag = True
+                break
+            if "slice" in node_arg_str:
+                flag = True
+                break
+        if flag == False:
             return

         # node args should be like [Ellipsis, slice(start, step, end), None]
@@ -461,8 +502,11 @@ class TraceIndice(object):
                 shape_gap = len(node_shape) - len(node_args) + 1
                 origin_idx_count += shape_gap
                 new_idx_count += shape_gap
-            # slice(None, None, None) means all indexes, doesn't support other slice
-            elif "slice(None, None, None)" == node_arg_str:
+            # slice(None, None, None) means all indexes
+            elif "slice" in node_arg_str:
+                if "slice(None, None, None)" != node_arg_str:
+                    self._del_dim(node_idx, new_idx_count)
+                    self._add_dim(node_idx, new_idx_count)
                 origin_idx_count += 1
                 new_idx_count += 1
             # None means a new dim
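For context on what these getitem hunks parse: when a tensor is indexed, torch.fx records an `operator.getitem` call whose second arg holds the raw index objects (Ellipsis, slice, None, int). For example:

    import torch
    from torch.fx import symbolic_trace

    def f(x):
        return x[..., 0:2, None]

    gm = symbolic_trace(f)
    getitem_node = next(n for n in gm.graph.nodes if "getitem" in n.name)
    print(getitem_node.args)
    # (x, (Ellipsis, slice(0, 2, None), None)): the partial slice 0:2 is
    # the case the new _del_dim/_add_dim branch re-registers.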
@@ -565,7 +609,7 @@ class TraceIndice(object):
                 self._assign_view_reshape_indice(node, idx)
             elif "unsqueeze" in node.name:
                 self._assign_unsqueeze_indice(node, idx)
-            elif any(i in node.name for i in ["to", "contiguous"]):
+            elif any(i in node.name for i in ["to", "contiguous", "clone"]):
                 self._assgin_no_change_indice(node, idx)
             elif "new_ones" in node.name:
                 self._assign_ones_like_indice(node, idx)
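Adding "clone" here is consistent with the other pass-through ops: `to`, `contiguous`, and `clone` all return a tensor of identical shape, so the indice assignment is copied unchanged (`_assgin_no_change_indice` is the method's actual spelling in the source). Quick check:

    import torch

    x = torch.randn(2, 3)
    for y in (x.to(torch.float16), x.contiguous(), x.clone()):
        print(y.shape)  # torch.Size([2, 3]) each time: shape is unchanged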
@@ -574,6 +618,8 @@ class TraceIndice(object):
         elif node.op == "call_function":
             if "linear" in node.name:
                 self._assign_linear_indice(node, idx)
+            elif "cat" in node.name:
+                self._assign_cat_indice(node, idx)
             elif "matmul" in node.name:
                 self._assign_matmul_indice(node, idx)
             elif "softmax" in node.name:
@@ -586,6 +632,8 @@ class TraceIndice(object):
                 self._assign_dropout_indice(node, idx)
             elif "einsum" in node.name:
                 self._assign_einsum_indice(node, idx)
+            elif "sum" in node.name:
+                self._assign_sum_indice(node, idx)
             elif "layer_norm" in node.name:
                 self._assign_layernorm_indice(node, idx)
             elif "getitem" in node.name:
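Both new branches rely on the dispatcher's convention of substring-matching `node.name`, which torch.fx derives from the call target; a traced `torch.cat` shows up as a node literally named "cat":

    import torch
    from torch.fx import symbolic_trace

    def f(x, y):
        return torch.softmax(torch.cat([x, y], dim=-1), dim=-1)

    gm = symbolic_trace(f)
    print([n.name for n in gm.graph.nodes])
    # e.g. ['x', 'y', 'cat', 'softmax', 'output']

Note the ordering: "einsum" is tested before the new "sum" branch, which matters because substring matching would otherwise route einsum nodes to _assign_sum_indice.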
@@ -3,10 +3,12 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
 from torch.fx.node import Node


-def flat_list(inputs):
+def flat_list(inputs: Any) -> List:
     """
     flat a list by recursion
     """
+    if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)):
+        return [inputs]
     res = []
     for i in inputs:
         if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
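The hunk shows only the head of `flat_list` plus the new scalar guard; the recursive middle is elided. A standalone sketch that matches the visible pieces, for experimenting with the new behavior (the loop body is my reconstruction, not quoted from the patch):

    from typing import Any, List

    def flat_list(inputs: Any) -> List:
        """Flatten nested lists/sets/tuples by recursion (sketch)."""
        if not isinstance(inputs, (list, set, tuple)):
            return [inputs]  # the new guard: a scalar becomes a one-element list
        res = []
        for i in inputs:
            if isinstance(i, (list, set, tuple)):
                res.extend(flat_list(i))  # reconstructed recursive step
            else:
                res.append(i)
        return res

    assert flat_list(1) == [1]                          # enabled by the new guard
    assert flat_list([1, (2, [3, 4])]) == [1, 2, 3, 4]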
@@ -16,7 +18,7 @@ def flat_list(inputs):
     return res


-def find_first_tensor_arg(node):
+def find_first_tensor_arg(node: Node) -> Node:
     """
     Find the first input tensor arg for a node
     """
@@ -26,7 +28,7 @@ def find_first_tensor_arg(node):
     raise RuntimeError()


-def is_non_compute_node(node):
+def is_non_compute_node(node: Node) -> bool:
     if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
         return True
     if "getitem" in node.name:
@@ -34,16 +36,26 @@ def is_non_compute_node(node):
         for node_arg in node_args:
             if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
                 return False
+            if "slice" in str(node_arg):
+                return False
         return True
     return False


-def get_node_shape(node):
+def get_node_shape(node: Node) -> List:
     if hasattr(node.meta["tensor_meta"], "shape"):
         return node.meta["tensor_meta"].shape
     return None


+def is_non_memory_node(node: Node) -> bool:
+    if "getitem" in node.name:
+        return True
+    if "output" in node.op:
+        return True
+    return is_non_compute_node(node)
+
+
 def is_non_compute_node_except_placeholder(node):
     if "placeholder" in node.op:
         return False
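Taken together, `is_non_memory_node` effectively classifies a superset of `is_non_compute_node`: it additionally treats every getitem and output node as memory-neutral, which is what the EstimateMemory hunk above relies on. A toy check with a minimal stand-in for an fx Node (real nodes come from a traced GraphModule, and the import path is an assumption; adjust it to wherever these helpers live in your tree):

    # Assumed import path for the patched helpers.
    from colossalai.autochunk.utils import is_non_memory_node

    class FakeNode:
        """Minimal stand-in exposing just .name and .op."""
        def __init__(self, name: str, op: str):
            self.name, self.op = name, op

    print(is_non_memory_node(FakeNode("getitem_3", "call_function")))  # True
    print(is_non_memory_node(FakeNode("output", "output")))            # True
    print(is_non_memory_node(FakeNode("add_1", "call_function")))      # False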