"""General prompt helper that can help deal with LLM context window token limitations.
|
|
|
|
At its core, it calculates available context size by starting with the context window
|
|
size of an LLM and reserve token space for the prompt template, and the output.
|
|
|
|
It provides utility for "repacking" text chunks (retrieved from index) to maximally
|
|
make use of the available context window (and thereby reducing the number of LLM calls
|
|
needed), or truncating them so that they fit in a single LLM call.
|
|
"""
|
|
|
|
import logging
|
|
from string import Formatter
|
|
from typing import Callable, List, Optional, Sequence, Set
|
|
|
|
from dbgpt._private.llm_metadata import LLMMetadata
|
|
from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr, model_validator
|
|
from dbgpt.core.interface.prompt import get_template_vars
|
|
from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter
|
|
from dbgpt.util.global_helper import globals_helper
|
|
|
|
DEFAULT_PADDING = 5
|
|
DEFAULT_CHUNK_OVERLAP_RATIO = 0.1
|
|
|
|
DEFAULT_CONTEXT_WINDOW = 3000 # tokens
|
|
DEFAULT_NUM_OUTPUTS = 256 # tokens
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class PromptHelper(BaseModel):
    """Prompt helper.

    General prompt helper that can help deal with LLM context window token limitations.

    At its core, it calculates the available context size by starting with the context
    window size of an LLM and reserving token space for the prompt template and the
    output.

    It provides utilities for "repacking" text chunks (retrieved from an index) to
    maximally use the available context window (and thereby reduce the number of LLM
    calls needed), or truncating them so that they fit in a single LLM call.

    Args:
        context_window (int): Context window for the LLM.
        num_output (int): Number of output tokens reserved for the LLM.
        chunk_overlap_ratio (float): Chunk overlap as a ratio of chunk size.
        chunk_size_limit (Optional[int]): Maximum chunk size to use.
        tokenizer (Optional[Callable[[str], List]]): Tokenizer to use.
        separator (str): Separator for the text splitter.
    """

    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum context size that will get sent to the LLM.",
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The amount of token-space to leave in input for generation.",
    )
    chunk_overlap_ratio: float = Field(
        default=DEFAULT_CHUNK_OVERLAP_RATIO,
        description="The percentage token amount that each chunk should overlap.",
    )
    chunk_size_limit: Optional[int] = Field(
        None, description="The maximum size of a chunk."
    )
    separator: str = Field(
        default=" ", description="The separator when chunking tokens."
    )

    _tokenizer: Optional[Callable[[str], List]] = PrivateAttr()

    def __init__(
        self,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        num_output: int = DEFAULT_NUM_OUTPUTS,
        chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
        chunk_size_limit: Optional[int] = None,
        tokenizer: Optional[Callable[[str], List]] = None,
        separator: str = " ",
        **kwargs,
    ) -> None:
        """Init params."""
        if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
            raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")

        super().__init__(
            context_window=context_window,
            num_output=num_output,
            chunk_overlap_ratio=chunk_overlap_ratio,
            chunk_size_limit=chunk_size_limit,
            separator=separator,
            **kwargs,
        )
        # TODO: make the tokenizer configurable
        self._tokenizer = tokenizer or globals_helper.tokenizer

    def token_count(self, prompt_template: str) -> int:
        """Get the token count of the prompt template."""
        empty_prompt_txt = get_empty_prompt_txt(prompt_template)
        return len(self._tokenizer(empty_prompt_txt))

    @classmethod
    def from_llm_metadata(
        cls,
        llm_metadata: LLMMetadata,
        chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
        chunk_size_limit: Optional[int] = None,
        tokenizer: Optional[Callable[[str], List]] = None,
        separator: str = " ",
    ) -> "PromptHelper":
        """Create from LLM metadata.

        This will autofill values like context_window and num_output.
        """
        context_window = llm_metadata.context_window
        if llm_metadata.num_output == -1:
            num_output = DEFAULT_NUM_OUTPUTS
        else:
            num_output = llm_metadata.num_output

        return cls(
            context_window=context_window,
            num_output=num_output,
            chunk_overlap_ratio=chunk_overlap_ratio,
            chunk_size_limit=chunk_size_limit,
            tokenizer=tokenizer,
            separator=separator,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get the class name."""
        return "PromptHelper"

    def _get_available_context_size(self, template: str) -> int:
        """Get available context size.

        This is calculated as:
            available context window = total context window
                - input (partially filled prompt)
                - output (room reserved for the response)

        Notes:
            - A ValueError is raised if the available context size is negative.
        """
        empty_prompt_txt = get_empty_prompt_txt(template)
        num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt))
        context_size_tokens = (
            self.context_window - num_empty_prompt_tokens - self.num_output
        )
        if context_size_tokens < 0:
            raise ValueError(
                f"Calculated available context size {context_size_tokens} was"
                " negative."
            )
        return context_size_tokens

    def _get_available_chunk_size(
        self, prompt_template: str, num_chunks: int = 1, padding: int = DEFAULT_PADDING
    ) -> int:
        """Get available chunk size.

        This is calculated as:
            available chunk size = available context window // num_chunks
                - padding

        Notes:
            - By default, we use a padding of 5 (to save space for formatting needs).
            - Available chunk size is further clamped to chunk_size_limit if specified.
        """
        available_context_size = self._get_available_context_size(prompt_template)
        result = available_context_size // num_chunks - padding
        if self.chunk_size_limit is not None:
            result = min(result, self.chunk_size_limit)
        return result

    def get_text_splitter_given_prompt(
        self,
        prompt_template: str,
        num_chunks: int = 1,
        padding: int = DEFAULT_PADDING,
    ) -> TokenTextSplitter:
        """Get a text splitter configured to maximally pack the available context
        window, taking into account the given prompt and desired number of chunks.
        """
        chunk_size = self._get_available_chunk_size(
            prompt_template, num_chunks, padding=padding
        )
        if chunk_size <= 0:
            raise ValueError(f"Chunk size {chunk_size} is not positive.")
        chunk_overlap = int(self.chunk_overlap_ratio * chunk_size)
        return TokenTextSplitter(
            separator=self.separator,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            tokenizer=self._tokenizer,
        )

    def repack(
        self,
        prompt_template: str,
        text_chunks: Sequence[str],
        padding: int = DEFAULT_PADDING,
    ) -> List[str]:
        """Repack text chunks to fit the available context window.

        This will combine the text chunks into consolidated chunks that more fully
        "pack" the prompt template, given the context_window.
        """
        text_splitter = self.get_text_splitter_given_prompt(
            prompt_template, padding=padding
        )
        combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()])
        return text_splitter.split_text(combined_str)


def get_empty_prompt_txt(template: str) -> str:
    """Get empty prompt text.

    Substitute empty strings in parts of the prompt that have not yet been filled out.
    Skip variables that have already been partially formatted. This is used to compute
    the initial token count of the prompt template.
    """
    # No partially formatted variables are tracked for a plain template string, so
    # start with an empty mapping.
    partial_kwargs = {}
    template_vars = get_template_vars(template)
    empty_kwargs = {v: "" for v in template_vars if v not in partial_kwargs}
    all_kwargs = {**partial_kwargs, **empty_kwargs}
    prompt = template.format(**all_kwargs)
    return prompt
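

# Illustrative sketch (assuming get_template_vars extracts "{name}"-style variables):
# get_empty_prompt_txt("Context: {context}\nQ: {question}") would return
# "Context: \nQ: ", whose token count is the template's fixed overhead.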
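

if __name__ == "__main__":
    # Minimal, illustrative usage sketch: repack retrieved chunks for a QA-style
    # template, here using a naive whitespace tokenizer in place of the default
    # tokenizer from `globals_helper`. Values and texts below are hypothetical.
    helper = PromptHelper(
        context_window=1024,
        num_output=128,
        tokenizer=lambda text: text.split(),
    )
    qa_template = (
        "Answer the question using the context below.\n"
        "Context: {context}\n"
        "Question: {question}\n"
    )
    retrieved_chunks = [
        "DB-GPT lets applications talk to databases through LLMs.",
        "PromptHelper repacks retrieved chunks to fit the model's context window.",
    ]
    print(helper.repack(qa_template, retrieved_chunks))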