[Inference] Optimize the usage of intermediate tensor space in flash attn (#5304)

* opt flash attn

* opt tmp tensor

* fix benchmark_llama

* fix code style

* fix None logic for output tensor

* fix adaptation to get_xine_cache

* add comment

* fix CI bugs

* fix some code

* remove duplicated code

* remove duplicated code

* fix code style

* add _get_dtype in config.py
yuehuayingxueluo
2024-01-26 14:00:10 +08:00
committed by GitHub
parent af8359c430
commit 4f28cb43c0
16 changed files with 199 additions and 57 deletions

config.py

@@ -55,6 +55,7 @@ class InferenceConfig:
    def __post_init__(self):
        self._init_batch_size()
        self._verify_config()
        self._get_dtype()

    def _init_batch_size(self):
        """
@@ -84,6 +85,7 @@ class InferenceConfig:
        assert (
            self.tp_size * self.pp_size == dist.get_world_size()
        ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
        assert self.dtype in [
            "fp16",
            "fp32",
@@ -97,3 +99,11 @@ class InferenceConfig:
            "gptq",
            None,
        ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."

    def _get_dtype(self) -> None:
        if self.dtype == "fp32" or self.dtype == torch.float32:
            self.dtype = torch.float32
        elif self.dtype == "fp16" or self.dtype == torch.float16:
            self.dtype = torch.float16
        else:
            self.dtype = torch.bfloat16
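
For clarity, the sketch below shows the effect of the dtype normalization this diff adds to __post_init__: a shorthand string such as "fp16" or "fp32" (or an already-resolved torch dtype) is mapped to the corresponding torch.dtype, and anything else falls through to bfloat16. The DemoConfig dataclass is an illustrative stand-in, not the actual InferenceConfig class.

# Minimal standalone sketch (DemoConfig is hypothetical, not the ColossalAI class)
# reproducing the _get_dtype() normalization added in the hunk above.
from dataclasses import dataclass
from typing import Union

import torch


@dataclass
class DemoConfig:
    # Accepts either a shorthand string or a torch.dtype, mirroring the diff.
    dtype: Union[str, torch.dtype] = "fp16"

    def __post_init__(self):
        # Normalize the dtype as part of dataclass post-init, as in the diff.
        self._get_dtype()

    def _get_dtype(self) -> None:
        # Same branching as the added method: fp32 and fp16 map to their torch
        # dtypes; everything else falls back to bfloat16.
        if self.dtype == "fp32" or self.dtype == torch.float32:
            self.dtype = torch.float32
        elif self.dtype == "fp16" or self.dtype == torch.float16:
            self.dtype = torch.float16
        else:
            self.dtype = torch.bfloat16


if __name__ == "__main__":
    print(DemoConfig(dtype="fp16").dtype)  # torch.float16
    print(DemoConfig(dtype="fp32").dtype)  # torch.float32
    print(DemoConfig(dtype="bf16").dtype)  # torch.bfloat16 (fallback branch)

Normalizing the dtype once in __post_init__ means downstream code can rely on self.dtype always being a torch.dtype rather than re-checking string variants at every use site.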