[Inference]Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding (#5461)
* Support FP16/BF16 Flash Attention 2
* fix bugs in test_kv_cache_memcpy.py
* add context_kv_cache_memcpy_kernel.cu
* rm typename MT
* add tail process
* add high_precision
* add high_precision to config.py
* rm unused code
* change the comment for the high_precision parameter
* update test_rotary_embdding_unpad.py
* fix vector_copy_utils.h
* add comment for self.high_precision when using float32
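For context, the high_precision flag opts fp16/bf16 paths into float32 arithmetic: inputs are upcast, the rotary-embedding math runs in float32, and the result is cast back to the storage dtype. Below is a minimal Python sketch of that upcast pattern; the function name apply_rotary_emb and its signature are illustrative only, since the commit's actual implementation is a fused CUDA kernel:

    import torch

    def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                         high_precision: bool = False) -> torch.Tensor:
        # Pick the compute dtype: float32 when high_precision is set,
        # otherwise stay in the tensor's own dtype (fp16/bf16).
        compute_dtype = torch.float32 if high_precision else x.dtype
        x_, cos_, sin_ = (t.to(compute_dtype) for t in (x, cos, sin))
        # Split the head dimension into two halves and rotate.
        x1, x2 = x_.chunk(2, dim=-1)
        out = torch.cat((x1 * cos_ - x2 * sin_, x2 * cos_ + x1 * sin_), dim=-1)
        # Cast back to the original storage dtype.
        return out.to(x.dtype)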
@@ -55,7 +55,7 @@ class InferenceConfig:
         pp_size (int): Pipeline parallel size, defaults to 1.
         micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+        high_precision (Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
     """
 
     # NOTE: arrange configs according to their importance and frequency of usage
@@ -89,6 +89,7 @@ class InferenceConfig:
     pp_size: int = 1
     micro_batch_size: int = 1
     micro_batch_buffer_size: int = None
+    high_precision: Optional[bool] = False
 
     def __post_init__(self):
         self._verify_config()
@@ -108,6 +109,10 @@ class InferenceConfig:
             self.dtype in _ALLOWED_DTYPES
         ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}"
 
+        # skip using casting when the data type is float32
+        if self.dtype == torch.float32:
+            self.high_precision = False
+
         # check distributed
         assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or (
             self.tp_size * self.pp_size == dist.get_world_size()