Mirror of https://github.com/hpcaitech/ColossalAI.git
Synced 2025-09-02 01:28:31 +00:00
[Inference] Support FP16/BF16 Flash Attention 2 And Add high_precision Flag To Rotary Embedding (#5461)
* Support FP16/BF16 Flash Attention 2
* fix bugs in test_kv_cache_memcpy.py
* add context_kv_cache_memcpy_kernel.cu
* rm typename MT
* add tail process
* add high_precision
* add high_precision to config.py
* rm unused code
* change the comment for the high_precision parameter
* update test_rotary_embdding_unpad.py
* fix vector_copy_utils.h
* add comment for self.high_precision when using float32
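Per the commit message, high_precision forces the rotary-embedding math to run in float32 even when the model itself is FP16/BF16. The actual implementation lives in CUDA kernels (e.g. context_kv_cache_memcpy_kernel.cu); the PyTorch sketch below only illustrates the idea, and its function name and signature are hypothetical, not the repository's API:

import torch

def apply_rotary_embedding(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                           high_precision: bool = False) -> torch.Tensor:
    # Hypothetical sketch: rotate channel pairs, optionally in float32.
    # q: [..., head_dim]; cos/sin broadcast against q with size head_dim // 2.
    orig_dtype = q.dtype
    if high_precision and orig_dtype in (torch.float16, torch.bfloat16):
        # Upcast so the sin/cos multiply-adds accumulate in float32,
        # which is what the high_precision flag trades speed for.
        q, cos, sin = q.float(), cos.float(), sin.float()
    x1, x2 = q.chunk(2, dim=-1)  # split head_dim into two halves
    rotated = torch.cat((x1 * cos - x2 * sin,
                         x1 * sin + x2 * cos), dim=-1)
    return rotated.to(orig_dtype)  # cast back to the input dtype (no-op if never upcast)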
@@ -56,6 +56,7 @@ class InferenceEngine:
         self.tokenizer = tokenizer
         self.tokenizer.pad_token = self.tokenizer.eos_token
         self.generation_config = inference_config.to_generation_config(self.model_config)
+        self.high_precision = inference_config.high_precision
         model = model.eval()
         model = model.cuda()
         model.to(self.dtype)

@@ -297,6 +298,7 @@ class InferenceEngine:
             batch,
             self.k_cahce,
             self.v_cache,
+            self.high_precision,
         )

         if self.inference_config.pad_input:
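Read together, the two hunks thread the flag from the inference config into the model call. A minimal sketch of that flow, assuming simplified class and method shapes (only high_precision and the forwarded KV caches are confirmed by this diff; the repo's attribute is spelled k_cahce at this point in history, written conventionally here):

class InferenceConfig:
    def __init__(self, high_precision: bool = False):
        # When True, rotary-embedding math runs in float32 even for
        # FP16/BF16 models, trading speed for numerical accuracy.
        self.high_precision = high_precision

class InferenceEngine:
    def __init__(self, model, inference_config):
        self.inference_config = inference_config
        # Mirrors the first hunk: cache the flag on the engine.
        self.high_precision = inference_config.high_precision
        self.model = model.eval().cuda()

    def step(self, batch, k_cache, v_cache):
        # Mirrors the second hunk: the flag is forwarded alongside the
        # KV caches so downstream kernels can pick the float32 code path.
        return self.model(batch, k_cache, v_cache, self.high_precision)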