diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py
index 126403270..6caad920d 100644
--- a/colossalai/fx/_compatibility.py
+++ b/colossalai/fx/_compatibility.py
@@ -2,11 +2,22 @@ from typing import Callable
 
+import warnings
 import torch
 
-try:
-    from . import _meta_registrations
-    META_COMPATIBILITY = True
-except:
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
     META_COMPATIBILITY = False
+elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
+    from . import _meta_regist_12
+    META_COMPATIBILITY = True
+elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
+    from . import _meta_regist_13
+    META_COMPATIBILITY = True
+elif TORCH_MAJOR == 2:
+    from . import _meta_regist_13
+    META_COMPATIBILITY = True
+    warnings.warn("ColossalAI is not tested with torch 2.0 yet, use it at your own risk")
 
 
 def compatibility(is_backward_compatible: bool = False) -> Callable:
diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_regist_12.py
similarity index 100%
rename from colossalai/fx/_meta_registrations.py
rename to colossalai/fx/_meta_regist_12.py
diff --git a/colossalai/fx/_meta_regist_13.py b/colossalai/fx/_meta_regist_13.py
new file mode 100644
index 000000000..6caa87c44
--- /dev/null
+++ b/colossalai/fx/_meta_regist_13.py
@@ -0,0 +1,57 @@
+import torch
+from torch._meta_registrations import register_meta
+from torch._prims_common import check
+
+aten = torch.ops.aten
+
+
+# since we pin the torch version to 1.13.1, we have to register the meta ops it does not implement yet
+# all these functions are taken from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
+@register_meta([aten.convolution_backward.default])
+def meta_convolution_backward(
+    grad_output_,
+    input_,
+    weight_,
+    bias_sizes_opt,
+    stride,
+    padding,
+    dilation,
+    transposed,
+    output_padding,
+    groups,
+    output_mask,
+):
+    # High level logic taken from slow_conv3d_backward_cpu which should
+    # be representative of all convolution_backward impls
+    backend_grad_input = None
+    backend_grad_weight = None
+    backend_grad_bias = None
+
+    if output_mask[0]:
+        backend_grad_input = grad_output_.new_empty(input_.size())
+    if output_mask[1]:
+        backend_grad_weight = grad_output_.new_empty(weight_.size())
+    if output_mask[2]:
+        backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
+
+    return (backend_grad_input, backend_grad_weight, backend_grad_bias)
+
+
+@register_meta(aten._adaptive_avg_pool2d_backward.default)
+def meta__adaptive_avg_pool2d_backward(grad_out, self):
+    ndim = grad_out.ndim
+    for i in range(1, ndim):
+        check(
+            grad_out.size(i) > 0,
+            lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
+                size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
+        )
+    check(
+        ndim == 3 or ndim == 4,
+        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
+    )
+    check(
+        self.dtype == grad_out.dtype,
+        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
+    )
+    return self.new_empty(self.shape)
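Below is a minimal usage sketch, not part of the patch, of how the meta kernels registered above can be exercised. It assumes ColossalAI is installed on torch 1.13 and that importing colossalai.fx pulls in the version-gated registration from _compatibility.py; the tensor shapes and the bias size [8] are arbitrary example values. On torch >= 2.0 the same call resolves to PyTorch's built-in meta kernels instead.

    import torch

    import colossalai.fx  # noqa: F401  (importing the package runs the version-gated meta registration)

    aten = torch.ops.aten

    # Meta tensors carry only shape and dtype, so nothing below allocates real memory.
    inp = torch.empty(2, 3, 32, 32, device='meta')
    weight = torch.empty(8, 3, 3, 3, device='meta')
    grad_out = torch.empty(2, 8, 32, 32, device='meta')   # output shape of a 3x3, stride-1, padding-1 conv

    # Dispatches to the convolution_backward meta kernel and returns three meta tensors
    # shaped like the real input, weight and bias gradients.
    grad_in, grad_w, grad_b = aten.convolution_backward.default(
        grad_out, inp, weight,
        [8],                        # bias_sizes
        [1, 1], [1, 1], [1, 1],     # stride, padding, dilation
        False, [0, 0], 1,           # transposed, output_padding, groups
        [True, True, True],         # output_mask: request all three gradients
    )
    print(grad_in.shape, grad_w.shape, grad_b.shape)
    # torch.Size([2, 3, 32, 32]) torch.Size([8, 3, 3, 3]) torch.Size([8])

Shape-only evaluation on the meta device like this is the pattern these registrations exist to support, which is why the kernels missing from torch 1.13 have to be added here.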