mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[autochunk] support multi outputs chunk search (#2538)
Support multi outputs chunk search. Previously we only support single output chunk search. It is more flexible and improve performance by a large margin. For transformer, we reduce memory by 40% than previous search strategy. 1. rewrite search strategy to support multi outputs chunk search 2. fix many, many bugs 3. update tests
This commit is contained in:
@@ -3,14 +3,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .utils import (
|
||||
find_first_tensor_arg,
|
||||
find_idx_by_name,
|
||||
flat_list,
|
||||
get_module_node_name,
|
||||
get_node_name,
|
||||
get_node_shape,
|
||||
)
|
||||
from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape
|
||||
|
||||
|
||||
class TraceIndice(object):
|
||||
@@ -35,8 +28,8 @@ class TraceIndice(object):
|
||||
node_list (List)
|
||||
"""
|
||||
|
||||
def __init__(self, node_list: List[Node]) -> None:
|
||||
self.node_list = node_list
|
||||
def __init__(self, node_mgr: NodeMgr) -> None:
|
||||
self.node_mgr = node_mgr
|
||||
self.indice_trace_list = self._init_indice_trace_list()
|
||||
self.indice_view_list = {}
|
||||
self.indice_count = -1
|
||||
@@ -45,7 +38,7 @@ class TraceIndice(object):
|
||||
|
||||
def _init_indice_trace_list(self) -> List:
|
||||
indice_trace_list = []
|
||||
for n in self.node_list:
|
||||
for n in self.node_mgr.get_node_list():
|
||||
if get_node_shape(n) != None:
|
||||
cur_trace = {
|
||||
"indice": [None for _ in range(len(get_node_shape(n)))],
|
||||
@@ -99,7 +92,7 @@ class TraceIndice(object):
|
||||
node_from_trace_source = self._find_source_trace_from_node(node_from)
|
||||
node_to_dim = self._transform_indice(node_to, node_to_dim)
|
||||
node_to_trace_source = self._find_source_trace_from_node(node_to)
|
||||
node_from_idx = find_idx_by_name(node_from.name, self.node_list)
|
||||
node_from_idx = self.node_mgr.find_node_idx(node_from)
|
||||
if init:
|
||||
node_to_trace_source[node_to_dim] = {}
|
||||
# add dim to cur new source
|
||||
@@ -200,7 +193,7 @@ class TraceIndice(object):
|
||||
idx (list): idx of the node
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
node_idx = self.node_mgr.find_node_idx(node)
|
||||
node_dict = self.indice_trace_list[node_idx]
|
||||
return node_dict
|
||||
|
||||
@@ -214,7 +207,7 @@ class TraceIndice(object):
|
||||
idx (list): idx of the node
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
node_idx = self.node_mgr.find_node_idx(node)
|
||||
node_dict = self.indice_trace_list[node_idx]
|
||||
return node_dict["source"]
|
||||
|
||||
@@ -227,7 +220,7 @@ class TraceIndice(object):
|
||||
Returns:
|
||||
idx (list): idx of the node
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
node_idx = self.node_mgr.find_node_idx(node)
|
||||
return self.indice_trace_list[node_idx]["indice"]
|
||||
|
||||
def _find_compute_trace_from_node(self, node: Node) -> List:
|
||||
@@ -239,7 +232,7 @@ class TraceIndice(object):
|
||||
Returns:
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
node_idx = self.node_mgr.find_node_idx(node)
|
||||
return self.indice_trace_list[node_idx]["compute"]
|
||||
|
||||
def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
|
||||
@@ -454,8 +447,6 @@ class TraceIndice(object):
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
for _ in range(len(get_node_shape(node.args[0]))):
|
||||
self._add_dim(node_idx, 0)
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
dim_idx = node.kwargs["dim"]
|
||||
self._del_dim(node_idx, dim_idx)
|
||||
@@ -702,21 +693,20 @@ class TraceIndice(object):
|
||||
if (view_dict["idx_to"] == idx_from and view_dict["dim_to"] == dim_from
|
||||
and view_dict["dim_from"] == dim_to):
|
||||
# inheirt indice from current node
|
||||
for dim_to_i in dim_to:
|
||||
for dim_from_i in dim_from:
|
||||
self._inherit_indice(origin_node, dim_from_i, node, dim_to_i, init=False)
|
||||
if len_diff == 1:
|
||||
if origin_shape[dim_from[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
|
||||
elif origin_shape[dim_from[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
elif len_diff == -1:
|
||||
if target_shape[dim_to[0]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
|
||||
elif target_shape[dim_to[1]] == 1:
|
||||
self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
|
||||
# inherid indice from input node of last view
|
||||
for dim_to_i in dim_to:
|
||||
self._inherit_indice(view_node.args[0], dim_to_i, node, dim_to_i, init=False)
|
||||
|
||||
# inherit computation
|
||||
compute_log = self._find_compute_trace_from_node(origin_node)
|
||||
for i in dim_from:
|
||||
if origin_trace[i] in compute_log:
|
||||
for j in dim_to:
|
||||
self._mark_computation(node, node_idx, [j])
|
||||
break
|
||||
|
||||
# log view, not used now
|
||||
view_dict = {
|
||||
"idx_from": [origin_trace[i] for i in dim_from],
|
||||
@@ -742,7 +732,7 @@ class TraceIndice(object):
|
||||
|
||||
active_nodes = self.active_node_list[trace_range[0]:trace_range[1] + 1]
|
||||
active_nodes = set(flat_list(active_nodes))
|
||||
active_nodes = [find_idx_by_name(i, self.node_list) for i in active_nodes]
|
||||
active_nodes = [self.node_mgr.find_node_idx_by_name(i) for i in active_nodes]
|
||||
for i in range(trace_range[0], trace_range[1] + 1):
|
||||
trace = self.indice_trace_list[i]
|
||||
# clear compute
|
||||
@@ -758,7 +748,7 @@ class TraceIndice(object):
|
||||
dim_source.pop(k)
|
||||
|
||||
def trace_indice(self) -> None:
|
||||
for idx, node in enumerate(self.node_list):
|
||||
for idx, node in enumerate(self.node_mgr.get_node_list()):
|
||||
node_name = get_node_name(node)
|
||||
if node.op == "placeholder":
|
||||
self._assign_all_indice(node, idx)
|
||||
|
Reference in New Issue
Block a user