mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 21:09:18 +00:00
[autochunk] support vit (#3084)
support vit for autochunk * support some new ops for vit * fix some bugs * add test for vit
This commit is contained in:
@@ -63,7 +63,7 @@ def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_oup
|
||||
context = ""
|
||||
for i in range(len(chunk_output)):
|
||||
shape_str = str(list(get_node_shape(chunk_output[i])))
|
||||
if get_node_name(chunk_output[i]) == "split":
|
||||
if get_node_name(chunk_output[i]) in ["split", "unbind"]:
|
||||
tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
|
||||
input_node.name)
|
||||
tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
|
||||
@@ -205,7 +205,7 @@ def _add_node_slice(
|
||||
if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
|
||||
chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
|
||||
get_node_shape(chunk_node))
|
||||
if get_node_name(chunk_node) == "split":
|
||||
if get_node_name(chunk_node) in ["split", "unbind"]:
|
||||
split_chunk_slice = ""
|
||||
for i in range(len(chunk_node.meta['tensor_meta'])):
|
||||
split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
|
||||
|
Reference in New Issue
Block a user