feat: add document summary

This commit is contained in:
aries_ckt
2023-11-09 15:14:08 +08:00
parent 4a52d1d8a0
commit 1b8a67851b
12 changed files with 1163 additions and 76 deletions


@@ -21,35 +21,27 @@ async def llm_chat_response(chat_scene: str, **chat_param):
return chat.stream_call()
def run_async_tasks(
async def run_async_tasks(
tasks: List[Coroutine],
show_progress: bool = False,
progress_bar_desc: str = "Running async tasks",
concurrency_limit: int = None,
) -> List[Any]:
"""Run a list of async tasks."""
tasks_to_execute: List[Any] = tasks
if show_progress:
try:
import nest_asyncio
from tqdm.asyncio import tqdm
nest_asyncio.apply()
loop = asyncio.get_event_loop()
async def _tqdm_gather() -> List[Any]:
return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)
tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
return tqdm_outputs
# fall back to running without tqdm on a fatal error,
# which may occur in environments where tqdm.asyncio
# is not supported
except Exception:
pass
async def _gather() -> List[Any]:
if concurrency_limit:
semaphore = asyncio.Semaphore(concurrency_limit)
async def _execute_task(task):
async with semaphore:
return await task
# Execute tasks with semaphore limit
return await asyncio.gather(
*[_execute_task(task) for task in tasks_to_execute]
)
else:
return await asyncio.gather(*tasks_to_execute)
outputs: List[Any] = asyncio.run(_gather())
return outputs
# outputs: List[Any] = asyncio.run(_gather())
return await _gather()
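Since run_async_tasks is now a coroutine, callers have to await it rather than call it synchronously. A minimal usage sketch (the fetch coroutine is illustrative; the import path matches how the service code imports this helper later in the commit):

import asyncio

from pilot.common.chat_util import run_async_tasks


async def fetch(i: int) -> int:
    # stand-in for an I/O-bound coroutine such as an LLM call
    await asyncio.sleep(0.1)
    return i * 2


async def main() -> None:
    tasks = [fetch(i) for i in range(10)]
    # the internal semaphore keeps at most 3 coroutines running at once
    results = await run_async_tasks(tasks=tasks, concurrency_limit=3)
    print(results)


asyncio.run(main())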


@@ -0,0 +1,448 @@
"""General utils functions."""
import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from itertools import islice
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Type,
Union,
cast,
)
class GlobalsHelper:
"""Helper to retrieve globals.
Helpful for global caching of certain variables that can be expensive to load.
(e.g. tokenization)
"""
_tokenizer: Optional[Callable[[str], List]] = None
_stopwords: Optional[List[str]] = None
@property
def tokenizer(self) -> Callable[[str], List]:
"""Get tokenizer."""
if self._tokenizer is None:
tiktoken_import_err = (
"`tiktoken` package not found, please run `pip install tiktoken`"
)
try:
import tiktoken
except ImportError:
raise ImportError(tiktoken_import_err)
enc = tiktoken.get_encoding("gpt2")
self._tokenizer = cast(Callable[[str], List], enc.encode)
self._tokenizer = partial(self._tokenizer, allowed_special="all")
return self._tokenizer # type: ignore
@property
def stopwords(self) -> List[str]:
"""Get stopwords."""
if self._stopwords is None:
try:
import nltk
from nltk.corpus import stopwords
except ImportError:
raise ImportError(
"`nltk` package not found, please run `pip install nltk`"
)
from llama_index.utils import get_cache_dir
cache_dir = get_cache_dir()
nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)
# update nltk path for nltk so that it finds the data
if nltk_data_dir not in nltk.data.path:
nltk.data.path.append(nltk_data_dir)
try:
nltk.data.find("corpora/stopwords")
except LookupError:
nltk.download("stopwords", download_dir=nltk_data_dir)
self._stopwords = stopwords.words("english")
return self._stopwords
globals_helper = GlobalsHelper()
def get_new_id(d: Set) -> str:
"""Get a new ID."""
while True:
new_id = str(uuid.uuid4())
if new_id not in d:
break
return new_id
def get_new_int_id(d: Set) -> int:
"""Get a new integer ID."""
while True:
new_id = random.randint(0, sys.maxsize)
if new_id not in d:
break
return new_id
@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
"""Temporary setter.
Utility context manager for temporarily setting attribute values on an object.
Taken from: https://tinyurl.com/2p89xymh
"""
prev_values = {k: getattr(obj, k) for k in kwargs}
for k, v in kwargs.items():
setattr(obj, k, v)
try:
yield
finally:
for k, v in prev_values.items():
setattr(obj, k, v)
@dataclass
class ErrorToRetry:
"""Exception types that should be retried.
Args:
exception_cls (Type[Exception]): Class of exception.
check_fn (Optional[Callable[[Any], bool]]):
A function that takes an exception instance as input and returns
whether to retry.
"""
exception_cls: Type[Exception]
check_fn: Optional[Callable[[Any], bool]] = None
def retry_on_exceptions_with_backoff(
lambda_fn: Callable,
errors_to_retry: List[ErrorToRetry],
max_tries: int = 10,
min_backoff_secs: float = 0.5,
max_backoff_secs: float = 60.0,
) -> Any:
"""Execute lambda function with retries and exponential backoff.
Args:
lambda_fn (Callable): Function to be called and output we want.
errors_to_retry (List[ErrorToRetry]): List of errors to retry.
At least one needs to be provided.
max_tries (int): Maximum number of tries, including the first. Defaults to 10.
min_backoff_secs (float): Minimum amount of backoff time between attempts.
Defaults to 0.5.
max_backoff_secs (float): Maximum amount of backoff time between attempts.
Defaults to 60.
"""
if not errors_to_retry:
raise ValueError("At least one error to retry needs to be provided")
error_checks = {
error_to_retry.exception_cls: error_to_retry.check_fn
for error_to_retry in errors_to_retry
}
exception_class_tuples = tuple(error_checks.keys())
backoff_secs = min_backoff_secs
tries = 0
while True:
try:
return lambda_fn()
except exception_class_tuples as e:
traceback.print_exc()
tries += 1
if tries >= max_tries:
raise
check_fn = error_checks.get(e.__class__)
if check_fn and not check_fn(e):
raise
time.sleep(backoff_secs)
backoff_secs = min(backoff_secs * 2, max_backoff_secs)
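A small sketch of how retry_on_exceptions_with_backoff and ErrorToRetry compose (flaky_call and the retry policy are illustrative):

from pilot.common.global_helper import ErrorToRetry, retry_on_exceptions_with_backoff

attempts = {"n": 0}


def flaky_call() -> str:
    # stand-in for a network call that fails the first two times
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("transient network error")
    return "ok"


result = retry_on_exceptions_with_backoff(
    flaky_call,
    errors_to_retry=[
        ErrorToRetry(ConnectionError),
        # only retry TimeoutError when its message marks it as retryable
        ErrorToRetry(TimeoutError, check_fn=lambda e: "retryable" in str(e)),
    ],
    max_tries=5,
    min_backoff_secs=0.1,
)
print(result)  # "ok" after two retried ConnectionErrors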
def truncate_text(text: str, max_length: int) -> str:
"""Truncate text to a maximum length."""
if len(text) <= max_length:
return text
return text[: max_length - 3] + "..."
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
"""Iterate over an iterable in batches.
>>> list(iter_batch([1,2,3,4,5], 3))
[[1, 2, 3], [4, 5]]
"""
source_iter = iter(iterable)
while source_iter:
b = list(islice(source_iter, size))
if len(b) == 0:
break
yield b
def concat_dirs(dirname: str, basename: str) -> str:
"""
Append basename to dirname, avoiding backslashes when running on windows.
os.path.join(dirname, basename) will join with a backslash on Windows unless
dirname already ends with a slash, so we make sure it does.
"""
dirname += "/" if dirname[-1] != "/" else ""
return os.path.join(dirname, basename)
def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
"""
Optionally get a tqdm iterable. Ensures tqdm.auto is used.
"""
_iterator = items
if show_progress:
try:
from tqdm.auto import tqdm
return tqdm(items, desc=desc)
except ImportError:
pass
return _iterator
def count_tokens(text: str) -> int:
tokens = globals_helper.tokenizer(text)
return len(tokens)
def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
"""
Args:
model_name(str): the model name of the tokenizer.
For instance, fxmarty/tiny-llama-fast-tokenizer.
"""
try:
from transformers import AutoTokenizer
except ImportError:
raise ValueError(
"`transformers` package not found, please run `pip install transformers`"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return tokenizer.tokenize
def get_cache_dir() -> str:
"""Locate a platform-appropriate cache directory for llama_index,
and create it if it doesn't yet exist.
"""
# User override
if "LLAMA_INDEX_CACHE_DIR" in os.environ:
path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])
# Linux, Unix, AIX, etc.
elif os.name == "posix" and sys.platform != "darwin":
path = Path("/tmp/llama_index")
# Mac OS
elif sys.platform == "darwin":
path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")
# Windows (hopefully)
else:
local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
"~\\AppData\\Local"
)
path = Path(local, "llama_index")
if not os.path.exists(path):
os.makedirs(
path, exist_ok=True
) # prevents https://github.com/jerryjliu/llama_index/issues/7362
return str(path)
def add_sync_version(func: Any) -> Any:
"""Decorator for adding sync version of an async function. The sync version
is added as a function attribute to the original function, func.
Args:
func(Any): the async function for which a sync variant will be built.
"""
assert asyncio.iscoroutinefunction(func)
@wraps(func)
def _wrapper(*args: Any, **kwds: Any) -> Any:
return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))
func.sync = _wrapper
return func
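A quick sketch of the decorator in use (the double coroutine is illustrative); the sync variant is attached to the original coroutine function as a .sync attribute:

import asyncio

from pilot.common.global_helper import add_sync_version


@add_sync_version
async def double(x: int) -> int:
    await asyncio.sleep(0)
    return x * 2


print(double.sync(21))  # 42, executed via run_until_complete on the current loop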
# Sample text from llama_index's readme
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.
Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM apps. It provides the following tools:
Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""
_LLAMA_INDEX_COLORS = {
"llama_pink": "38;2;237;90;200",
"llama_blue": "38;2;90;149;237",
"llama_turquoise": "38;2;11;159;203",
"llama_lavender": "38;2;155;135;227",
}
_ANSI_COLORS = {
"red": "31",
"green": "32",
"yellow": "33",
"blue": "34",
"magenta": "35",
"cyan": "36",
"pink": "38;5;200",
}
def get_color_mapping(
items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
"""
Get a mapping of items to colors.
Args:
items (List[str]): List of items to be mapped to colors.
use_llama_index_colors (bool, optional): Flag to indicate
whether to use LlamaIndex colors or ANSI colors.
Defaults to True.
Returns:
Dict[str, str]: Mapping of items to colors.
"""
if use_llama_index_colors:
color_palette = _LLAMA_INDEX_COLORS
else:
color_palette = _ANSI_COLORS
colors = list(color_palette.keys())
return {item: colors[i % len(colors)] for i, item in enumerate(items)}
def _get_colored_text(text: str, color: str) -> str:
"""
Get the colored version of the input text.
Args:
text (str): Input text.
color (str): Color to be applied to the text.
Returns:
str: Colored version of the input text.
"""
all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}
if color not in all_colors:
return f"\033[1;3m{text}\033[0m" # just bolded and italicized
color = all_colors[color]
return f"\033[1;3;{color}m{text}\033[0m"
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
"""
Print the text with the specified color.
Args:
text (str): Text to be printed.
color (str, optional): Color to be applied to the text. Supported colors are:
llama_pink, llama_blue, llama_turquoise, llama_lavender,
red, green, yellow, blue, magenta, cyan, pink.
end (str, optional): String appended after the last character of the text.
Returns:
None
"""
text_to_print = _get_colored_text(text, color) if color is not None else text
print(text_to_print, end=end)
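A short sketch showing how the color helpers compose (the stage names are illustrative): get_color_mapping assigns palette colors round-robin, and print_text applies them, with the plain ANSI names available as a fallback.

from pilot.common.global_helper import get_color_mapping, print_text

stages = ["retrieve", "rerank", "synthesize"]
color_of = get_color_mapping(stages)  # llama_* palette by default
for stage in stages:
    print_text(f"{stage}\n", color=color_of[stage])
print_text("fallback ANSI color\n", color="red")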
def infer_torch_device() -> str:
"""Infer the input to torch.device."""
try:
has_cuda = torch.cuda.is_available()
except NameError:
import torch
has_cuda = torch.cuda.is_available()
if has_cuda:
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"
def unit_generator(x: Any) -> Generator[Any, None, None]:
"""A function that returns a generator of a single element.
Args:
x (Any): the element to yield
Yields:
Any: the single element
"""
yield x
async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
"""A function that returns a generator of a single element.
Args:
x (Any): the element to yield
Yields:
Any: the single element
"""
yield x


@@ -0,0 +1,43 @@
from pydantic import Field, BaseModel
DEFAULT_CONTEXT_WINDOW = 3900
DEFAULT_NUM_OUTPUTS = 256
class LLMMetadata(BaseModel):
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description=(
"Total number of tokens the model can be input and output for one response."
),
)
num_output: int = Field(
default=DEFAULT_NUM_OUTPUTS,
description="Number of tokens the model can output when generating a response.",
)
is_chat_model: bool = Field(
default=False,
description=(
"Set True if the model exposes a chat interface (i.e. can be passed a"
" sequence of messages, rather than text), like OpenAI's"
" /v1/chat/completions endpoint."
),
)
is_function_calling_model: bool = Field(
default=False,
# SEE: https://openai.com/blog/function-calling-and-other-api-updates
description=(
"Set True if the model supports function calling messages, similar to"
" OpenAI's function calling API. For example, converting 'Email Anya to"
" see if she wants to get coffee next Friday' to a function call like"
" `send_email(to: string, body: string)`."
),
)
model_name: str = Field(
default="unknown",
description=(
"The model's name used for logging, testing, and sanity checking. For some"
" models this can be automatically discerned. For other models, like"
" locally loaded models, this must be manually specified."
),
)
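A minimal sketch of how LLMMetadata might feed the PromptHelper added later in this commit; the model name and sizes are illustrative, and the default tokenizer assumes tiktoken is installed:

from pilot.common.llm_metadata import LLMMetadata
from pilot.common.prompt_util import PromptHelper

metadata = LLMMetadata(
    context_window=4096,
    num_output=512,
    is_chat_model=True,
    model_name="vicuna-13b-v1.5",  # illustrative model name
)
# autofills context_window/num_output from the metadata
helper = PromptHelper.from_llm_metadata(metadata)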

pilot/common/prompt_util.py

@@ -0,0 +1,254 @@
"""General prompt helper that can help deal with LLM context window token limitations.
At its core, it calculates available context size by starting with the context window
size of an LLM and reserving token space for the prompt template and the output.
It provides utility for "repacking" text chunks (retrieved from index) to maximally
make use of the available context window (and thereby reducing the number of LLM calls
needed), or truncating them so that they fit in a single LLM call.
"""
import logging
from string import Formatter
from typing import Callable, List, Optional, Sequence
from pydantic import Field, PrivateAttr, BaseModel
from llama_index.prompts import BasePromptTemplate
from pilot.common.global_helper import globals_helper
from pilot.common.llm_metadata import LLMMetadata
from pilot.embedding_engine.loader.token_splitter import TokenTextSplitter
DEFAULT_PADDING = 5
DEFAULT_CHUNK_OVERLAP_RATIO = 0.1
DEFAULT_CONTEXT_WINDOW = 3000 # tokens
DEFAULT_NUM_OUTPUTS = 256 # tokens
logger = logging.getLogger(__name__)
class PromptHelper(BaseModel):
"""Prompt helper.
General prompt helper that can help deal with LLM context window token limitations.
At its core, it calculates available context size by starting with the context
window size of an LLM and reserving token space for the prompt template and the
output.
It provides utility for "repacking" text chunks (retrieved from index) to maximally
make use of the available context window (and thereby reducing the number of LLM
calls needed), or truncating them so that they fit in a single LLM call.
Args:
context_window (int): Context window for the LLM.
num_output (int): Number of outputs for the LLM.
chunk_overlap_ratio (float): Chunk overlap as a ratio of chunk size
chunk_size_limit (Optional[int]): Maximum chunk size to use.
tokenizer (Optional[Callable[[str], List]]): Tokenizer to use.
separator (str): Separator for text splitter
"""
context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description="The maximum context size that will get sent to the LLM.",
)
num_output: int = Field(
default=DEFAULT_NUM_OUTPUTS,
description="The amount of token-space to leave in input for generation.",
)
chunk_overlap_ratio: float = Field(
default=DEFAULT_CHUNK_OVERLAP_RATIO,
description="The percentage token amount that each chunk should overlap.",
)
chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.")
separator: str = Field(
default=" ", description="The separator when chunking tokens."
)
_tokenizer: Callable[[str], List] = PrivateAttr()
def __init__(
self,
context_window: int = DEFAULT_CONTEXT_WINDOW,
num_output: int = DEFAULT_NUM_OUTPUTS,
chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
chunk_size_limit: Optional[int] = None,
tokenizer: Optional[Callable[[str], List]] = None,
separator: str = " ",
) -> None:
"""Init params."""
if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")
# TODO: make configurable
self._tokenizer = tokenizer or globals_helper.tokenizer
super().__init__(
context_window=context_window,
num_output=num_output,
chunk_overlap_ratio=chunk_overlap_ratio,
chunk_size_limit=chunk_size_limit,
separator=separator,
)
@classmethod
def from_llm_metadata(
cls,
llm_metadata: LLMMetadata,
chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO,
chunk_size_limit: Optional[int] = None,
tokenizer: Optional[Callable[[str], List]] = None,
separator: str = " ",
) -> "PromptHelper":
"""Create from llm predictor.
This will autofill values like context_window and num_output.
"""
context_window = llm_metadata.context_window
if llm_metadata.num_output == -1:
num_output = DEFAULT_NUM_OUTPUTS
else:
num_output = llm_metadata.num_output
return cls(
context_window=context_window,
num_output=num_output,
chunk_overlap_ratio=chunk_overlap_ratio,
chunk_size_limit=chunk_size_limit,
tokenizer=tokenizer,
separator=separator,
)
@classmethod
def class_name(cls) -> str:
return "PromptHelper"
def _get_available_context_size(self, template: str) -> int:
"""Get available context size.
This is calculated as:
available context window = total context window
- input (partially filled prompt)
- output (room reserved for response)
Notes:
- Available context size is further clamped to be non-negative.
"""
empty_prompt_txt = get_empty_prompt_txt(template)
num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt))
context_size_tokens = (
self.context_window - num_empty_prompt_tokens - self.num_output
)
if context_size_tokens < 0:
raise ValueError(
f"Calculated available context size {context_size_tokens} was"
" not non-negative."
)
return context_size_tokens
def _get_available_chunk_size(
self, prompt_template: str, num_chunks: int = 1, padding: int = 5
) -> int:
"""Get available chunk size.
This is calculated as:
available chunk size = available context window // number_chunks
- padding
Notes:
- By default, we use padding of 5 (to save space for formatting needs).
- Available chunk size is further clamped to chunk_size_limit if specified.
"""
available_context_size = self._get_available_context_size(prompt_template)
result = available_context_size // num_chunks - padding
if self.chunk_size_limit is not None:
result = min(result, self.chunk_size_limit)
return result
def get_text_splitter_given_prompt(
self,
prompt_template: str,
num_chunks: int = 1,
padding: int = DEFAULT_PADDING,
) -> TokenTextSplitter:
"""Get text splitter configured to maximally pack available context window,
taking into account the given prompt and the desired number of chunks.
"""
chunk_size = self._get_available_chunk_size(
prompt_template, num_chunks, padding=padding
)
if chunk_size <= 0:
raise ValueError(f"Chunk size {chunk_size} is not positive.")
chunk_overlap = int(self.chunk_overlap_ratio * chunk_size)
return TokenTextSplitter(
separator=self.separator,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
tokenizer=self._tokenizer,
)
# def truncate(
# self,
# prompt: BasePromptTemplate,
# text_chunks: Sequence[str],
# padding: int = DEFAULT_PADDING,
# ) -> List[str]:
# """Truncate text chunks to fit available context window."""
# text_splitter = self.get_text_splitter_given_prompt(
# prompt,
# num_chunks=len(text_chunks),
# padding=padding,
# )
# return [truncate_text(chunk, text_splitter) for chunk in text_chunks]
def repack(
self,
prompt_template: str,
text_chunks: Sequence[str],
padding: int = DEFAULT_PADDING,
) -> List[str]:
"""Repack text chunks to fit available context window.
This will combine text chunks into consolidated chunks
that more fully "pack" the prompt template given the context_window.
"""
text_splitter = self.get_text_splitter_given_prompt(
prompt_template, padding=padding
)
combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()])
return text_splitter.split_text(combined_str)
def get_empty_prompt_txt(template: str) -> str:
"""Get empty prompt text.
Substitute empty strings in parts of the prompt that have
not yet been filled out. Skip variables that have already
been partially formatted. This is used to compute the initial tokens.
"""
# partial_kargs = prompt.kwargs
partial_kargs = {}
template_vars = get_template_vars(template)
empty_kwargs = {v: "" for v in template_vars if v not in partial_kargs}
all_kwargs = {**partial_kargs, **empty_kwargs}
prompt = template.format(**all_kwargs)
return prompt
def get_template_vars(template_str: str) -> List[str]:
"""Get template variables from a template string."""
variables = []
formatter = Formatter()
for _, variable_name, _, _ in formatter.parse(template_str):
if variable_name:
variables.append(variable_name)
return variables
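A minimal repack sketch, assuming tiktoken is available for the default tokenizer (the template and chunks are illustrative): retrieved chunks are merged into as few pieces as fit the window left over by the prompt.

from pilot.common.prompt_util import PromptHelper

template = "Write a summary of the following context:\n{context}\n"
chunks = ["first retrieved chunk ...", "second retrieved chunk ...", "third chunk ..."]

helper = PromptHelper(context_window=2000, num_output=256)
packed = helper.repack(prompt_template=template, text_chunks=chunks)
# packed holds fewer, larger chunks, each sized to fit the remaining context window
print(len(packed), packed[0][:80])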


@@ -0,0 +1,82 @@
from typing import Callable, List
def split_text_keep_separator(text: str, separator: str) -> List[str]:
"""Split text with separator and keep the separator at the end of each split."""
parts = text.split(separator)
result = [separator + s if i > 0 else s for i, s in enumerate(parts)]
return [s for s in result if s]
def split_by_sep(sep: str, keep_sep: bool = True) -> Callable[[str], List[str]]:
"""Split text by separator."""
if keep_sep:
return lambda text: split_text_keep_separator(text, sep)
else:
return lambda text: text.split(sep)
def split_by_char() -> Callable[[str], List[str]]:
"""Split text by character."""
return lambda text: list(text)
def split_by_sentence_tokenizer() -> Callable[[str], List[str]]:
import os
import nltk
from llama_index.utils import get_cache_dir
cache_dir = get_cache_dir()
nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)
# update nltk path for nltk so that it finds the data
if nltk_data_dir not in nltk.data.path:
nltk.data.path.append(nltk_data_dir)
try:
nltk.data.find("tokenizers/punkt")
except LookupError:
nltk.download("punkt", download_dir=nltk_data_dir)
tokenizer = nltk.tokenize.PunktSentenceTokenizer()
# get the spans and then return the sentences
# using the start index of each span
# instead of using end, use the start of the next span if available
def split(text: str) -> List[str]:
spans = list(tokenizer.span_tokenize(text))
sentences = []
for i, span in enumerate(spans):
start = span[0]
if i < len(spans) - 1:
end = spans[i + 1][0]
else:
end = len(text)
sentences.append(text[start:end])
return sentences
return split
def split_by_regex(regex: str) -> Callable[[str], List[str]]:
"""Split text by regex."""
import re
return lambda text: re.findall(regex, text)
def split_by_phrase_regex() -> Callable[[str], List[str]]:
"""Split text by phrase regex.
This regular expression will split the sentences into phrases,
where each phrase is a sequence of one or more non-comma,
non-period, and non-semicolon characters, followed by an optional comma,
period, or semicolon. The regular expression will also capture the
delimiters themselves as separate items in the list of phrases.
"""
regex = "[^,.;。]+[,.;。]?"
return split_by_regex(regex)
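A quick sketch of the splitter factories (requires nltk, which downloads the punkt data on first use, plus llama_index for its cache-dir helper):

from pilot.embedding_engine.loader.splitter_utils import split_by_sentence_tokenizer, split_by_sep

split_sentences = split_by_sentence_tokenizer()
print(split_sentences("First sentence. Second one follows. And a third."))

split_words = split_by_sep(" ")  # keep_sep=True keeps the separator on later pieces
print(split_words("alpha beta gamma"))  # ['alpha', ' beta', ' gamma']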


@@ -0,0 +1,184 @@
"""Token splitter."""
from typing import Callable, List, Optional
from pydantic import Field, PrivateAttr, BaseModel
from pilot.common.global_helper import globals_helper
from pilot.embedding_engine.loader.splitter_utils import split_by_sep, split_by_char
DEFAULT_METADATA_FORMAT_LEN = 2
DEFAULT_CHUNK_OVERLAP = 20
DEFAULT_CHUNK_SIZE = 1024
class TokenTextSplitter(BaseModel):
"""Implementation of splitting text that looks at word tokens."""
chunk_size: int = Field(
default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk."
)
chunk_overlap: int = Field(
default=DEFAULT_CHUNK_OVERLAP,
description="The token overlap of each chunk when splitting.",
)
separator: str = Field(
default=" ", description="Default separator for splitting into words"
)
backup_separators: List = Field(
default_factory=list, description="Additional separators for splitting."
)
# callback_manager: CallbackManager = Field(
# default_factory=CallbackManager, exclude=True
# )
tokenizer: Callable = Field(
default_factory=globals_helper.tokenizer, # type: ignore
description="Tokenizer for splitting words into tokens.",
exclude=True,
)
_split_fns: List[Callable] = PrivateAttr()
def __init__(
self,
chunk_size: int = DEFAULT_CHUNK_SIZE,
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
tokenizer: Optional[Callable] = None,
# callback_manager: Optional[CallbackManager] = None,
separator: str = " ",
backup_separators: Optional[List[str]] = ["\n"],
):
"""Initialize with parameters."""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
f"({chunk_size}), should be smaller."
)
# callback_manager = callback_manager or CallbackManager([])
tokenizer = tokenizer or globals_helper.tokenizer
all_seps = [separator] + (backup_separators or [])
self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]
super().__init__(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separator=separator,
backup_separators=backup_separators,
# callback_manager=callback_manager,
tokenizer=tokenizer,
)
@classmethod
def class_name(cls) -> str:
return "TokenTextSplitter"
def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]:
"""Split text into chunks, reserving space required for metadata str."""
metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN
effective_chunk_size = self.chunk_size - metadata_len
if effective_chunk_size <= 0:
raise ValueError(
f"Metadata length ({metadata_len}) is longer than chunk size "
f"({self.chunk_size}). Consider increasing the chunk size or "
"decreasing the size of your metadata to avoid this."
)
elif effective_chunk_size < 50:
print(
f"Metadata length ({metadata_len}) is close to chunk size "
f"({self.chunk_size}). Resulting chunks are less than 50 tokens. "
"Consider increasing the chunk size or decreasing the size of "
"your metadata to avoid this.",
flush=True,
)
return self._split_text(text, chunk_size=effective_chunk_size)
def split_text(self, text: str) -> List[str]:
"""Split text into chunks."""
return self._split_text(text, chunk_size=self.chunk_size)
def _split_text(self, text: str, chunk_size: int) -> List[str]:
"""Split text into chunks up to chunk_size."""
if text == "":
return []
splits = self._split(text, chunk_size)
chunks = self._merge(splits, chunk_size)
return chunks
def _split(self, text: str, chunk_size: int) -> List[str]:
"""Break text into splits that are smaller than chunk size.
The order of splitting is:
1. split by separator
2. split by backup separators (if any)
3. split by characters
NOTE: the splits contain the separators.
"""
if len(self.tokenizer(text)) <= chunk_size:
return [text]
for split_fn in self._split_fns:
splits = split_fn(text)
if len(splits) > 1:
break
new_splits = []
for split in splits:
split_len = len(self.tokenizer(split))
if split_len <= chunk_size:
new_splits.append(split)
else:
# recursively split
new_splits.extend(self._split(split, chunk_size=chunk_size))
return new_splits
def _merge(self, splits: List[str], chunk_size: int) -> List[str]:
"""Merge splits into chunks.
The high-level idea is to keep adding splits to a chunk until we
exceed the chunk size, then we start a new chunk with overlap.
When we start a new chunk, we pop off the first element of the previous
chunk until the total length is less than the chunk size.
"""
chunks: List[str] = []
cur_chunk: List[str] = []
cur_len = 0
for split in splits:
split_len = len(self.tokenizer(split))
if split_len > chunk_size:
print(
f"Got a split of size {split_len}, ",
f"larger than chunk size {chunk_size}.",
)
# if we exceed the chunk size after adding the new split, then
# we need to end the current chunk and start a new one
if cur_len + split_len > chunk_size:
# end the previous chunk
chunk = "".join(cur_chunk).strip()
if chunk:
chunks.append(chunk)
# start a new chunk with overlap
# keep popping off the first element of the previous chunk until:
# 1. the current chunk length is less than chunk overlap
# 2. the total length is less than chunk size
while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size:
# pop off the first element
first_chunk = cur_chunk.pop(0)
cur_len -= len(self.tokenizer(first_chunk))
cur_chunk.append(split)
cur_len += split_len
# handle the last chunk
chunk = "".join(cur_chunk).strip()
if chunk:
chunks.append(chunk)
return chunks
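A minimal TokenTextSplitter sketch, assuming tiktoken is installed for the default tokenizer (the sample text and sizes are illustrative):

from pilot.embedding_engine.loader.token_splitter import TokenTextSplitter

splitter = TokenTextSplitter(chunk_size=64, chunk_overlap=8)
long_text = "LlamaIndex is a data framework for LLM applications. " * 40
for i, chunk in enumerate(splitter.split_text(long_text)):
    # each chunk stays at or under 64 tokens, with roughly 8 tokens of overlap
    print(i, len(splitter.tokenizer(chunk)))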


@@ -12,11 +12,12 @@ CFG = Config()
PROMPT_SCENE_DEFINE = """"""
_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n再完善一下原来的总结。\n回答的时候最好按照1.2.3.进行总结"""
_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n根据你之前推理的内容进行最终的总结,总结的时候可以详细点,回答的时候最好按照1.2.3.进行总结."""
_DEFAULT_TEMPLATE_EN = """
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.please refine the original summary.
\nWhen answering, it is best to summarize according to points 1.2.3.
We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.
\nBased on the previous reasoning, please summarize the final conclusion, preferably organized as points 1, 2, 3, etc.
"""
_DEFAULT_TEMPLATE = (
@@ -27,7 +28,7 @@ PROMPT_RESPONSE = """"""
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
PROMPT_NEED_NEED_STREAM_OUT = True
prompt = PromptTemplate(
template_scene=ChatScene.ExtractRefineSummary.value(),


@@ -21,7 +21,8 @@ class ExtractSummary(BaseChat):
chat_param=chat_param,
)
self.user_input = chat_param["current_user_input"]
# self.user_input = chat_param["current_user_input"]
self.user_input = chat_param["select_param"]
# self.extract_mode = chat_param["select_param"]
def generate_input_values(self):


@@ -11,15 +11,15 @@ CFG = Config()
PROMPT_SCENE_DEFINE = """"""
_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行总结:
_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结:
{context}
回答的时候最好按照1.2.3.点进行总结
答案尽量精确和简单,不要过长长度控制在100字左右
"""
_DEFAULT_TEMPLATE_EN = """
Write a summary of the following context:
Write a quick summary of the following context:
{context}
When answering, it is best to summarize according to points 1.2.3.
The summary should be as concise as possible and not overly lengthy. Please keep the answer within approximately 200 characters.
"""
_DEFAULT_TEMPLATE = (


@@ -10,6 +10,7 @@ from pilot.configs.model_config import (
EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.openapi.api_v1.api_v1 import no_stream_generator, stream_generator
from pilot.openapi.api_view_model import Result
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
@@ -25,9 +26,11 @@ from pilot.server.knowledge.request.request import (
DocumentQueryRequest,
SpaceArgumentRequest,
EntityExtractRequest,
DocumentSummaryRequest,
)
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.utils.tracer import root_tracer, SpanType
logger = logging.getLogger(__name__)
@@ -201,6 +204,45 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest):
return {"response": res}
@router.post("/knowledge/document/summary")
async def document_summary(request: DocumentSummaryRequest):
print(f"/document/summary params: {request}")
try:
with root_tracer.start_span(
"get_chat_instance", span_type=SpanType.CHAT, metadata=request
):
chat = await knowledge_space_service.document_summary(request=request)
headers = {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
}
from starlette.responses import StreamingResponse
if not chat.prompt_template.stream_out:
return StreamingResponse(
no_stream_generator(chat),
headers=headers,
media_type="text/event-stream",
)
else:
return StreamingResponse(
stream_generator(chat, False, request.model_name),
headers=headers,
media_type="text/plain",
)
# return Result.succ(
# knowledge_space_service.create_knowledge_document(
# space=space_name, request=request
# )
# )
# return Result.succ([])
except Exception as e:
return Result.faild(code="E000X", msg=f"document summary error {e}")
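A sketch of how a client might call the new endpoint; the host, port, and /api/v1 prefix are assumptions about the local deployment, and the response streams either SSE or plain text depending on the prompt's stream_out flag:

import requests

resp = requests.post(
    "http://localhost:5000/api/v1/knowledge/document/summary",  # prefix/port assumed
    json={"doc_id": 1, "model_name": "vicuna-13b-v1.5"},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode("utf-8"))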
@router.post("/knowledge/entity/extract")
async def entity_extract(request: EntityExtractRequest):
logger.info(f"Received params: {request}")


@@ -108,6 +108,14 @@ class SpaceArgumentRequest(BaseModel):
argument: str
class DocumentSummaryRequest(BaseModel):
"""Sync request"""
"""doc_ids: doc ids"""
doc_id: int
model_name: str
class EntityExtractRequest(BaseModel):
"""argument: argument"""


@@ -10,7 +10,7 @@ from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.component import ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async
from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity,
@@ -31,6 +31,7 @@ from pilot.server.knowledge.request.request import (
ChunkQueryRequest,
SpaceArgumentRequest,
DocumentSyncRequest,
DocumentSummaryRequest,
)
from enum import Enum
@@ -303,7 +304,30 @@ class KnowledgeService:
]
document_chunk_dao.create_documents_chunks(chunk_entities)
return True
return doc.id
async def document_summary(self, request: DocumentSummaryRequest):
"""get document summary
Args:
- request: DocumentSummaryRequest
"""
doc_query = KnowledgeDocumentEntity(id=request.doc_id)
documents = knowledge_document_dao.get_documents(doc_query)
if len(documents) != 1:
raise Exception(f"can not found document for {request.doc_id}")
document = documents[0]
query = DocumentChunkEntity(
document_id=request.doc_id,
)
chunks = document_chunk_dao.get_document_chunks(query, page=1, page_size=100)
if len(chunks) == 0:
raise Exception(f"can not found chunks for {request.doc_id}")
from langchain.schema import Document
chunk_docs = [Document(page_content=chunk.content) for chunk in chunks]
return await self.async_document_summary(
model_name=request.model_name, chunk_docs=chunk_docs, doc=document
)
def update_knowledge_space(
self, space_id: int, space_request: KnowledgeSpaceRequest
@@ -417,30 +441,25 @@ class KnowledgeService:
logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}")
return knowledge_document_dao.update_knowledge_document(doc)
def async_document_summary(self, chunk_docs, doc):
async def async_document_summary(self, model_name, chunk_docs, doc):
"""async document extract summary
Args:
- model_name: str
- chunk_docs: List[Document]
- doc: KnowledgeDocumentEntity
"""
from llama_index import PromptHelper
from llama_index.prompts.default_prompt_selectors import (
DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
)
texts = [doc.page_content for doc in chunk_docs]
prompt_helper = PromptHelper(context_window=2000)
from pilot.common.prompt_util import PromptHelper
texts = prompt_helper.repack(
prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts
)
prompt_helper = PromptHelper()
from pilot.scene.chat_knowledge.summary.prompt import prompt
texts = prompt_helper.repack(prompt_template=prompt.template, text_chunks=texts)
logger.info(
f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary"
)
summary = self._mapreduce_extract_summary(texts)
print(f"final summary:{summary}")
doc.summary = summary
return knowledge_document_dao.update_knowledge_document(doc)
summary = await self._mapreduce_extract_summary(texts, model_name, 10, 3)
return await self._llm_extract_summary(summary, model_name)
def async_doc_embedding(self, client, chunk_docs, doc):
"""async document embedding into vector db
@@ -506,30 +525,37 @@ class KnowledgeService:
return json.loads(spaces[0].context)
return None
def _llm_extract_summary(self, doc: str):
async def _llm_extract_summary(self, doc: str, model_name: str = None):
"""Extract triplets from text by llm"""
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
import uuid
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
"current_user_input": "",
"select_param": doc,
"model_name": self.model_name,
"model_name": model_name,
}
from pilot.common.chat_util import run_async_tasks
executor = CFG.SYSTEM_APP.get_component(
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create()
from pilot.openapi.api_v1.api_v1 import CHAT_FACTORY
summary_iters = run_async_tasks(
[
llm_chat_response_nostream(
ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param}
chat = await blocking_func_to_async(
executor,
CHAT_FACTORY.get_implementation,
ChatScene.ExtractRefineSummary.value(),
**{"chat_param": chat_param},
)
]
)
return summary_iters[0]
return chat
def _mapreduce_extract_summary(self, docs):
async def _mapreduce_extract_summary(
self,
docs,
model_name: str = None,
max_iteration: int = 5,
concurrency_limit: int = None,
):
"""Extract summary by mapreduce mode
map -> multi async thread generate summary
reduce -> merge the summaries by map process
@@ -541,18 +567,16 @@ class KnowledgeService:
import uuid
tasks = []
max_iteration = 5
if len(docs) == 1:
summary = self._llm_extract_summary(doc=docs[0])
return summary
return docs[0]
else:
max_iteration = max_iteration if len(docs) > max_iteration else len(docs)
for doc in docs[0:max_iteration]:
chat_param = {
"chat_session_id": uuid.uuid1(),
"current_user_input": doc,
"select_param": "summary",
"model_name": self.model_name,
"current_user_input": "",
"select_param": doc,
"model_name": model_name,
}
tasks.append(
llm_chat_response_nostream(
@@ -561,14 +585,22 @@ class KnowledgeService:
)
from pilot.common.chat_util import run_async_tasks
summary_iters = run_async_tasks(tasks)
summary_iters = await run_async_tasks(
tasks=tasks, concurrency_limit=concurrency_limit
)
summary_iters = list(
filter(
lambda content: "LLMServer Generate Error" not in content,
summary_iters,
)
)
from pilot.common.prompt_util import PromptHelper
from llama_index.prompts.default_prompt_selectors import (
DEFAULT_TREE_SUMMARIZE_PROMPT_SEL,
)
from pilot.scene.chat_knowledge.summary.prompt import prompt
prompt_helper = PromptHelper(context_window=2500)
prompt_helper = PromptHelper()
summary_iters = prompt_helper.repack(
prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters
prompt_template=prompt.template, text_chunks=summary_iters
)
return await self._mapreduce_extract_summary(
summary_iters, model_name, max_iteration, concurrency_limit
)
return self._mapreduce_extract_summary(summary_iters)
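For clarity, a condensed, self-contained sketch of the map-reduce flow the service now follows (summarize is a stand-in for llm_chat_response_nostream, and the repack step is reduced to a plain join):

import asyncio
from typing import List


async def summarize(text: str) -> str:
    # stand-in for an LLM summary call
    await asyncio.sleep(0)
    return text[:80]


async def mapreduce_summary(chunks: List[str], max_iteration: int = 5) -> str:
    if len(chunks) == 1:
        # reduce has converged: hand the single chunk to the refine-summary chat
        return chunks[0]
    # map: summarize up to max_iteration chunks concurrently (the rest are dropped,
    # mirroring docs[0:max_iteration] in the service code)
    summaries = await asyncio.gather(*(summarize(c) for c in chunks[:max_iteration]))
    # "repack": merge the mapped summaries into fewer, larger chunks, then recurse
    return await mapreduce_summary(["\n\n".join(summaries)], max_iteration)


print(asyncio.run(mapreduce_summary([f"chunk {i} ..." for i in range(8)])))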