Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[Inference]Add Streaming LLM (#5745)
* Add Streaming LLM
* add some parameters to llama_generation.py
* verify streamingllm config
* add test_streamingllm.py
* modified according to the opinions of review
* add Citation
* change _block_tables tolist
@@ -166,8 +166,9 @@ class InferenceConfig(RPC_PARAM):
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
temperature (Optional[float]): The value used to control the randomness of sampling, defaults to 1.0.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition, defaults to 1.0.
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, an n-gram of that size can only appear once in the generated text.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition, defaults to 1.0.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
n_spec_tokens (int): The maximum number of speculative tokens, defaults to None.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of tokens in a logical KV cache block, defaults to 16.
@@ -176,10 +177,12 @@ class InferenceConfig(RPC_PARAM):
micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
use_cuda_kernel(bool): Whether to use CUDA kernels; faster, but may occasionally lose some precision.
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence
high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
ignore_eos(bool): Whether to ignore the EOS token and continue generating tokens when encountering the EOS token.
enable_streamingllm(bool): Whether to use StreamingLLM; the implementation follows the algorithm described in the paper at https://arxiv.org/pdf/2309.17453.
start_token_size(int): The number of start tokens to keep when using StreamingLLM.
generated_token_size(int): The number of generated tokens to keep when using StreamingLLM.
"""
# NOTE: arrange configs according to their importance and frequency of usage
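The docstring above lists the user-facing knobs. For orientation, here is a minimal usage sketch, assuming `InferenceConfig` lives in `colossalai.inference.config` and can be constructed with these fields as keyword arguments (the dataclass-style defaults in the hunks below suggest so); it is an illustration, not code from this commit.

```python
# Illustrative only: field names come from the docstring above; the import path and
# keyword-argument construction are assumptions, not confirmed by this diff.
from colossalai.inference.config import InferenceConfig

config = InferenceConfig(
    max_input_len=256,
    max_output_len=256,
    block_size=16,               # tokens per logical KV-cache block
    top_k=50,
    top_p=0.9,
    temperature=1.0,
    repetition_penalty=1.2,
    enable_streamingllm=True,    # sliding-window attention with attention sinks
    start_token_size=4,          # must be <= block_size (checked in __post_init__)
    generated_token_size=512,    # must be a multiple of block_size
)
```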
@@ -208,6 +211,7 @@ class InferenceConfig(RPC_PARAM):
no_repeat_ngram_size: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
forced_eos_token_id: int = None
ignore_eos: bool = False
# speculative decoding configs
max_n_spec_tokens: int = 5
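The `repetition_penalty` and `no_repeat_ngram_size` defaults above correspond to the sampling behaviour described in the docstring. The sketch below shows the conventional repetition-penalty rule so the direction of the effect is concrete; the helper name is hypothetical and the repository's actual sampling kernels may differ.

```python
# Hypothetical helper illustrating the usual repetition-penalty semantics:
# tokens that already appeared get their logit divided (if positive) or multiplied
# (if negative) by the penalty, so penalty > 1 discourages repetition and
# penalty < 1 encourages it.
def apply_repetition_penalty(logits: list[float], seen_token_ids: set[int], penalty: float) -> list[float]:
    out = list(logits)
    for tok in seen_token_ids:
        score = out[tok]
        out[tok] = score / penalty if score > 0 else score * penalty
    return out

# Token 2 already appeared, so a penalty of 1.2 lowers its score from 2.0 to ~1.67.
print(apply_repetition_penalty([1.0, 0.5, 2.0], {2}, penalty=1.2))
```

Under the same reading, `no_repeat_ngram_size = 0` leaves n-gram blocking disabled; a positive value would forbid any n-gram of that size from appearing twice in the generated text.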
@@ -221,15 +225,19 @@ class InferenceConfig(RPC_PARAM):
pp_size: int = 1
micro_batch_size: int = 1
micro_batch_buffer_size: int = None
high_precision: Optional[bool] = False
# cuda kernel option
use_cuda_kernel: bool = False
high_precision: Optional[bool] = False
# cuda_graph
use_cuda_graph: bool = False # NOTE only when we have the graph for specific decoding batch size can we use the cuda graph for inference
max_context_len_to_capture: int = 512
ignore_eos: bool = False
# StreamingLLM (sliding window attention with attention sinks)
enable_streamingllm: bool = False
start_token_size: int = 4
generated_token_size: int = 512
def __post_init__(self):
self.max_context_len_to_capture = self.max_input_len + self.max_output_len
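The hunk above adds the three StreamingLLM fields. Here is a small arithmetic sketch of the per-sequence KV-cache budget they roughly imply, using the defaults shown; this is an illustration only, not code from the repository.

```python
block_size = 16             # tokens per logical KV-cache block
start_token_size = 4        # attention-sink tokens requested by the user
generated_token_size = 512  # sliding window of recent tokens, a multiple of block_size

# __post_init__ (see the next hunk) rounds the sink region up to one full block.
start_token_size = block_size

sink_blocks = start_token_size // block_size        # 1 block of "start tokens"
window_blocks = generated_token_size // block_size  # 32 blocks of recent tokens
retained_tokens = (sink_blocks + window_blocks) * block_size  # roughly the tokens kept per sequence
print(sink_blocks, window_blocks, retained_tokens)  # 1 32 528
```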
@@ -260,6 +268,20 @@ class InferenceConfig(RPC_PARAM):
if self.dtype == torch.float32:
    self.high_precision = False
# check StreamingLLM
assert (
    self.start_token_size <= self.block_size
), f"According to the paper https://arxiv.org/pdf/2309.17453, a start_token_size greater than 4 has little impact on inference performance. Therefore, we assume that the start_token_size should be less than or equal to the block_size={self.block_size}, but got {self.start_token_size}."
assert (
    self.generated_token_size % self.block_size == 0
), f"We assume that the generated_token_size should be a multiple of the block_size, got generated_token_size={self.generated_token_size}."
# Our StreamingLLM implementation (sliding window attention with attention sinks) references https://arxiv.org/pdf/2309.17453 and has been optimized
# based on our framework's kvcache management mechanism. According to the paper, a start_token_size of 4 is sufficient. Therefore,
# we assume the start_token_size is less than or equal to the block size. When the start_token_size is smaller than the block size,
# we fill the first block with the start_token_size and subsequently generated tokens, using these as the "start tokens."
# Thereafter, we swap out tokens in units of blocks, always swapping out the second block when the generated tokens exceed the limit.
self.start_token_size = self.block_size
# check prompt template
if self.prompt_template is None:
    return
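The comments in this hunk describe the block-level eviction policy: the first block holds the attention-sink tokens and is never evicted, and once the cached tokens exceed the allowed window the second block, the oldest non-sink block, is swapped out. Below is a hedged sketch of that policy, with a plain list of block ids standing in for the real kvcache manager; the names and data structures are illustrative, not the repository's API.

```python
def evict_if_needed(block_table: list[int], cached_tokens: int, block_size: int,
                    start_token_size: int, generated_token_size: int) -> tuple[list[int], int]:
    """Drop the second block once the cached tokens exceed the allowed window."""
    limit = start_token_size + generated_token_size
    if cached_tokens > limit and len(block_table) > 1:
        evicted = block_table.pop(1)   # block 0 (attention sinks) is always kept
        cached_tokens -= block_size    # one block's worth of tokens is freed
        print(f"swapped out block {evicted}")
    return block_table, cached_tokens

# With block_size=16, a one-block sink region and a 64-token window, the 81st cached
# token pushes the sequence past the limit and block id 1 is swapped out.
table, n = evict_if_needed([0, 1, 2, 3, 4, 5], 81, block_size=16,
                           start_token_size=16, generated_token_size=64)
print(table, n)  # [0, 2, 3, 4, 5] 65
```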