"""General util functions."""

import asyncio
import os
import random
import sys
import time
import traceback
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from itertools import islice
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Type,
    Union,
    cast,
)


class GlobalsHelper:
    """Helper to retrieve globals.

    Helpful for global caching of certain variables that can be expensive to load
    (e.g. tokenization).
    """

    _tokenizer: Optional[Callable[[str], List]] = None
    _stopwords: Optional[List[str]] = None

    @property
    def tokenizer(self) -> Callable[[str], List]:
        """Get tokenizer."""
        if self._tokenizer is None:
            tiktoken_import_err = (
                "`tiktoken` package not found, please run `pip install tiktoken`"
            )
            try:
                import tiktoken
            except ImportError:
                raise ImportError(tiktoken_import_err)
            enc = tiktoken.get_encoding("gpt2")
            self._tokenizer = cast(Callable[[str], List], enc.encode)
            self._tokenizer = partial(self._tokenizer, allowed_special="all")
        return self._tokenizer  # type: ignore

    @property
    def stopwords(self) -> List[str]:
        """Get stopwords."""
        if self._stopwords is None:
            try:
                import nltk
                from nltk.corpus import stopwords
            except ImportError:
                raise ImportError(
                    "`nltk` package not found, please run `pip install nltk`"
                )

            # Use this module's own get_cache_dir (defined below) rather than
            # importing it from llama_index, which would add a hard dependency.
            cache_dir = get_cache_dir()
            nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir)

            # Update the nltk search path so that it finds the downloaded data.
            if nltk_data_dir not in nltk.data.path:
                nltk.data.path.append(nltk_data_dir)

            try:
                nltk.data.find("corpora/stopwords")
            except LookupError:
                nltk.download("stopwords", download_dir=nltk_data_dir)
            self._stopwords = stopwords.words("english")
        return self._stopwords


globals_helper = GlobalsHelper()
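
# Usage sketch (assumes `tiktoken` is installed; `nltk` is only needed if you
# touch `.stopwords`):
#
#     num_tokens = len(globals_helper.tokenizer("hello world"))
#     english_stopwords = globals_helper.stopwords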


def get_new_id(d: Set) -> str:
    """Get a new ID that is not already in the given set."""
    while True:
        new_id = str(uuid.uuid4())
        if new_id not in d:
            break
    return new_id


def get_new_int_id(d: Set) -> int:
    """Get a new integer ID that is not already in the given set."""
    while True:
        new_id = random.randint(0, sys.maxsize)
        if new_id not in d:
            break
    return new_id
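
# Usage sketch: draw IDs that are guaranteed not to collide with ones already
# handed out; the caller is responsible for recording them.
#
#     seen_ids: Set[str] = set()
#     new_id = get_new_id(seen_ids)
#     seen_ids.add(new_id)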


@contextmanager
def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator:
    """Temporary setter.

    Utility context manager for setting temporary values for attributes on an
    object, restoring the previous values on exit.
    Taken from: https://tinyurl.com/2p89xymh
    """
    prev_values = {k: getattr(obj, k) for k in kwargs}
    for k, v in kwargs.items():
        setattr(obj, k, v)
    try:
        yield
    finally:
        for k, v in prev_values.items():
            setattr(obj, k, v)
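
# Usage sketch with a hypothetical `config` object: `config.verbose` is True
# only inside the block and is restored afterwards, even on exception.
#
#     with temp_set_attrs(config, verbose=True):
#         run_noisy_step(config)  # `run_noisy_step` is a placeholder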


@dataclass
class ErrorToRetry:
    """Exception types that should be retried.

    Args:
        exception_cls (Type[Exception]): Class of exception.
        check_fn (Optional[Callable[[Any], bool]]):
            A function that takes an exception instance as input and returns
            whether to retry.
    """

    exception_cls: Type[Exception]
    check_fn: Optional[Callable[[Any], bool]] = None


def retry_on_exceptions_with_backoff(
    lambda_fn: Callable,
    errors_to_retry: List[ErrorToRetry],
    max_tries: int = 10,
    min_backoff_secs: float = 0.5,
    max_backoff_secs: float = 60.0,
) -> Any:
    """Execute a function with retries and exponential backoff.

    Args:
        lambda_fn (Callable): Function to be called, whose output we want.
        errors_to_retry (List[ErrorToRetry]): List of errors to retry.
            At least one needs to be provided.
        max_tries (int): Maximum number of tries, including the first. Defaults to 10.
        min_backoff_secs (float): Minimum amount of backoff time between attempts.
            Defaults to 0.5.
        max_backoff_secs (float): Maximum amount of backoff time between attempts.
            Defaults to 60.
    """
    if not errors_to_retry:
        raise ValueError("At least one error to retry needs to be provided")

    error_checks = {
        error_to_retry.exception_cls: error_to_retry.check_fn
        for error_to_retry in errors_to_retry
    }
    exception_class_tuples = tuple(error_checks.keys())

    backoff_secs = min_backoff_secs
    tries = 0

    while True:
        try:
            return lambda_fn()
        except exception_class_tuples as e:
            traceback.print_exc()
            tries += 1
            if tries >= max_tries:
                raise
            check_fn = error_checks.get(e.__class__)
            if check_fn and not check_fn(e):
                raise
            time.sleep(backoff_secs)
            backoff_secs = min(backoff_secs * 2, max_backoff_secs)
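
# Usage sketch: retry a flaky call up to 5 times, but only for a hypothetical
# `RateLimitError`, and only when its `retryable` flag is set (`client` and
# `RateLimitError` are placeholders, not part of this module).
#
#     result = retry_on_exceptions_with_backoff(
#         lambda: client.request(),
#         [ErrorToRetry(RateLimitError, check_fn=lambda e: e.retryable)],
#         max_tries=5,
#     )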


def truncate_text(text: str, max_length: int) -> str:
    """Truncate text to a maximum length, appending "..." when truncated."""
    if len(text) <= max_length:
        return text
    return text[: max_length - 3] + "..."


def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable:
    """Iterate over an iterable in batches.

    >>> list(iter_batch([1, 2, 3, 4, 5], 3))
    [[1, 2, 3], [4, 5]]
    """
    source_iter = iter(iterable)
    # An iterator object is always truthy, so loop until islice comes up empty.
    while True:
        b = list(islice(source_iter, size))
        if len(b) == 0:
            break
        yield b


def concat_dirs(dirname: str, basename: str) -> str:
    """
    Append basename to dirname, avoiding backslashes when running on windows.

    os.path.join(dirname, basename) would insert a platform separator (a
    backslash on Windows) if dirname does not end with a slash, so we make
    sure it does.
    """
    dirname += "/" if dirname[-1] != "/" else ""
    return os.path.join(dirname, basename)
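
# Sketch: on Windows, os.path.join("a/b", "c") yields "a/b\\c", whereas
# concat_dirs("a/b", "c") yields "a/b/c" on every platform.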


def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable:
    """
    Optionally get a tqdm iterable. Ensures tqdm.auto is used.
    """
    _iterator = items
    if show_progress:
        try:
            from tqdm.auto import tqdm

            return tqdm(items, desc=desc)
        except ImportError:
            pass
    return _iterator
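
# Usage sketch: wrap a loop in a progress bar only when requested; if `tqdm`
# is not installed, the plain iterable is returned (`documents` and `process`
# are placeholders).
#
#     for doc in get_tqdm_iterable(documents, show_progress=True, desc="Parsing"):
#         process(doc)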


def count_tokens(text: str) -> int:
    """Count the number of tokens in text using the cached global tokenizer."""
    tokens = globals_helper.tokenizer(text)
    return len(tokens)
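
# Sketch: count GPT-2 tokens in the sample text defined further below
# (requires `tiktoken`).
#
#     n = count_tokens(SAMPLE_TEXT)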


def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]:
    """Get the tokenize function of a Hugging Face transformers tokenizer.

    Args:
        model_name(str): the model name of the tokenizer.
            For instance, fxmarty/tiny-llama-fast-tokenizer.
    """
    try:
        from transformers import AutoTokenizer
    except ImportError:
        raise ImportError(
            "`transformers` package not found, please run `pip install transformers`"
        )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer.tokenize
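
# Usage sketch with the model named in the docstring (requires `transformers`):
#
#     tokenize = get_transformer_tokenizer_fn("fxmarty/tiny-llama-fast-tokenizer")
#     subword_tokens = tokenize("hello world")  # List[str] of subword pieces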


def get_cache_dir() -> str:
    """Locate a platform-appropriate cache directory for llama_index,
    and create it if it doesn't yet exist.
    """
    # User override
    if "LLAMA_INDEX_CACHE_DIR" in os.environ:
        path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"])

    # Linux, Unix, AIX, etc.
    elif os.name == "posix" and sys.platform != "darwin":
        path = Path("/tmp/llama_index")

    # Mac OS
    elif sys.platform == "darwin":
        path = Path(os.path.expanduser("~"), "Library/Caches/llama_index")

    # Windows (hopefully)
    else:
        local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser(
            "~\\AppData\\Local"
        )
        path = Path(local, "llama_index")

    if not os.path.exists(path):
        # exist_ok prevents https://github.com/jerryjliu/llama_index/issues/7362
        os.makedirs(path, exist_ok=True)
    return str(path)


def add_sync_version(func: Any) -> Any:
    """Decorator for adding a sync version of an async function. The sync version
    is added as a function attribute to the original function, func.

    Args:
        func(Any): the async function for which a sync variant will be built.
    """
    assert asyncio.iscoroutinefunction(func)

    @wraps(func)
    def _wrapper(*args: Any, **kwds: Any) -> Any:
        return asyncio.get_event_loop().run_until_complete(func(*args, **kwds))

    func.sync = _wrapper
    return func
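
# Usage sketch: the decorated coroutine function keeps its async interface and
# gains a blocking `.sync` attribute (note this relies on there being no event
# loop already running in the calling thread).
#
#     @add_sync_version
#     async def double(x: int) -> int:
#         return x * 2
#
#     assert double.sync(21) == 42  # runs the coroutine to completion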


# Sample text from llama_index's readme
SAMPLE_TEXT = """
Context
LLMs are a phenomenal piece of technology for knowledge generation and reasoning.
They are pre-trained on large amounts of publicly available data.
How do we best augment LLMs with our own private data?
We need a comprehensive toolkit to help perform this data augmentation for LLMs.

Proposed Solution
That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help
you build LLM apps. It provides the following tools:

Offers data connectors to ingest your existing data sources and data formats
(APIs, PDFs, docs, SQL, etc.)
Provides ways to structure your data (indices, graphs) so that this data can be
easily used with LLMs.
Provides an advanced retrieval/query interface over your data:
Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output.
Allows easy integrations with your outer application framework
(e.g. with LangChain, Flask, Docker, ChatGPT, anything else).
LlamaIndex provides tools for both beginner users and advanced users.
Our high-level API allows beginner users to use LlamaIndex to ingest and
query their data in 5 lines of code. Our lower-level APIs allow advanced users to
customize and extend any module (data connectors, indices, retrievers, query engines,
reranking modules), to fit their needs.
"""

_LLAMA_INDEX_COLORS = {
    "llama_pink": "38;2;237;90;200",
    "llama_blue": "38;2;90;149;237",
    "llama_turquoise": "38;2;11;159;203",
    "llama_lavender": "38;2;155;135;227",
}

_ANSI_COLORS = {
    "red": "31",
    "green": "32",
    "yellow": "33",
    "blue": "34",
    "magenta": "35",
    "cyan": "36",
    "pink": "38;5;200",
}


def get_color_mapping(
    items: List[str], use_llama_index_colors: bool = True
) -> Dict[str, str]:
    """
    Get a mapping of items to colors.

    Args:
        items (List[str]): List of items to be mapped to colors.
        use_llama_index_colors (bool, optional): Flag to indicate
            whether to use LlamaIndex colors or ANSI colors.
            Defaults to True.

    Returns:
        Dict[str, str]: Mapping of items to colors.
    """
    if use_llama_index_colors:
        color_palette = _LLAMA_INDEX_COLORS
    else:
        color_palette = _ANSI_COLORS

    colors = list(color_palette.keys())
    return {item: colors[i % len(colors)] for i, item in enumerate(items)}


def _get_colored_text(text: str, color: str) -> str:
    """
    Get the colored version of the input text.

    Args:
        text (str): Input text.
        color (str): Color to be applied to the text.

    Returns:
        str: Colored version of the input text.
    """
    all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS}

    if color not in all_colors:
        return f"\033[1;3m{text}\033[0m"  # just bolded and italicized

    color = all_colors[color]

    return f"\033[1;3;{color}m{text}\033[0m"


def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
    """
    Print the text with the specified color.

    Args:
        text (str): Text to be printed.
        color (str, optional): Color to be applied to the text. Supported colors are:
            llama_pink, llama_blue, llama_turquoise, llama_lavender,
            red, green, yellow, blue, magenta, cyan, pink.
        end (str, optional): String appended after the last character of the text.
            Defaults to "".

    Returns:
        None
    """
    text_to_print = _get_colored_text(text, color) if color is not None else text
    print(text_to_print, end=end)
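
# Usage sketch: assign each item a stable color name, then print with it
# ("planner" and "executor" are hypothetical labels).
#
#     color_by_name = get_color_mapping(["planner", "executor"])
#     print_text("thinking...\n", color=color_by_name["planner"])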


def infer_torch_device() -> str:
    """Infer the input to torch.device."""
    try:
        import torch
    except ImportError:
        raise ImportError("`torch` package not found, please run `pip install torch`")
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def unit_generator(x: Any) -> Generator[Any, None, None]:
    """A generator that yields a single element.

    Args:
        x (Any): the element to yield

    Yields:
        Any: the single element
    """
    yield x


async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
    """An async generator that yields a single element.

    Args:
        x (Any): the element to yield

    Yields:
        Any: the single element
    """
    yield x
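
# Usage sketch: adapt a single value to code that expects a (sync or async)
# stream of values.
#
#     for item in unit_generator("only value"):
#         ...
#
#     async for item in async_unit_generator("only value"):
#         ...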