Mirror of https://github.com/csunny/DB-GPT.git

Commit: feat: add document summary
@@ -21,35 +21,27 @@ async def llm_chat_response(chat_scene: str, **chat_param):
     return chat.stream_call()
 
 
-def run_async_tasks(
+async def run_async_tasks(
     tasks: List[Coroutine],
-    show_progress: bool = False,
-    progress_bar_desc: str = "Running async tasks",
+    concurrency_limit: int = None,
 ) -> List[Any]:
     """Run a list of async tasks."""
 
     tasks_to_execute: List[Any] = tasks
-    if show_progress:
-        try:
-            import nest_asyncio
-            from tqdm.asyncio import tqdm
-
-            nest_asyncio.apply()
-            loop = asyncio.get_event_loop()
-
-            async def _tqdm_gather() -> List[Any]:
-                return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)
-
-            tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
-            return tqdm_outputs
-        # run the operation w/o tqdm on hitting a fatal
-        # may occur in some environments where tqdm.asyncio
-        # is not supported
-        except Exception:
-            pass
 
     async def _gather() -> List[Any]:
-        return await asyncio.gather(*tasks_to_execute)
+        if concurrency_limit:
+            semaphore = asyncio.Semaphore(concurrency_limit)
 
-    outputs: List[Any] = asyncio.run(_gather())
-    return outputs
+            async def _execute_task(task):
+                async with semaphore:
+                    return await task
+
+            # Execute tasks with semaphore limit
+            return await asyncio.gather(
+                *[_execute_task(task) for task in tasks_to_execute]
+            )
+        else:
+            return await asyncio.gather(*tasks_to_execute)
+
+    # outputs: List[Any] = asyncio.run(_gather())
+    return await _gather()
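Since run_async_tasks is now itself a coroutine, callers must await it. A minimal usage sketch of the new signature (the fetch_summary coroutine and the limit of 3 are illustrative, not part of this commit):

import asyncio
from typing import List

from pilot.common.chat_util import run_async_tasks


async def fetch_summary(idx: int) -> str:
    # Stand-in for an LLM call such as llm_chat_response_nostream(...).
    await asyncio.sleep(0.1)
    return f"summary-{idx}"


async def main() -> List[str]:
    tasks = [fetch_summary(i) for i in range(10)]
    # At most 3 coroutines run at once; concurrency_limit=None gathers everything at once.
    return await run_async_tasks(tasks, concurrency_limit=3)


results = asyncio.run(main())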
pilot/common/global_helper.py (new file, 448 lines)
@@ -0,0 +1,448 @@
"""General utils functions."""

import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from itertools import islice
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Type,
    Union,
    cast,
)


class GlobalsHelper:
    """Helper to retrieve globals.

    Helpful for global caching of certain variables that can be expensive to load.
    (e.g. tokenization)

    """

    _tokenizer: Optional[Callable[[str], List]] = None
    _stopwords: Optional[List[str]] = None

    @property
    def tokenizer(self) -> Callable[[str], List]:
        """Get tokenizer."""
        if self._tokenizer is None:
            tiktoken_import_err = (
                "`tiktoken` package not found, please run `pip install tiktoken`"
            )
            try:
                import tiktoken
            except ImportError:
                raise ImportError(tiktoken_import_err)
            enc = tiktoken.get_encoding("gpt2")
            self._tokenizer = cast(Callable[[str], List], enc.encode)
            self._tokenizer = partial(self._tokenizer, allowed_special="all")
        return self._tokenizer  # type: ignore

    @property
    def stopwords(self) -> List[str]:
        """Get stopwords."""
        if self._stopwords is None:
            try:
                import nltk
                from nltk.corpus import stopwords
            except ImportError:
                raise ImportError(
                    "`nltk` package not found, please run `pip install nltk`"
                )

            from llama_index.utils import get_cache_dir

            cache_dir = get_cache_dir()
            nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)

            # update nltk path for nltk so that it finds the data
            if nltk_data_dir not in nltk.data.path:
                nltk.data.path.append(nltk_data_dir)

            try:
                nltk.data.find("corpora/stopwords")
            except LookupError:
                nltk.download("stopwords", download_dir=nltk_data_dir)
            self._stopwords = stopwords.words("english")
        return self._stopwords


globals_helper = GlobalsHelper()


def get_new_id(d: Set) -> str:
    """Get a new ID."""
    while True:
        new_id = str(uuid.uuid4())
        if new_id not in d:
            break
    return new_id


def get_new_int_id(d: Set) -> int:
    """Get a new integer ID."""
    while True:
        new_id = random.randint(0, sys.maxsize)
        if new_id not in d:
            break
    return new_id


@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
    """Temporary setter.

    Utility class for setting a temporary value for an attribute on a class.
    Taken from: https://tinyurl.com/2p89xymh

    """
    prev_values = {k: getattr(obj, k) for k in kwargs}
    for k, v in kwargs.items():
        setattr(obj, k, v)
    try:
        yield
    finally:
        for k, v in prev_values.items():
            setattr(obj, k, v)


@dataclass
class ErrorToRetry:
    """Exception types that should be retried.

    Args:
        exception_cls (Type[Exception]): Class of exception.
        check_fn (Optional[Callable[[Any], bool]]):
            A function that takes an exception instance as input and returns
            whether to retry.

    """

    exception_cls: Type[Exception]
    check_fn: Optional[Callable[[Any], bool]] = None


def retry_on_exceptions_with_backoff(
    lambda_fn: Callable,
    errors_to_retry: List[ErrorToRetry],
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> Any:
    """Execute lambda function with retries and exponential backoff.

    Args:
        lambda_fn (Callable): Function to be called and output we want.
        errors_to_retry (List[ErrorToRetry]): List of errors to retry.
            At least one needs to be provided.
        max_tries (int): Maximum number of tries, including the first. Defaults to 10.
        min_backoff_secs (float): Minimum amount of backoff time between attempts.
            Defaults to 0.5.
        max_backoff_secs (float): Maximum amount of backoff time between attempts.
            Defaults to 60.

    """
    if not errors_to_retry:
        raise ValueError("At least one error to retry needs to be provided")

    error_checks = {
        error_to_retry.exception_cls: error_to_retry.check_fn
        for error_to_retry in errors_to_retry
    }
    exception_class_tuples = tuple(error_checks.keys())

    backoff_secs = min_backoff_secs
    tries = 0

    while True:
        try:
            return lambda_fn()
        except exception_class_tuples as e:
            traceback.print_exc()
            tries += 1
            if tries >= max_tries:
                raise
            check_fn = error_checks.get(e.__class__)
            if check_fn and not check_fn(e):
                raise
            time.sleep(backoff_secs)
            backoff_secs = min(backoff_secs * 2, max_backoff_secs)


def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length."""
    if len(text) <= max_length:
        return text
    return text[: max_length - 3] + "..."


def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
    """Iterate over an iterable in batches.

    >>> list(iter_batch([1,2,3,4,5], 3))
    [[1, 2, 3], [4, 5]]
    """
    source_iter = iter(iterable)
    while source_iter:
        b = list(islice(source_iter, size))
        if len(b) == 0:
            break
        yield b


def concat_dirs(dirname: str, basename: str) -> str:
    """
    Append basename to dirname, avoiding backslashes when running on windows.

    os.path.join(dirname, basename) will add a backslash before dirname if
    basename does not end with a slash, so we make sure it does.
    """
    dirname += "/" if dirname[-1] != "/" else ""
    return os.path.join(dirname, basename)


def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
    """
    Optionally get a tqdm iterable. Ensures tqdm.auto is used.
    """
    _iterator = items
    if show_progress:
        try:
            from tqdm.auto import tqdm

            return tqdm(items, desc=desc)
        except ImportError:
            pass
    return _iterator


def count_tokens(text: str) -> int:
    tokens = globals_helper.tokenizer(text)
    return len(tokens)


def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
    """
    Args:
        model_name(str): the model name of the tokenizer.
            For instance, fxmarty/tiny-llama-fast-tokenizer.
    """
    try:
        from transformers import AutoTokenizer
    except ImportError:
        raise ValueError(
            "`transformers` package not found, please run `pip install transformers`"
        )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer.tokenize


def get_cache_dir() -> str:
    """Locate a platform-appropriate cache directory for llama_index,
    and create it if it doesn't yet exist.
    """
    # User override
    if "LLAMA_INDEX_CACHE_DIR" in os.environ:
        path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])

    # Linux, Unix, AIX, etc.
    elif os.name == "posix" and sys.platform != "darwin":
        path = Path("/tmp/llama_index")

    # Mac OS
    elif sys.platform == "darwin":
        path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")

    # Windows (hopefully)
    else:
        local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
            "~\\AppData\\Local"
        )
        path = Path(local, "llama_index")

    if not os.path.exists(path):
        os.makedirs(
            path, exist_ok=True
        )  # prevents https://github.com/jerryjliu/llama_index/issues/7362
    return str(path)


def add_sync_version(func: Any) -> Any:
    """Decorator for adding sync version of an async function. The sync version
    is added as a function attribute to the original function, func.

    Args:
        func(Any): the async function for which a sync variant will be built.
    """
    assert asyncio.iscoroutinefunction(func)

    @wraps(func)
    def _wrapper(*args: Any, **kwds: Any) -> Any:
        return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))

    func.sync = _wrapper
    return func


# Sample text from llama_index's readme
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.

Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM apps. It provides the following tools:

Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""

_LLAMA_INDEX_COLORS = {
    "llama_pink": "38;2;237;90;200",
    "llama_blue": "38;2;90;149;237",
    "llama_turquoise": "38;2;11;159;203",
    "llama_lavender": "38;2;155;135;227",
}

_ANSI_COLORS = {
    "red": "31",
    "green": "32",
    "yellow": "33",
    "blue": "34",
    "magenta": "35",
    "cyan": "36",
    "pink": "38;5;200",
}


def get_color_mapping(
    items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
    """
    Get a mapping of items to colors.

    Args:
        items (List[str]): List of items to be mapped to colors.
        use_llama_index_colors (bool, optional): Flag to indicate
            whether to use LlamaIndex colors or ANSI colors.
            Defaults to True.

    Returns:
        Dict[str, str]: Mapping of items to colors.
    """
    if use_llama_index_colors:
        color_palette = _LLAMA_INDEX_COLORS
    else:
        color_palette = _ANSI_COLORS

    colors = list(color_palette.keys())
    return {item: colors[i % len(colors)] for i, item in enumerate(items)}


def _get_colored_text(text: str, color: str) -> str:
    """
    Get the colored version of the input text.

    Args:
        text (str): Input text.
        color (str): Color to be applied to the text.

    Returns:
        str: Colored version of the input text.
    """
    all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}

    if color not in all_colors:
        return f"\033[1;3m{text}\033[0m"  # just bolded and italicized

    color = all_colors[color]

    return f"\033[1;3;{color}m{text}\033[0m"


def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
    """
    Print the text with the specified color.

    Args:
        text (str): Text to be printed.
        color (str, optional): Color to be applied to the text. Supported colors are:
            llama_pink, llama_blue, llama_turquoise, llama_lavender,
            red, green, yellow, blue, magenta, cyan, pink.
        end (str, optional): String appended after the last character of the text.

    Returns:
        None
    """
    text_to_print = _get_colored_text(text, color) if color is not None else text
    print(text_to_print, end=end)


def infer_torch_device() -> str:
    """Infer the input to torch.device."""
    try:
        has_cuda = torch.cuda.is_available()
    except NameError:
        import torch

        has_cuda = torch.cuda.is_available()
    if has_cuda:
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def unit_generator(x: Any) -> Generator[Any, None, None]:
    """A function that returns a generator of a single element.

    Args:
        x (Any): the element to build yield

    Yields:
        Any: the single element
    """
    yield x


async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
    """A function that returns a generator of a single element.

    Args:
        x (Any): the element to build yield

    Yields:
        Any: the single element
    """
    yield x
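The helpers above are largely ported from llama_index. A short usage sketch of the pieces this commit relies on; flaky_call is a made-up stand-in, not a function from the repository:

import random

from pilot.common.global_helper import (
    ErrorToRetry,
    globals_helper,
    retry_on_exceptions_with_backoff,
    truncate_text,
)

# Token counting through the lazily loaded tiktoken "gpt2" encoding.
num_tokens = len(globals_helper.tokenizer("How many tokens is this sentence?"))


def flaky_call() -> str:
    # Made-up stand-in for a network/LLM call that fails intermittently.
    if random.random() < 0.5:
        raise ConnectionError("transient failure")
    return "ok " * 50


# Retry with exponential backoff, but only when a ConnectionError is raised.
result = retry_on_exceptions_with_backoff(
    flaky_call,
    errors_to_retry=[ErrorToRetry(exception_cls=ConnectionError)],
    max_tries=5,
    min_backoff_secs=0.1,
)
print(num_tokens, truncate_text(result, max_length=40))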
pilot/common/llm_metadata.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from pydantic import Field, BaseModel

DEFAULT_CONTEXT_WINDOW = 3900
DEFAULT_NUM_OUTPUTS = 256


class LLMMetadata(BaseModel):
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=(
            "Total number of tokens the model can be input and output for one response."
        ),
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="Number of tokens the model can output when generating a response.",
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            "Set True if the model exposes a chat interface (i.e. can be passed a"
            " sequence of messages, rather than text), like OpenAI's"
            " /v1/chat/completions endpoint."
        ),
    )
    is_function_calling_model: bool = Field(
        default=False,
        # SEE: https://openai.com/blog/function-calling-and-other-api-updates
        description=(
            "Set True if the model supports function calling messages, similar to"
            " OpenAI's function calling API. For example, converting 'Email Anya to"
            " see if she wants to get coffee next Friday' to a function call like"
            " `send_email(to: string, body: string)`."
        ),
    )
    model_name: str = Field(
        default="unknown",
        description=(
            "The model's name used for logging, testing, and sanity checking. For some"
            " models this can be automatically discerned. For other models, like"
            " locally loaded models, this must be manually specified."
        ),
    )
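A minimal sketch of how this metadata is consumed: PromptHelper.from_llm_metadata (added below in pilot/common/prompt_util.py) reads context_window and num_output from it. The field values here are illustrative, not taken from the commit:

from pilot.common.llm_metadata import LLMMetadata

metadata = LLMMetadata(
    context_window=4096,
    num_output=512,
    is_chat_model=True,
    model_name="vicuna-13b-v1.5",  # illustrative model name
)
print(metadata.dict())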
pilot/common/prompt_util.py (new file, 254 lines)
@@ -0,0 +1,254 @@
"""General prompt helper that can help deal with LLM context window token limitations.

At its core, it calculates available context size by starting with the context window
size of an LLM and reserve token space for the prompt template, and the output.

It provides utility for "repacking" text chunks (retrieved from index) to maximally
make use of the available context window (and thereby reducing the number of LLM calls
needed), or truncating them so that they fit in a single LLM call.
"""

import logging
from string import Formatter
from typing import Callable, List, Optional, Sequence

from pydantic import Field, PrivateAttr, BaseModel
from llama_index.prompts import BasePromptTemplate

from pilot.common.global_helper import globals_helper
from pilot.common.llm_metadata import LLMMetadata
from pilot.embedding_engine.loader.token_splitter import TokenTextSplitter

DEFAULT_PADDING = 5
DEFAULT_CHUNK_OVERLAP_RATIO = 0.1

DEFAULT_CONTEXT_WINDOW = 3000  # tokens
DEFAULT_NUM_OUTPUTS = 256  # tokens

logger = logging.getLogger(__name__)


class PromptHelper(BaseModel):
    """Prompt helper.

    General prompt helper that can help deal with LLM context window token limitations.

    At its core, it calculates available context size by starting with the context
    window size of an LLM and reserve token space for the prompt template, and the
    output.

    It provides utility for "repacking" text chunks (retrieved from index) to maximally
    make use of the available context window (and thereby reducing the number of LLM
    calls needed), or truncating them so that they fit in a single LLM call.

    Args:
        context_window (int): Context window for the LLM.
        num_output (int): Number of outputs for the LLM.
        chunk_overlap_ratio (float): Chunk overlap as a ratio of chunk size
        chunk_size_limit (Optional[int]): Maximum chunk size to use.
        tokenizer (Optional[Callable[[str], List]]): Tokenizer to use.
        separator (str): Separator for text splitter

    """

    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum context size that will get sent to the LLM.",
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The amount of token-space to leave in input for generation.",
    )
    chunk_overlap_ratio: float = Field(
        default=DEFAULT_CHUNK_OVERLAP_RATIO,
        description="The percentage token amount that each chunk should overlap.",
    )
    chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.")
    separator: str = Field(
        default=" ", description="The separator when chunking tokens."
    )

    _tokenizer: Callable[[str], List] = PrivateAttr()

    def __init__(
        self,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        num_output: int = DEFAULT_NUM_OUTPUTS,
        chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
        chunk_size_limit: Optional[int] = None,
        tokenizer: Optional[Callable[[str], List]] = None,
        separator: str = " ",
    ) -> None:
        """Init params."""
        if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
            raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")

        # TODO: make configurable
        self._tokenizer = tokenizer or globals_helper.tokenizer

        super().__init__(
            context_window=context_window,
            num_output=num_output,
            chunk_overlap_ratio=chunk_overlap_ratio,
            chunk_size_limit=chunk_size_limit,
            separator=separator,
        )

    @classmethod
    def from_llm_metadata(
        cls,
        llm_metadata: LLMMetadata,
        chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
        chunk_size_limit: Optional[int] = None,
        tokenizer: Optional[Callable[[str], List]] = None,
        separator: str = " ",
    ) -> "PromptHelper":
        """Create from llm predictor.

        This will autofill values like context_window and num_output.

        """
        context_window = llm_metadata.context_window
        if llm_metadata.num_output == -1:
            num_output = DEFAULT_NUM_OUTPUTS
        else:
            num_output = llm_metadata.num_output

        return cls(
            context_window=context_window,
            num_output=num_output,
            chunk_overlap_ratio=chunk_overlap_ratio,
            chunk_size_limit=chunk_size_limit,
            tokenizer=tokenizer,
            separator=separator,
        )

    @classmethod
    def class_name(cls) -> str:
        return "PromptHelper"

    def _get_available_context_size(self, template: str) -> int:
        """Get available context size.

        This is calculated as:
            available context window = total context window
                - input (partially filled prompt)
                - output (room reserved for response)

        Notes:
        - Available context size is further clamped to be non-negative.
        """
        empty_prompt_txt = get_empty_prompt_txt(template)
        num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt))
        context_size_tokens = (
            self.context_window - num_empty_prompt_tokens - self.num_output
        )
        if context_size_tokens < 0:
            raise ValueError(
                f"Calculated available context size {context_size_tokens} was"
                " not non-negative."
            )
        return context_size_tokens

    def _get_available_chunk_size(
        self, prompt_template: str, num_chunks: int = 1, padding: int = 5
    ) -> int:
        """Get available chunk size.

        This is calculated as:
            available chunk size = available context window // number_chunks
                - padding

        Notes:
        - By default, we use padding of 5 (to save space for formatting needs).
        - Available chunk size is further clamped to chunk_size_limit if specified.
        """
        available_context_size = self._get_available_context_size(prompt_template)
        result = available_context_size // num_chunks - padding
        if self.chunk_size_limit is not None:
            result = min(result, self.chunk_size_limit)
        return result

    def get_text_splitter_given_prompt(
        self,
        prompt_template: str,
        num_chunks: int = 1,
        padding: int = DEFAULT_PADDING,
    ) -> TokenTextSplitter:
        """Get text splitter configured to maximally pack available context window,
        taking into account of given prompt, and desired number of chunks.
        """
        chunk_size = self._get_available_chunk_size(
            prompt_template, num_chunks, padding=padding
        )
        if chunk_size <= 0:
            raise ValueError(f"Chunk size {chunk_size} is not positive.")
        chunk_overlap = int(self.chunk_overlap_ratio * chunk_size)
        return TokenTextSplitter(
            separator=self.separator,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            tokenizer=self._tokenizer,
        )

    # def truncate(
    #     self,
    #     prompt: BasePromptTemplate,
    #     text_chunks: Sequence[str],
    #     padding: int = DEFAULT_PADDING,
    # ) -> List[str]:
    #     """Truncate text chunks to fit available context window."""
    #     text_splitter = self.get_text_splitter_given_prompt(
    #         prompt,
    #         num_chunks=len(text_chunks),
    #         padding=padding,
    #     )
    #     return [truncate_text(chunk, text_splitter) for chunk in text_chunks]

    def repack(
        self,
        prompt_template: str,
        text_chunks: Sequence[str],
        padding: int = DEFAULT_PADDING,
    ) -> List[str]:
        """Repack text chunks to fit available context window.

        This will combine text chunks into consolidated chunks
        that more fully "pack" the prompt template given the context_window.

        """
        text_splitter = self.get_text_splitter_given_prompt(
            prompt_template, padding=padding
        )
        combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()])
        return text_splitter.split_text(combined_str)


def get_empty_prompt_txt(template: str) -> str:
    """Get empty prompt text.

    Substitute empty strings in parts of the prompt that have
    not yet been filled out. Skip variables that have already
    been partially formatted. This is used to compute the initial tokens.

    """
    # partial_kargs = prompt.kwargs

    partial_kargs = {}
    template_vars = get_template_vars(template)
    empty_kwargs = {v: "" for v in template_vars if v not in partial_kargs}
    all_kwargs = {**partial_kargs, **empty_kwargs}
    prompt = template.format(**all_kwargs)
    return prompt


def get_template_vars(template_str: str) -> List[str]:
    """Get template variables from a template string."""
    variables = []
    formatter = Formatter()

    for _, variable_name, _, _ in formatter.parse(template_str):
        if variable_name:
            variables.append(variable_name)

    return variables
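A small sketch of the repack() call that the knowledge service uses later in this commit; the template text and chunks below are illustrative, and the default tokenizer requires tiktoken:

from pilot.common.prompt_util import PromptHelper

helper = PromptHelper(context_window=2000, num_output=256)

template = "Write a quick summary of the following context:\n{context}\n"
chunks = ["first retrieved chunk ...", "second retrieved chunk ...", "third ..."]

# Consolidates small chunks so each repacked chunk nearly fills the space left
# after reserving room for the (empty) prompt template and num_output tokens.
repacked = helper.repack(prompt_template=template, text_chunks=chunks)
print(len(repacked), "repacked chunk(s)")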
pilot/embedding_engine/loader/splitter_utils.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from typing import Callable, List


def split_text_keep_separator(text: str, separator: str) -> List[str]:
    """Split text with separator and keep the separator at the end of each split."""
    parts = text.split(separator)
    result = [separator + s if i > 0 else s for i, s in enumerate(parts)]
    return [s for s in result if s]


def split_by_sep(sep: str, keep_sep: bool = True) -> Callable[[str], List[str]]:
    """Split text by separator."""
    if keep_sep:
        return lambda text: split_text_keep_separator(text, sep)
    else:
        return lambda text: text.split(sep)


def split_by_char() -> Callable[[str], List[str]]:
    """Split text by character."""
    return lambda text: list(text)


def split_by_sentence_tokenizer() -> Callable[[str], List[str]]:
    import os

    import nltk

    from llama_index.utils import get_cache_dir

    cache_dir = get_cache_dir()
    nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)

    # update nltk path for nltk so that it finds the data
    if nltk_data_dir not in nltk.data.path:
        nltk.data.path.append(nltk_data_dir)

    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt", download_dir=nltk_data_dir)

    tokenizer = nltk.tokenize.PunktSentenceTokenizer()

    # get the spans and then return the sentences
    # using the start index of each span
    # instead of using end, use the start of the next span if available
    def split(text: str) -> List[str]:
        spans = list(tokenizer.span_tokenize(text))
        sentences = []
        for i, span in enumerate(spans):
            start = span[0]
            if i < len(spans) - 1:
                end = spans[i + 1][0]
            else:
                end = len(text)
            sentences.append(text[start:end])

        return sentences

    return split


def split_by_regex(regex: str) -> Callable[[str], List[str]]:
    """Split text by regex."""
    import re

    return lambda text: re.findall(regex, text)


def split_by_phrase_regex() -> Callable[[str], List[str]]:
    """Split text by phrase regex.

    This regular expression will split the sentences into phrases,
    where each phrase is a sequence of one or more non-comma,
    non-period, and non-semicolon characters, followed by an optional comma,
    period, or semicolon. The regular expression will also capture the
    delimiters themselves as separate items in the list of phrases.
    """
    regex = "[^,.;。]+[,.;。]?"
    return split_by_regex(regex)
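A short illustration of the splitter factories above; the sample sentence is made up:

from pilot.embedding_engine.loader.splitter_utils import (
    split_by_phrase_regex,
    split_by_sep,
)

text = "DB-GPT keeps data local; it is extensible, and it targets private LLM apps."

# Separator splitting keeps the separator attached to each split by default.
print(split_by_sep(";")(text))

# Phrase splitting breaks on commas, periods and semicolons (including the Chinese "。").
print(split_by_phrase_regex()(text))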
pilot/embedding_engine/loader/token_splitter.py (new file, 184 lines)
@@ -0,0 +1,184 @@
"""Token splitter."""
from typing import Callable, List, Optional

from pydantic import Field, PrivateAttr, BaseModel


from pilot.common.global_helper import globals_helper
from pilot.embedding_engine.loader.splitter_utils import split_by_sep, split_by_char

DEFAULT_METADATA_FORMAT_LEN = 2
DEFAULT_CHUNK_OVERLAP = 20
DEFAULT_CHUNK_SIZE = 1024


class TokenTextSplitter(BaseModel):
    """Implementation of splitting text that looks at word tokens."""

    chunk_size: int = Field(
        default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk."
    )
    chunk_overlap: int = Field(
        default=DEFAULT_CHUNK_OVERLAP,
        description="The token overlap of each chunk when splitting.",
    )
    separator: str = Field(
        default=" ", description="Default separator for splitting into words"
    )
    backup_separators: List = Field(
        default_factory=list, description="Additional separators for splitting."
    )
    # callback_manager: CallbackManager = Field(
    #     default_factory=CallbackManager, exclude=True
    # )
    tokenizer: Callable = Field(
        default_factory=globals_helper.tokenizer,  # type: ignore
        description="Tokenizer for splitting words into tokens.",
        exclude=True,
    )

    _split_fns: List[Callable] = PrivateAttr()

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        tokenizer: Optional[Callable] = None,
        # callback_manager: Optional[CallbackManager] = None,
        separator: str = " ",
        backup_separators: Optional[List[str]] = ["\n"],
    ):
        """Initialize with parameters."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        # callback_manager = callback_manager or CallbackManager([])
        tokenizer = tokenizer or globals_helper.tokenizer

        all_seps = [separator] + (backup_separators or [])
        self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]

        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            backup_separators=backup_separators,
            # callback_manager=callback_manager,
            tokenizer=tokenizer,
        )

    @classmethod
    def class_name(cls) -> str:
        return "TokenTextSplitter"

    def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
        """Split text into chunks, reserving space required for metadata str."""
        metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN
        effective_chunk_size = self.chunk_size - metadata_len
        if effective_chunk_size <= 0:
            raise ValueError(
                f"Metadata length ({metadata_len}) is longer than chunk size "
                f"({self.chunk_size}). Consider increasing the chunk size or "
                "decreasing the size of your metadata to avoid this."
            )
        elif effective_chunk_size < 50:
            print(
                f"Metadata length ({metadata_len}) is close to chunk size "
                f"({self.chunk_size}). Resulting chunks are less than 50 tokens. "
                "Consider increasing the chunk size or decreasing the size of "
                "your metadata to avoid this.",
                flush=True,
            )

        return self._split_text(text, chunk_size=effective_chunk_size)

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        return self._split_text(text, chunk_size=self.chunk_size)

    def _split_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks up to chunk_size."""
        if text == "":
            return []

        splits = self._split(text, chunk_size)
        chunks = self._merge(splits, chunk_size)
        return chunks

    def _split(self, text: str, chunk_size: int) -> List[str]:
        """Break text into splits that are smaller than chunk size.

        The order of splitting is:
        1. split by separator
        2. split by backup separators (if any)
        3. split by characters

        NOTE: the splits contain the separators.
        """
        if len(self.tokenizer(text)) <= chunk_size:
            return [text]

        for split_fn in self._split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        new_splits = []
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len <= chunk_size:
                new_splits.append(split)
            else:
                # recursively split
                new_splits.extend(self._split(split, chunk_size=chunk_size))
        return new_splits

    def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
        """Merge splits into chunks.

        The high-level idea is to keep adding splits to a chunk until we
        exceed the chunk size, then we start a new chunk with overlap.

        When we start a new chunk, we pop off the first element of the previous
        chunk until the total length is less than the chunk size.
        """
        chunks: List[str] = []

        cur_chunk: List[str] = []
        cur_len = 0
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len > chunk_size:
                print(
                    f"Got a split of size {split_len}, ",
                    f"larger than chunk size {chunk_size}.",
                )

            # if we exceed the chunk size after adding the new split, then
            # we need to end the current chunk and start a new one
            if cur_len + split_len > chunk_size:
                # end the previous chunk
                chunk = "".join(cur_chunk).strip()
                if chunk:
                    chunks.append(chunk)

                # start a new chunk with overlap
                # keep popping off the first element of the previous chunk until:
                # 1. the current chunk length is less than chunk overlap
                # 2. the total length is less than chunk size
                while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size:
                    # pop off the first element
                    first_chunk = cur_chunk.pop(0)
                    cur_len -= len(self.tokenizer(first_chunk))

            cur_chunk.append(split)
            cur_len += split_len

        # handle the last chunk
        chunk = "".join(cur_chunk).strip()
        if chunk:
            chunks.append(chunk)

        return chunks
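A minimal sketch of the splitter in isolation (requires tiktoken for the default tokenizer; the sizes and sample text are illustrative):

from pilot.embedding_engine.loader.token_splitter import TokenTextSplitter

splitter = TokenTextSplitter(chunk_size=64, chunk_overlap=8)

long_text = "LlamaIndex is a data framework to help you build LLM apps. " * 40
chunks = splitter.split_text(long_text)

# Each chunk stays within 64 tokens of the tiktoken "gpt2" encoding, with
# roughly 8 tokens of overlap carried between neighbouring chunks.
print(len(chunks), max(len(splitter.tokenizer(c)) for c in chunks))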
@@ -12,11 +12,12 @@ CFG = Config()
 
 PROMPT_SCENE_DEFINE = """"""
 
-_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请再完善一下原来的总结。\n回答的时候最好按照1.2.3.点进行总结"""
+_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请根据你之前推理的内容进行最终的总结,总结的时候可以详细点,回答的时候最好按照1.2.3.进行总结."""
 
 _DEFAULT_TEMPLATE_EN = """
-We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.please refine the original summary.
-\nWhen answering, it is best to summarize according to points 1.2.3.
+We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.
+\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1, 2, and 3. and etc.
 
 """
 
 _DEFAULT_TEMPLATE = (
@@ -27,7 +28,7 @@ PROMPT_RESPONSE = """"""
 
 PROMPT_SEP = SeparatorStyle.SINGLE.value
 
-PROMPT_NEED_NEED_STREAM_OUT = False
+PROMPT_NEED_NEED_STREAM_OUT = True
 
 prompt = PromptTemplate(
     template_scene=ChatScene.ExtractRefineSummary.value(),
@@ -21,7 +21,8 @@ class ExtractSummary(BaseChat):
             chat_param=chat_param,
         )
 
-        self.user_input = chat_param["current_user_input"]
+        # self.user_input = chat_param["current_user_input"]
+        self.user_input = chat_param["select_param"]
         # self.extract_mode = chat_param["select_param"]
 
     def generate_input_values(self):
@@ -11,15 +11,15 @@ CFG = Config()
 
 PROMPT_SCENE_DEFINE = """"""
 
-_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行总结:
+_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
 {context}
-回答的时候最好按照1.2.3.点进行总结
+答案尽量精确和简单,不要过长,长度控制在100字左右
 """
 
 _DEFAULT_TEMPLATE_EN = """
-Write a summary of the following context:
+Write a quick summary of the following context:
 {context}
-When answering, it is best to summarize according to points 1.2.3.
+the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters.
 """
 
 _DEFAULT_TEMPLATE = (
@@ -10,6 +10,7 @@ from pilot.configs.model_config import (
     EMBEDDING_MODEL_CONFIG,
     KNOWLEDGE_UPLOAD_ROOT_PATH,
 )
+from pilot.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
 
 from pilot.openapi.api_view_model import Result
 from pilot.embedding_engine.embedding_engine import EmbeddingEngine
@@ -25,9 +26,11 @@ from pilot.server.knowledge.request.request import (
     DocumentQueryRequest,
     SpaceArgumentRequest,
     EntityExtractRequest,
+    DocumentSummaryRequest,
 )
 
 from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
+from pilot.utils.tracer import root_tracer, SpanType
 
 logger = logging.getLogger(__name__)
 
@@ -201,6 +204,45 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
     return {"response": res}
 
 
+@router.post("/knowledge/document/summary")
+async def document_summary(request: DocumentSummaryRequest):
+    print(f"/document/summary params: {request}")
+    try:
+        with root_tracer.start_span(
+            "get_chat_instance", span_type=SpanType.CHAT, metadata=request
+        ):
+            chat = await knowledge_space_service.document_summary(request=request)
+        headers = {
+            "Content-Type": "text/event-stream",
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "Transfer-Encoding": "chunked",
+        }
+        from starlette.responses import StreamingResponse
+
+        if not chat.prompt_template.stream_out:
+            return StreamingResponse(
+                no_stream_generator(chat),
+                headers=headers,
+                media_type="text/event-stream",
+            )
+        else:
+            return StreamingResponse(
+                stream_generator(chat, False, request.model_name),
+                headers=headers,
+                media_type="text/plain",
+            )
+
+        # return Result.succ(
+        #     knowledge_space_service.create_knowledge_document(
+        #         space=space_name, request=request
+        #     )
+        # )
+        # return Result.succ([])
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"document add error {e}")
+
+
 @router.post("/knowledge/entity/extract")
 async def entity_extract(request: EntityExtractRequest):
     logger.info(f"Received params: {request}")
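A sketch of calling the new endpoint and consuming the streamed summary. The host, port and "/api/v1" prefix are assumptions about a typical local deployment; only the /knowledge/document/summary route and the request fields come from this commit:

import requests

resp = requests.post(
    "http://127.0.0.1:5000/api/v1/knowledge/document/summary",  # host/port/prefix assumed
    json={"doc_id": 1, "model_name": "vicuna-13b-v1.5"},  # fields of DocumentSummaryRequest below
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))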
@@ -108,6 +108,14 @@ class SpaceArgumentRequest(BaseModel):
     argument: str
 
 
+class DocumentSummaryRequest(BaseModel):
+    """Sync request"""
+
+    """doc_ids: doc ids"""
+    doc_id: int
+    model_name: str
+
+
 class EntityExtractRequest(BaseModel):
     """argument: argument"""
 
@@ -10,7 +10,7 @@ from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
 )
 from pilot.component import ComponentType
-from pilot.utils.executor_utils import ExecutorFactory
+from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
 
 from pilot.server.knowledge.chunk_db import (
     DocumentChunkEntity,
@@ -31,6 +31,7 @@ from pilot.server.knowledge.request.request import (
     ChunkQueryRequest,
     SpaceArgumentRequest,
     DocumentSyncRequest,
+    DocumentSummaryRequest,
 )
 from enum import Enum
 
@@ -303,7 +304,30 @@ class KnowledgeService:
         ]
         document_chunk_dao.create_documents_chunks(chunk_entities)
 
-        return True
+        return doc.id
+
+    async def document_summary(self, request: DocumentSummaryRequest):
+        """get document summary
+        Args:
+            - request: DocumentSummaryRequest
+        """
+        doc_query = KnowledgeDocumentEntity(id=request.doc_id)
+        documents = knowledge_document_dao.get_documents(doc_query)
+        if len(documents) != 1:
+            raise Exception(f"can not found document for {request.doc_id}")
+        document = documents[0]
+        query = DocumentChunkEntity(
+            document_id=request.doc_id,
+        )
+        chunks = document_chunk_dao.get_document_chunks(query, page=1, page_size=100)
+        if len(chunks) == 0:
+            raise Exception(f"can not found chunks for {request.doc_id}")
+        from langchain.schema import Document
+
+        chunk_docs = [Document(page_content=chunk.content) for chunk in chunks]
+        return await self.async_document_summary(
+            model_name=request.model_name, chunk_docs=chunk_docs, doc=document
+        )
 
     def update_knowledge_space(
         self, space_id: int, space_request: KnowledgeSpaceRequest
@@ -417,30 +441,25 @@ class KnowledgeService:
             logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
         return knowledge_document_dao.update_knowledge_document(doc)
 
-    def async_document_summary(self, chunk_docs, doc):
+    async def async_document_summary(self, model_name, chunk_docs, doc):
         """async document extract summary
         Args:
+            - model_name: str
             - chunk_docs: List[Document]
            - doc: KnowledgeDocumentEntity
         """
-        from llama_index import PromptHelper
-        from llama_index.prompts.default_prompt_selectors import (
-            DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
-        )
-
         texts = [doc.page_content for doc in chunk_docs]
-        prompt_helper = PromptHelper(context_window=2000)
+        from pilot.common.prompt_util import PromptHelper
 
-        texts = prompt_helper.repack(
-            prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts
-        )
+        prompt_helper = PromptHelper()
+        from pilot.scene.chat_knowledge.summary.prompt import prompt
+
+        texts = prompt_helper.repack(prompt_template=prompt.template, text_chunks=texts)
         logger.info(
             f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
         )
-        summary = self._mapreduce_extract_summary(texts)
-        print(f"final summary:{summary}")
-        doc.summary = summary
-        return knowledge_document_dao.update_knowledge_document(doc)
+        summary = await self._mapreduce_extract_summary(texts, model_name, 10, 3)
+        return await self._llm_extract_summary(summary, model_name)
 
     def async_doc_embedding(self, client, chunk_docs, doc):
         """async document embedding into vector db
@@ -506,30 +525,37 @@ class KnowledgeService:
             return json.loads(spaces[0].context)
         return None
 
-    def _llm_extract_summary(self, doc: str):
+    async def _llm_extract_summary(self, doc: str, model_name: str = None):
         """Extract triplets from text by llm"""
         from pilot.scene.base import ChatScene
-        from pilot.common.chat_util import llm_chat_response_nostream
         import uuid
 
         chat_param = {
             "chat_session_id": uuid.uuid1(),
-            "current_user_input": doc,
+            "current_user_input": "",
             "select_param": doc,
-            "model_name": self.model_name,
+            "model_name": model_name,
         }
-        from pilot.common.chat_util import run_async_tasks
+        executor = CFG.SYSTEM_APP.get_component(
+            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
+        ).create()
+        from pilot.openapi.api_v1.api_v1 import CHAT_FACTORY
 
-        summary_iters = run_async_tasks(
-            [
-                llm_chat_response_nostream(
-                    ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
-                )
-            ]
-        )
-        return summary_iters[0]
+        chat = await blocking_func_to_async(
+            executor,
+            CHAT_FACTORY.get_implementation,
+            ChatScene.ExtractRefineSummary.value(),
+            **{"chat_param": chat_param},
+        )
+        return chat
 
-    def _mapreduce_extract_summary(self, docs):
+    async def _mapreduce_extract_summary(
+        self,
+        docs,
+        model_name: str = None,
+        max_iteration: int = 5,
+        concurrency_limit: int = None,
+    ):
         """Extract summary by mapreduce mode
         map -> multi async thread generate summary
         reduce -> merge the summaries by map process
@@ -541,18 +567,16 @@ class KnowledgeService:
         import uuid
 
         tasks = []
-        max_iteration = 5
        if len(docs) == 1:
-            summary = self._llm_extract_summary(doc=docs[0])
-            return summary
+            return docs[0]
         else:
             max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
             for doc in docs[0:max_iteration]:
                 chat_param = {
                     "chat_session_id": uuid.uuid1(),
-                    "current_user_input": doc,
-                    "select_param": "summary",
-                    "model_name": self.model_name,
+                    "current_user_input": "",
+                    "select_param": doc,
+                    "model_name": model_name,
                 }
                 tasks.append(
                     llm_chat_response_nostream(
@@ -561,14 +585,22 @@ class KnowledgeService:
                 )
             from pilot.common.chat_util import run_async_tasks
 
-            summary_iters = run_async_tasks(tasks)
+            summary_iters = await run_async_tasks(
+                tasks=tasks, concurrency_limit=concurrency_limit
+            )
+            summary_iters = list(
+                filter(
+                    lambda content: "LLMServer Generate Error" not in content,
+                    summary_iters,
+                )
+            )
             from pilot.common.prompt_util import PromptHelper
-            from llama_index.prompts.default_prompt_selectors import (
-                DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
-            )
+            from pilot.scene.chat_knowledge.summary.prompt import prompt
 
-            prompt_helper = PromptHelper(context_window=2500)
+            prompt_helper = PromptHelper()
             summary_iters = prompt_helper.repack(
-                prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters
+                prompt_template=prompt.template, text_chunks=summary_iters
             )
-            return self._mapreduce_extract_summary(summary_iters)
+            return await self._mapreduce_extract_summary(
+                summary_iters, model_name, max_iteration, concurrency_limit
+            )
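Taken together, the service now summarizes a document with a recursive map-reduce loop followed by a final refine pass. The sketch below restates that control flow in a self-contained form; summarize() is a dummy stand-in for llm_chat_response_nostream with the summary scene, and the template and limits are illustrative:

import asyncio
from typing import List

from pilot.common.chat_util import run_async_tasks
from pilot.common.prompt_util import PromptHelper


async def summarize(text: str) -> str:
    # Dummy "map" step; the real code calls llm_chat_response_nostream(...).
    await asyncio.sleep(0)
    return text[:200]


async def mapreduce_summary(
    chunks: List[str],
    template: str = "Write a quick summary of the following context:\n{context}\n",
    max_iteration: int = 10,
    concurrency_limit: int = 3,
) -> str:
    if len(chunks) == 1:
        # Base case: a single chunk is the candidate for the final refine pass.
        return chunks[0]
    # Map: summarize at most max_iteration chunks, limited to concurrency_limit at a time.
    tasks = [summarize(c) for c in chunks[:max_iteration]]
    summaries = await run_async_tasks(tasks=tasks, concurrency_limit=concurrency_limit)
    # Reduce: repack the partial summaries to fit the prompt's context window, then recurse.
    repacked = PromptHelper().repack(prompt_template=template, text_chunks=summaries)
    return await mapreduce_summary(repacked, template, max_iteration, concurrency_limit)


final_chunk = asyncio.run(mapreduce_summary([f"chunk {i} ..." for i in range(8)]))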