Mirror of https://github.com/hpcaitech/ColossalAI.git
[inference] Optimize the usage of the mid tensors space in flash attn (#5304)
* opt flash attn
* opt tmp tensor
* fix benchmark_llama
* fix code style
* fix None logic for output tensor
* fix adapted to get_xine_cache
* add comment
* fix ci bugs
* fix some codes
* rm duplicated codes
* rm duplicated codes
* fix code style
* add _get_dtype in config.py
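The heart of this change is that the intermediate ("mid") tensors used by flash-decoding attention are now allocated once through FDIntermTensors and shared by the prefill and decode batches, instead of being re-created on every forward step. Below is a minimal sketch of that pre-allocation idea, assuming the buffers are the per-split partial outputs and their log-sum-exp values; the class name, buffer names, and shapes are illustrative and only the initialize() keyword arguments mirror what the diff shows.

import torch


class InterimTensorPool:
    """Illustrative stand-in for a flash-decoding intermediate-tensor holder.

    The buffers are allocated once at engine start-up and reused by every
    batch, rather than re-allocated on each attention call. Shapes are
    assumptions: one partial output and one log-sum-exp value per KV split.
    """

    def initialize(self, max_batch_size, num_attn_heads, kv_max_split_num, head_dim, dtype, device):
        # Partial attention outputs for every KV-cache split of every sequence.
        self.mid_output = torch.empty(
            (max_batch_size, num_attn_heads, kv_max_split_num, head_dim),
            dtype=dtype,
            device=device,
        )
        # Log-sum-exp of each partial softmax, used later to merge the splits.
        self.mid_output_lse = torch.empty(
            (max_batch_size, num_attn_heads, kv_max_split_num),
            dtype=dtype,
            device=device,
        )


# Usage mirrors the constructor change below: one pool, handed to both batches.
pool = InterimTensorPool()
pool.initialize(
    max_batch_size=8,
    num_attn_heads=32,
    kv_max_split_num=72,
    head_dim=128,
    dtype=torch.float16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

Because the same object is passed to both running_batch and prefill_batch in the diff, only one copy of this workspace exists per engine.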
@@ -4,6 +4,7 @@ import torch
from transformers.configuration_utils import PretrainedConfig

from colossalai.inference.config import InferenceConfig
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.kv_cache import KVCacheManager
from colossalai.inference.logit_processors import logit_processor
from colossalai.inference.sampler import *

@@ -69,20 +70,60 @@ class RequestHandler:
    Args:
        inference_config: Configuration for initialize and manage kv cache.
        model_config: Configuration for model
        dtype (torch.dtype): The data type for weights and activations.
    """

    def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
        self.inference_config = inference_config
        self._init_cache(model_config)

        self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
        self.waiting_list: List[List] = [[], [], []]
        self.done_list: List[Sequence] = []
        device = torch.cuda.current_device()
        self.running_batch = BatchInfo(is_prompts=False, device=device)
        self.prefill_batch = BatchInfo(is_prompts=True, device=device)
        self.dtype = inference_config.dtype
        self.max_batch_size = inference_config.max_batch_size

        # initialize cache
        self._init_cache(model_config)

        # initialize batch
        device = torch.cuda.current_device()
        kv_max_split_num = (
            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
        ) // inference_config.block_size
        head_dim = model_config.hidden_size // model_config.num_attention_heads

        fd_inter_tensor = FDIntermTensors()
        fd_inter_tensor.initialize(
            max_batch_size=self.max_batch_size,
            num_attn_heads=model_config.num_attention_heads,
            kv_max_split_num=kv_max_split_num,
            head_dim=head_dim,
            dtype=self.dtype,
            device=device,
        )

        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
        # which may cause bugs and this issue should be fixed later.
        self.running_batch = BatchInfo(
            max_batch_size=self.max_batch_size,
            kv_max_split_num=kv_max_split_num,
            num_heads=model_config.num_attention_heads,
            head_dim=head_dim,
            is_prompts=False,
            device=device,
            dtype=self.dtype,
            fd_inter_tensor=fd_inter_tensor,
        )
        self.prefill_batch = BatchInfo(
            max_batch_size=self.max_batch_size,
            kv_max_split_num=kv_max_split_num,
            num_heads=model_config.num_attention_heads,
            head_dim=head_dim,
            is_prompts=True,
            device=device,
            dtype=self.dtype,
            fd_inter_tensor=fd_inter_tensor,
        )

    def _init_cache(self, model_config):
        self.cache_manager = KVCacheManager(self.inference_config, model_config)
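A note on the kv_max_split_num expression in the diff: it is a ceiling division that bounds how many KV-cache blocks (and therefore flash-decoding splits) a single sequence can occupy, and it sizes the per-sequence dimension of the pre-allocated mid tensors. A quick sanity check with assumed example values, not taken from the commit:

# Example values only; the real numbers come from InferenceConfig.
max_input_len, max_output_len, block_size = 1024, 128, 16

# Ceiling division via the (a + b - 1) // b idiom used in the diff.
kv_max_split_num = (max_input_len + max_output_len + block_size - 1) // block_size
assert kv_max_split_num == 72  # ceil((1024 + 128) / 16) = 72

Sizing the workspace for this worst case means it never has to grow mid-generation, which is the point of the optimization.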