	[graph] improve the graph building. (#1157)
colossalai/nn/_ops/linear.py
@@ -1,14 +1,14 @@
 import torch.nn.functional as F
 from typing import Optional
+from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
-from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
 from colossalai.context import ParallelMode
-from ._utils import GeneralTensor, convert_to_colo_tensor
+from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
@@ -28,7 +28,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
@@ -48,23 +48,21 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     assert mode in ('row', 'col')
     funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
     return funcs[mode](input_tensor, weight, bias)


-@colo_op_impl(F.linear)
-def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
+@register_colo_graph(input_pos=[1], param_pos=[2, 3])
+def colo_linear_imp(input_tensor: GeneralTensor,
+                    weight: GeneralTensor,
+                    bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

-    # building the computing graph, inputs -> op
-    if GraphGlobalEnv().graph_building:
-        cur_op_node = GraphOpNode('linear', [weight, bias])
-        cur_op_node.add_prev_tensor(input_tensor)
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
@@ -82,7 +80,11 @@ def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Option
     else:
         raise NotImplementedError

-    # building the computing graph, op -> output
-    if GraphGlobalEnv().graph_building:
-        cur_op_node.add_post_tensor(ret_tensor)
     return ret_tensor
+
+
+@colo_op_impl(F.linear)
+def colo_linear(input_tensor: GeneralTensor,
+                weight: GeneralTensor,
+                bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+    return colo_linear_imp(input_tensor, weight, bias)
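For orientation, a minimal usage sketch of the path this file now takes: F.linear on ColoTensor inputs dispatches to colo_linear via @colo_op_impl, which delegates to colo_linear_imp, whose @register_colo_graph wrapper records a GraphOpNode whenever graph building is on. The sketch is patterned on the test removed at the end of this commit and assumes the GraphContext re-exported from colossalai.nn.graph behaves like the old colossalai.tensor.graph.GraphContext; the `net` model is illustrative.

import torch
from torch import nn
from colossalai.tensor import ColoTensor
from colossalai.nn.graph import GraphContext

# a tiny stack of nn.Linear layers; each forward F.linear call is intercepted by colo_linear above
net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
net.eval()

colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
graph_ctx = GraphContext()
with torch.no_grad():
    with graph_ctx:              # GraphGlobalEnv().graph_building is True inside the context
        out = net(colo_input)    # colo_linear_imp records one GraphOpNode per linear layer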
							
								
								
									
colossalai/nn/graph/__init__.py (new file, +4)
@@ -0,0 +1,4 @@
+from .utils import register_colo_graph
+from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode
+
+__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']
graph_node.py
@@ -74,7 +74,6 @@ class GraphOpNode(GraphNode):
             assert isinstance(colo_tensor, ColoTensor)
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
-
             prev_ops = colo_tensor._graph_node.prev_nodes
             for op_node in prev_ops:
                 self.add_prev_node(op_node)
@@ -85,7 +84,7 @@ class GraphOpNode(GraphNode):
         Op <- Activation (colo_tensor)
         """
         if GraphGlobalEnv().graph_building:
-            assert isinstance(colo_tensor, ColoTensor)
+            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
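These GraphOpNode hooks are what register_colo_graph (next file) calls while recording. A hedged sketch of walking the recorded graph forward through post_nodes, in the spirit of the _visit_graph helper in the test removed at the end of this commit:

def visit_graph(node):
    """Depth-first walk over recorded graph nodes, following post_nodes edges."""
    if node is None:
        return
    node.print()                 # printing helper used by the removed test below
    for nxt in node.post_nodes:  # successors recorded via add_post_tensor
        visit_graph(nxt)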
							
								
								
									
colossalai/nn/graph/utils.py (new file, +50)
@@ -0,0 +1,50 @@
+import functools
+import torch
+from colossalai.tensor import ColoTensor
+from typing import Callable, List
+from colossalai.nn._ops._utils import convert_to_colo_tensor
+
+
+def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
+    """register_colo_graph
+    Registers an Op (Layer) to the ColoGraph.
+    Records the ColoTensor-typed args of the wrapped function in the graph.
+    Args:
+        input_pos (List[int]): positions of activation (input) tensors in the wrapped function's args.
+        param_pos (List[int]): positions of parameter tensors (e.g. weight, bias) in the wrapped function's args.
+    Returns:
+        Callable: wrapper function.
+    """
+
+    def register_colo_graph_decorator(func):
+        from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            param_list = []
+            input_list = []
+            for idx, arg in enumerate(args):
+                if isinstance(arg, torch.Tensor) and idx in input_pos:
+                    input_list.append(convert_to_colo_tensor(arg))
+                if isinstance(arg, torch.Tensor) and idx in param_pos:
+                    param_list.append(convert_to_colo_tensor(arg))
+            print(f'Op {func}')
+            # building the computing graph, inputs -> op
+            if GraphGlobalEnv().graph_building:
+                cur_op_node = GraphOpNode('linear', param_list)
+                # TODO supports a list of ColoTensor as args
+                if len(input_list) > 0:
+                    cur_op_node.add_prev_tensor(input_list[0])
+
+            outputs = func(*args, **kwargs)
+
+            # building the computing graph, op -> output
+            if GraphGlobalEnv().graph_building:
+                # TODO supports a list of ColoTensor as args
+                if isinstance(outputs[0], ColoTensor):
+                    cur_op_node.add_post_tensor(outputs[0])
+            return outputs
+
+        return wrapper
+
+    return register_colo_graph_decorator
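As a hedged illustration, another op wrapper could opt into recording the same way. The embedding wrapper below is hypothetical (not part of this commit, and ColossalAI may already ship its own); the positions assume the same convention used for colo_linear_imp above, where input_pos=[1] marks the first positional argument.

import torch.nn.functional as F
from colossalai.nn.graph import register_colo_graph
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn._ops._utils import GeneralTensor, convert_to_colo_tensor


@register_colo_graph(input_pos=[1], param_pos=[2])    # same position convention as colo_linear_imp
def colo_embedding_imp(input_tensor: GeneralTensor, weight: GeneralTensor) -> 'ColoTensor':
    input_tensor, weight = tuple(map(convert_to_colo_tensor, (input_tensor, weight)))
    # distributed handling elided in this sketch; plain lookup only
    return F.embedding(input_tensor, weight)


@colo_op_impl(F.embedding)
def colo_embedding(input_tensor: GeneralTensor, weight: GeneralTensor) -> 'ColoTensor':
    return colo_embedding_imp(input_tensor, weight)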
@@ -17,6 +17,13 @@ class _DistSpec:
                  dist_placement_pattern: DistPlacementPattern,
                  process_group: Optional[ProcessGroup] = None,
                  **meta_info):
+        """_DistSpec, Distributed Specification
+
+        Args:
+            dist_placement_pattern (DistPlacementPattern): the pattern describing how a tensor is distributed among processes.
+                The pattern is picked from a limited set; currently replicate and shard are supported.
+            process_group (Optional[ProcessGroup], optional): the process group the tensor is distributed across. Defaults to None, meaning the global process group.
+        """
         self.placement = dist_placement_pattern
         self.process_group = process_group
         for k, v in meta_info.items():
@@ -37,6 +44,7 @@ class _DistSpec:
                 res += f'{attr}: {str(getattr(self, attr))}\n\t'
         return res

+
 def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
     # process_group=None means global process group
     return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
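For reference, a minimal sketch of building a replicated spec through the helper shown above; it assumes replicate is reachable through the distspec module imported in linear.py, and shard (mentioned in the new docstring) is left out because its signature is not part of this diff.

from colossalai.tensor import distspec

# process_group=None means the global process group, per the comment in replicate()
spec = distspec.replicate(process_group=None)
print(spec)    # expected to show placement / process_group via the attribute loop above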
colossalai/tensor/graph/__init__.py (deleted)
@@ -1,3 +0,0 @@
-from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv
-
-__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']
test of the old colossalai.tensor.graph API (deleted, -84)
@@ -1,84 +0,0 @@
-import pytest
-from torch import nn
-import torch
-from colossalai.tensor import ColoTensor
-from colossalai.tensor.graph import GraphContext
-import gc
-
-
-class SimpleNet(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.proj1 = nn.Linear(4, 8)
-        self.proj2 = nn.Linear(8, 4)
-        self.proj3 = nn.Linear(4, 4)
-        self.proj4 = nn.Linear(4, 4)
-
-    def forward(self, x):
-        x = self.proj1(x)
-        x = self.proj2(x)
-        x = self.proj3(x)
-        x = self.proj4(x)
-        return x
-
-
-def _visit_graph(start_node):
-    if start_node is None:
-        return
-
-    start_node.print()
-
-    post_node_list = start_node.post_nodes
-    for node in post_node_list:
-        _visit_graph(node)
-
-
-def _get_tensors():
-    for obj in gc.get_objects():
-        try:
-            if torch.is_tensor(obj):
-                yield obj
-        except Exception as e:
-            print('A trivial exception occured: {}'.format(e))
-
-
-def _count_tensors():
-    cnt = 0
-    for t in _get_tensors():
-        cnt += 1
-    return cnt
-
-
-def count_tensors(use_colossal):
-    model = SimpleNet()
-
-    model.eval()
-    with torch.no_grad():
-        if use_colossal:
-            colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
-            graph_ctx = GraphContext()
-            with graph_ctx:
-                output = model(colo_input)
-            output = model(colo_input)
-            ret = _count_tensors()
-
-            _visit_graph(graph_ctx.graph_nodes[0])
-
-            del graph_ctx
-            return ret
-        else:
-            input_t = torch.randn(4)
-            output = model(input_t)
-            output = model(input_t)
-            return _count_tensors()
-
-
-@pytest.mark.skip
-# FIXME(ver217)
-def test_check_activation_tensors():
-    assert count_tensors(False) == count_tensors(True)
-
-
-if __name__ == "__main__":
-    count_tensors(True)