[NFC] fix typo colossalai/amp auto_parallel autochunk (#3756)

This commit is contained in:
digger yu
2023-05-19 13:50:00 +08:00
committed by GitHub
parent 21e29e2212
commit 32f81f14d4
6 changed files with 12 additions and 12 deletions

View File

@@ -461,7 +461,7 @@ class TraceIndice(object):
nodes_in.append(node_in)
self._inherit_more_indice_from_node_with_exclude(node_in, node)
def _assgin_no_change_indice(self, node, idx):
def _assign_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
@@ -792,7 +792,7 @@ class TraceIndice(object):
self._add_dim(node_idx, i)
dim_from.reverse()
# inheirt indice from current node
# inherit indice from current node
if len(dim_from) != 0 and len(dim_to) != 0:
if dim_diff == 1:
if origin_shape[dim_from[0]] == 1:
@@ -852,7 +852,7 @@ class TraceIndice(object):
elif "split" == node_name:
self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
self._assgin_no_change_indice(node, idx)
self._assign_no_change_indice(node, idx)
elif "new_ones" == node_name:
self._assign_all_indice(node, idx)
elif "flatten" == node_name:
@@ -914,7 +914,7 @@ class TraceIndice(object):
elif "conv2d" == node_name:
self._assign_conv2d_indice(node, idx)
elif "identity" == node_name:
self._assgin_no_change_indice(node, idx)
self._assign_no_change_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
self._assign_elementwise_indice(node, idx)
else: