mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[NFC] fix typo colossalai/amp auto_parallel autochunk (#3756)
This commit is contained in:
@@ -461,7 +461,7 @@ class TraceIndice(object):
|
||||
nodes_in.append(node_in)
|
||||
self._inherit_more_indice_from_node_with_exclude(node_in, node)
|
||||
|
||||
def _assgin_no_change_indice(self, node, idx):
|
||||
def _assign_no_change_indice(self, node, idx):
|
||||
self._assign_indice_as_input(node, idx)
|
||||
for node_in in node.args:
|
||||
if type(node_in) == type(node):
|
||||
@@ -792,7 +792,7 @@ class TraceIndice(object):
|
||||
self._add_dim(node_idx, i)
|
||||
dim_from.reverse()
|
||||
|
||||
# inheirt indice from current node
|
||||
# inherit indice from current node
|
||||
if len(dim_from) != 0 and len(dim_to) != 0:
|
||||
if dim_diff == 1:
|
||||
if origin_shape[dim_from[0]] == 1:
|
||||
@@ -852,7 +852,7 @@ class TraceIndice(object):
|
||||
elif "split" == node_name:
|
||||
self._assign_split_indice(node, idx)
|
||||
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
|
||||
self._assgin_no_change_indice(node, idx)
|
||||
self._assign_no_change_indice(node, idx)
|
||||
elif "new_ones" == node_name:
|
||||
self._assign_all_indice(node, idx)
|
||||
elif "flatten" == node_name:
|
||||
@@ -914,7 +914,7 @@ class TraceIndice(object):
|
||||
elif "conv2d" == node_name:
|
||||
self._assign_conv2d_indice(node, idx)
|
||||
elif "identity" == node_name:
|
||||
self._assgin_no_change_indice(node, idx)
|
||||
self._assign_no_change_indice(node, idx)
|
||||
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
|
||||
self._assign_elementwise_indice(node, idx)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user