Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
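The hunks below are mechanical reformatting produced by the updated pre-commit hooks. The new output is consistent with Black-style conventions: double-quoted strings, two spaces before inline comments, and wrapped calls, signatures, and conditions closed by a parenthesis on its own line. A minimal, self-contained before/after illustration of that style (not taken from this file; all names are made up):

```python
# Made-up example (not from this commit) of the before/after style the hunks below apply.


def some_function(first_argument: int, second_argument: int, third_argument: int) -> int:
    return first_argument + second_argument + third_argument


# old style, as in the removed lines: aligned-continuation wrapping, single quotes,
# extra spaces before inline comments
result_old = some_function(1, 2,
                           3)    # trailing comment
meta_old = {'fwd_out': [1, 2]}['fwd_out']

# new style, as in the added lines: closing parenthesis on its own line,
# double quotes, two spaces before inline comments
result_new = some_function(
    1, 2, 3
)  # trailing comment
meta_new = {"fwd_out": [1, 2]}["fwd_out"]

assert result_old == result_new and meta_old == meta_new
```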
@@ -16,7 +16,6 @@ from .utils import (
 
 
 class TraceFlow(object):
-
     def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
         self.trace_indice = trace_indice
         self.node_mgr = node_mgr
@@ -151,7 +150,7 @@ class TraceFlow(object):
         return True
 
     def _get_all_node_info(self, end_dim, start_idx, end_idx):
-        cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)]    # start from the last node
+        cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)]  # start from the last node
         all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
 
         while len(cur_node_list) > 0:
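For context, `_get_all_node_info` seeds a worklist with the region's last node and walks backwards while recording a `{"chunk_dim", "fix_dim"}` entry per node. The sketch below shows only that worklist pattern; the function name and the "copy the consumer's entry" rule are placeholders, not the real index tracing:

```python
# A worklist sketch only, not the real _get_all_node_info: the helper name and the
# "inherit the consumer's entry" rule are placeholders for the real index tracing.
from typing import Dict, List

from torch.fx import Node


def backward_walk(end_node: Node, end_dim: int) -> Dict[Node, Dict]:
    cur_node_list: List[Node] = [end_node]  # start from the last node
    all_node_info: Dict[Node, Dict] = {end_node: {"chunk_dim": end_dim, "fix_dim": []}}
    while len(cur_node_list) > 0:
        next_node_list: List[Node] = []
        for node in cur_node_list:
            for arg in node.all_input_nodes:  # walk towards producers
                if arg in all_node_info:
                    continue
                # placeholder: copy the consumer's entry; the real code derives the
                # producer's chunk_dim/fix_dim from how the op maps dimensions
                all_node_info[arg] = dict(all_node_info[node])
                next_node_list.append(arg)
        cur_node_list = next_node_list
    return all_node_info
```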
@@ -266,7 +265,7 @@ class TraceFlow(object):
         maybe_prepose_nodes.sort(
             key=lambda x: self.node_mgr.find_node_idx(x),
             reverse=True,
-        )    # from last node to first node
+        )  # from last node to first node
         prepose_nodes = []
         # set every node as root, search its args, if all legal, turn root and args as prepose nodes
         while len(maybe_prepose_nodes) > 0:
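The comment in this hunk ("set every node as root, search its args, if all legal, turn root and args as prepose nodes") describes a pop-and-check loop over candidates sorted from last to first. A rough sketch of that loop, with a hypothetical `is_legal_prepose` predicate standing in for the real legality checks and `node_idx` standing in for `self.node_mgr.find_node_idx`:

```python
# Sketch of the prepose search described by the comment above; is_legal_prepose
# is hypothetical and node_idx stands in for self.node_mgr.find_node_idx.
from typing import Callable, List

from torch.fx import Node


def collect_prepose_nodes(
    maybe_prepose_nodes: List[Node],
    node_idx: Callable[[Node], int],
    is_legal_prepose: Callable[[Node], bool],
) -> List[Node]:
    maybe_prepose_nodes.sort(key=node_idx, reverse=True)  # from last node to first node
    prepose_nodes: List[Node] = []
    while len(maybe_prepose_nodes) > 0:
        root = maybe_prepose_nodes.pop(0)
        group = [root] + list(root.all_input_nodes)  # the root together with its args
        if all(is_legal_prepose(n) for n in group):
            for n in group:
                if n not in prepose_nodes:
                    prepose_nodes.append(n)
    return prepose_nodes
```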
@@ -328,7 +327,8 @@ class TraceFlow(object):
 
     def flow_search(self, start_idx, start_dim, end_idx, end_dim):
         inputs, outputs = find_chunk_compute_input_and_output_nodes(
-            self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
+            self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
+        )
 
         # get every node's chunk dim and fix dim
         all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
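`flow_search` first asks `find_chunk_compute_input_and_output_nodes` for the nodes that cross the region boundary. The helper itself is not shown in this diff; the sketch below is an assumption about what "input and output nodes of a node slice" typically means for a torch.fx graph (args produced outside the slice, nodes consumed outside the slice):

```python
# Assumed semantics only; the real helper lives in the autochunk utils module.
from typing import List, Tuple

from torch.fx import Node


def find_region_inputs_and_outputs(region: List[Node]) -> Tuple[List[Node], List[Node]]:
    region_set = set(region)
    inputs: List[Node] = []
    outputs: List[Node] = []
    for node in region:
        # an argument produced outside the region is an input of the region
        for arg in node.all_input_nodes:
            if arg not in region_set and arg not in inputs:
                inputs.append(arg)
        # a node consumed outside the region is an output of the region
        if any(user not in region_set for user in node.users):
            outputs.append(node)
    return inputs, outputs
```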
@@ -371,8 +371,9 @@ class TraceFlow(object):
 
         return chunk_info
 
-    def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
-                               chunk_info: Dict):
+    def _get_other_output_info(
+        self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int, chunk_info: Dict
+    ):
         start_node = self.node_mgr.get_node_by_idx(start_idx)
         # loop all outputs
         for output in outputs:
@@ -384,8 +385,8 @@ class TraceFlow(object):
             # skip non tensor
             if get_node_shape(output) is None:
                 # log shape tensor
-                if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
-                    chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
+                if len(output.meta["fwd_out"]) > 0 and isinstance(output.meta["fwd_out"][0], int):
+                    chunk_info["outputs_non_tensor"][output] = str(output.meta["fwd_out"])
                 continue
             # loop every dim of outputs, try to find a legal one
             for output_dim in range(len(get_node_shape(output))):
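The only change in that hunk is quote normalization, but the logic is worth spelling out: outputs with no tensor shape whose traced forward value is a list of ints (e.g. a shape) are remembered as strings under `chunk_info["outputs_non_tensor"]`. A tiny illustration with made-up values and a string standing in for the `fx.Node` key:

```python
# Illustration with made-up values of the bookkeeping above; "size_node" stands in
# for the actual fx.Node that would be used as the dictionary key.
chunk_info = {"outputs_non_tensor": {}}
fwd_out = [2, 128, 64]  # traced value of a shape-producing node
output_key = "size_node"
if len(fwd_out) > 0 and isinstance(fwd_out[0], int):
    chunk_info["outputs_non_tensor"][output_key] = str(fwd_out)
assert chunk_info["outputs_non_tensor"] == {"size_node": "[2, 128, 64]"}
```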
@@ -421,7 +422,8 @@ class TraceFlow(object):
         for k, v in new_all_node_info.items():
             if k in chunk_info["node_chunk_dim"]:
                 chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
-                    set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
+                    set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"])
+                )
             else:
                 chunk_info["node_chunk_dim"][k] = v
         chunk_info["outputs"].append(output)
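The reformatted line merges two `fix_dim` lists by concatenating them and deduplicating through a set, so the result has unique dims but no guaranteed order. A tiny worked example:

```python
# Worked example of the fix_dim merge in the hunk above.
existing_fix_dim = [0, 2]
incoming_fix_dim = [1, 2]
merged = list(set(existing_fix_dim + incoming_fix_dim))
assert sorted(merged) == [0, 1, 2]  # unique dims; set iteration order is not guaranteed
```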
@@ -443,8 +445,11 @@ class TraceFlow(object):
             if node.args[0] in chunk_info["inputs_non_chunk"]:
                 continue
             reshape_args = flat_list(node.args[1:])
-            if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
-                    reshape_args[0].meta['fwd_out']) > 1:
+            if (
+                len(reshape_args) == 1
+                and get_node_shape(reshape_args[0]) is None
+                and len(reshape_args[0].meta["fwd_out"]) > 1
+            ):
                 continue
             chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
             new_shape = ""
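The condition that gets re-wrapped here skips a reshape whose single non-tensor argument carries several ints in `meta["fwd_out"]`, i.e. a whole target shape produced at runtime. An illustration of that guard with made-up values (strings stand in for the actual `fx.Node` objects):

```python
# Illustration with made-up values of the reshape guard above.
reshape_args = ["size_node"]  # stand-in for a single fx.Node argument
node_shape = None             # get_node_shape(...) is None for non-tensor nodes
fwd_out = [4, 16, 64]         # the traced value: a full shape, not a single int
skip = len(reshape_args) == 1 and node_shape is None and len(fwd_out) > 1
assert skip is True
```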
@@ -462,16 +467,17 @@ class TraceFlow(object):
         chunk_info["reshape_size"] = reshape_size
         return chunk_info
 
-    def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
-                               end_idx: int) -> bool:
+    def check_region_start_end(
+        self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int, end_idx: int
+    ) -> bool:
         """
         check if region start and end is legal
         """
         # dim cannot be None
-        if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
+        if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
             return False
         # dim size cannot be 1
-        if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
+        if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
             return False
         # must have users
         if len(end_node.users) == 0:
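The checks visible in this last hunk read on their own: a region is rejected when either endpoint has no tensor shape, when the chosen dimension has size 1, or when the end node has no users. A standalone restatement with `get_node_shape` passed in as a stub; the returns after the users check are assumed, since the hunk ends mid-method:

```python
# Standalone restatement of the visible checks; get_node_shape is passed in as a
# stub and the trailing returns are assumed because the hunk is cut off.
from typing import Callable, Optional, Tuple

from torch.fx import Node


def region_start_end_is_legal(
    start_node: Node,
    start_dim: int,
    end_node: Node,
    end_dim: int,
    get_node_shape: Callable[[Node], Optional[Tuple[int, ...]]],
) -> bool:
    # dim cannot be None
    if get_node_shape(end_node) is None or get_node_shape(start_node) is None:
        return False
    # dim size cannot be 1
    if get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1:
        return False
    # must have users
    if len(end_node.users) == 0:
        return False
    return True
```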