From 8a989a0d89418c308c1d97b4d692a4e753395732 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 6 Jan 2023 17:55:22 +0800 Subject: [PATCH] code style --- colossalai/autochunk/autochunk_codegen.py | 69 +++++++++++++---------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 891753faa..0db2e5908 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -98,6 +98,39 @@ def _replace_reshape_size(context, node_name, reshape_size_dict): return context +def _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body): + if "ones_like" in node.name: + meta_node = search_chunk.trace_index.node_list[node_idx] + chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"] + if get_node_shape(meta_node)[chunk_dim] != 1: + source_node = meta_node.args[0].args[0] + if ( + source_node not in chunk_infos[region_idx]["node_chunk_dim"] + or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] + is None + ): + chunk_slice = _gen_chunk_slice_dim( + chunk_dim, "chunk_idx", get_node_shape(node) + ) + body[-1] = _replace_name( + body[-1], node.args[0].name, node.args[0].name + chunk_slice + ) + return body + + +def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body): + for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): + for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim( + dim[0], "chunk_idx", get_node_shape(input_node) + ) + body[-1] = _replace_name( + body[-1], input_node.name, input_node.name + chunk_slice + ) + return body + + def emit_code_with_chunk( body, nodes, @@ -156,36 +189,14 @@ def emit_code_with_chunk( if within_chunk_region: emit_node_func(node, body) # replace input var with chunk var - for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): - if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim( - dim[0], "chunk_idx", get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) + body = _replace_input_var( + chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body + ) # ones like - if "ones_like" in node.name: - meta_node = search_chunk.trace_index.node_list[node_idx] - chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][ - "chunk_dim" - ] - if get_node_shape(meta_node)[chunk_dim] != 1: - source_node = meta_node.args[0].args[0] - if ( - source_node not in chunk_infos[region_idx]["node_chunk_dim"] - or chunk_infos[region_idx]["node_chunk_dim"][source_node][ - "chunk_dim" - ] - is None - ): - chunk_slice = _gen_chunk_slice_dim( - chunk_dim, "chunk_idx", get_node_shape(node) - ) - body[-1] = _replace_name( - body[-1], node.args[0].name, node.args[0].name + chunk_slice - ) + body = _replace_ones_like( + search_chunk, chunk_infos, region_idx, node_idx, node, body + ) + # reassgin reshape size body[-1] = _replace_reshape_size( body[-1], node.name, chunk_infos[region_idx]["reshape_size"] )