[inference]Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
@@ -55,6 +55,7 @@ class InferenceConfig:
     def __post_init__(self):
         self._init_batch_size()
         self._verify_config()
+        self._get_dtype()
 
     def _init_batch_size(self):
         """
@@ -84,6 +85,7 @@ class InferenceConfig:
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+
         assert self.dtype in [
             "fp16",
             "fp32",
@@ -97,3 +99,11 @@ class InferenceConfig:
             "gptq",
             None,
         ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
+
+    def _get_dtype(self) -> None:
+        if self.dtype == "fp32" or self.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif self.dtype == "fp16" or self.dtype == torch.float16:
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.bfloat16
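For context, the diff only shows the new _get_dtype in place inside InferenceConfig. Below is a minimal standalone sketch of the same normalization logic; normalize_dtype is a hypothetical helper name used here for illustration and is not part of the commit:

import torch

# Sketch (hypothetical helper, not in the commit) mirroring _get_dtype above:
# map the "fp32"/"fp16" strings (or the matching torch dtypes) to torch
# dtypes, with everything else falling back to torch.bfloat16.
def normalize_dtype(dtype):
    if dtype == "fp32" or dtype == torch.float32:
        return torch.float32
    elif dtype == "fp16" or dtype == torch.float16:
        return torch.float16
    else:
        return torch.bfloat16

assert normalize_dtype("fp16") is torch.float16
assert normalize_dtype("bf16") is torch.bfloat16  # unrecognized values fall back to bf16

Storing a real torch.dtype on the config up front means downstream code (e.g. tensor allocation for the flash-attention workspace) can use self.dtype directly instead of re-parsing the string at each call site.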