[autochunk] support multi outputs chunk search (#2538)

Support multi outputs chunk search. Previously we only support single output chunk search. It is more flexible and improve performance by a large margin. For transformer, we reduce memory by 40% than previous search strategy.

1. rewrite search strategy to support multi outputs chunk search
2. fix many, many bugs
3. update tests
This commit is contained in:
oahzxl
2023-02-01 13:18:51 +08:00
committed by GitHub
parent f477a14f4a
commit 05671fcb42
14 changed files with 428 additions and 258 deletions

View File

@@ -25,7 +25,7 @@ if AUTOCHUNK_AVAILABLE:
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_logger, get_node_shape
from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape
def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
@@ -51,7 +51,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
return new_shape
def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str:
def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str:
"""
Generate chunk loop start
@@ -70,22 +70,28 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim
context (str): generated str
"""
input_node = chunk_input[0]
out_shape = get_node_shape(chunk_output)
out_str = str(list(out_shape))
context = (
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" %
(out_str, input_node.name, input_node.name, chunk_size))
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
context = ""
for i in range(len(chunk_output)):
shape_str = str(list(get_node_shape(chunk_output[i])))
if get_node_name(chunk_output[i]) == "split":
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
input_node.name)
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
tensor_str = "[" + tensor_str[:-2] + "]"
context += "%s = %s; " % (chunk_output[i].name, tensor_str)
else:
context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
input_node.name, input_node.name)
out_shape = get_node_shape(chunk_output[0])
chunk_shape = out_shape[chunk_ouput_dim[0]]
context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
return context
def _gen_loop_end(
chunk_inputs: List[Node],
chunk_non_compute_inputs: List[Node],
chunk_outputs: Node,
chunk_outputs_dim: int,
node_list: List[Node],
) -> str:
def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
"""
Generate chunk loop end
@@ -102,22 +108,13 @@ def _gen_loop_end(
Returns:
context (str): generated str
"""
chunk_outputs_name = chunk_outputs.name
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
context = " chunk_result%s = %s; %s = None\n" % (
chunk_slice,
chunk_outputs_name,
chunk_outputs_name,
)
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
context = "chunk_size = None"
# determine if its the last use for chunk input
for chunk_input in chunk_inputs + chunk_non_compute_inputs:
if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
context += "; %s = None" % chunk_input.name
for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items():
context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val)
context += "\n"
return context
@@ -158,7 +155,7 @@ def _replace_ones_like(
add chunk slice for new tensor op such as ones like
"""
if "ones_like" in node.name:
meta_node = search_chunk.trace_indice.node_list[node_idx]
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
@@ -169,21 +166,37 @@ def _replace_ones_like(
return body
def _replace_input_node(
chunk_inputs: List[Node],
def _add_node_slice(
chunk_nodes: List[Node],
region_idx: int,
chunk_inputs_dim: Dict,
chunk_nodes_dim: Dict,
node_idx: int,
body: List[str],
node: Node,
) -> List[str]:
"""
add chunk slice for input nodes
"""
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node))
body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice)
for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]):
# inputs node
if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict):
for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node))
body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
# outputs node
else:
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
get_node_shape(chunk_node))
if get_node_name(chunk_node) == "split":
split_chunk_slice = ""
for i in range(len(chunk_node.meta['tensor_meta'])):
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
split_chunk_slice = split_chunk_slice[:-2]
body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
else:
body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
return body
@@ -222,7 +235,8 @@ def emit_code_with_chunk(
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
# chunk outputs
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
chunk_outputs = [i["outputs"] for i in chunk_infos]
chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
@@ -248,7 +262,9 @@ def emit_code_with_chunk(
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body)
body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
# replace output var with chunk var
body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
# ones like
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# reassgin reshape size
@@ -263,13 +279,8 @@ def emit_code_with_chunk(
# generate chunk region end
if node_idx in chunk_ends:
body.append(
_gen_loop_end(
chunk_inputs[region_idx],
chunk_inputs_non_chunk[region_idx],
chunk_outputs[region_idx],
chunk_outputs_dim[region_idx],
node_list,
))
_gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
within_chunk_region = False
node_idx += 1