diff --git a/colossalai/fx/profiler/_meta_registrations.py b/colossalai/fx/profiler/_meta_registrations.py index 7dd3a21c9..94f559f38 100644 --- a/colossalai/fx/profiler/_meta_registrations.py +++ b/colossalai/fx/profiler/_meta_registrations.py @@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union import torch from torch.utils._pytree import tree_map - aten = torch.ops.aten meta_lib = torch.library.Library("aten", "IMPL", "Meta") @@ -14,16 +13,17 @@ meta_table = {} def register_meta(op, register_dispatcher=True): + def wrapper(f): + def add_func(op): meta_table[op] = f if register_dispatcher: - name = ( - op.__name__ - if op._overloadname != "default" - else op.overloadpacket.__name__ - ) - meta_lib.impl(name, f) + name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__) + try: + meta_lib.impl(name, f) + except: + pass tree_map(add_func, op) return f @@ -44,6 +44,7 @@ def meta_conv( output_padding: List[int], groups: int, ): + def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: """ Formula to apply to calculate the length of some dimension of the output @@ -120,14 +121,9 @@ def meta_conv( kernel_size[i], stride[i], output_padding_list[i], - ) - ) + )) else: - ret_shape.append( - _formula( - dims[i], padding[i], dilation[i], kernel_size[i], stride[i] - ) - ) + ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])) return ret_shape def pick_memory_format(): @@ -156,20 +152,16 @@ def meta_conv( out_channels = weight.shape[0] if weight.shape[1] != input_tensor.shape[1] / groups: raise RuntimeError("Invalid channel dimensions") - shape_out = calc_conv_nd_return_shape( - dims, kernel_size, stride, padding, dilation - ) + shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) mem_fmt = pick_memory_format() - out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] + out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] return out @register_meta(aten.convolution_backward.default) -def meta_conv_backward( - grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask -): +def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, + padding, dilation, transposed, output_padding, groups, output_mask): return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') @@ -184,21 +176,18 @@ def meta_hardswish(input: torch.Tensor): @register_meta(aten.hardswish_backward.default) -def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor): +def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor): grad_in = torch.empty_like(input) return grad_in -@register_meta([aten.roll.default, ]) -def meta_roll(input:torch.Tensor, shifts, dims): +@register_meta(aten.roll.default) +def meta_roll(input: torch.Tensor, shifts, dims): return torch.empty_like(input) @register_meta(aten.native_batch_norm.default) -def meta_bn( - input: torch.Tensor, - weight, bias, running_mean, running_var, training, momentum, eps -): +def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps): n_input = input.size(1) output = torch.empty_like(input) @@ -208,10 +197,8 @@ def meta_bn( @register_meta(aten.native_batch_norm_backward.default) -def meta_bn_backward( - dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - running_mean, running_var, save_mean, save_invstd, train, eps, output_mask -): +def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean, + save_invstd, train, eps, output_mask): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(weight) @@ -219,10 +206,7 @@ def meta_bn_backward( @register_meta(aten.native_layer_norm.default) -def meta_ln( - input: torch.Tensor, - normalized_shape, weight, bias, eps -): +def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps): n_input = input.size(1) output = torch.empty_like(input) @@ -232,11 +216,8 @@ def meta_ln( @register_meta(aten.native_layer_norm_backward.default) -def meta_ln_backward( - dY: torch.Tensor, - input: torch.Tensor, - normalized_shape, mean, rstd, weight, bias, grad_input_mask -): +def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, + grad_input_mask): dX = torch.empty_like(input) dgamma = torch.empty_like(weight) dbeta = torch.empty_like(bias) @@ -245,7 +226,8 @@ def meta_ln_backward( @register_meta(aten._adaptive_avg_pool2d_backward.default) def meta_adaptive_avg_pool2d_backward( - grad_output: torch.Tensor, input: torch.Tensor, + grad_output: torch.Tensor, + input: torch.Tensor, ): grad_input = torch.empty_like(input) return torch.empty_like(input) @@ -266,7 +248,9 @@ def meta_index_Tensor(self, indices): k = len(result) assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" for j in range(index.ndim): - assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" + assert index.shape[j] == self.shape[ + k + + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" result.append(nonzero.select(1, j)) else: result.append(index) @@ -275,7 +259,7 @@ def meta_index_Tensor(self, indices): indices = result assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})" # expand_outplace - import torch._refs as refs # avoid import cycle in mypy + import torch._refs as refs # avoid import cycle in mypy indices = list(refs._maybe_broadcast(*indices)) # add missing null tensors