The writing style of tail processing and the logic related to macro definitions have been optimized. (#5519)

This commit is contained in:
yuehuayingxueluo
2024-03-28 10:42:51 +08:00
committed by GitHub
parent e6496dd371
commit 934e31afb2
5 changed files with 129 additions and 165 deletions

View File

@@ -56,21 +56,14 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
TYPE, NAME, ...) \
switch (HIGH_PRECISION) { \
case false: { \
const bool high_precision = false; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
break; \
} \
case true: { \
const bool high_precision = true; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
break; \
} \
default: \
AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
TYPE, NAME, ...) \
if (HIGH_PRECISION) { \
const bool high_precision = true; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
} else { \
const bool high_precision = false; \
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \

View File

@@ -27,17 +27,11 @@ struct MPTypeTrait<at::BFloat16> {
using Type = float;
};
template <bool high_precision, typename scalar_t>
struct ScalarTypeTrait;
template <typename T>
struct ScalarTypeTrait<true, T> {
using Type = typename MPTypeTrait<T>::Type;
};
template <typename T>
struct ScalarTypeTrait<false, T> {
using Type = T;
template <bool high_precision, typename T>
struct ScalarTypeTrait {
using Type =
typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
T>::type;
};
} // namespace common