[fx] fixed adapative pooling size concatenation error (#1489)

Author: Frank Lee
Date: 2022-08-25 09:05:07 +08:00 (committed by GitHub)
Parent: cde7b8a5b8
Commit: 3da68d6b1b
2 changed files with 98 additions and 17 deletions


@@ -22,7 +22,7 @@ def torch_nn_avgpool1d(self, input):
     l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
-    result_shape = input.shape[:-1] + (l_out,)
+    result_shape = tuple(input.shape[:-1]) + (l_out,)
     return torch.empty(result_shape, device='meta')
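
For context, a minimal sketch of the shape inference this helper performs for AvgPool1d on meta tensors, using the same formula as the patched torch_nn_avgpool1d (the kernel size, stride, padding, and input shape below are assumed example values, not taken from the patch):

    import math
    import torch

    # Assumed example values, for illustration only.
    batch, channels, l_in = 4, 8, 32
    kernel_size, stride, padding = (5,), (2,), (1,)

    # Same formula as the patched helper: floor((L + 2p - k) / s + 1).
    l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)

    x = torch.empty(batch, channels, l_in, device='meta')
    result_shape = tuple(x.shape[:-1]) + (l_out,)
    out = torch.empty(result_shape, device='meta')

    # Cross-check the inferred shape against the real module on a CPU tensor.
    pool = torch.nn.AvgPool1d(kernel_size=5, stride=2, padding=1)
    expected = pool(torch.zeros(batch, channels, l_in)).shape
    assert out.shape == expected  # torch.Size([4, 8, 15])
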
@@ -46,7 +46,7 @@ def torch_nn_avgpool2d(self, input):
     h_out = math.floor((h_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
     w_out = math.floor((w_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
-    result_shape = input.shape[:-2] + (
+    result_shape = tuple(input.shape[:-2]) + (
         h_out,
         w_out,
     )
@@ -74,7 +74,7 @@ def torch_nn_avgpool3d(self, input):
     h_out = math.floor((h_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
     w_out = math.floor((w_in + 2 * padding[2] - kernel_size[2]) / stride[2] + 1)
-    result_shape = input.shape[:-3] + (
+    result_shape = tuple(input.shape[:-3]) + (
         d_out,
         h_out,
         w_out,
@@ -102,7 +102,7 @@ def torch_nn_maxpool1d(self, input):
     l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
-    result_shape = input.shape[:-1] + (l_out,)
+    result_shape = tuple(input.shape[:-1]) + (l_out,)
     return torch.empty(result_shape, device='meta')
@@ -127,7 +127,7 @@ def torch_nn_maxpool2d(self, input):
     h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
     w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
-    result_shape = input.shape[:-2] + (
+    result_shape = tuple(input.shape[:-2]) + (
         h_out,
         w_out,
     )
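
Similarly, a small sketch of the 2D max-pool case, where the formula also has to account for dilation (again with assumed example values, not taken from the patch):

    import math
    import torch

    # Assumed example values, for illustration only.
    n, c, h_in, w_in = 2, 3, 64, 48
    kernel_size, stride, padding, dilation = (3, 3), (2, 2), (1, 1), (2, 2)

    # Same formulas as the patched torch_nn_maxpool2d, including the dilation term.
    h_out = math.floor((h_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
    w_out = math.floor((w_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)

    x = torch.empty(n, c, h_in, w_in, device='meta')
    result_shape = tuple(x.shape[:-2]) + (h_out, w_out)

    # Cross-check against the real module on a CPU tensor.
    pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=2)
    expected = pool(torch.zeros(n, c, h_in, w_in)).shape
    assert torch.Size(result_shape) == expected  # torch.Size([2, 3, 31, 23])
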
@@ -156,7 +156,7 @@ def torch_nn_maxpool3d(self, input):
     h_out = math.floor((h_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
     w_out = math.floor((w_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
-    result_shape = input.shape[:-3] + (
+    result_shape = tuple(input.shape[:-3]) + (
         d_out,
         h_out,
         w_out,
@@ -167,26 +167,34 @@ def torch_nn_maxpool3d(self, input):
 @meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
 @meta_patched_module.register(torch.nn.AdaptiveMaxPool1d)
 def torch_nn_adapative_pooling_1d(self, input):
-    result_shape = input.shape[:-1] + (self.output_size,)
+    assert input.dim() in [2, 3]
+    if isinstance(self.output_size, int):
+        output_size = (self.output_size,)
+    else:
+        output_size = self.output_size
+    result_shape = tuple(input.shape[:-1]) + output_size
     return torch.empty(result_shape, device='meta')
 
 
 @meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
 @meta_patched_module.register(torch.nn.AdaptiveMaxPool2d)
 def torch_nn_adapative_pooling_2d(self, input):
-    result_shape = input.shape[:-2] + (
-        self.output_size,
-        self.output_size,
-    )
+    assert input.dim() in [3, 4]
+    if isinstance(self.output_size, int):
+        output_size = (self.output_size,) * 2
+    else:
+        output_size = self.output_size
+    result_shape = tuple(input.shape[:-2]) + output_size
     return torch.empty(result_shape, device='meta')
 
 
 @meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
 @meta_patched_module.register(torch.nn.AdaptiveMaxPool3d)
 def torch_nn_adapative_pooling_3d(self, input):
-    result_shape = input.shape[:-3] + (
-        self.output_size,
-        self.output_size,
-        self.output_size,
-    )
-    return torch.empty(result_shape, device='meta')
+    assert input.dim() in [4, 5]
+    if isinstance(self.output_size, int):
+        output_size = (self.output_size,) * 3
+    else:
+        output_size = self.output_size
+    result_shape = tuple(input.shape[:-3]) + output_size
+    return torch.empty(result_shape, device='meta')
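
To illustrate the error this hunk fixes: the adaptive pooling modules accept either an int or a tuple as output_size. With a tuple such as (7, 7), the old code nested that tuple inside the shape it built, and torch.empty rejects a size containing non-integer elements; the patched code normalizes output_size to a flat tuple before concatenating. A minimal sketch with assumed example sizes:

    import torch

    x = torch.empty(2, 3, 17, 17, device='meta')  # assumed example input
    output_size = (7, 7)                          # tuple form accepted by AdaptiveAvgPool2d

    # Old behaviour: the tuple ends up nested inside the shape, e.g. (2, 3, (7, 7), (7, 7)).
    broken_shape = tuple(x.shape[:-2]) + (output_size, output_size)
    try:
        torch.empty(broken_shape, device='meta')
    except TypeError:
        pass  # the concatenation error this commit fixes

    # Patched behaviour: normalize int vs. tuple, then concatenate flat tuples.
    if isinstance(output_size, int):
        output_size = (output_size,) * 2
    result_shape = tuple(x.shape[:-2]) + output_size
    out = torch.empty(result_shape, device='meta')
    assert out.shape == torch.Size([2, 3, 7, 7])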