mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 05:01:44 +00:00
[autochunk] support evoformer tracer (#2485)
support full evoformer tracer, which is a main module of alphafold. previously we just support a simplifed version of it. 1. support some evoformer's op in fx 2. support evoformer test 3. add repos for test code
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .utils import find_idx_by_name, get_node_shape
|
||||
from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
|
||||
|
||||
|
||||
class TraceIndice(object):
|
||||
@@ -79,9 +79,7 @@ class TraceIndice(object):
|
||||
node_from_trace = self._find_trace_from_node(node_from)
|
||||
node_to_trace = self._find_trace_from_node(node_to)
|
||||
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
|
||||
node_to_trace["compute"][node_to_dim] = copy.deepcopy(
|
||||
node_from_trace["compute"][node_from_dim]
|
||||
)
|
||||
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
|
||||
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
|
||||
|
||||
def _inherit_all_computation(self, node_from, node_to):
|
||||
@@ -209,7 +207,7 @@ class TraceIndice(object):
|
||||
node_idx (int)
|
||||
"""
|
||||
if input_node == None:
|
||||
input_node = node.args[0]
|
||||
input_node = find_first_tensor_arg(node)
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
|
||||
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
|
||||
|
||||
@@ -227,6 +225,8 @@ class TraceIndice(object):
|
||||
node_idx (int)
|
||||
"""
|
||||
shape = node.meta["tensor_meta"].shape
|
||||
if shape is None:
|
||||
return
|
||||
new_trace = []
|
||||
for _ in shape:
|
||||
new_trace.append(self._add_indice())
|
||||
@@ -259,7 +259,7 @@ class TraceIndice(object):
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
permute_dim = node.args[1:]
|
||||
permute_dim = unflat_list(node.args[1:])
|
||||
input_node = node.args[0]
|
||||
|
||||
self._assign_indice_as_input(node, node_idx, input_node)
|
||||
@@ -359,6 +359,15 @@ class TraceIndice(object):
|
||||
left, right = patterns.split("->")
|
||||
left = left.split(",")
|
||||
|
||||
if '...' in right:
|
||||
replace_list = "!@#$%^&*"
|
||||
target_len = len(get_node_shape(node))
|
||||
add_len = target_len - len(right) + 3
|
||||
replace_str = replace_list[:add_len]
|
||||
right = right.replace("...", replace_str)
|
||||
for ll in range(len(left)):
|
||||
left[ll] = left[ll].replace("...", replace_str)
|
||||
|
||||
all_index = []
|
||||
for i in left:
|
||||
for c in i:
|
||||
@@ -369,9 +378,7 @@ class TraceIndice(object):
|
||||
for left_idx, left_str in enumerate(left):
|
||||
if right_indice in left_str:
|
||||
source_idx = left_str.index(right_indice)
|
||||
self._inherit_indice(
|
||||
input_nodes[left_idx], source_idx, node, right_idx
|
||||
)
|
||||
self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx)
|
||||
|
||||
def _assign_softmax_indice(self, node, idx):
|
||||
"""
|
||||
@@ -440,11 +447,12 @@ class TraceIndice(object):
|
||||
origin_node = node.args[0]
|
||||
origin_shape = origin_node.meta["tensor_meta"].shape
|
||||
target_shape = []
|
||||
for i in range(1, len(node.args)):
|
||||
if isinstance(node.args[i], int):
|
||||
target_shape.append(node.args[i])
|
||||
unflated_args = unflat_list(node.args)
|
||||
for i in range(1, len(unflated_args)):
|
||||
if isinstance(unflated_args[i], int):
|
||||
target_shape.append(unflated_args[i])
|
||||
else:
|
||||
target_shape.append(node.args[i].meta["fwd_out"][0])
|
||||
target_shape.append(unflated_args[i].meta["fwd_out"][0])
|
||||
|
||||
# compute the value of -1
|
||||
if -1 in target_shape:
|
||||
@@ -472,13 +480,7 @@ class TraceIndice(object):
|
||||
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._del_dim(node_idx, -1)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"shape"
|
||||
+ str(origin_shape)
|
||||
+ "and"
|
||||
+ str(target_shape)
|
||||
+ "view not implemented"
|
||||
)
|
||||
raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
|
||||
|
||||
# get new indice
|
||||
origin_trace = self._find_indice_trace_from_node(origin_node)
|
||||
@@ -521,6 +523,8 @@ class TraceIndice(object):
|
||||
self._assign_unsqueeze_indice(node, idx)
|
||||
elif any(i in node.name for i in ["to", "contiguous"]):
|
||||
self._assgin_no_change_indice(node, idx)
|
||||
elif "new_ones" in node.name:
|
||||
self._assign_ones_like_indice(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||
elif node.op == "call_function":
|
||||
@@ -530,7 +534,7 @@ class TraceIndice(object):
|
||||
self._assign_matmul_indice(node, idx)
|
||||
elif "softmax" in node.name:
|
||||
self._assign_softmax_indice(node, idx)
|
||||
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]):
|
||||
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
|
||||
self._assign_elementwise_indice(node, idx)
|
||||
elif "ones_like" in node.name:
|
||||
self._assign_ones_like_indice(node, idx)
|
||||
@@ -538,21 +542,21 @@ class TraceIndice(object):
|
||||
self._assign_dropout_indice(node, idx)
|
||||
elif "einsum" in node.name:
|
||||
self._assign_einsum_indice(node, idx)
|
||||
elif "getattr" in node.name:
|
||||
continue # get attr like shape
|
||||
elif "getitem" in node.name:
|
||||
continue # get item in list
|
||||
elif "layer_norm" in node.name:
|
||||
self._assign_layernorm_indice(node, idx)
|
||||
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
node.name, "function not implemented yet!"
|
||||
)
|
||||
raise NotImplementedError(node.name, "function not implemented yet!")
|
||||
elif node.op == "call_module":
|
||||
if any(n in node.name for n in ["layernorm", "norm"]):
|
||||
self._assign_layernorm_indice(node, idx)
|
||||
elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
|
||||
self._assign_elementwise_indice(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "module not implemented yet!")
|
||||
elif node.op == "get_attr":
|
||||
self._assign_all_indice(node, idx) # get param
|
||||
self._assign_all_indice(node, idx) # get param
|
||||
elif node.op == "output":
|
||||
continue
|
||||
else:
|
||||
|
Reference in New Issue
Block a user