diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_registrations.py index f9100d842..d614219db 100644 --- a/colossalai/fx/_meta_registrations.py +++ b/colossalai/fx/_meta_registrations.py @@ -163,6 +163,23 @@ def meta_conv( return out +@register_meta(aten._convolution.default) +def meta_conv_1( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: List[int], + padding: List[int], + dilation: List[int], + is_transposed: bool, + output_padding: List[int], + groups: int, + *extra_args +): + out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups) + return out + + @register_meta(aten.convolution_backward.default) def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask):