From e0da01ea7143e9e9cd2c1cc30b1599d8aff70c14 Mon Sep 17 00:00:00 2001 From: xcnick Date: Tue, 8 Nov 2022 09:40:24 +0800 Subject: [PATCH] [hotfix] fix build error when torch version >= 1.13 (#1803) --- .../kernel/cuda_native/csrc/multihead_attention_1d.cpp | 5 +++++ .../kernel/cuda_native/csrc/multihead_attention_1d.h | 8 +++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp index b02556f79..166c698f6 100644 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp @@ -2,8 +2,13 @@ #include #include +#include +#if TORCH_VERSION_MINOR >= 13 +#include +#else #include +#endif #include #include "context.h" diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h index 70b3419d8..db50071b6 100644 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h +++ b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h @@ -4,8 +4,14 @@ #include #include #include +#include +#if TORCH_VERSION_MINOR >= 13 +#include +#else #include +#endif + #include #include @@ -157,4 +163,4 @@ class MultiHeadAttention { c10::intrusive_ptr pg; int pg_size; -}; \ No newline at end of file +};