| """General util functions."""
 | |
| 
 | |
| import asyncio
 | |
| import os
 | |
| import random
 | |
| import sys
 | |
| import time
 | |
| import traceback
 | |
| import uuid
 | |
| from contextlib import contextmanager
 | |
| from dataclasses import dataclass
 | |
| from functools import partial, wraps
 | |
| from itertools import islice
 | |
| from pathlib import Path
 | |
| from typing import (
 | |
|     Any,
 | |
|     AsyncGenerator,
 | |
|     Callable,
 | |
|     Dict,
 | |
|     Generator,
 | |
|     Iterable,
 | |
|     List,
 | |
|     Optional,
 | |
|     Set,
 | |
|     Type,
 | |
|     Union,
 | |
|     cast,
 | |
| )
 | |
| 
 | |
| 
 | |
class GlobalsHelper:
    """Helper to retrieve globals.

    Helpful for global caching of certain variables that can be expensive to load.
    (e.g. tokenization)

    """

    _tokenizer: Optional[Callable[[str], List]] = None
    _stopwords: Optional[List[str]] = None

    @property
    def tokenizer(self) -> Callable[[str], List]:
        """Get tokenizer."""
        if self._tokenizer is None:
            tiktoken_import_err = (
                "`tiktoken` package not found, please run `pip install tiktoken`"
            )
            try:
                import tiktoken
            except ImportError:
                raise ImportError(tiktoken_import_err)
            enc = tiktoken.get_encoding("gpt2")
            self._tokenizer = cast(Callable[[str], List], enc.encode)
            self._tokenizer = partial(self._tokenizer, allowed_special="all")
        return self._tokenizer  # type: ignore

    @property
    def stopwords(self) -> List[str]:
        """Get stopwords."""
        if self._stopwords is None:
            try:
                import nltk
                from nltk.corpus import stopwords
            except ImportError:
                raise ImportError(
                    "`nltk` package not found, please run `pip install nltk`"
                )

            # use the get_cache_dir defined later in this module, rather than
            # importing it from llama_index, to avoid a hard dependency
            cache_dir = get_cache_dir()
            nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)

            # update nltk path for nltk so that it finds the data
            if nltk_data_dir not in nltk.data.path:
                nltk.data.path.append(nltk_data_dir)

            try:
                nltk.data.find("corpora/stopwords")
            except LookupError:
                nltk.download("stopwords", download_dir=nltk_data_dir)
            self._stopwords = stopwords.words("english")
        return self._stopwords


globals_helper = GlobalsHelper()


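# Illustrative usage (comments only, not part of the original module). Both
# properties are lazy: the first access imports tiktoken / nltk and caches the
# result on the singleton, so later calls are cheap.
# >>> globals_helper.tokenizer("hello world")   # requires `tiktoken`
# [31373, 995]
# >>> "the" in globals_helper.stopwords          # requires `nltk`
# True

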
def get_new_id(d: Set) -> str:
    """Get a new ID."""
    while True:
        new_id = str(uuid.uuid4())
        if new_id not in d:
            break
    return new_id


def get_new_int_id(d: Set) -> int:
    """Get a new integer ID."""
    while True:
        new_id = random.randint(0, sys.maxsize)
        if new_id not in d:
            break
    return new_id


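# Illustrative usage (not part of the original module): both helpers loop
# until they draw an ID not already in the given set, so a collision with an
# existing ID is impossible; the caller is responsible for adding the new ID
# to the set afterwards.
# >>> seen = {"a", "b"}
# >>> new = get_new_id(seen)
# >>> new in seen
# False

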
@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
    """Temporary setter.

    Utility context manager for temporarily setting attribute values on an
    object, restoring the previous values on exit.
    Taken from: https://tinyurl.com/2p89xymh

    """
    prev_values = {k: getattr(obj, k) for k in kwargs}
    for k, v in kwargs.items():
        setattr(obj, k, v)
    try:
        yield
    finally:
        for k, v in prev_values.items():
            setattr(obj, k, v)


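# Illustrative usage (not part of the original module); `Config` is a
# hypothetical class used only for the example. The `finally` block restores
# the previous values even if the body raises.
# >>> class Config:
# ...     verbose = False
# >>> cfg = Config()
# >>> with temp_set_attrs(cfg, verbose=True):
# ...     cfg.verbose
# True
# >>> cfg.verbose
# False

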
@dataclass
class ErrorToRetry:
    """Exception types that should be retried.

    Args:
        exception_cls (Type[Exception]): Class of exception.
        check_fn (Optional[Callable[[Any], bool]]):
            A function that takes an exception instance as input and returns
            whether to retry.

    """

    exception_cls: Type[Exception]
    check_fn: Optional[Callable[[Any], bool]] = None


def retry_on_exceptions_with_backoff(
    lambda_fn: Callable,
    errors_to_retry: List[ErrorToRetry],
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> Any:
    """Execute lambda function with retries and exponential backoff.

    Args:
        lambda_fn (Callable): Function to call; its return value is returned
            on success.
        errors_to_retry (List[ErrorToRetry]): List of errors to retry.
            At least one needs to be provided.
        max_tries (int): Maximum number of tries, including the first. Defaults to 10.
        min_backoff_secs (float): Minimum amount of backoff time between attempts.
            Defaults to 0.5.
        max_backoff_secs (float): Maximum amount of backoff time between attempts.
            Defaults to 60.

    """
    if not errors_to_retry:
        raise ValueError("At least one error to retry needs to be provided")

    error_checks = {
        error_to_retry.exception_cls: error_to_retry.check_fn
        for error_to_retry in errors_to_retry
    }
    exception_class_tuples = tuple(error_checks.keys())

    backoff_secs = min_backoff_secs
    tries = 0

    while True:
        try:
            return lambda_fn()
        except exception_class_tuples as e:
            traceback.print_exc()
            tries += 1
            if tries >= max_tries:
                raise
            # look up the retry check registered for this exact exception class
            check_fn = error_checks.get(e.__class__)
            if check_fn and not check_fn(e):
                raise
            time.sleep(backoff_secs)
            backoff_secs = min(backoff_secs * 2, max_backoff_secs)


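# Illustrative usage (not part of the original module): retry a flaky call
# with doubling backoff, but only when the error's check says it is
# retryable. `call_api` and `RateLimitError` are hypothetical names used
# only for the example.
# >>> result = retry_on_exceptions_with_backoff(
# ...     lambda: call_api(),
# ...     [ErrorToRetry(RateLimitError, check_fn=lambda e: e.retryable)],
# ...     max_tries=5,
# ... )

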
def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length."""
    if len(text) <= max_length:
        return text
    return text[: max_length - 3] + "..."


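# Illustrative behavior (not part of the original module): the result is at
# most `max_length` characters, with the last three replaced by "..." when
# truncation occurs.
# >>> truncate_text("hello world", 8)
# 'hello...'

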
def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
    """Iterate over an iterable in batches.

    >>> list(iter_batch([1, 2, 3, 4, 5], 3))
    [[1, 2, 3], [4, 5]]
    """
    source_iter = iter(iterable)
    # an iterator object is always truthy, so loop until islice comes back empty
    while True:
        b = list(islice(source_iter, size))
        if len(b) == 0:
            break
        yield b


def concat_dirs(dirname: str, basename: str) -> str:
    """
    Append basename to dirname, avoiding backslashes when running on windows.

    os.path.join(dirname, basename) will join with a backslash on Windows if
    dirname does not end with a slash, so we make sure it does.
    """
    dirname += "/" if dirname[-1] != "/" else ""
    return os.path.join(dirname, basename)


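# Illustrative behavior (not part of the original module): because dirname is
# forced to end with "/", os.path.join never inserts an OS-specific separator.
# >>> concat_dirs("storage/index", "docstore.json")
# 'storage/index/docstore.json'

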
def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
    """
    Optionally get a tqdm iterable. Ensures tqdm.auto is used.
    """
    _iterator = items
    if show_progress:
        try:
            from tqdm.auto import tqdm

            return tqdm(items, desc=desc)
        except ImportError:
            pass
    return _iterator


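# Illustrative usage (not part of the original module): when `tqdm` is not
# installed or show_progress is False, the plain iterable is returned, so the
# caller can loop over the result either way. `docs` is a hypothetical
# iterable used only for the example.
# >>> for doc in get_tqdm_iterable(docs, True, "Parsing"):
# ...     ...

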
def count_tokens(text: str) -> int:
    """Count the number of tokens in text using the global tokenizer."""
    tokens = globals_helper.tokenizer(text)
    return len(tokens)


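# Illustrative usage (not part of the original module): counts come from the
# lazily-loaded GPT-2 encoding above, so this requires `tiktoken`.
# >>> count_tokens("hello world")
# 2

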
def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
    """Get the tokenize function of a Hugging Face transformers tokenizer.

    Args:
        model_name(str): the model name of the tokenizer.
                        For instance, fxmarty/tiny-llama-fast-tokenizer.
    """
    try:
        from transformers import AutoTokenizer
    except ImportError:
        raise ImportError(
            "`transformers` package not found, please run `pip install transformers`"
        )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer.tokenize


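# Illustrative usage (not part of the original module): the returned callable
# maps a string to subword token strings (not ids); the exact pieces depend
# on the model's tokenizer.
# >>> tokenize = get_transformer_tokenizer_fn("fxmarty/tiny-llama-fast-tokenizer")
# >>> tokenize("hello world")  # doctest: +SKIP
# ['▁hello', '▁world']

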
def get_cache_dir() -> str:
    """Locate a platform-appropriate cache directory for llama_index,
    and create it if it doesn't yet exist.
    """
    # User override
    if "LLAMA_INDEX_CACHE_DIR" in os.environ:
        path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])

    # Linux, Unix, AIX, etc.
    elif os.name == "posix" and sys.platform != "darwin":
        path = Path("/tmp/llama_index")

    # Mac OS
    elif sys.platform == "darwin":
        path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")

    # Windows (hopefully)
    else:
        local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
            "~\\AppData\\Local"
        )
        path = Path(local, "llama_index")

    if not os.path.exists(path):
        os.makedirs(
            path, exist_ok=True
        )  # prevents https://github.com/jerryjliu/llama_index/issues/7362
    return str(path)


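# Illustrative resolution order (not part of the original module): the
# LLAMA_INDEX_CACHE_DIR environment variable wins, then a platform default:
#   Linux/Unix: /tmp/llama_index
#   macOS:      ~/Library/Caches/llama_index
#   Windows:    %LOCALAPPDATA%\llama_index

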
def add_sync_version(func: Any) -> Any:
    """Decorator for adding sync version of an async function. The sync version
    is added as a function attribute to the original function, func.

    Args:
        func(Any): the async function for which a sync variant will be built.
    """
    assert asyncio.iscoroutinefunction(func)

    @wraps(func)
    def _wrapper(*args: Any, **kwds: Any) -> Any:
        # Note: this drives the coroutine on the current event loop, so the
        # sync variant must not be called while that loop is already running.
        return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))

    func.sync = _wrapper
    return func


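# Illustrative usage (not part of the original module): decorate a coroutine
# function, then call the attached `.sync` attribute from synchronous code.
# >>> @add_sync_version
# ... async def fetch(x: int) -> int:
# ...     return x + 1
# >>> fetch.sync(41)
# 42

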
# Sample text from llama_index's readme
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.

Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM apps. It provides the following tools:

Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""

_LLAMA_INDEX_COLORS = {
    "llama_pink": "38;2;237;90;200",
    "llama_blue": "38;2;90;149;237",
    "llama_turquoise": "38;2;11;159;203",
    "llama_lavender": "38;2;155;135;227",
}

_ANSI_COLORS = {
    "red": "31",
    "green": "32",
    "yellow": "33",
    "blue": "34",
    "magenta": "35",
    "cyan": "36",
    "pink": "38;5;200",
}


def get_color_mapping(
    items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
    """
    Get a mapping of items to colors.

    Args:
        items (List[str]): List of items to be mapped to colors.
        use_llama_index_colors (bool, optional): Flag to indicate
            whether to use LlamaIndex colors or ANSI colors.
            Defaults to True.

    Returns:
        Dict[str, str]: Mapping of items to colors.
    """
    if use_llama_index_colors:
        color_palette = _LLAMA_INDEX_COLORS
    else:
        color_palette = _ANSI_COLORS

    colors = list(color_palette.keys())
    return {item: colors[i % len(colors)] for i, item in enumerate(items)}


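# Illustrative usage (not part of the original module): items are assigned
# palette colors round-robin, so more items than colors simply wrap around.
# >>> get_color_mapping(["agent", "tool"])
# {'agent': 'llama_pink', 'tool': 'llama_blue'}

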
def _get_colored_text(text: str, color: str) -> str:
    """
    Get the colored version of the input text.

    Args:
        text (str): Input text.
        color (str): Color to be applied to the text.

    Returns:
        str: Colored version of the input text.
    """
    all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}

    if color not in all_colors:
        return f"\033[1;3m{text}\033[0m"  # just bolded and italicized

    color = all_colors[color]

    return f"\033[1;3;{color}m{text}\033[0m"


def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
    """
    Print the text with the specified color.

    Args:
        text (str): Text to be printed.
        color (str, optional): Color to be applied to the text. Supported colors are:
            llama_pink, llama_blue, llama_turquoise, llama_lavender,
            red, green, yellow, blue, magenta, cyan, pink.
        end (str, optional): String appended after the last character of the text.

    Returns:
        None
    """
    text_to_print = _get_colored_text(text, color) if color is not None else text
    print(text_to_print, end=end)


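# Illustrative usage (not part of the original module): unknown colors fall
# back to plain bold/italic, and `end` defaults to "" (no trailing newline).
# >>> print_text("thinking...", color="llama_turquoise", end="\n")

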
def infer_torch_device() -> str:
    """Infer the input to torch.device."""
    # torch is never imported at module scope, so import it here and fail
    # with a helpful message instead of the original NameError dance
    try:
        import torch
    except ImportError:
        raise ImportError(
            "`torch` package not found, please run `pip install torch`"
        )
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def unit_generator(x: Any) -> Generator[Any, None, None]:
    """A function that returns a generator of a single element.

    Args:
        x (Any): the element to yield

    Yields:
        Any: the single element
    """
    yield x


async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
    """A function that returns an async generator of a single element.

    Args:
        x (Any): the element to yield

    Yields:
        Any: the single element
    """
    yield x

