Mirror of https://github.com/hpcaitech/ColossalAI.git
[inference]Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
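The core idea of the change is to reuse a preallocated buffer for the intermediate ("mid") tensors that flash attention writes into, rather than allocating a fresh output tensor on every decoding step. A minimal sketch of that pattern, where the names (MidTensorPool, max_tokens) are illustrative assumptions, not the actual ColossalAI API:

import torch

class MidTensorPool:
    """Illustrative only: preallocate one buffer for attention outputs and
    hand out per-step views instead of allocating a new tensor each step."""

    def __init__(self, max_tokens, num_heads, head_dim, dtype, device):
        # Sized once for the worst case; reused across all decoding steps.
        self.buffer = torch.empty(max_tokens, num_heads, head_dim, dtype=dtype, device=device)

    def get(self, num_tokens):
        # A view into the shared buffer; no allocation on the hot path.
        return self.buffer[:num_tokens]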
@@ -51,17 +51,10 @@ class InferenceEngine:
         self.inference_config = inference_config
         self.model_config = model.config
         self.device = torch.device("cuda")
+        self.dtype = inference_config.dtype
 
         model = model.eval()
-
-        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
-            self.dtype = torch.float16
-            model.half()
-        else:
-            self.dtype = torch.bfloat16
-            model.to(torch.bfloat16)
+        model.to(self.dtype)
 
         if model_policy is None:
             model_policy = model_policy_map[self.model_config.model_type]()
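The removed if/elif/else branch above was folded into the config layer; the commit message mentions adding a _get_dtype helper in config.py, so inference_config.dtype already holds a torch.dtype by the time the engine reads it. A hedged sketch of such a normalizer, inferred from the string aliases in the removed branch (the real helper may differ):

import torch

# Maps the string aliases seen in the removed branch to torch dtypes;
# bfloat16 mirrors the old else-branch default.
_DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

def _get_dtype(dtype):
    # Pass torch dtypes through untouched, normalize known string aliases.
    if isinstance(dtype, torch.dtype):
        return dtype
    return _DTYPE_MAP.get(dtype, torch.bfloat16)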
@@ -217,6 +210,7 @@ class InferenceEngine:
             None,
             block_table,
             self.tokenizer.eos_token_id,
+            self.tokenizer.pad_token_id,
             self.inference_config.max_output_len,
         )
         self.request_handler.add_sequence(sequence)
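For context, the eos/pad ids passed into the sequence above come from the tokenizer, and some tokenizers define no pad token at all. A small illustrative snippet; the fallback-to-eos convention is an assumption here, not necessarily what ColossalAI does:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
eos_id = tokenizer.eos_token_id
# Llama-style tokenizers often define no pad token; falling back to eos is a
# common convention (an assumption here, not ColossalAI's confirmed behavior).
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id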
@@ -241,7 +235,6 @@ class InferenceEngine:
             batch,
             self.k_cahce,
             self.v_cache,
-            padding_id=self.tokenizer.pad_token_id,
         )
 
         logits = logits[:, -1, :]