Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-04 17:49:48 +00:00)
[hotfix] avoid conflict of meta registry with torch 1.13.0. (#1530)
* [hotfix] avoid conflict of meta registry with torch 1.13.0.
This commit is contained in: parent b231430bcb, commit 112a1f0a8f
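The gist of the hotfix: torch 1.13.0 ships its own meta kernels for many aten ops, so re-registering them through torch.library raises a RuntimeError at import time; the change wraps the registration in a try/except (a bare except in the commit) so that whichever kernel is already registered is kept instead of crashing. A minimal sketch of the pattern, assuming only a torch build that provides torch.library; the helper name safe_impl and the choice of the roll op are illustrative, not part of the commit:

    import torch

    # Custom meta (shape-only) kernels are registered on the "Meta" dispatch key.
    meta_lib = torch.library.Library("aten", "IMPL", "Meta")


    def safe_impl(name, fn):
        # Hypothetical helper: on torch 1.13.0 many aten ops already have built-in
        # meta kernels, and a second registration raises a RuntimeError. Swallowing
        # the error keeps the kernel that is already registered.
        try:
            meta_lib.impl(name, fn)
        except RuntimeError:
            pass


    def roll_meta(input, shifts, dims=()):
        # Shape-only stub: allocate nothing, just mirror the input's metadata.
        return torch.empty_like(input)


    safe_impl("roll", roll_meta)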
@@ -5,7 +5,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch.utils._pytree import tree_map
 
-
 aten = torch.ops.aten
 
 meta_lib = torch.library.Library("aten", "IMPL", "Meta")
@@ -14,16 +13,17 @@ meta_table = {}
 
 
 def register_meta(op, register_dispatcher=True):
 
     def wrapper(f):
 
         def add_func(op):
             meta_table[op] = f
             if register_dispatcher:
-                name = (
-                    op.__name__
-                    if op._overloadname != "default"
-                    else op.overloadpacket.__name__
-                )
-                meta_lib.impl(name, f)
+                name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+                try:
+                    meta_lib.impl(name, f)
+                except:
+                    pass
 
         tree_map(add_func, op)
         return f
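For context on the decorator this hunk touches: register_meta maps an aten overload (or a pytree of overloads) to a shape-only Python function in meta_table and, when register_dispatcher is set, also registers it with the dispatcher through meta_lib. A rough usage sketch, assuming this module is imported on a torch version with meta-device support; it mirrors the file's own meta_roll registration further down in the diff:

    # The decorated stub becomes the Meta-dispatch kernel for aten.roll.
    @register_meta(aten.roll.default)
    def meta_roll(input: torch.Tensor, shifts, dims):
        return torch.empty_like(input)

    # Ops on meta tensors now resolve to the stub, so shapes propagate
    # without allocating any real memory.
    x = torch.empty(2, 3, device="meta")
    y = torch.roll(x, shifts=1, dims=0)
    assert y.shape == (2, 3) and y.device.type == "meta"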
@@ -44,6 +44,7 @@ def meta_conv(
     output_padding: List[int],
     groups: int,
 ):
+
     def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
         """
         Formula to apply to calculate the length of some dimension of the output
@@ -120,14 +121,9 @@ def meta_conv(
                         kernel_size[i],
                         stride[i],
                         output_padding_list[i],
-                    )
-                )
+                    ))
             else:
-                ret_shape.append(
-                    _formula(
-                        dims[i], padding[i], dilation[i], kernel_size[i], stride[i]
-                    )
-                )
+                ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
         return ret_shape
 
     def pick_memory_format():
@@ -156,20 +152,16 @@ def meta_conv(
     out_channels = weight.shape[0]
     if weight.shape[1] != input_tensor.shape[1] / groups:
         raise RuntimeError("Invalid channel dimensions")
-    shape_out = calc_conv_nd_return_shape(
-        dims, kernel_size, stride, padding, dilation
-    )
+    shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
     out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
     mem_fmt = pick_memory_format()
     out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
     return out
 
 
 @register_meta(aten.convolution_backward.default)
-def meta_conv_backward(
-    grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask
-):
+def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
+                       padding, dilation, transposed, output_padding, groups, output_mask):
     return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
 
 
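In practice these convolution meta kernels enable shape and dtype propagation without touching real memory: with inputs on the meta device, the convolution dispatches to meta_conv and only metadata is computed. A small illustrative sketch, assuming either these registrations are active (torch 1.12 and below) or torch's built-in meta kernels (1.13.0+); the concrete sizes are arbitrary:

    import torch
    import torch.nn.functional as F

    # Both tensors live on the meta device, so no data is allocated anywhere.
    x = torch.empty(8, 3, 224, 224, device="meta")
    w = torch.empty(64, 3, 7, 7, device="meta")
    out = F.conv2d(x, w, stride=2, padding=3)
    print(out.shape)    # torch.Size([8, 64, 112, 112])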
@@ -184,21 +176,18 @@ def meta_hardswish(input: torch.Tensor):
 
 
 @register_meta(aten.hardswish_backward.default)
-def meta_hardswish_backward(grad_out:torch.Tensor, input: torch.Tensor):
+def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
     grad_in = torch.empty_like(input)
     return grad_in
 
 
-@register_meta([aten.roll.default, ])
-def meta_roll(input:torch.Tensor, shifts, dims):
+@register_meta(aten.roll.default)
+def meta_roll(input: torch.Tensor, shifts, dims):
     return torch.empty_like(input)
 
 
 @register_meta(aten.native_batch_norm.default)
-def meta_bn(
-    input: torch.Tensor,
-    weight, bias, running_mean, running_var, training, momentum, eps
-):
+def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
     n_input = input.size(1)
 
     output = torch.empty_like(input)
@@ -208,10 +197,8 @@ def meta_bn(
 
 
 @register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(
-    dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-    running_mean, running_var, save_mean, save_invstd, train, eps, output_mask
-):
+def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
+                     save_invstd, train, eps, output_mask):
     dX = torch.empty_like(input)
     dgamma = torch.empty_like(weight)
     dbeta = torch.empty_like(weight)
@@ -219,10 +206,7 @@ def meta_bn_backward(
 
 
 @register_meta(aten.native_layer_norm.default)
-def meta_ln(
-    input: torch.Tensor,
-    normalized_shape, weight, bias, eps
-):
+def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
     n_input = input.size(1)
 
     output = torch.empty_like(input)
@@ -232,11 +216,8 @@ def meta_ln(
 
 
 @register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(
-    dY: torch.Tensor,
-    input: torch.Tensor,
-    normalized_shape, mean, rstd, weight, bias, grad_input_mask
-):
+def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
+                     grad_input_mask):
     dX = torch.empty_like(input)
     dgamma = torch.empty_like(weight)
     dbeta = torch.empty_like(bias)
@@ -245,7 +226,8 @@ def meta_ln_backward(
 
 @register_meta(aten._adaptive_avg_pool2d_backward.default)
 def meta_adaptive_avg_pool2d_backward(
-    grad_output: torch.Tensor, input: torch.Tensor,
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
 ):
     grad_input = torch.empty_like(input)
     return torch.empty_like(input)
@@ -266,7 +248,9 @@ def meta_index_Tensor(self, indices):
                 k = len(result)
                 assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                 for j in range(index.ndim):
-                    assert index.shape[j] == self.shape[k + j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                    assert index.shape[j] == self.shape[
+                        k +
+                        j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                     result.append(nonzero.select(1, j))
             else:
                 result.append(index)
@@ -275,7 +259,7 @@ def meta_index_Tensor(self, indices):
     indices = result
     assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
     # expand_outplace
     import torch._refs as refs  # avoid import cycle in mypy
 
     indices = list(refs._maybe_broadcast(*indices))
     # add missing null tensors