From fbd2a9e05b44b3e95df35518a8f09e9105313358 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 Date: Thu, 30 Mar 2023 11:22:20 +0800 Subject: [PATCH] [hotfix] meta_tensor_compatibility_with_torch2 --- colossalai/fx/_compatibility.py | 2 -- colossalai/fx/profiler/opcount.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py index 6caad920d..0444a4816 100644 --- a/colossalai/fx/_compatibility.py +++ b/colossalai/fx/_compatibility.py @@ -14,9 +14,7 @@ elif TORCH_MAJOR == 1 and TORCH_MINOR == 13: from . import _meta_regist_13 META_COMPATIBILITY = True elif TORCH_MAJOR == 2: - from . import _meta_regist_13 META_COMPATIBILITY = True - raise UserWarning("Colossalai is not tested with torch2.0 yet!!!") def compatibility(is_backward_compatible: bool = False) -> Callable: diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py index 407a6bed5..ba090a2ec 100644 --- a/colossalai/fx/profiler/opcount.py +++ b/colossalai/fx/profiler/opcount.py @@ -223,7 +223,8 @@ def zero_flop_jit(*args): return 0 -if version.parse(torch.__version__) >= version.parse('1.12.0'): +if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse( + torch.__version__) < version.parse('2.0.0'): flop_mapping = { # gemm, gemv and dot aten.mm.default: matmul_flop_jit,