mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 22:10:37 +00:00
[autochunk] support diffusion for autochunk (#2621)
* add alphafold benchmark * renae alphafold test * rename tests * rename diffuser * renme * rename * update transformer * update benchmark * update benchmark * update bench memory * update transformer benchmark * rename * support diffuser * support unet metainfo prop * fix bug and simplify code * update linear and support some op * optimize max region search, support conv * update unet test * support some op * support groupnorm and interpolate * update flow search * add fix dim in node flow * fix utils * rename * support diffusion * update diffuser * update chunk search * optimize imports * import * finish autochunk
This commit is contained in:
@@ -9,18 +9,7 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL
|
||||
AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
|
||||
|
||||
if AUTOCHUNK_AVAILABLE:
|
||||
from torch.fx.graph import (
|
||||
CodeGen,
|
||||
PythonCode,
|
||||
_custom_builtins,
|
||||
_CustomBuiltin,
|
||||
_format_target,
|
||||
_is_from_torch,
|
||||
_Namespace,
|
||||
_origin_type_map,
|
||||
inplace_methods,
|
||||
magic_methods,
|
||||
)
|
||||
from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
|
||||
|
||||
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
|
||||
|
||||
@@ -143,7 +132,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
|
||||
return context
|
||||
|
||||
|
||||
def _replace_ones_like(
|
||||
def _replace_new_tensor_like_shape(
|
||||
search_chunk: SearchChunk,
|
||||
chunk_infos: List[Dict],
|
||||
region_idx: int,
|
||||
@@ -154,7 +143,7 @@ def _replace_ones_like(
|
||||
"""
|
||||
add chunk slice for new tensor op such as ones like
|
||||
"""
|
||||
if "ones_like" in node.name:
|
||||
if get_node_name(node) in ["ones_like", "zeros_like", "empty_like"]:
|
||||
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
|
||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
||||
if get_node_shape(meta_node)[chunk_dim] != 1:
|
||||
@@ -166,6 +155,33 @@ def _replace_ones_like(
|
||||
return body
|
||||
|
||||
|
||||
def _replace_new_tensor_shape(
|
||||
search_chunk: SearchChunk,
|
||||
chunk_infos: List[Dict],
|
||||
region_idx: int,
|
||||
node_idx: int,
|
||||
node: Node,
|
||||
body: List[str],
|
||||
) -> List[str]:
|
||||
"""
|
||||
add chunk slice for new tensor op such as ones
|
||||
"""
|
||||
if get_node_name(node) in ["ones", "zeros", "empty"]:
|
||||
meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
|
||||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
||||
if chunk_dim is None:
|
||||
return
|
||||
if get_node_shape(meta_node)[chunk_dim] == 1:
|
||||
return
|
||||
origin_shape = str(node.args)
|
||||
new_shape = list(node.args)
|
||||
new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % get_node_shape(meta_node)[chunk_dim]
|
||||
new_shape = str(new_shape)
|
||||
new_shape = new_shape.replace("'", "")
|
||||
body[-1] = _replace_name(body[-1], origin_shape[1:-1], new_shape[1:-1])
|
||||
return body
|
||||
|
||||
|
||||
def _add_node_slice(
|
||||
chunk_nodes: List[Node],
|
||||
region_idx: int,
|
||||
@@ -265,8 +281,10 @@ def emit_code_with_chunk(
|
||||
body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
|
||||
# replace output var with chunk var
|
||||
body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
|
||||
# ones like
|
||||
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
|
||||
# new tensor like
|
||||
body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
|
||||
# new tensor
|
||||
body = _replace_new_tensor_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
|
||||
# reassgin reshape size
|
||||
body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
|
||||
body[-1] = " " + body[-1]
|
||||
|
Reference in New Issue
Block a user