diff --git a/pilot/common/chat_util.py b/pilot/common/chat_util.py index ae0ce73ed..6cb8f1b53 100644 --- a/pilot/common/chat_util.py +++ b/pilot/common/chat_util.py @@ -21,35 +21,27 @@ async def llm_chat_response(chat_scene: str, **chat_param): return chat.stream_call() -def run_async_tasks( +async def run_async_tasks( tasks: List[Coroutine], - show_progress: bool = False, - progress_bar_desc: str = "Running async tasks", + concurrency_limit: int = None, ) -> List[Any]: """Run a list of async tasks.""" - tasks_to_execute: List[Any] = tasks - if show_progress: - try: - import nest_asyncio - from tqdm.asyncio import tqdm - - nest_asyncio.apply() - loop = asyncio.get_event_loop() - - async def _tqdm_gather() -> List[Any]: - return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc) - - tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather()) - return tqdm_outputs - # run the operation w/o tqdm on hitting a fatal - # may occur in some environments where tqdm.asyncio - # is not supported - except Exception: - pass async def _gather() -> List[Any]: - return await asyncio.gather(*tasks_to_execute) + if concurrency_limit: + semaphore = asyncio.Semaphore(concurrency_limit) - outputs: List[Any] = asyncio.run(_gather()) - return outputs + async def _execute_task(task): + async with semaphore: + return await task + + # Execute tasks with semaphore limit + return await asyncio.gather( + *[_execute_task(task) for task in tasks_to_execute] + ) + else: + return await asyncio.gather(*tasks_to_execute) + + # outputs: List[Any] = asyncio.run(_gather()) + return await _gather() diff --git a/pilot/common/global_helper.py b/pilot/common/global_helper.py new file mode 100644 index 000000000..d189d1356 --- /dev/null +++ b/pilot/common/global_helper.py @@ -0,0 +1,448 @@ +"""General utils functions.""" + +import asyncio +import os +import random +import sys +import time +import traceback +import uuid +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial, wraps +from itertools import islice +from pathlib import Path +from typing import ( + Any, + AsyncGenerator, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Type, + Union, + cast, +) + + +class GlobalsHelper: + """Helper to retrieve globals. + + Helpful for global caching of certain variables that can be expensive to load. + (e.g. 
tokenization) + + """ + + _tokenizer: Optional[Callable[[str], List]] = None + _stopwords: Optional[List[str]] = None + + @property + def tokenizer(self) -> Callable[[str], List]: + """Get tokenizer.""" + if self._tokenizer is None: + tiktoken_import_err = ( + "`tiktoken` package not found, please run `pip install tiktoken`" + ) + try: + import tiktoken + except ImportError: + raise ImportError(tiktoken_import_err) + enc = tiktoken.get_encoding("gpt2") + self._tokenizer = cast(Callable[[str], List], enc.encode) + self._tokenizer = partial(self._tokenizer, allowed_special="all") + return self._tokenizer # type: ignore + + @property + def stopwords(self) -> List[str]: + """Get stopwords.""" + if self._stopwords is None: + try: + import nltk + from nltk.corpus import stopwords + except ImportError: + raise ImportError( + "`nltk` package not found, please run `pip install nltk`" + ) + + from llama_index.utils import get_cache_dir + + cache_dir = get_cache_dir() + nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir) + + # update nltk path for nltk so that it finds the data + if nltk_data_dir not in nltk.data.path: + nltk.data.path.append(nltk_data_dir) + + try: + nltk.data.find("corpora/stopwords") + except LookupError: + nltk.download("stopwords", download_dir=nltk_data_dir) + self._stopwords = stopwords.words("english") + return self._stopwords + + +globals_helper = GlobalsHelper() + + +def get_new_id(d: Set) -> str: + """Get a new ID.""" + while True: + new_id = str(uuid.uuid4()) + if new_id not in d: + break + return new_id + + +def get_new_int_id(d: Set) -> int: + """Get a new integer ID.""" + while True: + new_id = random.randint(0, sys.maxsize) + if new_id not in d: + break + return new_id + + +@contextmanager +def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator: + """Temporary setter. + + Utility class for setting a temporary value for an attribute on a class. + Taken from: https://tinyurl.com/2p89xymh + + """ + prev_values = {k: getattr(obj, k) for k in kwargs} + for k, v in kwargs.items(): + setattr(obj, k, v) + try: + yield + finally: + for k, v in prev_values.items(): + setattr(obj, k, v) + + +@dataclass +class ErrorToRetry: + """Exception types that should be retried. + + Args: + exception_cls (Type[Exception]): Class of exception. + check_fn (Optional[Callable[[Any]], bool]]): + A function that takes an exception instance as input and returns + whether to retry. + + """ + + exception_cls: Type[Exception] + check_fn: Optional[Callable[[Any], bool]] = None + + +def retry_on_exceptions_with_backoff( + lambda_fn: Callable, + errors_to_retry: List[ErrorToRetry], + max_tries: int = 10, + min_backoff_secs: float = 0.5, + max_backoff_secs: float = 60.0, +) -> Any: + """Execute lambda function with retries and exponential backoff. + + Args: + lambda_fn (Callable): Function to be called and output we want. + errors_to_retry (List[ErrorToRetry]): List of errors to retry. + At least one needs to be provided. + max_tries (int): Maximum number of tries, including the first. Defaults to 10. + min_backoff_secs (float): Minimum amount of backoff time between attempts. + Defaults to 0.5. + max_backoff_secs (float): Maximum amount of backoff time between attempts. + Defaults to 60. 
+ + """ + if not errors_to_retry: + raise ValueError("At least one error to retry needs to be provided") + + error_checks = { + error_to_retry.exception_cls: error_to_retry.check_fn + for error_to_retry in errors_to_retry + } + exception_class_tuples = tuple(error_checks.keys()) + + backoff_secs = min_backoff_secs + tries = 0 + + while True: + try: + return lambda_fn() + except exception_class_tuples as e: + traceback.print_exc() + tries += 1 + if tries >= max_tries: + raise + check_fn = error_checks.get(e.__class__) + if check_fn and not check_fn(e): + raise + time.sleep(backoff_secs) + backoff_secs = min(backoff_secs * 2, max_backoff_secs) + + +def truncate_text(text: str, max_length: int) -> str: + """Truncate text to a maximum length.""" + if len(text) <= max_length: + return text + return text[: max_length - 3] + "..." + + +def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable: + """Iterate over an iterable in batches. + + >>> list(iter_batch([1,2,3,4,5], 3)) + [[1, 2, 3], [4, 5]] + """ + source_iter = iter(iterable) + while source_iter: + b = list(islice(source_iter, size)) + if len(b) == 0: + break + yield b + + +def concat_dirs(dirname: str, basename: str) -> str: + """ + Append basename to dirname, avoiding backslashes when running on windows. + + os.path.join(dirname, basename) will add a backslash before dirname if + basename does not end with a slash, so we make sure it does. + """ + dirname += "/" if dirname[-1] != "/" else "" + return os.path.join(dirname, basename) + + +def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable: + """ + Optionally get a tqdm iterable. Ensures tqdm.auto is used. + """ + _iterator = items + if show_progress: + try: + from tqdm.auto import tqdm + + return tqdm(items, desc=desc) + except ImportError: + pass + return _iterator + + +def count_tokens(text: str) -> int: + tokens = globals_helper.tokenizer(text) + return len(tokens) + + +def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]: + """ + Args: + model_name(str): the model name of the tokenizer. + For instance, fxmarty/tiny-llama-fast-tokenizer. + """ + try: + from transformers import AutoTokenizer + except ImportError: + raise ValueError( + "`transformers` package not found, please run `pip install transformers`" + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + return tokenizer.tokenize + + +def get_cache_dir() -> str: + """Locate a platform-appropriate cache directory for llama_index, + and create it if it doesn't yet exist. + """ + # User override + if "LLAMA_INDEX_CACHE_DIR" in os.environ: + path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"]) + + # Linux, Unix, AIX, etc. + elif os.name == "posix" and sys.platform != "darwin": + path = Path("/tmp/llama_index") + + # Mac OS + elif sys.platform == "darwin": + path = Path(os.path.expanduser("~"), "Library/Caches/llama_index") + + # Windows (hopefully) + else: + local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser( + "~\\AppData\\Local" + ) + path = Path(local, "llama_index") + + if not os.path.exists(path): + os.makedirs( + path, exist_ok=True + ) # prevents https://github.com/jerryjliu/llama_index/issues/7362 + return str(path) + + +def add_sync_version(func: Any) -> Any: + """Decorator for adding sync version of an async function. The sync version + is added as a function attribute to the original function, func. + + Args: + func(Any): the async function for which a sync variant will be built. 
+ """ + assert asyncio.iscoroutinefunction(func) + + @wraps(func) + def _wrapper(*args: Any, **kwds: Any) -> Any: + return asyncio.get_event_loop().run_until_complete(func(*args, **kwds)) + + func.sync = _wrapper + return func + + +# Sample text from llama_index's readme +SAMPLE_TEXT = """ +Context +LLMs are a phenomenal piece of technology for knowledge generation and reasoning. +They are pre-trained on large amounts of publicly available data. +How do we best augment LLMs with our own private data? +We need a comprehensive toolkit to help perform this data augmentation for LLMs. + +Proposed Solution +That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help +you build LLM apps. It provides the following tools: + +Offers data connectors to ingest your existing data sources and data formats +(APIs, PDFs, docs, SQL, etc.) +Provides ways to structure your data (indices, graphs) so that this data can be +easily used with LLMs. +Provides an advanced retrieval/query interface over your data: +Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output. +Allows easy integrations with your outer application framework +(e.g. with LangChain, Flask, Docker, ChatGPT, anything else). +LlamaIndex provides tools for both beginner users and advanced users. +Our high-level API allows beginner users to use LlamaIndex to ingest and +query their data in 5 lines of code. Our lower-level APIs allow advanced users to +customize and extend any module (data connectors, indices, retrievers, query engines, +reranking modules), to fit their needs. +""" + +_LLAMA_INDEX_COLORS = { + "llama_pink": "38;2;237;90;200", + "llama_blue": "38;2;90;149;237", + "llama_turquoise": "38;2;11;159;203", + "llama_lavender": "38;2;155;135;227", +} + +_ANSI_COLORS = { + "red": "31", + "green": "32", + "yellow": "33", + "blue": "34", + "magenta": "35", + "cyan": "36", + "pink": "38;5;200", +} + + +def get_color_mapping( + items: List[str], use_llama_index_colors: bool = True +) -> Dict[str, str]: + """ + Get a mapping of items to colors. + + Args: + items (List[str]): List of items to be mapped to colors. + use_llama_index_colors (bool, optional): Flag to indicate + whether to use LlamaIndex colors or ANSI colors. + Defaults to True. + + Returns: + Dict[str, str]: Mapping of items to colors. + """ + if use_llama_index_colors: + color_palette = _LLAMA_INDEX_COLORS + else: + color_palette = _ANSI_COLORS + + colors = list(color_palette.keys()) + return {item: colors[i % len(colors)] for i, item in enumerate(items)} + + +def _get_colored_text(text: str, color: str) -> str: + """ + Get the colored version of the input text. + + Args: + text (str): Input text. + color (str): Color to be applied to the text. + + Returns: + str: Colored version of the input text. + """ + all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS} + + if color not in all_colors: + return f"\033[1;3m{text}\033[0m" # just bolded and italicized + + color = all_colors[color] + + return f"\033[1;3;{color}m{text}\033[0m" + + +def print_text(text: str, color: Optional[str] = None, end: str = "") -> None: + """ + Print the text with the specified color. + + Args: + text (str): Text to be printed. + color (str, optional): Color to be applied to the text. Supported colors are: + llama_pink, llama_blue, llama_turquoise, llama_lavender, + red, green, yellow, blue, magenta, cyan, pink. + end (str, optional): String appended after the last character of the text. 
+ + Returns: + None + """ + text_to_print = _get_colored_text(text, color) if color is not None else text + print(text_to_print, end=end) + + +def infer_torch_device() -> str: + """Infer the input to torch.device.""" + try: + has_cuda = torch.cuda.is_available() + except NameError: + import torch + + has_cuda = torch.cuda.is_available() + if has_cuda: + return "cuda" + if torch.backends.mps.is_available(): + return "mps" + return "cpu" + + +def unit_generator(x: Any) -> Generator[Any, None, None]: + """A function that returns a generator of a single element. + + Args: + x (Any): the element to build yield + + Yields: + Any: the single element + """ + yield x + + +async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]: + """A function that returns a generator of a single element. + + Args: + x (Any): the element to build yield + + Yields: + Any: the single element + """ + yield x diff --git a/pilot/common/llm_metadata.py b/pilot/common/llm_metadata.py new file mode 100644 index 000000000..73d5bc0fa --- /dev/null +++ b/pilot/common/llm_metadata.py @@ -0,0 +1,43 @@ +from pydantic import Field, BaseModel + +DEFAULT_CONTEXT_WINDOW = 3900 +DEFAULT_NUM_OUTPUTS = 256 + + +class LLMMetadata(BaseModel): + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description=( + "Total number of tokens the model can be input and output for one response." + ), + ) + num_output: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="Number of tokens the model can output when generating a response.", + ) + is_chat_model: bool = Field( + default=False, + description=( + "Set True if the model exposes a chat interface (i.e. can be passed a" + " sequence of messages, rather than text), like OpenAI's" + " /v1/chat/completions endpoint." + ), + ) + is_function_calling_model: bool = Field( + default=False, + # SEE: https://openai.com/blog/function-calling-and-other-api-updates + description=( + "Set True if the model supports function calling messages, similar to" + " OpenAI's function calling API. For example, converting 'Email Anya to" + " see if she wants to get coffee next Friday' to a function call like" + " `send_email(to: string, body: string)`." + ), + ) + model_name: str = Field( + default="unknown", + description=( + "The model's name used for logging, testing, and sanity checking. For some" + " models this can be automatically discerned. For other models, like" + " locally loaded models, this must be manually specified." + ), + ) diff --git a/pilot/common/prompt_util.py b/pilot/common/prompt_util.py new file mode 100644 index 000000000..525536164 --- /dev/null +++ b/pilot/common/prompt_util.py @@ -0,0 +1,254 @@ +"""General prompt helper that can help deal with LLM context window token limitations. + +At its core, it calculates available context size by starting with the context window +size of an LLM and reserve token space for the prompt template, and the output. + +It provides utility for "repacking" text chunks (retrieved from index) to maximally +make use of the available context window (and thereby reducing the number of LLM calls +needed), or truncating them so that they fit in a single LLM call. 
+""" + +import logging +from string import Formatter +from typing import Callable, List, Optional, Sequence + +from pydantic import Field, PrivateAttr, BaseModel +from llama_index.prompts import BasePromptTemplate + +from pilot.common.global_helper import globals_helper +from pilot.common.llm_metadata import LLMMetadata +from pilot.embedding_engine.loader.token_splitter import TokenTextSplitter + +DEFAULT_PADDING = 5 +DEFAULT_CHUNK_OVERLAP_RATIO = 0.1 + +DEFAULT_CONTEXT_WINDOW = 3000 # tokens +DEFAULT_NUM_OUTPUTS = 256 # tokens + +logger = logging.getLogger(__name__) + + +class PromptHelper(BaseModel): + """Prompt helper. + + General prompt helper that can help deal with LLM context window token limitations. + + At its core, it calculates available context size by starting with the context + window size of an LLM and reserve token space for the prompt template, and the + output. + + It provides utility for "repacking" text chunks (retrieved from index) to maximally + make use of the available context window (and thereby reducing the number of LLM + calls needed), or truncating them so that they fit in a single LLM call. + + Args: + context_window (int): Context window for the LLM. + num_output (int): Number of outputs for the LLM. + chunk_overlap_ratio (float): Chunk overlap as a ratio of chunk size + chunk_size_limit (Optional[int]): Maximum chunk size to use. + tokenizer (Optional[Callable[[str], List]]): Tokenizer to use. + separator (str): Separator for text splitter + + """ + + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum context size that will get sent to the LLM.", + ) + num_output: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="The amount of token-space to leave in input for generation.", + ) + chunk_overlap_ratio: float = Field( + default=DEFAULT_CHUNK_OVERLAP_RATIO, + description="The percentage token amount that each chunk should overlap.", + ) + chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.") + separator: str = Field( + default=" ", description="The separator when chunking tokens." + ) + + _tokenizer: Callable[[str], List] = PrivateAttr() + + def __init__( + self, + context_window: int = DEFAULT_CONTEXT_WINDOW, + num_output: int = DEFAULT_NUM_OUTPUTS, + chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO, + chunk_size_limit: Optional[int] = None, + tokenizer: Optional[Callable[[str], List]] = None, + separator: str = " ", + ) -> None: + """Init params.""" + if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0: + raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.") + + # TODO: make configurable + self._tokenizer = tokenizer or globals_helper.tokenizer + + super().__init__( + context_window=context_window, + num_output=num_output, + chunk_overlap_ratio=chunk_overlap_ratio, + chunk_size_limit=chunk_size_limit, + separator=separator, + ) + + @classmethod + def from_llm_metadata( + cls, + llm_metadata: LLMMetadata, + chunk_overlap_ratio: float = DEFAULT_CHUNK_OVERLAP_RATIO, + chunk_size_limit: Optional[int] = None, + tokenizer: Optional[Callable[[str], List]] = None, + separator: str = " ", + ) -> "PromptHelper": + """Create from llm predictor. + + This will autofill values like context_window and num_output. 
+ + """ + context_window = llm_metadata.context_window + if llm_metadata.num_output == -1: + num_output = DEFAULT_NUM_OUTPUTS + else: + num_output = llm_metadata.num_output + + return cls( + context_window=context_window, + num_output=num_output, + chunk_overlap_ratio=chunk_overlap_ratio, + chunk_size_limit=chunk_size_limit, + tokenizer=tokenizer, + separator=separator, + ) + + @classmethod + def class_name(cls) -> str: + return "PromptHelper" + + def _get_available_context_size(self, template: str) -> int: + """Get available context size. + + This is calculated as: + available context window = total context window + - input (partially filled prompt) + - output (room reserved for response) + + Notes: + - Available context size is further clamped to be non-negative. + """ + empty_prompt_txt = get_empty_prompt_txt(template) + num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt)) + context_size_tokens = ( + self.context_window - num_empty_prompt_tokens - self.num_output + ) + if context_size_tokens < 0: + raise ValueError( + f"Calculated available context size {context_size_tokens} was" + " not non-negative." + ) + return context_size_tokens + + def _get_available_chunk_size( + self, prompt_template: str, num_chunks: int = 1, padding: int = 5 + ) -> int: + """Get available chunk size. + + This is calculated as: + available chunk size = available context window // number_chunks + - padding + + Notes: + - By default, we use padding of 5 (to save space for formatting needs). + - Available chunk size is further clamped to chunk_size_limit if specified. + """ + available_context_size = self._get_available_context_size(prompt_template) + result = available_context_size // num_chunks - padding + if self.chunk_size_limit is not None: + result = min(result, self.chunk_size_limit) + return result + + def get_text_splitter_given_prompt( + self, + prompt_template: str, + num_chunks: int = 1, + padding: int = DEFAULT_PADDING, + ) -> TokenTextSplitter: + """Get text splitter configured to maximally pack available context window, + taking into account of given prompt, and desired number of chunks. + """ + chunk_size = self._get_available_chunk_size( + prompt_template, num_chunks, padding=padding + ) + if chunk_size <= 0: + raise ValueError(f"Chunk size {chunk_size} is not positive.") + chunk_overlap = int(self.chunk_overlap_ratio * chunk_size) + return TokenTextSplitter( + separator=self.separator, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + tokenizer=self._tokenizer, + ) + + # def truncate( + # self, + # prompt: BasePromptTemplate, + # text_chunks: Sequence[str], + # padding: int = DEFAULT_PADDING, + # ) -> List[str]: + # """Truncate text chunks to fit available context window.""" + # text_splitter = self.get_text_splitter_given_prompt( + # prompt, + # num_chunks=len(text_chunks), + # padding=padding, + # ) + # return [truncate_text(chunk, text_splitter) for chunk in text_chunks] + + def repack( + self, + prompt_template: str, + text_chunks: Sequence[str], + padding: int = DEFAULT_PADDING, + ) -> List[str]: + """Repack text chunks to fit available context window. + + This will combine text chunks into consolidated chunks + that more fully "pack" the prompt template given the context_window. 
+ + """ + text_splitter = self.get_text_splitter_given_prompt( + prompt_template, padding=padding + ) + combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()]) + return text_splitter.split_text(combined_str) + + +def get_empty_prompt_txt(template: str) -> str: + """Get empty prompt text. + + Substitute empty strings in parts of the prompt that have + not yet been filled out. Skip variables that have already + been partially formatted. This is used to compute the initial tokens. + + """ + # partial_kargs = prompt.kwargs + + partial_kargs = {} + template_vars = get_template_vars(template) + empty_kwargs = {v: "" for v in template_vars if v not in partial_kargs} + all_kwargs = {**partial_kargs, **empty_kwargs} + prompt = template.format(**all_kwargs) + return prompt + + +def get_template_vars(template_str: str) -> List[str]: + """Get template variables from a template string.""" + variables = [] + formatter = Formatter() + + for _, variable_name, _, _ in formatter.parse(template_str): + if variable_name: + variables.append(variable_name) + + return variables diff --git a/pilot/embedding_engine/loader/splitter_utils.py b/pilot/embedding_engine/loader/splitter_utils.py new file mode 100644 index 000000000..06fe6920a --- /dev/null +++ b/pilot/embedding_engine/loader/splitter_utils.py @@ -0,0 +1,82 @@ +from typing import Callable, List + + +def split_text_keep_separator(text: str, separator: str) -> List[str]: + """Split text with separator and keep the separator at the end of each split.""" + parts = text.split(separator) + result = [separator + s if i > 0 else s for i, s in enumerate(parts)] + return [s for s in result if s] + + +def split_by_sep(sep: str, keep_sep: bool = True) -> Callable[[str], List[str]]: + """Split text by separator.""" + if keep_sep: + return lambda text: split_text_keep_separator(text, sep) + else: + return lambda text: text.split(sep) + + +def split_by_char() -> Callable[[str], List[str]]: + """Split text by character.""" + return lambda text: list(text) + + +def split_by_sentence_tokenizer() -> Callable[[str], List[str]]: + import os + + import nltk + + from llama_index.utils import get_cache_dir + + cache_dir = get_cache_dir() + nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir) + + # update nltk path for nltk so that it finds the data + if nltk_data_dir not in nltk.data.path: + nltk.data.path.append(nltk_data_dir) + + try: + nltk.data.find("tokenizers/punkt") + except LookupError: + nltk.download("punkt", download_dir=nltk_data_dir) + + tokenizer = nltk.tokenize.PunktSentenceTokenizer() + + # get the spans and then return the sentences + # using the start index of each span + # instead of using end, use the start of the next span if available + def split(text: str) -> List[str]: + spans = list(tokenizer.span_tokenize(text)) + sentences = [] + for i, span in enumerate(spans): + start = span[0] + if i < len(spans) - 1: + end = spans[i + 1][0] + else: + end = len(text) + sentences.append(text[start:end]) + + return sentences + + return split + + +def split_by_regex(regex: str) -> Callable[[str], List[str]]: + """Split text by regex.""" + import re + + return lambda text: re.findall(regex, text) + + +def split_by_phrase_regex() -> Callable[[str], List[str]]: + """Split text by phrase regex. + + This regular expression will split the sentences into phrases, + where each phrase is a sequence of one or more non-comma, + non-period, and non-semicolon characters, followed by an optional comma, + period, or semicolon. 
The regular expression will also capture the + delimiters themselves as separate items in the list of phrases. + """ + regex = "[^,.;。]+[,.;。]?" + return split_by_regex(regex) + diff --git a/pilot/embedding_engine/loader/token_splitter.py b/pilot/embedding_engine/loader/token_splitter.py new file mode 100644 index 000000000..358d51790 --- /dev/null +++ b/pilot/embedding_engine/loader/token_splitter.py @@ -0,0 +1,184 @@ +"""Token splitter.""" +from typing import Callable, List, Optional + +from pydantic import Field, PrivateAttr, BaseModel + + +from pilot.common.global_helper import globals_helper +from pilot.embedding_engine.loader.splitter_utils import split_by_sep, split_by_char + +DEFAULT_METADATA_FORMAT_LEN = 2 +DEFAULT_CHUNK_OVERLAP = 20 +DEFAULT_CHUNK_SIZE = 1024 + + +class TokenTextSplitter(BaseModel): + """Implementation of splitting text that looks at word tokens.""" + + chunk_size: int = Field( + default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk." + ) + chunk_overlap: int = Field( + default=DEFAULT_CHUNK_OVERLAP, + description="The token overlap of each chunk when splitting.", + ) + separator: str = Field( + default=" ", description="Default separator for splitting into words" + ) + backup_separators: List = Field( + default_factory=list, description="Additional separators for splitting." + ) + # callback_manager: CallbackManager = Field( + # default_factory=CallbackManager, exclude=True + # ) + tokenizer: Callable = Field( + default_factory=globals_helper.tokenizer, # type: ignore + description="Tokenizer for splitting words into tokens.", + exclude=True, + ) + + _split_fns: List[Callable] = PrivateAttr() + + def __init__( + self, + chunk_size: int = DEFAULT_CHUNK_SIZE, + chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, + tokenizer: Optional[Callable] = None, + # callback_manager: Optional[CallbackManager] = None, + separator: str = " ", + backup_separators: Optional[List[str]] = ["\n"], + ): + """Initialize with parameters.""" + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." + ) + # callback_manager = callback_manager or CallbackManager([]) + tokenizer = tokenizer or globals_helper.tokenizer + + all_seps = [separator] + (backup_separators or []) + self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()] + + super().__init__( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + separator=separator, + backup_separators=backup_separators, + # callback_manager=callback_manager, + tokenizer=tokenizer, + ) + + @classmethod + def class_name(cls) -> str: + return "TokenTextSplitter" + + def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: + """Split text into chunks, reserving space required for metadata str.""" + metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN + effective_chunk_size = self.chunk_size - metadata_len + if effective_chunk_size <= 0: + raise ValueError( + f"Metadata length ({metadata_len}) is longer than chunk size " + f"({self.chunk_size}). Consider increasing the chunk size or " + "decreasing the size of your metadata to avoid this." + ) + elif effective_chunk_size < 50: + print( + f"Metadata length ({metadata_len}) is close to chunk size " + f"({self.chunk_size}). Resulting chunks are less than 50 tokens. 
" + "Consider increasing the chunk size or decreasing the size of " + "your metadata to avoid this.", + flush=True, + ) + + return self._split_text(text, chunk_size=effective_chunk_size) + + def split_text(self, text: str) -> List[str]: + """Split text into chunks.""" + return self._split_text(text, chunk_size=self.chunk_size) + + def _split_text(self, text: str, chunk_size: int) -> List[str]: + """Split text into chunks up to chunk_size.""" + if text == "": + return [] + + splits = self._split(text, chunk_size) + chunks = self._merge(splits, chunk_size) + return chunks + + def _split(self, text: str, chunk_size: int) -> List[str]: + """Break text into splits that are smaller than chunk size. + + The order of splitting is: + 1. split by separator + 2. split by backup separators (if any) + 3. split by characters + + NOTE: the splits contain the separators. + """ + if len(self.tokenizer(text)) <= chunk_size: + return [text] + + for split_fn in self._split_fns: + splits = split_fn(text) + if len(splits) > 1: + break + + new_splits = [] + for split in splits: + split_len = len(self.tokenizer(split)) + if split_len <= chunk_size: + new_splits.append(split) + else: + # recursively split + new_splits.extend(self._split(split, chunk_size=chunk_size)) + return new_splits + + def _merge(self, splits: List[str], chunk_size: int) -> List[str]: + """Merge splits into chunks. + + The high-level idea is to keep adding splits to a chunk until we + exceed the chunk size, then we start a new chunk with overlap. + + When we start a new chunk, we pop off the first element of the previous + chunk until the total length is less than the chunk size. + """ + chunks: List[str] = [] + + cur_chunk: List[str] = [] + cur_len = 0 + for split in splits: + split_len = len(self.tokenizer(split)) + if split_len > chunk_size: + print( + f"Got a split of size {split_len}, ", + f"larger than chunk size {chunk_size}.", + ) + + # if we exceed the chunk size after adding the new split, then + # we need to end the current chunk and start a new one + if cur_len + split_len > chunk_size: + # end the previous chunk + chunk = "".join(cur_chunk).strip() + if chunk: + chunks.append(chunk) + + # start a new chunk with overlap + # keep popping off the first element of the previous chunk until: + # 1. the current chunk length is less than chunk overlap + # 2. 
the total length is less than chunk size + while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size: + # pop off the first element + first_chunk = cur_chunk.pop(0) + cur_len -= len(self.tokenizer(first_chunk)) + + cur_chunk.append(split) + cur_len += split_len + + # handle the last chunk + chunk = "".join(cur_chunk).strip() + if chunk: + chunks.append(chunk) + + return chunks diff --git a/pilot/scene/chat_knowledge/refine_summary/prompt.py b/pilot/scene/chat_knowledge/refine_summary/prompt.py index cd5087f35..c2d7fb2cb 100644 --- a/pilot/scene/chat_knowledge/refine_summary/prompt.py +++ b/pilot/scene/chat_knowledge/refine_summary/prompt.py @@ -12,11 +12,12 @@ CFG = Config() PROMPT_SCENE_DEFINE = """""" -_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请再完善一下原来的总结。\n回答的时候最好按照1.2.3.点进行总结""" +_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请根据你之前推理的内容进行最终的总结,总结的时候可以详细点,回答的时候最好按照1.2.3.进行总结.""" _DEFAULT_TEMPLATE_EN = """ -We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below.please refine the original summary. -\nWhen answering, it is best to summarize according to points 1.2.3. +We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below. +\nBased on the previous reasoning, please summarize the final conclusion in accordance with points 1, 2, and 3. and etc. + """ _DEFAULT_TEMPLATE = ( @@ -27,7 +28,7 @@ PROMPT_RESPONSE = """""" PROMPT_SEP = SeparatorStyle.SINGLE.value -PROMPT_NEED_NEED_STREAM_OUT = False +PROMPT_NEED_NEED_STREAM_OUT = True prompt = PromptTemplate( template_scene=ChatScene.ExtractRefineSummary.value(), diff --git a/pilot/scene/chat_knowledge/summary/chat.py b/pilot/scene/chat_knowledge/summary/chat.py index f887bde82..96e486ea8 100644 --- a/pilot/scene/chat_knowledge/summary/chat.py +++ b/pilot/scene/chat_knowledge/summary/chat.py @@ -21,7 +21,8 @@ class ExtractSummary(BaseChat): chat_param=chat_param, ) - self.user_input = chat_param["current_user_input"] + # self.user_input = chat_param["current_user_input"] + self.user_input = chat_param["select_param"] # self.extract_mode = chat_param["select_param"] def generate_input_values(self): diff --git a/pilot/scene/chat_knowledge/summary/prompt.py b/pilot/scene/chat_knowledge/summary/prompt.py index 10a239586..8115de078 100644 --- a/pilot/scene/chat_knowledge/summary/prompt.py +++ b/pilot/scene/chat_knowledge/summary/prompt.py @@ -11,15 +11,15 @@ CFG = Config() PROMPT_SCENE_DEFINE = """""" -_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行总结: +_DEFAULT_TEMPLATE_ZH = """请根据提供的上下文信息的进行精简地总结: {context} -回答的时候最好按照1.2.3.点进行总结 +答案尽量精确和简单,不要过长,长度控制在100字左右 """ _DEFAULT_TEMPLATE_EN = """ -Write a summary of the following context: +Write a quick summary of the following context: {context} -When answering, it is best to summarize according to points 1.2.3. +the summary should be as concise as possible and not overly lengthy.Please keep the answer within approximately 200 characters. 
""" _DEFAULT_TEMPLATE = ( diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py index 8e5e52b58..164158694 100644 --- a/pilot/server/knowledge/api.py +++ b/pilot/server/knowledge/api.py @@ -10,6 +10,7 @@ from pilot.configs.model_config import ( EMBEDDING_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH, ) +from pilot.openapi.api_v1.api_v1 import no_stream_generator, stream_generator from pilot.openapi.api_view_model import Result from pilot.embedding_engine.embedding_engine import EmbeddingEngine @@ -25,9 +26,11 @@ from pilot.server.knowledge.request.request import ( DocumentQueryRequest, SpaceArgumentRequest, EntityExtractRequest, + DocumentSummaryRequest, ) from pilot.server.knowledge.request.request import KnowledgeSpaceRequest +from pilot.utils.tracer import root_tracer, SpanType logger = logging.getLogger(__name__) @@ -201,6 +204,45 @@ def similar_query(space_name: str, query_request: KnowledgeQueryRequest): return {"response": res} +@router.post("/knowledge/document/summary") +async def document_summary(request: DocumentSummaryRequest): + print(f"/document/summary params: {request}") + try: + with root_tracer.start_span( + "get_chat_instance", span_type=SpanType.CHAT, metadata=request + ): + chat = await knowledge_space_service.document_summary(request=request) + headers = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + } + from starlette.responses import StreamingResponse + + if not chat.prompt_template.stream_out: + return StreamingResponse( + no_stream_generator(chat), + headers=headers, + media_type="text/event-stream", + ) + else: + return StreamingResponse( + stream_generator(chat, False, request.model_name), + headers=headers, + media_type="text/plain", + ) + + # return Result.succ( + # knowledge_space_service.create_knowledge_document( + # space=space_name, request=request + # ) + # ) + # return Result.succ([]) + except Exception as e: + return Result.faild(code="E000X", msg=f"document add error {e}") + + @router.post("/knowledge/entity/extract") async def entity_extract(request: EntityExtractRequest): logger.info(f"Received params: {request}") diff --git a/pilot/server/knowledge/request/request.py b/pilot/server/knowledge/request/request.py index 032b97ba1..1e5bf46ed 100644 --- a/pilot/server/knowledge/request/request.py +++ b/pilot/server/knowledge/request/request.py @@ -108,6 +108,14 @@ class SpaceArgumentRequest(BaseModel): argument: str +class DocumentSummaryRequest(BaseModel): + """Sync request""" + + """doc_ids: doc ids""" + doc_id: int + model_name: str + + class EntityExtractRequest(BaseModel): """argument: argument""" diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index ec03419e5..15f07f07f 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -10,7 +10,7 @@ from pilot.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, ) from pilot.component import ComponentType -from pilot.utils.executor_utils import ExecutorFactory +from pilot.utils.executor_utils import ExecutorFactory, blocking_func_to_async from pilot.server.knowledge.chunk_db import ( DocumentChunkEntity, @@ -31,6 +31,7 @@ from pilot.server.knowledge.request.request import ( ChunkQueryRequest, SpaceArgumentRequest, DocumentSyncRequest, + DocumentSummaryRequest, ) from enum import Enum @@ -303,7 +304,30 @@ class KnowledgeService: ] document_chunk_dao.create_documents_chunks(chunk_entities) - return True + return doc.id + + async def 
document_summary(self, request: DocumentSummaryRequest): + """get document summary + Args: + - request: DocumentSummaryRequest + """ + doc_query = KnowledgeDocumentEntity(id=request.doc_id) + documents = knowledge_document_dao.get_documents(doc_query) + if len(documents) != 1: + raise Exception(f"can not found document for {request.doc_id}") + document = documents[0] + query = DocumentChunkEntity( + document_id=request.doc_id, + ) + chunks = document_chunk_dao.get_document_chunks(query, page=1, page_size=100) + if len(chunks) == 0: + raise Exception(f"can not found chunks for {request.doc_id}") + from langchain.schema import Document + + chunk_docs = [Document(page_content=chunk.content) for chunk in chunks] + return await self.async_document_summary( + model_name=request.model_name, chunk_docs=chunk_docs, doc=document + ) def update_knowledge_space( self, space_id: int, space_request: KnowledgeSpaceRequest @@ -417,30 +441,25 @@ class KnowledgeService: logger.error(f"document build graph failed:{doc.doc_name}, {str(e)}") return knowledge_document_dao.update_knowledge_document(doc) - def async_document_summary(self, chunk_docs, doc): + async def async_document_summary(self, model_name, chunk_docs, doc): """async document extract summary Args: + - model_name: str - chunk_docs: List[Document] - doc: KnowledgeDocumentEntity """ - from llama_index import PromptHelper - from llama_index.prompts.default_prompt_selectors import ( - DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, - ) - texts = [doc.page_content for doc in chunk_docs] - prompt_helper = PromptHelper(context_window=2000) + from pilot.common.prompt_util import PromptHelper - texts = prompt_helper.repack( - prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=texts - ) + prompt_helper = PromptHelper() + from pilot.scene.chat_knowledge.summary.prompt import prompt + + texts = prompt_helper.repack(prompt_template=prompt.template, text_chunks=texts) logger.info( f"async_document_summary, doc:{doc.doc_name}, chunk_size:{len(texts)}, begin generate summary" ) - summary = self._mapreduce_extract_summary(texts) - print(f"final summary:{summary}") - doc.summary = summary - return knowledge_document_dao.update_knowledge_document(doc) + summary = await self._mapreduce_extract_summary(texts, model_name, 10, 3) + return await self._llm_extract_summary(summary, model_name) def async_doc_embedding(self, client, chunk_docs, doc): """async document embedding into vector db @@ -506,30 +525,37 @@ class KnowledgeService: return json.loads(spaces[0].context) return None - def _llm_extract_summary(self, doc: str): + async def _llm_extract_summary(self, doc: str, model_name: str = None): """Extract triplets from text by llm""" from pilot.scene.base import ChatScene - from pilot.common.chat_util import llm_chat_response_nostream import uuid chat_param = { "chat_session_id": uuid.uuid1(), - "current_user_input": doc, + "current_user_input": "", "select_param": doc, - "model_name": self.model_name, + "model_name": model_name, } - from pilot.common.chat_util import run_async_tasks + executor = CFG.SYSTEM_APP.get_component( + ComponentType.EXECUTOR_DEFAULT, ExecutorFactory + ).create() + from pilot.openapi.api_v1.api_v1 import CHAT_FACTORY - summary_iters = run_async_tasks( - [ - llm_chat_response_nostream( - ChatScene.ExtractRefineSummary.value(), **{"chat_param": chat_param} - ) - ] + chat = await blocking_func_to_async( + executor, + CHAT_FACTORY.get_implementation, + ChatScene.ExtractRefineSummary.value(), + **{"chat_param": chat_param}, ) - return summary_iters[0] + 
return chat - def _mapreduce_extract_summary(self, docs): + async def _mapreduce_extract_summary( + self, + docs, + model_name: str = None, + max_iteration: int = 5, + concurrency_limit: int = None, + ): """Extract summary by mapreduce mode map -> multi async thread generate summary reduce -> merge the summaries by map process @@ -541,18 +567,16 @@ class KnowledgeService: import uuid tasks = [] - max_iteration = 5 if len(docs) == 1: - summary = self._llm_extract_summary(doc=docs[0]) - return summary + return docs[0] else: max_iteration = max_iteration if len(docs) > max_iteration else len(docs) for doc in docs[0:max_iteration]: chat_param = { "chat_session_id": uuid.uuid1(), - "current_user_input": doc, - "select_param": "summary", - "model_name": self.model_name, + "current_user_input": "", + "select_param": doc, + "model_name": model_name, } tasks.append( llm_chat_response_nostream( @@ -561,14 +585,22 @@ class KnowledgeService: ) from pilot.common.chat_util import run_async_tasks - summary_iters = run_async_tasks(tasks) + summary_iters = await run_async_tasks( + tasks=tasks, concurrency_limit=concurrency_limit + ) + summary_iters = list( + filter( + lambda content: "LLMServer Generate Error" not in content, + summary_iters, + ) + ) from pilot.common.prompt_util import PromptHelper - from llama_index.prompts.default_prompt_selectors import ( - DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, - ) + from pilot.scene.chat_knowledge.summary.prompt import prompt - prompt_helper = PromptHelper(context_window=2500) + prompt_helper = PromptHelper() summary_iters = prompt_helper.repack( - prompt=DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, text_chunks=summary_iters + prompt_template=prompt.template, text_chunks=summary_iters + ) + return await self._mapreduce_extract_summary( + summary_iters, model_name, max_iteration, concurrency_limit ) - return self._mapreduce_extract_summary(summary_iters)
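
A minimal usage sketch of the reworked run_async_tasks (illustrative, not part of the diff): the function now awaits the coroutines itself and, when concurrency_limit is set, gates them with an asyncio.Semaphore so only that many run at once. The fetch coroutine and the numbers below are assumptions made only for the demo.

import asyncio

from pilot.common.chat_util import run_async_tasks


async def fetch(i: int) -> int:
    # Hypothetical coroutine standing in for an LLM call or other I/O-bound task.
    await asyncio.sleep(0.1)
    return i * i


async def main() -> None:
    # With concurrency_limit=3, at most three of the ten coroutines hold the
    # semaphore at any moment; without it, all of them run under a plain gather().
    results = await run_async_tasks([fetch(i) for i in range(10)], concurrency_limit=3)
    print(results)


asyncio.run(main())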
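
A sketch of how the new PromptHelper budgets tokens when repacking, assuming the project environment (tiktoken and llama_index installed): the available chunk size is the context window minus the tokens of the empty-formatted template and the reserved output, split across the requested number of chunks, minus DEFAULT_PADDING. The template string and chunks here are placeholders; in the service the template comes from the summary prompt and the chunks come from the document store.

from pilot.common.prompt_util import PromptHelper

# Hypothetical template and chunks used only for illustration.
template = "Write a quick summary of the following context:\n{context}\n"
chunks = ["first retrieved passage ...", "second retrieved passage ..."]

# Budget: context_window - tokens(template with {context} blanked) - num_output,
# divided by the number of chunks requested, minus DEFAULT_PADDING (5).
helper = PromptHelper(context_window=3000, num_output=256)
packed = helper.repack(prompt_template=template, text_chunks=chunks)
print(f"{len(packed)} packed chunk(s)")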
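
A simplified, self-contained sketch of the map-reduce summarization loop in KnowledgeService: summarize up to max_iteration chunks concurrently under a semaphore (map), then combine the partial summaries and recurse until a single text remains (reduce). summarize_one is a hypothetical stand-in for llm_chat_response_nostream, and the pairwise joining below stands in for the PromptHelper repacking the real code performs between rounds.

import asyncio
from typing import List


async def summarize_one(text: str) -> str:
    # Hypothetical stand-in for llm_chat_response_nostream(ExtractSummary, ...).
    await asyncio.sleep(0)
    return text[:200]


async def mapreduce_summary(
    docs: List[str], max_iteration: int = 5, concurrency_limit: int = 3
) -> str:
    if len(docs) == 1:
        # In the service, the final text is handed to the ExtractRefineSummary chat.
        return await summarize_one(docs[0])

    # Map: summarize at most max_iteration chunks, bounded by a semaphore
    # (mirrors the truncation to docs[0:max_iteration] in the service).
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def _bounded(text: str) -> str:
        async with semaphore:
            return await summarize_one(text)

    partials = await asyncio.gather(*[_bounded(d) for d in docs[:max_iteration]])

    # Reduce: the service repacks the partial summaries with PromptHelper so that
    # several fit into one prompt; joining them in pairs plays that role crudely.
    grouped = ["\n\n".join(partials[i : i + 2]) for i in range(0, len(partials), 2)]
    return await mapreduce_summary(grouped, max_iteration, concurrency_limit)


print(asyncio.run(mapreduce_summary([f"chunk {i} ..." for i in range(8)])))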
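
A client-side sketch for the new /knowledge/document/summary endpoint, which streams the generated summary back to the caller. The host, port, API prefix, doc_id, and model name are assumptions; adjust them to however the knowledge router is mounted in your deployment.

import requests

resp = requests.post(
    "http://localhost:5000/knowledge/document/summary",  # hypothetical base URL/prefix
    json={"doc_id": 1, "model_name": "vicuna-13b-v1.5"},  # DocumentSummaryRequest fields
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    # The endpoint streams the summary (text/event-stream or text/plain depending
    # on prompt.stream_out), so print chunks as they arrive.
    print(chunk, end="", flush=True)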