mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[autochunk] support multi outputs chunk search (#2538)
Support multi outputs chunk search. Previously we only support single output chunk search. It is more flexible and improve performance by a large margin. For transformer, we reduce memory by 40% than previous search strategy. 1. rewrite search strategy to support multi outputs chunk search 2. fix many, many bugs 3. update tests
This commit is contained in:
@@ -25,7 +25,7 @@ if AUTOCHUNK_AVAILABLE:
|
||||
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
|
||||
|
||||
from .search_chunk import SearchChunk
|
||||
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
|
||||
from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape
|
||||
|
||||
|
||||
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
|
||||
@@ -51,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
|
||||
return new_shape
|
||||
|
||||
|
||||
def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str:
|
||||
def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str:
|
||||
"""
|
||||
Generate chunk loop start
|
||||
|
||||
@@ -70,22 +70,28 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim
|
||||
context (str): generated str
|
||||
"""
|
||||
input_node = chunk_input[0]
|
||||
out_shape = get_node_shape(chunk_output)
|
||||
out_str = str(list(out_shape))
|
||||
context = (
|
||||
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" %
|
||||
(out_str, input_node.name, input_node.name, chunk_size))
|
||||
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
|
||||
|
||||
context = ""
|
||||
for i in range(len(chunk_output)):
|
||||
shape_str = str(list(get_node_shape(chunk_output[i])))
|
||||
if get_node_name(chunk_output[i]) == "split":
|
||||
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
|
||||
input_node.name)
|
||||
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
|
||||
tensor_str = "[" + tensor_str[:-2] + "]"
|
||||
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
|
||||
else:
|
||||
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
|
||||
input_node.name, input_node.name)
|
||||
|
||||
out_shape = get_node_shape(chunk_output[0])
|
||||
chunk_shape = out_shape[chunk_ouput_dim[0]]
|
||||
context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
|
||||
return context
|
||||
|
||||
|
||||
def _gen_loop_end(
|
||||
chunk_inputs: List[Node],
|
||||
chunk_non_compute_inputs: List[Node],
|
||||
chunk_outputs: Node,
|
||||
chunk_outputs_dim: int,
|
||||
node_list: List[Node],
|
||||
) -> str:
|
||||
def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
|
||||
chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
|
||||
"""
|
||||
Generate chunk loop end
|
||||
|
||||
@@ -102,22 +108,13 @@ def _gen_loop_end(
|
||||
Returns:
|
||||
context (str): generated str
|
||||
"""
|
||||
chunk_outputs_name = chunk_outputs.name
|
||||
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
|
||||
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
|
||||
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
|
||||
context = " chunk_result%s = %s; %s = None\n" % (
|
||||
chunk_slice,
|
||||
chunk_outputs_name,
|
||||
chunk_outputs_name,
|
||||
)
|
||||
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
|
||||
|
||||
context = "chunk_size = None"
|
||||
# determine if its the last use for chunk input
|
||||
for chunk_input in chunk_inputs + chunk_non_compute_inputs:
|
||||
if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
|
||||
if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
|
||||
context += "; %s = None" % chunk_input.name
|
||||
|
||||
for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items():
|
||||
context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val)
|
||||
context += "\n"
|
||||
return context
|
||||
|
||||
@@ -158,7 +155,7 @@ def _replace_ones_like(
|
||||
add chunk slice for new tensor op such as ones like
|
||||
"""
|
||||
if "ones_like" in node.name:
|
||||
meta_node = search_chunk.trace_indice.node_list[node_idx]
|
||||
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
|
||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
||||
if get_node_shape(meta_node)[chunk_dim] != 1:
|
||||
source_node = meta_node.args[0].args[0]
|
||||
@@ -169,21 +166,37 @@ def _replace_ones_like(
|
||||
return body
|
||||
|
||||
|
||||
def _replace_input_node(
|
||||
chunk_inputs: List[Node],
|
||||
def _add_node_slice(
|
||||
chunk_nodes: List[Node],
|
||||
region_idx: int,
|
||||
chunk_inputs_dim: Dict,
|
||||
chunk_nodes_dim: Dict,
|
||||
node_idx: int,
|
||||
body: List[str],
|
||||
node: Node,
|
||||
) -> List[str]:
|
||||
"""
|
||||
add chunk slice for input nodes
|
||||
"""
|
||||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node))
|
||||
body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice)
|
||||
for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]):
|
||||
# inputs node
|
||||
if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict):
|
||||
for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node))
|
||||
body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
|
||||
# outputs node
|
||||
else:
|
||||
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
|
||||
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
|
||||
get_node_shape(chunk_node))
|
||||
if get_node_name(chunk_node) == "split":
|
||||
split_chunk_slice = ""
|
||||
for i in range(len(chunk_node.meta['tensor_meta'])):
|
||||
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
|
||||
split_chunk_slice = split_chunk_slice[:-2]
|
||||
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
|
||||
else:
|
||||
body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
|
||||
return body
|
||||
|
||||
|
||||
@@ -222,7 +235,8 @@ def emit_code_with_chunk(
|
||||
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
|
||||
|
||||
# chunk outputs
|
||||
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||
chunk_outputs = [i["outputs"] for i in chunk_infos]
|
||||
chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos]
|
||||
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
|
||||
|
||||
node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
|
||||
@@ -248,7 +262,9 @@ def emit_code_with_chunk(
|
||||
if within_chunk_region:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body)
|
||||
body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
|
||||
# replace output var with chunk var
|
||||
body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
|
||||
# ones like
|
||||
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
|
||||
# reassgin reshape size
|
||||
@@ -263,13 +279,8 @@ def emit_code_with_chunk(
|
||||
# generate chunk region end
|
||||
if node_idx in chunk_ends:
|
||||
body.append(
|
||||
_gen_loop_end(
|
||||
chunk_inputs[region_idx],
|
||||
chunk_inputs_non_chunk[region_idx],
|
||||
chunk_outputs[region_idx],
|
||||
chunk_outputs_dim[region_idx],
|
||||
node_list,
|
||||
))
|
||||
_gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
|
||||
chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
|
||||
within_chunk_region = False
|
||||
|
||||
node_idx += 1
|
||||
|
Reference in New Issue
Block a user