[Inference] Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding (#5461)

* Support FP16/BF16 Flash Attention 2

* fix bugs in test_kv_cache_memcpy.py

* add context_kv_cache_memcpy_kernel.cu

* rm typename MT

* add tail process (see the sketch after this list)

* add high_precision

* add high_precision to config.py

* rm unused code

* change the comment for the high_precision parameter

* update test_rotary_embdding_unpad.py

* fix vector_copy_utils.h

* add comment for self.high_precision when using float32
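A note on the "add tail process" bullet: it presumably refers to handling the leftover elements when a copy length is not a multiple of the vectorized copy width used by the new CUDA kernel. A minimal Python sketch of the idea (illustrative only; the actual implementation is CUDA code in context_kv_cache_memcpy_kernel.cu and vector_copy_utils.h):

def copy_with_tail(src: list, dst: list, vec_size: int = 4) -> None:
    """Copy src into dst in vec_size chunks, then handle the remainder."""
    n = len(src)
    main = (n // vec_size) * vec_size
    for i in range(0, main, vec_size):
        dst[i:i + vec_size] = src[i:i + vec_size]  # stand-in for one vectorized copy
    for i in range(main, n):  # tail process: fewer than vec_size elements remain
        dst[i] = src[i]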
Author: yuehuayingxueluo
Date: 2024-03-25 13:40:34 +08:00
Committed by: GitHub
Parent: 7ff42cc06d
Commit: 87079cffe8
15 changed files with 550 additions and 138 deletions

config.py

@@ -55,7 +55,7 @@ class InferenceConfig:
        pp_size (int): Pipeline parallel size, defaults to 1.
        micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
        micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
        high_precision (Optional[bool]): Whether to use float32 for the underlying calculations of float16 data to achieve higher precision, defaults to False.
    """

    # NOTE: arrange configs according to their importance and frequency of usage
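The docstring entry above captures the flag's intent: with float16/bfloat16 models, numerically sensitive steps such as the rotary embedding are computed in float32 and cast back. A minimal sketch of that cast-up/compute/cast-down pattern (illustrative only; the function and argument names are hypothetical, not ColossalAI's kernel API):

import torch

def apply_rotary_high_precision(q, cos, sin, high_precision=False):
    # q: (..., head_dim) in float16/bfloat16; cos/sin broadcastable to q.
    compute_dtype = torch.float32 if high_precision else q.dtype
    q32 = q.to(compute_dtype)
    half = q32.shape[-1] // 2
    q_rot = torch.cat((-q32[..., half:], q32[..., :half]), dim=-1)  # rotate-half
    out = q32 * cos.to(compute_dtype) + q_rot * sin.to(compute_dtype)
    return out.to(q.dtype)  # cast back to the original dtype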
@@ -89,6 +89,7 @@ class InferenceConfig:
    pp_size: int = 1
    micro_batch_size: int = 1
    micro_batch_buffer_size: int = None
    high_precision: Optional[bool] = False

    def __post_init__(self):
        self._verify_config()
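With the new field in place, enabling it is a single constructor argument. A usage sketch (the import path is assumed from the "add high_precision to config.py" bullet; other fields are left at their defaults):

import torch
from colossalai.inference.config import InferenceConfig  # assumed import path

# Run float16 inference, but do the rotary-embedding math in float32.
config = InferenceConfig(dtype=torch.float16, high_precision=True)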
@@ -108,6 +109,10 @@ class InferenceConfig:
            self.dtype in _ALLOWED_DTYPES
        ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"

        # skip using casting when the data type is float32
        if self.dtype == torch.float32:
            self.high_precision = False

        # check distributed
        assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
            self.tp_size * self.pp_size == dist.get_world_size()
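The new check makes the flag a no-op for float32 models: the data is already at the higher precision, so extra casting would only add overhead. Behavior sketch (same assumed import as above):

config = InferenceConfig(dtype=torch.float32, high_precision=True)
assert config.high_precision is False  # __post_init__ silently disables the flag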