Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-31 05:49:56 +00:00)
	[graph] improve the graph building. (#1157)
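Summary of the change: the hand-written graph bookkeeping inside each op (the GraphGlobalEnv checks in the linear implementation below) moves into a reusable register_colo_graph decorator in a new colossalai.nn.graph package; the old colossalai.tensor.graph package and its activation-counting test are deleted, and _DistSpec gains a constructor docstring.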
@@ -1,14 +1,14 @@
 import torch.nn.functional as F
 from typing import Optional
-from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
-from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
-from colossalai.context import ParallelMode
+from ._utils import GeneralTensor, convert_to_colo_tensor
+from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv


-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
@@ -28,7 +28,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
@@ -48,23 +48,21 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     return output


-def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
+def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
     assert mode in ('row', 'col')
     funcs = {'row': colo_linear_1Drow, 'col': colo_linear_1Dcol}
     return funcs[mode](input_tensor, weight, bias)


-@colo_op_impl(F.linear)
-def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Optional[GeneralTensor] = None):
+@register_colo_graph(input_pos=[1], param_pos=[2, 3])
+def colo_linear_imp(input_tensor: GeneralTensor,
+                    weight: GeneralTensor,
+                    bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a linear.
     """
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

-    # building the computing graph, inputs -> op
-    if GraphGlobalEnv().graph_building:
-        cur_op_node = GraphOpNode('linear', [weight, bias])
-        cur_op_node.add_prev_tensor(input_tensor)
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_spec():    # No Model Parallel Applied
@@ -82,7 +80,11 @@ def colo_linear(input_tensor: GeneralTensor, weight: GeneralTensor, bias: Option
     else:
         raise NotImplementedError

-    # building the computing graph, op -> output
-    if GraphGlobalEnv().graph_building:
-        cur_op_node.add_post_tensor(ret_tensor)
     return ret_tensor


+@colo_op_impl(F.linear)
+def colo_linear(input_tensor: GeneralTensor,
+                weight: GeneralTensor,
+                bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+    return colo_linear_imp(input_tensor, weight, bias)
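After this hunk, @colo_op_impl(F.linear) still routes F.linear calls involving ColoTensors to the implementation, while the graph edges are recorded by the decorator instead of the op body. A minimal usage sketch, assuming a working ColossalAI install; the shapes are illustrative and the GraphContext usage mirrors the test deleted at the end of this commit:

import torch
import torch.nn.functional as F
from colossalai.tensor import ColoTensor
from colossalai.nn.graph import GraphContext

x = ColoTensor.from_torch_tensor(torch.randn(4, 8))
w = ColoTensor.from_torch_tensor(torch.randn(16, 8))

# Inside the context, graph_building is on, so the decorator records
# input -> op -> output edges while dispatch runs colo_linear_imp.
graph_ctx = GraphContext()
with graph_ctx:
    y = F.linear(x, w)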
colossalai/nn/graph/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+from .utils import register_colo_graph
+from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode
+
+__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode']
@@ -74,7 +74,6 @@ class GraphOpNode(GraphNode):
             assert isinstance(colo_tensor, ColoTensor)
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
-
             prev_ops = colo_tensor._graph_node.prev_nodes
             for op_node in prev_ops:
                 self.add_prev_node(op_node)
@@ -85,7 +84,7 @@ class GraphOpNode(GraphNode):
         Op <- Activation (colo_tensor)
         """
         if GraphGlobalEnv().graph_building:
-            assert isinstance(colo_tensor, ColoTensor)
+            assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}'
             if colo_tensor._graph_node is None:
                 colo_tensor._graph_node = GraphNode()
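For orientation, the wiring these hunks touch: every activation ColoTensor carries a _graph_node, and an op node links itself backward through the op nodes reachable from its inputs' graph nodes. A simplified standalone sketch of that linking scheme, with illustrative names rather than the real ColossalAI classes:

# Minimal model of the prev/post wiring in GraphOpNode above (not the real API).
class GraphNodeSketch:
    def __init__(self):
        self.prev_nodes = []
        self.post_nodes = []

    def add_prev_node(self, node):
        # doubly link: node -> self
        self.prev_nodes.append(node)
        node.post_nodes.append(self)


class OpNodeSketch(GraphNodeSketch):
    def add_prev_tensor(self, tensor_node):
        # an op inherits edges from the ops that produced its input activation
        for op in tensor_node.prev_nodes:
            self.add_prev_node(op)

    def add_post_tensor(self, tensor_node):
        # record this op as the producer of the output activation
        tensor_node.add_prev_node(self)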
colossalai/nn/graph/utils.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+import functools
+import torch
+from colossalai.tensor import ColoTensor
+from typing import Callable, List
+from colossalai.nn._ops._utils import convert_to_colo_tensor
+
+
+def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable:
+    """register_colo_graph
+    Register a Op (Layer) to ColoGraph.
+    Recoders the input args in types of ColoTensor to the Graph.
+    Args:
+        func (Callable): a function implements the Op.
+
+    Returns:
+        Callable: wrapper function.
+    """
+
+    def register_colo_graph_decorator(func):
+        from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            param_list = []
+            input_list = []
+            for idx, arg in enumerate(args):
+                if isinstance(arg, torch.Tensor) and idx in input_pos:
+                    input_list.append(convert_to_colo_tensor(arg))
+                if isinstance(arg, torch.Tensor) and idx in param_pos:
+                    param_list.append(convert_to_colo_tensor(arg))
+            print(f'Op {func}')
+            # building the computing graph, inputs -> op
+            if GraphGlobalEnv().graph_building:
+                cur_op_node = GraphOpNode('linear', param_list)
+                # TODO supports a list of ColoTensor as args
+                if len(input_list) > 0:
+                    cur_op_node.add_prev_tensor(input_list[0])
+
+            outputs = func(*args, **kwargs)
+
+            # building the computing graph, op -> output
+            if GraphGlobalEnv().graph_building:
+                # TODO supports a list of ColoTensor as args
+                if isinstance(outputs[0], ColoTensor):
+                    cur_op_node.add_post_tensor(outputs[0])
+            return outputs
+
+        return wrapper
+
+    return register_colo_graph_decorator
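The capture pattern above in isolation: positional arguments whose index appears in input_pos are treated as activations, those in param_pos as parameters, and (per the TODOs) only the first input and the first output are linked into the graph. A simplified standalone sketch of that pattern in plain Python, with illustrative names rather than the ColossalAI API:

import functools
from typing import Callable, List

def capture_positions(input_pos: List[int], param_pos: List[int]) -> Callable:
    # Simplified model of register_colo_graph: split positional args by index,
    # then call through to the wrapped op.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            inputs = [a for i, a in enumerate(args) if i in input_pos]
            params = [a for i, a in enumerate(args) if i in param_pos]
            # a real implementation would link inputs[0] -> op -> output here
            print(f'{func.__name__}: {len(inputs)} input(s), {len(params)} param(s)')
            return func(*args, **kwargs)
        return wrapper
    return decorator

@capture_positions(input_pos=[0], param_pos=[1, 2])
def linear(x, w, b=None):
    return [x]  # stand-in for the op's outputs

linear('activation', 'weight', 'bias')  # prints: linear: 1 input(s), 2 param(s)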
@@ -17,6 +17,13 @@ class _DistSpec:
                  dist_placement_pattern: DistPlacementPattern,
                  process_group: Optional[ProcessGroup] = None,
                  **meta_info):
+        """_DistSpec, Distributed Specification
+
+        Args:
+            dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
+                                                    The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
+            process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
+        """
         self.placement = dist_placement_pattern
         self.process_group = process_group
         for k, v in meta_info.items():
@@ -37,6 +44,7 @@ class _DistSpec:
                 res += f'{attr}: {str(getattr(self, attr))}\n\t'
         return res

+
 def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
     # process_group=None means global process group
     return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
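A small usage sketch for the spec above, assuming a ColossalAI environment and that replicate() is exposed via the distspec module imported in linear.py; process_group=None selects the global process group, as the comment states:

from colossalai.tensor import distspec

# A replicated spec over the global process group.
spec = distspec.replicate()  # _DistSpec with DistPlacementPattern.REPLICATE
print(spec)                  # __repr__ walks the spec's attributes (see above)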
@@ -1,3 +0,0 @@
-from .graph_node import GraphNode, GraphOpNode, GraphContext, GraphGlobalEnv
-
-__all__ = ['GraphNode', 'GraphOpNode', 'GraphContext', 'GraphGlobalEnv']
@@ -1,84 +0,0 @@
-import pytest
-from torch import nn
-import torch
-from colossalai.tensor import ColoTensor
-from colossalai.tensor.graph import GraphContext
-import gc
-
-
-class SimpleNet(nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.proj1 = nn.Linear(4, 8)
-        self.proj2 = nn.Linear(8, 4)
-        self.proj3 = nn.Linear(4, 4)
-        self.proj4 = nn.Linear(4, 4)
-
-    def forward(self, x):
-        x = self.proj1(x)
-        x = self.proj2(x)
-        x = self.proj3(x)
-        x = self.proj4(x)
-        return x
-
-
-def _visit_graph(start_node):
-    if start_node is None:
-        return
-
-    start_node.print()
-
-    post_node_list = start_node.post_nodes
-    for node in post_node_list:
-        _visit_graph(node)
-
-
-def _get_tensors():
-    for obj in gc.get_objects():
-        try:
-            if torch.is_tensor(obj):
-                yield obj
-        except Exception as e:
-            print('A trivial exception occured: {}'.format(e))
-
-
-def _count_tensors():
-    cnt = 0
-    for t in _get_tensors():
-        cnt += 1
-    return cnt
-
-
-def count_tensors(use_colossal):
-    model = SimpleNet()
-
-    model.eval()
-    with torch.no_grad():
-        if use_colossal:
-            colo_input = ColoTensor.from_torch_tensor(torch.randn(4))
-            graph_ctx = GraphContext()
-            with graph_ctx:
-                output = model(colo_input)
-            output = model(colo_input)
-            ret = _count_tensors()
-
-            _visit_graph(graph_ctx.graph_nodes[0])
-
-            del graph_ctx
-            return ret
-        else:
-            input_t = torch.randn(4)
-            output = model(input_t)
-            output = model(input_t)
-            return _count_tensors()
-
-
-@pytest.mark.skip
-# FIXME(ver217)
-def test_check_activation_tensors():
-    assert count_tensors(False) == count_tensors(True)
-
-
-if __name__ == "__main__":
-    count_tensors(True)
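The deleted test relied on a gc sweep to count live tensors, a trick that is useful on its own for spotting leaked activations. A standalone sketch of the same idea in plain PyTorch, no ColossalAI needed:

import gc
import torch

def count_live_tensors() -> int:
    # Count tensors currently tracked by the garbage collector.
    cnt = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                cnt += 1
        except Exception:
            pass  # some objects raise on inspection; skip them
    return cnt

# In a fresh script, creating one tensor raises the count by one.
before = count_live_tensors()
x = torch.randn(4)
assert count_live_tensors() == before + 1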