From 2edbef13cc2f08e3d74ea72a68d2299f3e7cdbb7 Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Wed, 23 Nov 2022 10:55:46 +0800
Subject: [PATCH] [fx] add more meta_registry for MetaTensor execution. (#2000)

* [sc] add examples for auto checkpoint.

* merge upstream

* [fx] add more meta_registry for MetaTensor execution.
---
 colossalai/fx/_meta_registrations.py      | 52 ++++++++++++++++++--
 colossalai/fx/profiler/tensor.py          | 10 ++++
 colossalai/fx/tracer/_symbolic_trace.py   | 24 ++++-----
 .../tutorial/auto_parallel/bench_utils.py  |  4 --
 4 files changed, 69 insertions(+), 21 deletions(-)

diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_registrations.py
index 94387fbe0..f9100d842 100644
--- a/colossalai/fx/_meta_registrations.py
+++ b/colossalai/fx/_meta_registrations.py
@@ -3,7 +3,7 @@
 # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 # for more meta_registrations
 
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 from torch.utils._pytree import tree_map
@@ -179,6 +179,41 @@ def meta_adaptive_avg_pool2d_backward(
     return grad_input
 
 
+# ================================ RNN =============================================
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
+@register_meta(aten._cudnn_rnn.default)
+def meta_cuda_rnn(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_stride0: int,
+    weight_buf: torch.Tensor,
+    hx: torch.Tensor,
+    cx: Optional[torch.Tensor] = None,
+    *args,
+    **kwargs,
+):
+    if cx is not None:
+        return torch.empty_like(input), torch.empty_like(hx), torch.empty_like(cx)
+    else:
+        return torch.empty_like(input), torch.empty_like(hx), torch.empty((), device='meta')
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
+@register_meta(aten._cudnn_rnn_backward.default)
+def meta_cudnn_rnn_backward(input: torch.Tensor,
+                            weight: torch.Tensor,
+                            weight_stride0: int,
+                            hx: torch.Tensor,
+                            cx: Optional[torch.Tensor] = None,
+                            *args,
+                            **kwargs):
+    grad_input = torch.empty_like(input)
+    grad_weight = torch.empty_like(weight)
+    grad_hx = torch.empty_like(hx)
+    grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
+    return grad_input, grad_weight, grad_hx, grad_cx
+
+
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
 # ============================== Activations =======================================
 @register_meta(aten.relu.default)
@@ -186,6 +221,11 @@ def meta_relu(input: torch.Tensor):
     return torch.empty_like(input)
 
 
+@register_meta(aten.prelu.default)
+def meta_prelu(input: torch.Tensor, weight: torch.Tensor):
+    return torch.empty_like(input)
+
+
 @register_meta(aten.hardswish.default)
 def meta_hardswish(input: torch.Tensor):
     return torch.empty_like(input)
@@ -278,12 +318,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
 
 
 # ================================== Misc ==========================================
-#https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 @register_meta(aten.roll.default)
 def meta_roll(input: torch.Tensor, shifts, dims):
     return input
 
 
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
+@register_meta(aten._local_scalar_dense.default)
+def meta_local_scalar_dense(self: torch.Tensor):
+    return 0
+
+
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
 @register_meta(aten.where.self)
 def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
@@ -317,7 +363,7 @@ def meta_index_Tensor(self, indices):
         indices = result
     assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
     # expand_outplace
-    import torch._refs as refs    # avoid import cycle in mypy
+    import torch._refs as refs
     indices = list(refs._maybe_broadcast(*indices))
 
     # add missing null tensors
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 4e9fb5c8c..43165305f 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -128,3 +128,13 @@ class MetaTensor(torch.Tensor):
         if device is not None:
             result = MetaTensor(result, fake_device=device)
         return result
+
+    def cpu(self, *args, **kwargs):
+        if self.device.type == 'cpu':
+            return self.to(*args, **kwargs)
+        return self.to(*args, device='cpu', **kwargs)
+
+    def cuda(self, *args, **kwargs):
+        if self.device.type == 'cuda':
+            return self.to(*args, **kwargs)
+        return self.to(*args, device='cuda', **kwargs)
diff --git a/colossalai/fx/tracer/_symbolic_trace.py b/colossalai/fx/tracer/_symbolic_trace.py
index 39da62473..bff2f6a10 100644
--- a/colossalai/fx/tracer/_symbolic_trace.py
+++ b/colossalai/fx/tracer/_symbolic_trace.py
@@ -20,28 +20,25 @@ def symbolic_trace(
     Given an ``nn.Module`` or function instance ``root``, this function will return a ``ColoGraphModule``
     constructed by recording operations seen while tracing through ``root``.
 
-    With ``meta_args`` and ``concrete_args``, we can trace the model that are untraceable subject to control flow.
-    If specified using ``meta_args`` only, the tracing can be done ahead of time.
+    With ``meta_args``, we can trace models that are otherwise untraceable due to control flow. If only
+    ``meta_args`` is specified, the tracing can be done ahead of time.
 
-    Note that both ``meta_args`` and ``concrete_args`` are kwargs, which contains the key of the argument's names
-    and the value of the argument's values.
+    Note that ``meta_args`` is a dict of keyword arguments whose keys are argument names and whose values
+    are the corresponding argument values.
 
     Uses:
         >>> model = ...
 
         # if this works
-        >>> gm = symbolic_trace(model)
+        >>> gm = symbolic_trace(model, concrete_args=concrete_args)
 
         # else try this
-        >>> gm = symbolic_trace(model, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
-
-        # else try this
-        >>> gm = symbolic_trace(model, concrete_args={'x': torch.rand(1, 3, 224, 224)})
+        >>> gm = symbolic_trace(model, concrete_args=concrete_args, meta_args={'x': torch.rand(1, 3, 224, 224, device='meta')})
 
     Args:
         root (Union[torch.nn.Module, Callable[..., Any]]): Module or function to be traced and converted
             into a Graph representation.
-        concrete_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized. Defaults to None.
+        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be used for tracing. Defaults to None.
         meta_args (Optional[Dict[str, Any]], optional): Inputs to be partially specialized, special for
             ``ColoTracer``. Defaults to None.
 
@@ -52,7 +49,6 @@
     This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI
    team. 
""" - tracer = ColoTracer() - graph = tracer.trace(root, concrete_args, meta_args) - name = (root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__) - return ColoGraphModule(tracer.root, graph, name) + graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args) + name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + return ColoGraphModule(root, graph, name) diff --git a/examples/tutorial/auto_parallel/bench_utils.py b/examples/tutorial/auto_parallel/bench_utils.py index b4141da24..69859f885 100644 --- a/examples/tutorial/auto_parallel/bench_utils.py +++ b/examples/tutorial/auto_parallel/bench_utils.py @@ -18,13 +18,11 @@ def bench(gm: torch.fx.GraphModule, data_gen: Callable, num_steps: int = 5) -> Tuple[int, int]: """Benchmarking a given graph module - Args: gm (torch.fx.GraphModule): The graph module to benchmark. criterion (torch.nn.Module): Loss function. data_gen (Callable): Data generator. num_steps (int, optional): Number of test steps. Defaults to 5. - Returns: Tuple[int, int]: peak memory in MB and step time in MS. """ @@ -69,7 +67,6 @@ def bench_rotor(gm: torch.fx.GraphModule, start_factor: int = 4) -> Tuple[np.array, list, list]: """Auto Checkpoint Rotor Algorithm benchmarking Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data. - Args: gm (torch.fx.GraphModule): The graph module to benchmark. criterion (torch.nn.Module): Loss function. @@ -79,7 +76,6 @@ def bench_rotor(gm: torch.fx.GraphModule, free_memory (int, optional): Max memory budget in Byte. Defaults to torch.cuda.mem_get_info()[0]. start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor. Defaults to 4. - Returns: Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS). """