[autochunk] support diffusion for autochunk (#2621)

* add alphafold benchmark

* renae alphafold test

* rename tests

* rename diffuser

* renme

* rename

* update transformer

* update benchmark

* update benchmark

* update bench memory

* update transformer benchmark

* rename

* support diffuser

* support unet metainfo prop

* fix bug and simplify code

* update linear and support some op

* optimize max region search, support conv

* update unet test

* support some op

* support groupnorm and interpolate

* update flow search

* add fix dim in node flow

* fix utils

* rename

* support diffusion

* update diffuser

* update chunk search

* optimize imports

* import

* finish autochunk
This commit is contained in:
oahzxl
2023-02-07 16:32:45 +08:00
committed by GitHub
parent 291b051171
commit 6ba8364881
6 changed files with 216 additions and 166 deletions

View File

@@ -9,18 +9,7 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
if AUTOCHUNK_AVAILABLE:
from torch.fx.graph import (
CodeGen,
PythonCode,
_custom_builtins,
_CustomBuiltin,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
inplace_methods,
magic_methods,
)
from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
@@ -143,7 +132,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
return context
def _replace_ones_like(
def _replace_new_tensor_like_shape(
search_chunk: SearchChunk,
chunk_infos: List[Dict],
region_idx: int,
@@ -154,7 +143,7 @@ def _replace_ones_like(
"""
add chunk slice for new tensor op such as ones like
"""
if "ones_like" in node.name:
if get_node_name(node) in ["ones_like", "zeros_like", "empty_like"]:
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
@@ -166,6 +155,33 @@ def _replace_ones_like(
return body
def _replace_new_tensor_shape(
search_chunk: SearchChunk,
chunk_infos: List[Dict],
region_idx: int,
node_idx: int,
node: Node,
body: List[str],
) -> List[str]:
"""
add chunk slice for new tensor op such as ones
"""
if get_node_name(node) in ["ones", "zeros", "empty"]:
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if chunk_dim is None:
return
if get_node_shape(meta_node)[chunk_dim] == 1:
return
origin_shape = str(node.args)
new_shape = list(node.args)
new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % get_node_shape(meta_node)[chunk_dim]
new_shape = str(new_shape)
new_shape = new_shape.replace("'", "")
body[-1] = _replace_name(body[-1], origin_shape[1:-1], new_shape[1:-1])
return body
def _add_node_slice(
chunk_nodes: List[Node],
region_idx: int,
@@ -265,8 +281,10 @@ def emit_code_with_chunk(
body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
# replace output var with chunk var
body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
# ones like
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# new tensor like
body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# new tensor
body = _replace_new_tensor_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
# reassgin reshape size
body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
body[-1] = " " + body[-1]