mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 11:32:10 +00:00
code style
This commit is contained in:
@@ -81,7 +81,9 @@ class TraceFlow(object):
|
||||
input_dim_after_node = {}
|
||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_index.node_list[k])
|
||||
inherit_dim = self._find_inherit_dim(
|
||||
input_node, v, self.trace_index.node_list[k]
|
||||
)
|
||||
if inherit_dim:
|
||||
input_dim_after_node[k] = inherit_dim
|
||||
|
||||
@@ -217,7 +219,9 @@ class TraceFlow(object):
|
||||
for arg in arg_list:
|
||||
if not (
|
||||
start_idx
|
||||
<= find_idx_by_name(arg.name, self.trace_index.node_list)
|
||||
<= find_idx_by_name(
|
||||
arg.name, self.trace_index.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
continue
|
||||
@@ -255,7 +259,9 @@ class TraceFlow(object):
|
||||
if start_idx <= user_idx <= end_idx:
|
||||
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||
if chunk_dim is not None:
|
||||
user_source = self.trace_index._find_source_trace_from_node(user)[chunk_dim]
|
||||
user_source = self.trace_index._find_source_trace_from_node(
|
||||
user
|
||||
)[chunk_dim]
|
||||
if input_node_idx in user_source:
|
||||
input_dict[user_idx] = user_source[input_node_idx]
|
||||
else:
|
||||
|
Reference in New Issue
Block a user