Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
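The diff below shows one of the reformatted files, a module that registers meta implementations for aten ops. Every hunk is mechanical formatting (black-style signature wrapping, double-quoted string literals, and removal of stray blank lines) with no behavioural change. Assuming the standard pre-commit CLI, hooks like these are applied to the entire tree with: pre-commit run --all-files.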
@@ -3,7 +3,7 @@
 # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 # for more meta_registrations

-from typing import Callable, List, Optional, Tuple, Union
+from typing import List, Optional, Union

 import torch
 from torch.utils._pytree import tree_map
@@ -16,13 +16,11 @@ meta_table = {}


 def register_meta(op, register_dispatcher=True):
-
     def wrapper(f):
-
         def add_func(op):
             meta_table[op] = f
             if register_dispatcher:
-                name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+                name = op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__
                 try:
                     meta_lib.impl(name, f)
                 except:
@@ -48,7 +46,6 @@ def meta_conv(
     output_padding: List[int],
     groups: int,
 ):
-
     def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
         """
         Formula to apply to calculate the length of some dimension of the output
@@ -125,7 +122,8 @@ def meta_conv(
                         kernel_size[i],
                         stride[i],
                         output_padding_list[i],
-                    ))
+                    )
+                )
             else:
                 ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
         return ret_shape
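The _formula helper reformatted above (and called in the else branch of this hunk) implements the textbook output-length computation for a non-transposed convolution. As a standalone illustration only, not code from this commit, it reduces to:

# Standard output length of a convolution along one dimension.
# ln: input length, p: padding, d: dilation, k: kernel size, s: stride.
def conv_output_length(ln: int, p: int, d: int, k: int, s: int) -> int:
    return (ln + 2 * p - d * (k - 1) - 1) // s + 1


# Example: a 224-long input with padding 3, dilation 1, kernel 7, stride 2
# (a typical ResNet stem) gives an output length of 112.
print(conv_output_length(224, 3, 1, 7, 2))  # 112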
@@ -159,22 +157,42 @@ def meta_conv(
     shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
     out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
     mem_fmt = pick_memory_format()
-    out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
+    out = out.to(memory_format=mem_fmt)  # type: ignore[call-overload]
     return out


 @register_meta(aten._convolution.default)
-def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
-                padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
-                *extra_args):
+def meta_conv_1(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: List[int],
+    padding: List[int],
+    dilation: List[int],
+    is_transposed: bool,
+    output_padding: List[int],
+    groups: int,
+    *extra_args,
+):
     out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
     return out


 @register_meta(aten.convolution_backward.default)
-def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
-                       padding, dilation, transposed, output_padding, groups, output_mask):
-    return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device='meta')
+def meta_conv_backward(
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias_sizes,
+    stride,
+    padding,
+    dilation,
+    transposed,
+    output_padding,
+    groups,
+    output_mask,
+):
+    return torch.empty_like(input), torch.empty_like(weight), torch.empty((bias_sizes), device="meta")


 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -208,7 +226,6 @@ def meta_cuda_rnn(
     batch_sizes,
     dropout_state,
 ):
-
     is_input_packed = len(batch_sizes) != 0
     if is_input_packed:
         seq_length = len(batch_sizes)
@@ -224,8 +241,11 @@ def meta_cuda_rnn(
     if is_input_packed:
         out_shape = [batch_sizes_sum, out_size * num_directions]
     else:
-        out_shape = ([mini_batch, seq_length, out_size *
-                      num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+        out_shape = (
+            [mini_batch, seq_length, out_size * num_directions]
+            if batch_first
+            else [seq_length, mini_batch, out_size * num_directions]
+        )
     output = input.new_empty(out_shape)

     cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -242,18 +262,20 @@ def meta_cuda_rnn(

 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
 @register_meta(aten._cudnn_rnn_backward.default)
-def meta_cudnn_rnn_backward(input: torch.Tensor,
-                            weight: torch.Tensor,
-                            weight_stride0: int,
-                            hx: torch.Tensor,
-                            cx: Optional[torch.Tensor] = None,
-                            *args,
-                            **kwargs):
+def meta_cudnn_rnn_backward(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_stride0: int,
+    hx: torch.Tensor,
+    cx: Optional[torch.Tensor] = None,
+    *args,
+    **kwargs,
+):
     print(input, weight, hx, cx)
     grad_input = torch.empty_like(input)
     grad_weight = torch.empty_like(weight)
     grad_hx = torch.empty_like(hx)
-    grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
+    grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device="meta")
     return grad_input, grad_weight, grad_hx, grad_cx

@@ -298,15 +320,25 @@ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, traini
     n_input = input.size(1)

     output = torch.empty_like(input)
-    running_mean = torch.empty((n_input), device='meta')
-    running_var = torch.empty((n_input), device='meta')
+    running_mean = torch.empty((n_input), device="meta")
+    running_var = torch.empty((n_input), device="meta")
     return output, running_mean, running_var


 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
 @register_meta(aten.native_batch_norm_backward.default)
-def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var, save_mean,
-                     save_invstd, train, eps, output_mask):
+def meta_bn_backward(
+    dY: torch.Tensor,
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    running_mean,
+    running_var,
+    save_mean,
+    save_invstd,
+    train,
+    eps,
+    output_mask,
+):
     dX = torch.empty_like(input)
     dgamma = torch.empty_like(weight)
     dbeta = torch.empty_like(weight)
@@ -319,9 +351,9 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
     n_input = input.size(1)

     output = torch.empty_like(input)
-    running_mean = torch.empty((n_input), device='meta')
-    running_var = torch.empty((n_input), device='meta')
-    reserve = torch.empty((0), dtype=torch.uint8, device='meta')
+    running_mean = torch.empty((n_input), device="meta")
+    running_var = torch.empty((n_input), device="meta")
+    reserve = torch.empty((0), dtype=torch.uint8, device="meta")
     return output, running_mean, running_var, reserve

@@ -330,8 +362,17 @@ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var,
 # in training mode (evaluation mode batchnorm has a different algorithm),
 # which is why this doesn't accept a 'training' parameter.
 @register_meta(aten.cudnn_batch_norm_backward.default)
-def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
-                           save_mean, save_invstd, eps, reserve):
+def meta_cudnn_bn_backward(
+    dY: torch.Tensor,
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    running_mean,
+    running_var,
+    save_mean,
+    save_invstd,
+    eps,
+    reserve,
+):
     dX = torch.empty_like(input)
     dgamma = torch.empty_like(weight)
     dbeta = torch.empty_like(weight)
@@ -345,15 +386,16 @@ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
     n_input = input.size(1)

     output = torch.empty_like(input)
-    running_mean = torch.empty((bs, n_input, 1), device='meta')
-    running_var = torch.empty((bs, n_input, 1), device='meta')
+    running_mean = torch.empty((bs, n_input, 1), device="meta")
+    running_var = torch.empty((bs, n_input, 1), device="meta")
     return output, running_mean, running_var


 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
 @register_meta(aten.native_layer_norm_backward.default)
-def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
-                     grad_input_mask):
+def meta_ln_backward(
+    dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias, grad_input_mask
+):
     dX = torch.empty_like(input)
     dgamma = torch.empty_like(weight)
     dbeta = torch.empty_like(bias)
@@ -397,16 +439,19 @@ def meta_index_Tensor(self, indices):
     result: List[Optional[torch.Tensor]] = []
     for i, index in enumerate(indices):
         if index is not None:
-            assert index.dtype in [torch.long, torch.int8, torch.bool],\
-                "tensors used as indices must be long, byte or bool tensors"
+            assert index.dtype in [
+                torch.long,
+                torch.int8,
+                torch.bool,
+            ], "tensors used as indices must be long, byte or bool tensors"
             if index.dtype in [torch.int8, torch.bool]:
                 nonzero = index.nonzero()
                 k = len(result)
                 assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                 for j in range(index.ndim):
-                    assert index.shape[j] == self.shape[
-                        k +
-                        j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+                    assert (
+                        index.shape[j] == self.shape[k + j]
+                    ), f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                     result.append(nonzero.select(1, j))
             else:
                 result.append(index)
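The re-wrapped assertions above do not change what meta_index_Tensor computes: a byte/bool index is still expanded into per-dimension long indices via nonzero(), one column per masked dimension. A small standalone check of that equivalence (an illustration, not part of the commit):

import torch

t = torch.arange(12).reshape(3, 4)
mask = t % 2 == 0  # bool index covering both dimensions

# Indexing with a bool mask is equivalent to indexing with the columns of
# mask.nonzero(), which is what meta_index_Tensor mimics for meta tensors.
nz = mask.nonzero()
rows, cols = nz.select(1, 0), nz.select(1, 1)
assert torch.equal(t[mask], t[rows, cols])
print(t[mask])  # tensor([ 0,  2,  4,  6,  8, 10])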
@@ -482,12 +527,15 @@ def meta_index_Tensor(self, indices):
 # ============================== Embedding =========================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
 @register_meta(aten.embedding_dense_backward.default)
-def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
-                                  scale_grad_by_freq):
-    return torch.empty((num_weights, grad_output.size(-1)),
-                       dtype=grad_output.dtype,
-                       device=grad_output.device,
-                       layout=grad_output.layout)
+def meta_embedding_dense_backward(
+    grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx, scale_grad_by_freq
+):
+    return torch.empty(
+        (num_weights, grad_output.size(-1)),
+        dtype=grad_output.dtype,
+        device=grad_output.device,
+        layout=grad_output.layout,
+    )


 # ============================== Dropout ===========================================
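For context on the decorator that appears throughout these hunks, here is a simplified, self-contained sketch of the registration pattern (the decorator in the diff additionally registers each kernel with meta_lib.impl, and the tree_map import suggests it can take several ops at once):

from typing import Callable, Dict

import torch

aten = torch.ops.aten
meta_table: Dict[object, Callable] = {}


def register_meta(op):
    # Simplified version of the decorator in the diff: record the meta
    # implementation in a lookup table keyed by the aten overload.
    def wrapper(f):
        meta_table[op] = f
        return f

    return wrapper


@register_meta(aten.relu.default)
def meta_relu(input: torch.Tensor) -> torch.Tensor:
    # Meta kernels only propagate shape and dtype; no data is touched.
    return torch.empty_like(input)


x = torch.empty((2, 3), device="meta")
print(meta_table[aten.relu.default](x).shape)  # torch.Size([2, 3])

Kernels like this only describe output shapes and dtypes, which is why most functions in the file simply return torch.empty_like(...) or torch.empty(..., device="meta") tensors.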