mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
The writing style of tail processing and the logic related to macro definitions have been optimized. (#5519)
This commit is contained in:
@@ -56,21 +56,14 @@
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
|
||||
TYPE, NAME, ...) \
|
||||
switch (HIGH_PRECISION) { \
|
||||
case false: { \
|
||||
const bool high_precision = false; \
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
|
||||
break; \
|
||||
} \
|
||||
case true: { \
|
||||
const bool high_precision = true; \
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
|
||||
TYPE, NAME, ...) \
|
||||
if (HIGH_PRECISION) { \
|
||||
const bool high_precision = true; \
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
|
||||
} else { \
|
||||
const bool high_precision = false; \
|
||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
|
||||
}
|
||||
|
||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||
|
@@ -27,17 +27,11 @@ struct MPTypeTrait<at::BFloat16> {
|
||||
using Type = float;
|
||||
};
|
||||
|
||||
template <bool high_precision, typename scalar_t>
|
||||
struct ScalarTypeTrait;
|
||||
|
||||
template <typename T>
|
||||
struct ScalarTypeTrait<true, T> {
|
||||
using Type = typename MPTypeTrait<T>::Type;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ScalarTypeTrait<false, T> {
|
||||
using Type = T;
|
||||
template <bool high_precision, typename T>
|
||||
struct ScalarTypeTrait {
|
||||
using Type =
|
||||
typename std::conditional<high_precision, typename MPTypeTrait<T>::Type,
|
||||
T>::type;
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
|
Reference in New Issue
Block a user