mirror of https://github.com/csunny/DB-GPT.git
synced 2025-09-02 01:27:14 +00:00
feat: add document summary
pilot/common/chat_util.py
@@ -21,35 +21,27 @@ async def llm_chat_response(chat_scene: str, **chat_param):
    return chat.stream_call()


def run_async_tasks(
async def run_async_tasks(
    tasks: List[Coroutine],
    show_progress: bool = False,
    progress_bar_desc: str = "Running async tasks",
    concurrency_limit: int = None,
) -> List[Any]:
    """Run a list of async tasks."""
    tasks_to_execute: List[Any] = tasks
    if show_progress:
        try:
            import nest_asyncio
            from tqdm.asyncio import tqdm

            nest_asyncio.apply()
            loop = asyncio.get_event_loop()

            async def _tqdm_gather() -> List[Any]:
                return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)

            tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
            return tqdm_outputs
        # run the operation w/o tqdm on hitting a fatal
        # may occur in some environments where tqdm.asyncio
        # is not supported
        except Exception:
            pass

    async def _gather() -> List[Any]:
        return await asyncio.gather(*tasks_to_execute)
        if concurrency_limit:
            semaphore = asyncio.Semaphore(concurrency_limit)

    outputs: List[Any] = asyncio.run(_gather())
    return outputs
            async def _execute_task(task):
                async with semaphore:
                    return await task

            # Execute tasks with semaphore limit
            return await asyncio.gather(
                *[_execute_task(task) for task in tasks_to_execute]
            )
        else:
            return await asyncio.gather(*tasks_to_execute)

    # outputs: List[Any] = asyncio.run(_gather())
    return await _gather()
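For reference, a minimal sketch of how the reworked run_async_tasks could be awaited with a concurrency limit; the fake_llm_call coroutine and the numbers are illustrative, not part of this commit.

import asyncio

from pilot.common.chat_util import run_async_tasks


async def main():
    async def fake_llm_call(idx: int) -> str:
        # placeholder coroutine standing in for llm_chat_response_nostream
        await asyncio.sleep(0.1)
        return f"summary-{idx}"

    tasks = [fake_llm_call(i) for i in range(10)]
    # at most 3 coroutines run at once thanks to the internal semaphore
    results = await run_async_tasks(tasks=tasks, concurrency_limit=3)
    print(results)


asyncio.run(main())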
pilot/common/global_helper.py (new file, 448 lines)
@@ -0,0 +1,448 @@
"""General utils functions."""

import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from itertools import islice
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Type,
    Union,
    cast,
)


class GlobalsHelper:
    """Helper to retrieve globals.

    Helpful for global caching of certain variables that can be expensive to load.
    (e.g. tokenization)

    """

    _tokenizer: Optional[Callable[[str], List]] = None
    _stopwords: Optional[List[str]] = None

    @property
    def tokenizer(self) -> Callable[[str], List]:
        """Get tokenizer."""
        if self._tokenizer is None:
            tiktoken_import_err = (
                "`tiktoken` package not found, please run `pip install tiktoken`"
            )
            try:
                import tiktoken
            except ImportError:
                raise ImportError(tiktoken_import_err)
            enc = tiktoken.get_encoding("gpt2")
            self._tokenizer = cast(Callable[[str], List], enc.encode)
            self._tokenizer = partial(self._tokenizer, allowed_special="all")
        return self._tokenizer  # type: ignore

    @property
    def stopwords(self) -> List[str]:
        """Get stopwords."""
        if self._stopwords is None:
            try:
                import nltk
                from nltk.corpus import stopwords
            except ImportError:
                raise ImportError(
                    "`nltk` package not found, please run `pip install nltk`"
                )

            from llama_index.utils import get_cache_dir

            cache_dir = get_cache_dir()
            nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)

            # update nltk path for nltk so that it finds the data
            if nltk_data_dir not in nltk.data.path:
                nltk.data.path.append(nltk_data_dir)

            try:
                nltk.data.find("corpora/stopwords")
            except LookupError:
                nltk.download("stopwords", download_dir=nltk_data_dir)
            self._stopwords = stopwords.words("english")
        return self._stopwords


globals_helper = GlobalsHelper()


def get_new_id(d: Set) -> str:
    """Get a new ID."""
    while True:
        new_id = str(uuid.uuid4())
        if new_id not in d:
            break
    return new_id


def get_new_int_id(d: Set) -> int:
    """Get a new integer ID."""
    while True:
        new_id = random.randint(0, sys.maxsize)
        if new_id not in d:
            break
    return new_id


@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
    """Temporary setter.

    Utility class for setting a temporary value for an attribute on a class.
    Taken from: https://tinyurl.com/2p89xymh

    """
    prev_values = {k: getattr(obj, k) for k in kwargs}
    for k, v in kwargs.items():
        setattr(obj, k, v)
    try:
        yield
    finally:
        for k, v in prev_values.items():
            setattr(obj, k, v)


@dataclass
class ErrorToRetry:
    """Exception types that should be retried.

    Args:
        exception_cls (Type[Exception]): Class of exception.
        check_fn (Optional[Callable[[Any]], bool]]):
            A function that takes an exception instance as input and returns
            whether to retry.

    """

    exception_cls: Type[Exception]
    check_fn: Optional[Callable[[Any], bool]] = None


def retry_on_exceptions_with_backoff(
    lambda_fn: Callable,
    errors_to_retry: List[ErrorToRetry],
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> Any:
    """Execute lambda function with retries and exponential backoff.

    Args:
        lambda_fn (Callable): Function to be called and output we want.
        errors_to_retry (List[ErrorToRetry]): List of errors to retry.
            At least one needs to be provided.
        max_tries (int): Maximum number of tries, including the first. Defaults to 10.
        min_backoff_secs (float): Minimum amount of backoff time between attempts.
            Defaults to 0.5.
        max_backoff_secs (float): Maximum amount of backoff time between attempts.
            Defaults to 60.

    """
    if not errors_to_retry:
        raise ValueError("At least one error to retry needs to be provided")

    error_checks = {
        error_to_retry.exception_cls: error_to_retry.check_fn
        for error_to_retry in errors_to_retry
    }
    exception_class_tuples = tuple(error_checks.keys())

    backoff_secs = min_backoff_secs
    tries = 0

    while True:
        try:
            return lambda_fn()
        except exception_class_tuples as e:
            traceback.print_exc()
            tries += 1
            if tries >= max_tries:
                raise
            check_fn = error_checks.get(e.__class__)
            if check_fn and not check_fn(e):
                raise
            time.sleep(backoff_secs)
            backoff_secs = min(backoff_secs * 2, max_backoff_secs)

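A small usage sketch for retry_on_exceptions_with_backoff; the flaky_request function, the requests dependency, and the retry policy are made up for illustration only.

import requests  # assumed available; any callable and exception types work

from pilot.common.global_helper import ErrorToRetry, retry_on_exceptions_with_backoff


def flaky_request():
    # hypothetical call that sometimes times out
    return requests.get("http://localhost:8000/health", timeout=1)


result = retry_on_exceptions_with_backoff(
    lambda_fn=flaky_request,
    errors_to_retry=[
        ErrorToRetry(requests.exceptions.Timeout),
        # retry ConnectionError only when check_fn says the error is transient
        ErrorToRetry(
            requests.exceptions.ConnectionError,
            check_fn=lambda e: "Name or service not known" not in str(e),
        ),
    ],
    max_tries=5,
    min_backoff_secs=0.5,
    max_backoff_secs=8.0,
)
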
def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length."""
    if len(text) <= max_length:
        return text
    return text[: max_length - 3] + "..."


def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
    """Iterate over an iterable in batches.

    >>> list(iter_batch([1,2,3,4,5], 3))
    [[1, 2, 3], [4, 5]]
    """
    source_iter = iter(iterable)
    while source_iter:
        b = list(islice(source_iter, size))
        if len(b) == 0:
            break
        yield b


def concat_dirs(dirname: str, basename: str) -> str:
    """
    Append basename to dirname, avoiding backslashes when running on windows.

    os.path.join(dirname, basename) will add a backslash before dirname if
    basename does not end with a slash, so we make sure it does.
    """
    dirname += "/" if dirname[-1] != "/" else ""
    return os.path.join(dirname, basename)


def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
    """
    Optionally get a tqdm iterable. Ensures tqdm.auto is used.
    """
    _iterator = items
    if show_progress:
        try:
            from tqdm.auto import tqdm

            return tqdm(items, desc=desc)
        except ImportError:
            pass
    return _iterator


def count_tokens(text: str) -> int:
    tokens = globals_helper.tokenizer(text)
    return len(tokens)


def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
    """
    Args:
        model_name(str): the model name of the tokenizer.
            For instance, fxmarty/tiny-llama-fast-tokenizer.
    """
    try:
        from transformers import AutoTokenizer
    except ImportError:
        raise ValueError(
            "`transformers` package not found, please run `pip install transformers`"
        )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer.tokenize


def get_cache_dir() -> str:
    """Locate a platform-appropriate cache directory for llama_index,
    and create it if it doesn't yet exist.
    """
    # User override
    if "LLAMA_INDEX_CACHE_DIR" in os.environ:
        path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])

    # Linux, Unix, AIX, etc.
    elif os.name == "posix" and sys.platform != "darwin":
        path = Path("/tmp/llama_index")

    # Mac OS
    elif sys.platform == "darwin":
        path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")

    # Windows (hopefully)
    else:
        local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
            "~\\AppData\\Local"
        )
        path = Path(local, "llama_index")

    if not os.path.exists(path):
        os.makedirs(
            path, exist_ok=True
        )  # prevents https://github.com/jerryjliu/llama_index/issues/7362
    return str(path)


def add_sync_version(func: Any) -> Any:
    """Decorator for adding sync version of an async function. The sync version
    is added as a function attribute to the original function, func.

    Args:
        func(Any): the async function for which a sync variant will be built.
    """
    assert asyncio.iscoroutinefunction(func)

    @wraps(func)
    def _wrapper(*args: Any, **kwds: Any) -> Any:
        return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))

    func.sync = _wrapper
    return func

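A sketch of what the add_sync_version decorator enables; fetch_answer is a made-up coroutine, not something added by this commit.

from pilot.common.global_helper import add_sync_version


@add_sync_version
async def fetch_answer(question: str) -> str:
    # stand-in for an async LLM call
    return f"answer to {question}"


# async callers await fetch_answer(...) as usual;
# sync callers use the attached .sync attribute
print(fetch_answer.sync("what is DB-GPT?"))
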
# Sample text from llama_index's readme
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.

Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM apps. It provides the following tools:

Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""

_LLAMA_INDEX_COLORS = {
    "llama_pink": "38;2;237;90;200",
    "llama_blue": "38;2;90;149;237",
    "llama_turquoise": "38;2;11;159;203",
    "llama_lavender": "38;2;155;135;227",
}

_ANSI_COLORS = {
    "red": "31",
    "green": "32",
    "yellow": "33",
    "blue": "34",
    "magenta": "35",
    "cyan": "36",
    "pink": "38;5;200",
}


def get_color_mapping(
    items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
    """
    Get a mapping of items to colors.

    Args:
        items (List[str]): List of items to be mapped to colors.
        use_llama_index_colors (bool, optional): Flag to indicate
            whether to use LlamaIndex colors or ANSI colors.
            Defaults to True.

    Returns:
        Dict[str, str]: Mapping of items to colors.
    """
    if use_llama_index_colors:
        color_palette = _LLAMA_INDEX_COLORS
    else:
        color_palette = _ANSI_COLORS

    colors = list(color_palette.keys())
    return {item: colors[i % len(colors)] for i, item in enumerate(items)}


def _get_colored_text(text: str, color: str) -> str:
    """
    Get the colored version of the input text.

    Args:
        text (str): Input text.
        color (str): Color to be applied to the text.

    Returns:
        str: Colored version of the input text.
    """
    all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}

    if color not in all_colors:
        return f"\033[1;3m{text}\033[0m"  # just bolded and italicized

    color = all_colors[color]

    return f"\033[1;3;{color}m{text}\033[0m"


def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
    """
    Print the text with the specified color.

    Args:
        text (str): Text to be printed.
        color (str, optional): Color to be applied to the text. Supported colors are:
            llama_pink, llama_blue, llama_turquoise, llama_lavender,
            red, green, yellow, blue, magenta, cyan, pink.
        end (str, optional): String appended after the last character of the text.

    Returns:
        None
    """
    text_to_print = _get_colored_text(text, color) if color is not None else text
    print(text_to_print, end=end)


def infer_torch_device() -> str:
    """Infer the input to torch.device."""
    try:
        has_cuda = torch.cuda.is_available()
    except NameError:
        import torch

        has_cuda = torch.cuda.is_available()
    if has_cuda:
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def unit_generator(x: Any) -> Generator[Any, None, None]:
    """A function that returns a generator of a single element.

    Args:
        x (Any): the element to build yield

    Yields:
        Any: the single element
    """
    yield x


async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
    """A function that returns a generator of a single element.

    Args:
        x (Any): the element to build yield

    Yields:
        Any: the single element
    """
    yield x
pilot/common/llm_metadata.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from pydantic import Field, BaseModel

DEFAULT_CONTEXT_WINDOW = 3900
DEFAULT_NUM_OUTPUTS = 256


class LLMMetadata(BaseModel):
    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description=(
            "Total number of tokens the model can be input and output for one response."
        ),
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="Number of tokens the model can output when generating a response.",
    )
    is_chat_model: bool = Field(
        default=False,
        description=(
            "Set True if the model exposes a chat interface (i.e. can be passed a"
            " sequence of messages, rather than text), like OpenAI's"
            " /v1/chat/completions endpoint."
        ),
    )
    is_function_calling_model: bool = Field(
        default=False,
        # SEE: https://openai.com/blog/function-calling-and-other-api-updates
        description=(
            "Set True if the model supports function calling messages, similar to"
            " OpenAI's function calling API. For example, converting 'Email Anya to"
            " see if she wants to get coffee next Friday' to a function call like"
            " `send_email(to: string, body: string)`."
        ),
    )
    model_name: str = Field(
        default="unknown",
        description=(
            "The model's name used for logging, testing, and sanity checking. For some"
            " models this can be automatically discerned. For other models, like"
            " locally loaded models, this must be manually specified."
        ),
    )
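For context, this is roughly how the new LLMMetadata model could be filled in for a locally served model; the concrete values and model name are illustrative, not part of the commit.

from pilot.common.llm_metadata import LLMMetadata

vicuna_meta = LLMMetadata(
    context_window=4096,
    num_output=512,
    is_chat_model=True,
    is_function_calling_model=False,
    model_name="vicuna-13b-v1.5",  # hypothetical local model name
)
print(vicuna_meta.dict())
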
pilot/common/prompt_util.py (new file, 254 lines)
@@ -0,0 +1,254 @@
"""General prompt helper that can help deal with LLM context window token limitations.

At its core, it calculates available context size by starting with the context window
size of an LLM and reserve token space for the prompt template, and the output.

It provides utility for "repacking" text chunks (retrieved from index) to maximally
make use of the available context window (and thereby reducing the number of LLM calls
needed), or truncating them so that they fit in a single LLM call.
"""

import logging
from string import Formatter
from typing import Callable, List, Optional, Sequence

from pydantic import Field, PrivateAttr, BaseModel
from llama_index.prompts import BasePromptTemplate

from pilot.common.global_helper import globals_helper
from pilot.common.llm_metadata import LLMMetadata
from pilot.embedding_engine.loader.token_splitter import TokenTextSplitter

DEFAULT_PADDING = 5
DEFAULT_CHUNK_OVERLAP_RATIO = 0.1

DEFAULT_CONTEXT_WINDOW = 3000  # tokens
DEFAULT_NUM_OUTPUTS = 256  # tokens

logger = logging.getLogger(__name__)


class PromptHelper(BaseModel):
    """Prompt helper.

    General prompt helper that can help deal with LLM context window token limitations.

    At its core, it calculates available context size by starting with the context
    window size of an LLM and reserve token space for the prompt template, and the
    output.

    It provides utility for "repacking" text chunks (retrieved from index) to maximally
    make use of the available context window (and thereby reducing the number of LLM
    calls needed), or truncating them so that they fit in a single LLM call.

    Args:
        context_window (int): Context window for the LLM.
        num_output (int): Number of outputs for the LLM.
        chunk_overlap_ratio (float): Chunk overlap as a ratio of chunk size
        chunk_size_limit (Optional[int]): Maximum chunk size to use.
        tokenizer (Optional[Callable[[str], List]]): Tokenizer to use.
        separator (str): Separator for text splitter

    """

    context_window: int = Field(
        default=DEFAULT_CONTEXT_WINDOW,
        description="The maximum context size that will get sent to the LLM.",
    )
    num_output: int = Field(
        default=DEFAULT_NUM_OUTPUTS,
        description="The amount of token-space to leave in input for generation.",
    )
    chunk_overlap_ratio: float = Field(
        default=DEFAULT_CHUNK_OVERLAP_RATIO,
        description="The percentage token amount that each chunk should overlap.",
    )
    chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.")
    separator: str = Field(
        default=" ", description="The separator when chunking tokens."
    )

    _tokenizer: Callable[[str], List] = PrivateAttr()

    def __init__(
        self,
        context_window: int = DEFAULT_CONTEXT_WINDOW,
        num_output: int = DEFAULT_NUM_OUTPUTS,
        chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
        chunk_size_limit: Optional[int] = None,
        tokenizer: Optional[Callable[[str], List]] = None,
        separator: str = " ",
    ) -> None:
        """Init params."""
        if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
            raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")

        # TODO: make configurable
        self._tokenizer = tokenizer or globals_helper.tokenizer

        super().__init__(
            context_window=context_window,
            num_output=num_output,
            chunk_overlap_ratio=chunk_overlap_ratio,
            chunk_size_limit=chunk_size_limit,
            separator=separator,
        )

    @classmethod
    def from_llm_metadata(
        cls,
        llm_metadata: LLMMetadata,
        chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
        chunk_size_limit: Optional[int] = None,
        tokenizer: Optional[Callable[[str], List]] = None,
        separator: str = " ",
    ) -> "PromptHelper":
        """Create from llm predictor.

        This will autofill values like context_window and num_output.

        """
        context_window = llm_metadata.context_window
        if llm_metadata.num_output == -1:
            num_output = DEFAULT_NUM_OUTPUTS
        else:
            num_output = llm_metadata.num_output

        return cls(
            context_window=context_window,
            num_output=num_output,
            chunk_overlap_ratio=chunk_overlap_ratio,
            chunk_size_limit=chunk_size_limit,
            tokenizer=tokenizer,
            separator=separator,
        )

    @classmethod
    def class_name(cls) -> str:
        return "PromptHelper"

    def _get_available_context_size(self, template: str) -> int:
        """Get available context size.

        This is calculated as:
            available context window = total context window
                - input (partially filled prompt)
                - output (room reserved for response)

        Notes:
        - Available context size is further clamped to be non-negative.
        """
        empty_prompt_txt = get_empty_prompt_txt(template)
        num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt))
        context_size_tokens = (
            self.context_window - num_empty_prompt_tokens - self.num_output
        )
        if context_size_tokens < 0:
            raise ValueError(
                f"Calculated available context size {context_size_tokens} was"
                " not non-negative."
            )
        return context_size_tokens

    def _get_available_chunk_size(
        self, prompt_template: str, num_chunks: int = 1, padding: int = 5
    ) -> int:
        """Get available chunk size.

        This is calculated as:
            available chunk size = available context window // number_chunks
                - padding

        Notes:
        - By default, we use padding of 5 (to save space for formatting needs).
        - Available chunk size is further clamped to chunk_size_limit if specified.
        """
        available_context_size = self._get_available_context_size(prompt_template)
        result = available_context_size // num_chunks - padding
        if self.chunk_size_limit is not None:
            result = min(result, self.chunk_size_limit)
        return result

    def get_text_splitter_given_prompt(
        self,
        prompt_template: str,
        num_chunks: int = 1,
        padding: int = DEFAULT_PADDING,
    ) -> TokenTextSplitter:
        """Get text splitter configured to maximally pack available context window,
        taking into account of given prompt, and desired number of chunks.
        """
        chunk_size = self._get_available_chunk_size(
            prompt_template, num_chunks, padding=padding
        )
        if chunk_size <= 0:
            raise ValueError(f"Chunk size {chunk_size} is not positive.")
        chunk_overlap = int(self.chunk_overlap_ratio * chunk_size)
        return TokenTextSplitter(
            separator=self.separator,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            tokenizer=self._tokenizer,
        )

    # def truncate(
    #     self,
    #     prompt: BasePromptTemplate,
    #     text_chunks: Sequence[str],
    #     padding: int = DEFAULT_PADDING,
    # ) -> List[str]:
    #     """Truncate text chunks to fit available context window."""
    #     text_splitter = self.get_text_splitter_given_prompt(
    #         prompt,
    #         num_chunks=len(text_chunks),
    #         padding=padding,
    #     )
    #     return [truncate_text(chunk, text_splitter) for chunk in text_chunks]

    def repack(
        self,
        prompt_template: str,
        text_chunks: Sequence[str],
        padding: int = DEFAULT_PADDING,
    ) -> List[str]:
        """Repack text chunks to fit available context window.

        This will combine text chunks into consolidated chunks
        that more fully "pack" the prompt template given the context_window.

        """
        text_splitter = self.get_text_splitter_given_prompt(
            prompt_template, padding=padding
        )
        combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()])
        return text_splitter.split_text(combined_str)


def get_empty_prompt_txt(template: str) -> str:
    """Get empty prompt text.

    Substitute empty strings in parts of the prompt that have
    not yet been filled out. Skip variables that have already
    been partially formatted. This is used to compute the initial tokens.

    """
    # partial_kargs = prompt.kwargs

    partial_kargs = {}
    template_vars = get_template_vars(template)
    empty_kwargs = {v: "" for v in template_vars if v not in partial_kargs}
    all_kwargs = {**partial_kargs, **empty_kwargs}
    prompt = template.format(**all_kwargs)
    return prompt


def get_template_vars(template_str: str) -> List[str]:
    """Get template variables from a template string."""
    variables = []
    formatter = Formatter()

    for _, variable_name, _, _ in formatter.parse(template_str):
        if variable_name:
            variables.append(variable_name)

    return variables
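A short sketch of how PromptHelper.repack is meant to be used by the summary flow: with context_window=3000, num_output=256 and an empty prompt of roughly 40 tokens, the available window is about 3000 - 40 - 256 = 2704 tokens per packed chunk (minus padding). The template and chunks below are illustrative.

from pilot.common.prompt_util import PromptHelper

template = "Write a quick summary of the following context:\n{context}\n"
helper = PromptHelper(context_window=3000, num_output=256, chunk_overlap_ratio=0.1)

text_chunks = [
    "DB-GPT is a project for interacting with databases using LLMs.",
    "It supports knowledge spaces, embeddings and document summarization.",
    # ... more retrieved chunks
]

# consolidate small chunks so each packed chunk nearly fills the available window
packed = helper.repack(prompt_template=template, text_chunks=text_chunks)
print(len(packed), packed[0][:80])
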
pilot/embedding_engine/loader/splitter_utils.py (new file, 82 lines)
@@ -0,0 +1,82 @@
from typing import Callable, List


def split_text_keep_separator(text: str, separator: str) -> List[str]:
    """Split text with separator and keep the separator at the end of each split."""
    parts = text.split(separator)
    result = [separator + s if i > 0 else s for i, s in enumerate(parts)]
    return [s for s in result if s]


def split_by_sep(sep: str, keep_sep: bool = True) -> Callable[[str], List[str]]:
    """Split text by separator."""
    if keep_sep:
        return lambda text: split_text_keep_separator(text, sep)
    else:
        return lambda text: text.split(sep)


def split_by_char() -> Callable[[str], List[str]]:
    """Split text by character."""
    return lambda text: list(text)


def split_by_sentence_tokenizer() -> Callable[[str], List[str]]:
    import os

    import nltk

    from llama_index.utils import get_cache_dir

    cache_dir = get_cache_dir()
    nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)

    # update nltk path for nltk so that it finds the data
    if nltk_data_dir not in nltk.data.path:
        nltk.data.path.append(nltk_data_dir)

    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt", download_dir=nltk_data_dir)

    tokenizer = nltk.tokenize.PunktSentenceTokenizer()

    # get the spans and then return the sentences
    # using the start index of each span
    # instead of using end, use the start of the next span if available
    def split(text: str) -> List[str]:
        spans = list(tokenizer.span_tokenize(text))
        sentences = []
        for i, span in enumerate(spans):
            start = span[0]
            if i < len(spans) - 1:
                end = spans[i + 1][0]
            else:
                end = len(text)
            sentences.append(text[start:end])

        return sentences

    return split


def split_by_regex(regex: str) -> Callable[[str], List[str]]:
    """Split text by regex."""
    import re

    return lambda text: re.findall(regex, text)


def split_by_phrase_regex() -> Callable[[str], List[str]]:
    """Split text by phrase regex.

    This regular expression will split the sentences into phrases,
    where each phrase is a sequence of one or more non-comma,
    non-period, and non-semicolon characters, followed by an optional comma,
    period, or semicolon. The regular expression will also capture the
    delimiters themselves as separate items in the list of phrases.
    """
    regex = "[^,.;。]+[,.;。]?"
    return split_by_regex(regex)
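A quick illustration of what the splitter helpers return; the sample sentence and the outputs in the comments are approximate and only for orientation.

from pilot.embedding_engine.loader.splitter_utils import (
    split_by_char,
    split_by_phrase_regex,
    split_by_sep,
)

text = "DB-GPT builds a summary, then refines it; finally it stores the result."

print(split_by_sep(" ")(text)[:3])   # ['DB-GPT', ' builds', ' a'] -- separator kept with the split
print(split_by_phrase_regex()(text))  # phrase-level splits ending at ',', ';' or '.'
print(split_by_char()("abc"))         # ['a', 'b', 'c']
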
pilot/embedding_engine/loader/token_splitter.py (new file, 184 lines)
@@ -0,0 +1,184 @@
"""Token splitter."""
from typing import Callable, List, Optional

from pydantic import Field, PrivateAttr, BaseModel


from pilot.common.global_helper import globals_helper
from pilot.embedding_engine.loader.splitter_utils import split_by_sep, split_by_char

DEFAULT_METADATA_FORMAT_LEN = 2
DEFAULT_CHUNK_OVERLAP = 20
DEFAULT_CHUNK_SIZE = 1024


class TokenTextSplitter(BaseModel):
    """Implementation of splitting text that looks at word tokens."""

    chunk_size: int = Field(
        default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk."
    )
    chunk_overlap: int = Field(
        default=DEFAULT_CHUNK_OVERLAP,
        description="The token overlap of each chunk when splitting.",
    )
    separator: str = Field(
        default=" ", description="Default separator for splitting into words"
    )
    backup_separators: List = Field(
        default_factory=list, description="Additional separators for splitting."
    )
    # callback_manager: CallbackManager = Field(
    #     default_factory=CallbackManager, exclude=True
    # )
    tokenizer: Callable = Field(
        default_factory=globals_helper.tokenizer,  # type: ignore
        description="Tokenizer for splitting words into tokens.",
        exclude=True,
    )

    _split_fns: List[Callable] = PrivateAttr()

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        tokenizer: Optional[Callable] = None,
        # callback_manager: Optional[CallbackManager] = None,
        separator: str = " ",
        backup_separators: Optional[List[str]] = ["\n"],
    ):
        """Initialize with parameters."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        # callback_manager = callback_manager or CallbackManager([])
        tokenizer = tokenizer or globals_helper.tokenizer

        all_seps = [separator] + (backup_separators or [])
        self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]

        super().__init__(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separator=separator,
            backup_separators=backup_separators,
            # callback_manager=callback_manager,
            tokenizer=tokenizer,
        )

    @classmethod
    def class_name(cls) -> str:
        return "TokenTextSplitter"

    def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
        """Split text into chunks, reserving space required for metadata str."""
        metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN
        effective_chunk_size = self.chunk_size - metadata_len
        if effective_chunk_size <= 0:
            raise ValueError(
                f"Metadata length ({metadata_len}) is longer than chunk size "
                f"({self.chunk_size}). Consider increasing the chunk size or "
                "decreasing the size of your metadata to avoid this."
            )
        elif effective_chunk_size < 50:
            print(
                f"Metadata length ({metadata_len}) is close to chunk size "
                f"({self.chunk_size}). Resulting chunks are less than 50 tokens. "
                "Consider increasing the chunk size or decreasing the size of "
                "your metadata to avoid this.",
                flush=True,
            )

        return self._split_text(text, chunk_size=effective_chunk_size)

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks."""
        return self._split_text(text, chunk_size=self.chunk_size)

    def _split_text(self, text: str, chunk_size: int) -> List[str]:
        """Split text into chunks up to chunk_size."""
        if text == "":
            return []

        splits = self._split(text, chunk_size)
        chunks = self._merge(splits, chunk_size)
        return chunks

    def _split(self, text: str, chunk_size: int) -> List[str]:
        """Break text into splits that are smaller than chunk size.

        The order of splitting is:
        1. split by separator
        2. split by backup separators (if any)
        3. split by characters

        NOTE: the splits contain the separators.
        """
        if len(self.tokenizer(text)) <= chunk_size:
            return [text]

        for split_fn in self._split_fns:
            splits = split_fn(text)
            if len(splits) > 1:
                break

        new_splits = []
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len <= chunk_size:
                new_splits.append(split)
            else:
                # recursively split
                new_splits.extend(self._split(split, chunk_size=chunk_size))
        return new_splits

    def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
        """Merge splits into chunks.

        The high-level idea is to keep adding splits to a chunk until we
        exceed the chunk size, then we start a new chunk with overlap.

        When we start a new chunk, we pop off the first element of the previous
        chunk until the total length is less than the chunk size.
        """
        chunks: List[str] = []

        cur_chunk: List[str] = []
        cur_len = 0
        for split in splits:
            split_len = len(self.tokenizer(split))
            if split_len > chunk_size:
                print(
                    f"Got a split of size {split_len}, ",
                    f"larger than chunk size {chunk_size}.",
                )

            # if we exceed the chunk size after adding the new split, then
            # we need to end the current chunk and start a new one
            if cur_len + split_len > chunk_size:
                # end the previous chunk
                chunk = "".join(cur_chunk).strip()
                if chunk:
                    chunks.append(chunk)

                # start a new chunk with overlap
                # keep popping off the first element of the previous chunk until:
                # 1. the current chunk length is less than chunk overlap
                # 2. the total length is less than chunk size
                while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size:
                    # pop off the first element
                    first_chunk = cur_chunk.pop(0)
                    cur_len -= len(self.tokenizer(first_chunk))

            cur_chunk.append(split)
            cur_len += split_len

        # handle the last chunk
        chunk = "".join(cur_chunk).strip()
        if chunk:
            chunks.append(chunk)

        return chunks
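A minimal sketch of the splitter in use; the chunk sizes are deliberately tiny so the overlap behaviour is visible, and the tokenizer falls back to the tiktoken-based globals_helper tokenizer. The sample text is made up.

from pilot.embedding_engine.loader.token_splitter import TokenTextSplitter

splitter = TokenTextSplitter(chunk_size=64, chunk_overlap=8)

long_text = "DB-GPT ingests a document, splits it into token chunks, " * 20
chunks = splitter.split_text(long_text)

# each chunk is at most ~64 tokens and overlaps the previous one by ~8 tokens
print(len(chunks), chunks[0][:60])
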
pilot/scene/chat_knowledge/extract_refine_summary/prompt.py
@@ -12,11 +12,12 @@ CFG = Config()

PROMPT_SCENE_DEFINE = """"""

_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请再完善一下原来的总结。\n回答的时候最好按照1.2.3.点进行总结"""
_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请根据你之前推理的内容进行最终的总结,总结的时候可以详细点,回答的时候最好按照1.2.3.进行总结."""

_DEFAULT_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.please refine the original summary.
\nWhen answering, it is best to summarize according to points 1.2.3.
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.
\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1, 2, and 3. and etc.

"""

_DEFAULT_TEMPLATE = (
@@ -27,7 +28,7 @@ PROMPT_RESPONSE = """"""

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_NEED_STREAM_OUT = False
PROMPT_NEED_NEED_STREAM_OUT = True

prompt = PromptTemplate(
    template_scene=ChatScene.ExtractRefineSummary.value(),
@@ -21,7 +21,8 @@ class ExtractSummary(BaseChat):
            chat_param=chat_param,
        )

        self.user_input = chat_param["current_user_input"]
        # self.user_input = chat_param["current_user_input"]
        self.user_input = chat_param["select_param"]
        # self.extract_mode = chat_param["select_param"]

    def generate_input_values(self):
pilot/scene/chat_knowledge/summary/prompt.py
@@ -11,15 +11,15 @@ CFG = Config()

PROMPT_SCENE_DEFINE = """"""

_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行总结:
_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
{context}
回答的时候最好按照1.2.3.点进行总结
答案尽量精确和简单,不要过长,长度控制在100字左右
"""

_DEFAULT_TEMPLATE_EN = """
Write a summary of the following context:
Write a quick summary of the following context:
{context}
When answering, it is best to summarize according to points 1.2.3.
the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters.
"""

_DEFAULT_TEMPLATE = (
pilot/server/knowledge/api.py
@@ -10,6 +10,7 @@ from pilot.configs.model_config import (
    EMBEDDING_MODEL_CONFIG,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.openapi.api_v1.api_v1 import no_stream_generator, stream_generator

from pilot.openapi.api_view_model import Result
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
@@ -25,9 +26,11 @@ from pilot.server.knowledge.request.request import (
    DocumentQueryRequest,
    SpaceArgumentRequest,
    EntityExtractRequest,
    DocumentSummaryRequest,
)

from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.utils.tracer import root_tracer, SpanType

logger = logging.getLogger(__name__)

@@ -201,6 +204,45 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
    return {"response": res}


@router.post("/knowledge/document/summary")
async def document_summary(request: DocumentSummaryRequest):
    print(f"/document/summary params: {request}")
    try:
        with root_tracer.start_span(
            "get_chat_instance", span_type=SpanType.CHAT, metadata=request
        ):
            chat = await knowledge_space_service.document_summary(request=request)
        headers = {
            "Content-Type": "text/event-stream",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked",
        }
        from starlette.responses import StreamingResponse

        if not chat.prompt_template.stream_out:
            return StreamingResponse(
                no_stream_generator(chat),
                headers=headers,
                media_type="text/event-stream",
            )
        else:
            return StreamingResponse(
                stream_generator(chat, False, request.model_name),
                headers=headers,
                media_type="text/plain",
            )

        # return Result.succ(
        #     knowledge_space_service.create_knowledge_document(
        #         space=space_name, request=request
        #     )
        # )
        # return Result.succ([])
    except Exception as e:
        return Result.faild(code="E000X", msg=f"document add error {e}")


@router.post("/knowledge/entity/extract")
async def entity_extract(request: EntityExtractRequest):
    logger.info(f"Received params: {request}")

pilot/server/knowledge/request/request.py
@@ -108,6 +108,14 @@ class SpaceArgumentRequest(BaseModel):
    argument: str


class DocumentSummaryRequest(BaseModel):
    """Sync request"""

    """doc_ids: doc ids"""
    doc_id: int
    model_name: str


class EntityExtractRequest(BaseModel):
    """argument: argument"""

pilot/server/knowledge/service.py
@@ -10,7 +10,7 @@ from pilot.configs.model_config import (
    KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.component import ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async

from pilot.server.knowledge.chunk_db import (
    DocumentChunkEntity,
@@ -31,6 +31,7 @@ from pilot.server.knowledge.request.request import (
    ChunkQueryRequest,
    SpaceArgumentRequest,
    DocumentSyncRequest,
    DocumentSummaryRequest,
)
from enum import Enum

@@ -303,7 +304,30 @@ class KnowledgeService:
        ]
        document_chunk_dao.create_documents_chunks(chunk_entities)

        return True
        return doc.id

    async def document_summary(self, request: DocumentSummaryRequest):
        """get document summary
        Args:
            - request: DocumentSummaryRequest
        """
        doc_query = KnowledgeDocumentEntity(id=request.doc_id)
        documents = knowledge_document_dao.get_documents(doc_query)
        if len(documents) != 1:
            raise Exception(f"can not found document for {request.doc_id}")
        document = documents[0]
        query = DocumentChunkEntity(
            document_id=request.doc_id,
        )
        chunks = document_chunk_dao.get_document_chunks(query, page=1, page_size=100)
        if len(chunks) == 0:
            raise Exception(f"can not found chunks for {request.doc_id}")
        from langchain.schema import Document

        chunk_docs = [Document(page_content=chunk.content) for chunk in chunks]
        return await self.async_document_summary(
            model_name=request.model_name, chunk_docs=chunk_docs, doc=document
        )

    def update_knowledge_space(
        self, space_id: int, space_request: KnowledgeSpaceRequest
@@ -417,30 +441,25 @@
            logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
        return knowledge_document_dao.update_knowledge_document(doc)

    def async_document_summary(self, chunk_docs, doc):
    async def async_document_summary(self, model_name, chunk_docs, doc):
        """async document extract summary
        Args:
            - model_name: str
            - chunk_docs: List[Document]
            - doc: KnowledgeDocumentEntity
        """
        from llama_index import PromptHelper
        from llama_index.prompts.default_prompt_selectors import (
            DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
        )

        texts = [doc.page_content for doc in chunk_docs]
        prompt_helper = PromptHelper(context_window=2000)
        from pilot.common.prompt_util import PromptHelper

        texts = prompt_helper.repack(
            prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts
        )
        prompt_helper = PromptHelper()
        from pilot.scene.chat_knowledge.summary.prompt import prompt

        texts = prompt_helper.repack(prompt_template=prompt.template, text_chunks=texts)
        logger.info(
            f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
        )
        summary = self._mapreduce_extract_summary(texts)
        print(f"final summary:{summary}")
        doc.summary = summary
        return knowledge_document_dao.update_knowledge_document(doc)
        summary = await self._mapreduce_extract_summary(texts, model_name, 10, 3)
        return await self._llm_extract_summary(summary, model_name)

    def async_doc_embedding(self, client, chunk_docs, doc):
        """async document embedding into vector db
@@ -506,30 +525,37 @@
            return json.loads(spaces[0].context)
        return None

    def _llm_extract_summary(self, doc: str):
    async def _llm_extract_summary(self, doc: str, model_name: str = None):
        """Extract triplets from text by llm"""
        from pilot.scene.base import ChatScene
        from pilot.common.chat_util import llm_chat_response_nostream
        import uuid

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": doc,
            "current_user_input": "",
            "select_param": doc,
            "model_name": self.model_name,
            "model_name": model_name,
        }
        from pilot.common.chat_util import run_async_tasks
        executor = CFG.SYSTEM_APP.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()
        from pilot.openapi.api_v1.api_v1 import CHAT_FACTORY

        summary_iters = run_async_tasks(
            [
                llm_chat_response_nostream(
                    ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
                )
            ]
        chat = await blocking_func_to_async(
            executor,
            CHAT_FACTORY.get_implementation,
            ChatScene.ExtractRefineSummary.value(),
            **{"chat_param": chat_param},
        )
        return summary_iters[0]
        return chat

    def _mapreduce_extract_summary(self, docs):
    async def _mapreduce_extract_summary(
        self,
        docs,
        model_name: str = None,
        max_iteration: int = 5,
        concurrency_limit: int = None,
    ):
        """Extract summary by mapreduce mode
        map -> multi async thread generate summary
        reduce -> merge the summaries by map process
@@ -541,18 +567,16 @@
        import uuid

        tasks = []
        max_iteration = 5
        if len(docs) == 1:
            summary = self._llm_extract_summary(doc=docs[0])
            return summary
            return docs[0]
        else:
            max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
            for doc in docs[0:max_iteration]:
                chat_param = {
                    "chat_session_id": uuid.uuid1(),
                    "current_user_input": doc,
                    "select_param": "summary",
                    "model_name": self.model_name,
                    "current_user_input": "",
                    "select_param": doc,
                    "model_name": model_name,
                }
                tasks.append(
                    llm_chat_response_nostream(
@@ -561,14 +585,22 @@
                )
            from pilot.common.chat_util import run_async_tasks

            summary_iters = run_async_tasks(tasks)
            summary_iters = await run_async_tasks(
                tasks=tasks, concurrency_limit=concurrency_limit
            )
            summary_iters = list(
                filter(
                    lambda content: "LLMServer Generate Error" not in content,
                    summary_iters,
                )
            )
            from pilot.common.prompt_util import PromptHelper
            from llama_index.prompts.default_prompt_selectors import (
                DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
            )
            from pilot.scene.chat_knowledge.summary.prompt import prompt

            prompt_helper = PromptHelper(context_window=2500)
            prompt_helper = PromptHelper()
            summary_iters = prompt_helper.repack(
                prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters
                prompt_template=prompt.template, text_chunks=summary_iters
            )
            return await self._mapreduce_extract_summary(
                summary_iters, model_name, max_iteration, concurrency_limit
            )
        return self._mapreduce_extract_summary(summary_iters)
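The map-reduce flow above can be pictured with a small standalone sketch, independent of the DB-GPT classes: the map step summarizes up to max_iteration chunks concurrently, then the partial summaries are repacked and reduced again until a single text remains. fake_summarize stands in for the real LLM call and is purely illustrative.

import asyncio
from typing import List


async def fake_summarize(text: str) -> str:
    # stand-in for llm_chat_response_nostream(ChatScene.ExtractSummary, ...)
    await asyncio.sleep(0)
    return text[:40]


async def mapreduce_summary(docs: List[str], max_iteration: int = 5) -> str:
    if len(docs) == 1:
        return docs[0]
    selected = docs[:max_iteration]
    partial = await asyncio.gather(*[fake_summarize(d) for d in selected])
    merged = ["\n\n".join(partial)]  # here PromptHelper.repack would re-pack chunks
    return await mapreduce_summary(merged, max_iteration)


print(asyncio.run(mapreduce_summary(["chunk one " * 20, "chunk two " * 20])))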