[graph] improve the graph building. (#1157)

Author: Jiarui Fang
Date: 2022-06-22 16:47:20 +08:00
Committed by: GitHub
Parent: 22717a856f
Commit: 07f9c781f9
7 changed files with 79 additions and 103 deletions


@@ -1,14 +1,14 @@
import torch.nn.functional as F
from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
from colossalai.context import ParallelMode
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    # Input:S[1] x Weight:S[0] = Output:P
    # All-Reduce(Output) + bias = res
    # Input:S[1]
@@ -28,7 +28,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
    return output
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
    # All-Gather(Output)
    # Input:B
@@ -48,23 +48,21 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
    return output
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
    assert mode in ('row', 'col')
    funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
    return funcs[mode](input_tensor, weight, bias)
@colo_op_impl(F.linear)
def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
@register_colo_graph(input_pos=[1], param_pos=[2, 3])
def colo_linear_imp(input_tensor: GeneralTensor,
                    weight: GeneralTensor,
                    bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a linear.
"""
    input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
    # building the computing graph, inputs -> op
    if GraphGlobalEnv().graph_building:
        cur_op_node = GraphOpNode('linear', [weight, bias])
        cur_op_node.add_prev_tensor(input_tensor)
    # Add communication logic before and after linear call.
    ret_tensor = None
    if not weight.has_spec(): # No Model Parallel Applied
@@ -82,7 +80,11 @@ def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Option
    else:
        raise NotImplementedError
    # building the computing graph, op -> output
    if GraphGlobalEnv().graph_building:
        cur_op_node.add_post_tensor(ret_tensor)
    return ret_tensor
@colo_op_impl(F.linear)
def colo_linear(input_tensor: GeneralTensor,
                weight: GeneralTensor,
                bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
    return colo_linear_imp(input_tensor, weight, bias)
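
For readers unfamiliar with the dispatch path named in the docstring above, here is a minimal, self-contained sketch of how a __torch_function__ override lets a tensor subclass intercept F.linear, which is the mechanism colo_op_impl relies on to route the call to colo_linear. It uses plain PyTorch only; none of the names below come from Colossal-AI.

import torch
import torch.nn.functional as F

class InterceptTensor(torch.Tensor):
    # Hypothetical subclass for illustration; ColoTensor's real dispatch goes
    # through colo_op_impl and is not reproduced here.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            print('intercepted F.linear')    # a custom handler would run here
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(2, 4).as_subclass(InterceptTensor)
w = torch.randn(3, 4)
out = F.linear(x, w)    # prints 'intercepted F.linear'
print(out.shape)        # torch.Size([2, 3])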


@@ -0,0 +1,4 @@
from .utils import register_colo_graph
from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode
__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']


@@ -74,7 +74,6 @@ class GraphOpNode(GraphNode):
            assert isinstance(colo_tensor, ColoTensor)
            if colo_tensor._graph_node is None:
                colo_tensor._graph_node = GraphNode()
            prev_ops = colo_tensor._graph_node.prev_nodes
            for op_node in prev_ops:
                self.add_prev_node(op_node)
@@ -85,7 +84,7 @@ class GraphOpNode(GraphNode):
        Op <- Activation (colo_tensor)
        """
        if GraphGlobalEnv().graph_building:
            assert isinstance(colo_tensor, ColoTensor)
            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
            if colo_tensor._graph_node is None:
                colo_tensor._graph_node = GraphNode()
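
To make the two methods touched by this hunk easier to follow, here is a small self-contained toy (hypothetical names, not the library code) of the same linking pattern: every activation carries a graph node, add_prev_tensor inherits the edges already recorded on an input activation, and add_post_tensor registers the op as the producer of an output activation.

class ToyNode:
    def __init__(self, name='activation'):
        self.name = name
        self.prev_nodes = []

class ToyOpNode(ToyNode):
    def add_prev_tensor(self, tensor):
        # pull in the ops that produced this input activation
        if tensor.graph_node is None:
            tensor.graph_node = ToyNode()
        self.prev_nodes.extend(tensor.graph_node.prev_nodes)

    def add_post_tensor(self, tensor):
        # mark this op as the producer of the output activation
        if tensor.graph_node is None:
            tensor.graph_node = ToyNode()
        tensor.graph_node.prev_nodes.append(self)

class ToyActivation:
    def __init__(self):
        self.graph_node = None

x, y = ToyActivation(), ToyActivation()
op = ToyOpNode('linear')
op.add_prev_tensor(x)                     # x has no producer yet, so no edges are inherited
op.add_post_tensor(y)                     # y now records 'linear' as its producer
print(y.graph_node.prev_nodes[0].name)    # linear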


@@ -0,0 +1,50 @@
import functools
import torch
from colossalai.tensor import ColoTensor
from typing import Callable, List
from colossalai.nn._ops._utils import convert_to_colo_tensor
def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
    """register_colo_graph
    Register an op (layer) to the ColoGraph and record its ColoTensor args as graph nodes.
    Args:
        input_pos (List[int]): indices into the positional args that are activation inputs.
        param_pos (List[int]): indices into the positional args that are parameters, e.g. weight and bias.
    Returns:
        Callable: a decorator that wraps the op implementation.
    """

    def register_colo_graph_decorator(func):
        from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            param_list = []
            input_list = []
            for idx, arg in enumerate(args):
                if isinstance(arg, torch.Tensor) and idx in input_pos:
                    input_list.append(convert_to_colo_tensor(arg))
                if isinstance(arg, torch.Tensor) and idx in param_pos:
                    param_list.append(convert_to_colo_tensor(arg))
            print(f'Op {func}')
            # building the computing graph, inputs -> op
            if GraphGlobalEnv().graph_building:
                cur_op_node = GraphOpNode('linear', param_list)
                # TODO supports a list of ColoTensor as args
                if len(input_list) > 0:
                    cur_op_node.add_prev_tensor(input_list[0])

            outputs = func(*args, **kwargs)

            # building the computing graph, op -> output
            if GraphGlobalEnv().graph_building:
                # TODO supports a list of ColoTensor as args
                if isinstance(outputs[0], ColoTensor):
                    cur_op_node.add_post_tensor(outputs[0])
            return outputs

        return wrapper

    return register_colo_graph_decorator
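
A self-contained toy with the same shape as the decorator above (hypothetical names, not the library code): the op declares which positional args are activations and which are parameters, and the wrapper records them before delegating to the real implementation. The index semantics in this toy are plain 0-based positions into *args.

import functools
import torch

def record_op(input_pos, param_pos):

    def decorator(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            inputs = [a for i, a in enumerate(args) if i in input_pos and isinstance(a, torch.Tensor)]
            params = [a for i, a in enumerate(args) if i in param_pos and isinstance(a, torch.Tensor)]
            print(f'{func.__name__}: {len(inputs)} activation(s), {len(params)} parameter(s)')
            return func(*args, **kwargs)

        return wrapper

    return decorator

@record_op(input_pos=[0], param_pos=[1, 2])
def toy_linear(x, weight, bias=None):
    return x @ weight.t() + (bias if bias is not None else 0)

out = toy_linear(torch.randn(2, 4), torch.randn(3, 4), torch.randn(3))    # prints 'toy_linear: 1 activation(s), 2 parameter(s)'
print(out.shape)                                                          # torch.Size([2, 3])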


@@ -17,6 +17,13 @@ class _DistSpec:
                 dist_placement_pattern: DistPlacementPattern,
                 process_group: Optional[ProcessGroup] = None,
                 **meta_info):
"""_DistSpec, Distributed Specification
Args:
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
"""
        self.placement = dist_placement_pattern
        self.process_group = process_group
        for k, v in meta_info.items():
@@ -37,6 +44,7 @@ class _DistSpec:
            res += f'{attr}: {str(getattr(self, attr))}\n\t'
        return res
def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
    # process_group=None means global process group
    return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
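
A brief, hedged example of the documented constructor and the replicate() helper. It uses only names visible in this hunk and in the imports earlier in this diff, and assumes colossalai is installed.

from colossalai.tensor import distspec

# With no process_group argument, the tensor is replicated over the global process group.
spec = distspec.replicate()
print(spec.placement)        # DistPlacementPattern.REPLICATE
print(spec.process_group)    # None, i.e. the global process group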


@@ -1,3 +0,0 @@
from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv
__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']