mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-27 03:26:41 +00:00
[graph] improve the graph building. (#1157)
This commit is contained in:
50
colossalai/nn/graph/utils.py
Normal file
50
colossalai/nn/graph/utils.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import functools
|
||||
import torch
|
||||
from colossalai.tensor import ColoTensor
|
||||
from typing import Callable, List
|
||||
from colossalai.nn._ops._utils import convert_to_colo_tensor
|
||||
|
||||
|
||||
def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
|
||||
"""register_colo_graph
|
||||
Register a Op (Layer) to ColoGraph.
|
||||
Recoders the input args in types of ColoTensor to the Graph.
|
||||
Args:
|
||||
func (Callable): a function implements the Op.
|
||||
|
||||
Returns:
|
||||
Callable: wrapper function.
|
||||
"""
|
||||
|
||||
def register_colo_graph_decorator(func):
|
||||
from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
param_list = []
|
||||
input_list = []
|
||||
for idx, arg in enumerate(args):
|
||||
if isinstance(arg, torch.Tensor) and idx in input_pos:
|
||||
input_list.append(convert_to_colo_tensor(arg))
|
||||
if isinstance(arg, torch.Tensor) and idx in param_pos:
|
||||
param_list.append(convert_to_colo_tensor(arg))
|
||||
print(f'Op {func}')
|
||||
# building the computing graph, inputs -> op
|
||||
if GraphGlobalEnv().graph_building:
|
||||
cur_op_node = GraphOpNode('linear', param_list)
|
||||
# TODO supports a list of ColoTensor as args
|
||||
if len(input_list) > 0:
|
||||
cur_op_node.add_prev_tensor(input_list[0])
|
||||
|
||||
outputs = func(*args, **kwargs)
|
||||
|
||||
# building the computing graph, op -> output
|
||||
if GraphGlobalEnv().graph_building:
|
||||
# TODO supports a list of ColoTensor as args
|
||||
if isinstance(outputs[0], ColoTensor):
|
||||
cur_op_node.add_post_tensor(outputs[0])
|
||||
return outputs
|
||||
|
||||
return wrapper
|
||||
|
||||
return register_colo_graph_decorator
|
||||
Reference in New Issue
Block a user