[hotfix] avoid conflict of meta registry with torch 1.13.0. (#1530)

* [hotfix] avoid conflict of meta registry with torch 1.13.0.

* [hotfix] avoid conflict of meta registry with torch 1.13.0.
This commit is contained in:
Super Daniel 2022-09-01 15:31:21 +08:00 committed by GitHub
parent b231430bcb
commit 112a1f0a8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
aten = torch.ops.aten aten = torch.ops.aten
meta_lib = torch.library.Library("aten", "IMPL", "Meta") meta_lib = torch.library.Library("aten", "IMPL", "Meta")
@ -14,16 +13,17 @@ meta_table = {}
def register_meta(op, register_dispatcher=True): def register_meta(op, register_dispatcher=True):
def wrapper(f): def wrapper(f):
def add_func(op): def add_func(op):
meta_table[op] = f meta_table[op] = f
if register_dispatcher: if register_dispatcher:
name = ( name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
op.__name__ try:
if op._overloadname != "default"
else op.overloadpacket.__name__
)
meta_lib.impl(name, f) meta_lib.impl(name, f)
except:
pass
tree_map(add_func, op) tree_map(add_func, op)
return f return f
@ -44,6 +44,7 @@ def meta_conv(
output_padding: List[int], output_padding: List[int],
groups: int, groups: int,
): ):
def _formula(ln: int, p: int, d: int, k: int, s: int) -> int: def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
""" """
Formula to apply to calculate the length of some dimension of the output Formula to apply to calculate the length of some dimension of the output
@ -120,14 +121,9 @@ def meta_conv(
kernel_size[i], kernel_size[i],
stride[i], stride[i],
output_padding_list[i], output_padding_list[i],
) ))
)
else: else:
ret_shape.append( ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
_formula(
dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
)
)
return ret_shape return ret_shape
def pick_memory_format(): def pick_memory_format():
@ -156,9 +152,7 @@ def meta_conv(
out_channels = weight.shape[0] out_channels = weight.shape[0]
if weight.shape[1] != input_tensor.shape[1] / groups: if weight.shape[1] != input_tensor.shape[1] / groups:
raise RuntimeError("Invalid channel dimensions") raise RuntimeError("Invalid channel dimensions")
shape_out = calc_conv_nd_return_shape( shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
dims, kernel_size, stride, padding, dilation
)
out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out)) out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
mem_fmt = pick_memory_format() mem_fmt = pick_memory_format()
out = out.to(memory_format=mem_fmt) # type: ignore[call-overload] out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
@ -166,10 +160,8 @@ def meta_conv(
@register_meta(aten.convolution_backward.default) @register_meta(aten.convolution_backward.default)
def meta_conv_backward( def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, padding, dilation, transposed, output_padding, groups, output_mask):
bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask
):
return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta') return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
@ -184,21 +176,18 @@ def meta_hardswish(input: torch.Tensor):
@register_meta(aten.hardswish_backward.default) @register_meta(aten.hardswish_backward.default)
def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor): def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
grad_in = torch.empty_like(input) grad_in = torch.empty_like(input)
return grad_in return grad_in
@register_meta([aten.roll.default, ]) @register_meta(aten.roll.default)
def meta_roll(input:torch.Tensor, shifts, dims): def meta_roll(input: torch.Tensor, shifts, dims):
return torch.empty_like(input) return torch.empty_like(input)
@register_meta(aten.native_batch_norm.default) @register_meta(aten.native_batch_norm.default)
def meta_bn( def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
input: torch.Tensor,
weight, bias, running_mean, running_var, training, momentum, eps
):
n_input = input.size(1) n_input = input.size(1)
output = torch.empty_like(input) output = torch.empty_like(input)
@ -208,10 +197,8 @@ def meta_bn(
@register_meta(aten.native_batch_norm_backward.default) @register_meta(aten.native_batch_norm_backward.default)
def meta_bn_backward( def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, save_invstd, train, eps, output_mask):
running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
):
dX = torch.empty_like(input) dX = torch.empty_like(input)
dgamma = torch.empty_like(weight) dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(weight) dbeta = torch.empty_like(weight)
@ -219,10 +206,7 @@ def meta_bn_backward(
@register_meta(aten.native_layer_norm.default) @register_meta(aten.native_layer_norm.default)
def meta_ln( def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
input: torch.Tensor,
normalized_shape, weight, bias, eps
):
n_input = input.size(1) n_input = input.size(1)
output = torch.empty_like(input) output = torch.empty_like(input)
@ -232,11 +216,8 @@ def meta_ln(
@register_meta(aten.native_layer_norm_backward.default) @register_meta(aten.native_layer_norm_backward.default)
def meta_ln_backward( def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
dY: torch.Tensor, grad_input_mask):
input: torch.Tensor,
normalized_shape, mean, rstd, weight, bias, grad_input_mask
):
dX = torch.empty_like(input) dX = torch.empty_like(input)
dgamma = torch.empty_like(weight) dgamma = torch.empty_like(weight)
dbeta = torch.empty_like(bias) dbeta = torch.empty_like(bias)
@ -245,7 +226,8 @@ def meta_ln_backward(
@register_meta(aten._adaptive_avg_pool2d_backward.default) @register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta_adaptive_avg_pool2d_backward( def meta_adaptive_avg_pool2d_backward(
grad_output: torch.Tensor, input: torch.Tensor, grad_output: torch.Tensor,
input: torch.Tensor,
): ):
grad_input = torch.empty_like(input) grad_input = torch.empty_like(input)
return torch.empty_like(input) return torch.empty_like(input)
@ -266,7 +248,9 @@ def meta_index_Tensor(self, indices):
k = len(result) k = len(result)
assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}" assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
for j in range(index.ndim): for j in range(index.ndim):
assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}" assert index.shape[j] == self.shape[
k +
j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
result.append(nonzero.select(1, j)) result.append(nonzero.select(1, j))
else: else:
result.append(index) result.append(index)