Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-11 05:49:55 +00:00
[Inference]Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding (#5461)
* Support FP16/BF16 Flash Attention 2
* fix bugs in test_kv_cache_memcpy.py
* add context_kv_cache_memcpy_kernel.cu
* rm typename MT
* add tail process
* add high_precision
* add high_precision to config.py
* rm unused code
* change the comment for the high_precision parameter
* update test_rotary_embdding_unpad.py
* fix vector_copy_utils.h
* add comment for self.high_precision when using float32
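The high_precision flag described in this message controls whether FP16/BF16 rotary-embedding math is carried out with float32 intermediates. A minimal sketch of that idea follows, assuming a helper named apply_rotary and a compute_t alias; neither is code from this commit.

#include <type_traits>

// Sketch only: with high_precision, half/bfloat16 inputs are promoted to
// float for the sin/cos rotation and cast back to the storage type on write.
// scalar_t could be float, at::Half, or at::BFloat16 (with the corresponding
// ATen headers included).
template <typename scalar_t, bool high_precision>
void apply_rotary(scalar_t& x, scalar_t& y, float cos_v, float sin_v) {
  using compute_t =
      typename std::conditional<high_precision, float, scalar_t>::type;
  const compute_t xc = static_cast<compute_t>(x);
  const compute_t yc = static_cast<compute_t>(y);
  x = static_cast<scalar_t>(xc * static_cast<compute_t>(cos_v) -
                            yc * static_cast<compute_t>(sin_v));
  y = static_cast<scalar_t>(xc * static_cast<compute_t>(sin_v) +
                            yc * static_cast<compute_t>(cos_v));
}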
@@ -56,6 +56,23 @@
   AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }
 
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \
+                                                           TYPE, NAME, ...) \
+  switch (HIGH_PRECISION) { \
+    case false: { \
+      const bool high_precision = false; \
+      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
+      break; \
+    } \
+    case true: { \
+      const bool high_precision = true; \
+      DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \
+      break; \
+    } \
+    default: \
+      AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \
+  }
+
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
   switch (TYPEIN) { \
     case at::ScalarType::Float: { \
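The macro added in this hunk turns the runtime HIGH_PRECISION argument into a compile-time const bool high_precision before delegating to the existing DISPATCH_FLOAT_HALF_AND_BFLOAT per-dtype dispatch, so a kernel can be templated on both the element type and the precision mode. A minimal usage sketch follows; the launcher, tensor names, and kernel signature are assumptions for illustration, not the actual call sites changed by this commit.

// Hypothetical call site (illustrative only). DISPATCH_FLOAT_HALF_AND_BFLOAT
// defines scalar_t for Float/Half/BFloat16, and the outer switch makes
// high_precision a const bool, so both can be used as template arguments.
DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(
    use_high_precision, query.scalar_type(), "apply_rotary_embedding",
    apply_rotary_embedding_kernel<scalar_t, high_precision>
        <<<grid, block, 0, stream>>>(
            query.data_ptr<scalar_t>(), cos.data_ptr<scalar_t>(),
            sin.data_ptr<scalar_t>(), total_tokens, head_dim));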