Mirror of https://github.com/hpcaitech/ColossalAI.git
support unet metainfo prop (#2544)
commit fa3d66feb9
parent c4b15661d7
@@ -164,18 +164,9 @@ def meta_conv(
 @register_meta(aten._convolution.default)
-def meta_conv_1(
-    input_tensor: torch.Tensor,
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    stride: List[int],
-    padding: List[int],
-    dilation: List[int],
-    is_transposed: bool,
-    output_padding: List[int],
-    groups: int,
-    *extra_args
-):
+def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
+                padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
+                *extra_args):
     out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
     return out
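Context (not part of the diff): a function registered with @register_meta runs on the meta device and only propagates shapes and dtypes; no real data is ever allocated. A minimal sketch of the same idea using plain meta tensors, assuming PyTorch >= 1.12:

import torch

# Meta tensors carry shape/dtype but no storage, so ops on them act as
# pure shape inference: this is what lets metainfo propagate through
# a UNet without materializing any activations.
x = torch.empty(1, 3, 224, 224, device='meta')
w = torch.empty(64, 3, 7, 7, device='meta')
out = torch.nn.functional.conv2d(x, w, stride=2, padding=3)
print(out.shape)  # torch.Size([1, 64, 112, 112]), computed without data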
@@ -233,11 +224,8 @@ def meta_cuda_rnn(
     if is_input_packed:
         out_shape = [batch_sizes_sum, out_size * num_directions]
     else:
-        out_shape = (
-            [mini_batch, seq_length, out_size * num_directions]
-            if batch_first
-            else [seq_length, mini_batch, out_size * num_directions]
-        )
+        out_shape = ([mini_batch, seq_length, out_size *
+                      num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
     output = input.new_empty(out_shape)
 
     cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
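The hunk above only reflows the conditional expression; behavior is unchanged. A quick check with assumed sizes:

# Assumed example sizes, exercising the batch_first branch:
mini_batch, seq_length, out_size, num_directions = 8, 16, 32, 2
batch_first = True
out_shape = ([mini_batch, seq_length, out_size *
              num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
print(out_shape)  # [8, 16, 64]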
@@ -372,6 +360,15 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
     return dX, dgamma, dbeta
 
 
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp
+@register_meta(aten.native_group_norm_backward.default)
+def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask):
+    dX = torch.empty_like(input)
+    dgamma = torch.empty_like(gamma)
+    dbeta = torch.empty_like(gamma)
+    return dX, dgamma, dbeta
+
+
 # ================================== Misc ==========================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 @register_meta(aten.roll.default)
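Note (not from the commit): as with the layer-norm meta above it, this backward only has to produce correctly shaped placeholders. A shape-only sketch with assumed sizes:

import torch

# Assumed sizes: N=8, C=16, HxW=32*32. Gradients mirror the input and the
# per-channel affine parameter, exactly as meta_gn_backward returns them.
x = torch.empty(8, 16, 32, 32, device='meta')
gamma = torch.empty(16, device='meta')
dX, dgamma, dbeta = torch.empty_like(x), torch.empty_like(gamma), torch.empty_like(gamma)
print(dX.shape, dgamma.shape, dbeta.shape)  # [8, 16, 32, 32] [16] [16]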
@@ -70,6 +70,19 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     return flops
 
 
+def baddbmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+    """
+    Count flops for the baddbmm(batch add and batch matmul) operation.
+    """
+    # Inputs = [input, batch1, batch2]
+    # out = input + batch1 x batch2
+    assert len(inputs) == 3, len(inputs)
+    n, c, t = inputs[1].shape
+    d = inputs[2].shape[-1]
+    flops = n * c * t * d
+    return flops
+
+
 def conv_flop_count(
     x_shape: List[int],
     w_shape: List[int],
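Sanity check (not from the commit) of the n * c * t * d count: batch1 is (n, c, t) and batch2 is (n, t, d), so the batched matmul performs n * c * d dot products of length t, i.e. n * c * t * d multiply-accumulates; the elementwise add is ignored, matching the counter above.

import torch

# Assumed small sizes for illustration.
n, c, t, d = 2, 3, 4, 5
out = torch.baddbmm(torch.randn(n, c, d), torch.randn(n, c, t), torch.randn(n, t, d))
assert out.shape == (n, c, d)
print(n * c * t * d)  # 120, what baddbmm_flop_jit reports for these shapes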
@@ -196,6 +209,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.matmul.default: matmul_flop_jit,
         aten.addmm.default: addmm_flop_jit,
         aten.bmm.default: bmm_flop_jit,
+        aten.baddbmm.default: baddbmm_flop_jit,
 
         # convolution
         aten.convolution.default: conv_flop_jit,
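Aside: once an overload is in this table, counting presumably reduces to a dictionary lookup keyed by the ATen op. A hypothetical sketch (helper and mapping names assumed, not ColossalAI's actual API):

import torch

aten = torch.ops.aten

def count_flops(func, inputs, outputs, mapping):
    # Ops without a registered handler are charged zero flops.
    handler = mapping.get(func)
    return handler(inputs, outputs) if handler is not None else 0

# e.g. count_flops(aten.baddbmm.default, inputs, outputs, flop_mapping)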
@@ -209,6 +223,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
         aten.native_layer_norm.default: norm_flop_counter(2, 0),
         aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
+        aten.native_group_norm.default: norm_flop_counter(2, 0),
+        aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
 
         # pooling
         aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
@@ -230,6 +246,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
         aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
         aten.embedding.default: elementwise_flop_counter(1, 0),
+        aten.upsample_nearest2d.vec: elementwise_flop_counter(0, 1),
+        aten.upsample_nearest2d_backward.vec: elementwise_flop_counter(0, 1),
     }
 
     elementwise_flop_aten = [
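Aside: the two arguments to elementwise_flop_counter appear to follow the fvcore-style (input_scale, output_scale) convention, charging flops per element of the input and output respectively; that is an assumption, but it would explain why the upsample entries use (0, 1), since their cost tracks the output size. A minimal sketch under that assumption:

def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0):
    # Assumed convention: flops = input_scale * numel(inputs[0])
    #                             + output_scale * numel(outputs[0]).
    def counter(inputs, outputs):
        return input_scale * inputs[0].numel() + output_scale * outputs[0].numel()
    return counter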
@@ -251,6 +269,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.mean.dim,
         aten.sub.Tensor,
         aten.sub_.Tensor,
+        aten.exp.default,
+        aten.sin.default,
+        aten.cos.default,
 
         # activation op
         aten.hardswish.default,
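Aside: ops listed in elementwise_flop_aten are presumably wired to a shared elementwise counter rather than given individual table entries. A hypothetical wiring sketch (the mapping's name is assumed):

# Hypothetical: give every listed elementwise op a default one-flop-per-input-element counter.
for op in elementwise_flop_aten:
    flop_mapping.setdefault(op, elementwise_flop_counter(1, 0))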