"""General util functions.""" import asyncio import os import random import sys import time import traceback import uuid from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps from itertools import islice from pathlib import Path from typing import ( Any, AsyncGenerator, Callable, Dict, Generator, Iterable, List, Optional, Set, Type, Union, cast, ) class GlobalsHelper: """Helper to retrieve globals. Helpful for global caching of certain variables that can be expensive to load. (e.g. tokenization) """ _tokenizer: Optional[Callable[[str], List]] = None _stopwords: Optional[List[str]] = None @property def tokenizer(self) -> Callable[[str], List]: """Get tokenizer.""" if self._tokenizer is None: tiktoken_import_err = ( "`tiktoken` package not found, please run `pip install tiktoken`" ) try: import tiktoken except ImportError: raise ImportError(tiktoken_import_err) enc = tiktoken.get_encoding("gpt2") self._tokenizer = cast(Callable[[str], List], enc.encode) self._tokenizer = partial(self._tokenizer, allowed_special="all") return self._tokenizer # type: ignore @property def stopwords(self) -> List[str]: """Get stopwords.""" if self._stopwords is None: try: import nltk from nltk.corpus import stopwords except ImportError: raise ImportError( "`nltk` package not found, please run `pip install nltk`" ) from llama_index.utils import get_cache_dir cache_dir = get_cache_dir() nltk_data_dir = os.environ.get("NLTK_DATA", cache_dir) # update nltk path for nltk so that it finds the data if nltk_data_dir not in nltk.data.path: nltk.data.path.append(nltk_data_dir) try: nltk.data.find("corpora/stopwords") except LookupError: nltk.download("stopwords", download_dir=nltk_data_dir) self._stopwords = stopwords.words("english") return self._stopwords globals_helper = GlobalsHelper() def get_new_id(d: Set) -> str: """Get a new ID.""" while True: new_id = str(uuid.uuid4()) if new_id not in d: break return new_id def get_new_int_id(d: Set) -> int: """Get a new integer ID.""" while True: new_id = random.randint(0, sys.maxsize) if new_id not in d: break return new_id @contextmanager def temp_set_attrs(obj: Any, **kwargs: Any) -> Generator: """Temporary setter. Utility class for setting a temporary value for an attribute on a class. Taken from: https://tinyurl.com/2p89xymh """ prev_values = {k: getattr(obj, k) for k in kwargs} for k, v in kwargs.items(): setattr(obj, k, v) try: yield finally: for k, v in prev_values.items(): setattr(obj, k, v) @dataclass class ErrorToRetry: """Exception types that should be retried. Args: exception_cls (Type[Exception]): Class of exception. check_fn (Optional[Callable[[Any]], bool]]): A function that takes an exception instance as input and returns whether to retry. """ exception_cls: Type[Exception] check_fn: Optional[Callable[[Any], bool]] = None def retry_on_exceptions_with_backoff( lambda_fn: Callable, errors_to_retry: List[ErrorToRetry], max_tries: int = 10, min_backoff_secs: float = 0.5, max_backoff_secs: float = 60.0, ) -> Any: """Execute lambda function with retries and exponential backoff. Args: lambda_fn (Callable): Function to be called and output we want. errors_to_retry (List[ErrorToRetry]): List of errors to retry. At least one needs to be provided. max_tries (int): Maximum number of tries, including the first. Defaults to 10. min_backoff_secs (float): Minimum amount of backoff time between attempts. Defaults to 0.5. max_backoff_secs (float): Maximum amount of backoff time between attempts. Defaults to 60. 
""" if not errors_to_retry: raise ValueError("At least one error to retry needs to be provided") error_checks = { error_to_retry.exception_cls: error_to_retry.check_fn for error_to_retry in errors_to_retry } exception_class_tuples = tuple(error_checks.keys()) backoff_secs = min_backoff_secs tries = 0 while True: try: return lambda_fn() except exception_class_tuples as e: traceback.print_exc() tries += 1 if tries >= max_tries: raise check_fn = error_checks.get(e.__class__) if check_fn and not check_fn(e): raise time.sleep(backoff_secs) backoff_secs = min(backoff_secs * 2, max_backoff_secs) def truncate_text(text: str, max_length: int) -> str: """Truncate text to a maximum length.""" if len(text) <= max_length: return text return text[: max_length - 3] + "..." def iter_batch(iterable: Union[Iterable, Generator], size: int) -> Iterable: """Iterate over an iterable in batches. >>> list(iter_batch([1, 2, 3, 4, 5], 3)) [[1, 2, 3], [4, 5]] """ source_iter = iter(iterable) while source_iter: b = list(islice(source_iter, size)) if len(b) == 0: break yield b def concat_dirs(dirname: str, basename: str) -> str: """ Append basename to dirname, avoiding backslashes when running on windows. os.path.join(dirname, basename) will add a backslash before dirname if basename does not end with a slash, so we make sure it does. """ dirname += "/" if dirname[-1] != "/" else "" return os.path.join(dirname, basename) def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterable: """ Optionally get a tqdm iterable. Ensures tqdm.auto is used. """ _iterator = items if show_progress: try: from tqdm.auto import tqdm return tqdm(items, desc=desc) except ImportError: pass return _iterator def count_tokens(text: str) -> int: tokens = globals_helper.tokenizer(text) return len(tokens) def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]: """ Args: model_name(str): the model name of the tokenizer. For instance, fxmarty/tiny-llama-fast-tokenizer. """ try: from transformers import AutoTokenizer except ImportError: raise ValueError( "`transformers` package not found, please run `pip install transformers`" ) tokenizer = AutoTokenizer.from_pretrained(model_name) return tokenizer.tokenize def get_cache_dir() -> str: """Locate a platform-appropriate cache directory for llama_index, and create it if it doesn't yet exist. """ # User override if "LLAMA_INDEX_CACHE_DIR" in os.environ: path = Path(os.environ["LLAMA_INDEX_CACHE_DIR"]) # Linux, Unix, AIX, etc. elif os.name == "posix" and sys.platform != "darwin": path = Path("/tmp/llama_index") # Mac OS elif sys.platform == "darwin": path = Path(os.path.expanduser("~"), "Library/Caches/llama_index") # Windows (hopefully) else: local = os.environ.get("LOCALAPPDATA", None) or os.path.expanduser( "~\\AppData\\Local" ) path = Path(local, "llama_index") if not os.path.exists(path): os.makedirs( path, exist_ok=True ) # prevents https://github.com/jerryjliu/llama_index/issues/7362 return str(path) def add_sync_version(func: Any) -> Any: """Decorator for adding sync version of an async function. The sync version is added as a function attribute to the original function, func. Args: func(Any): the async function for which a sync variant will be built. 
""" assert asyncio.iscoroutinefunction(func) @wraps(func) def _wrapper(*args: Any, **kwds: Any) -> Any: return asyncio.get_event_loop().run_until_complete(func(*args, **kwds)) func.sync = _wrapper return func # Sample text from llama_index's readme SAMPLE_TEXT = """ Context LLMs are a phenomenal piece of technology for knowledge generation and reasoning. They are pre-trained on large amounts of publicly available data. How do we best augment LLMs with our own private data? We need a comprehensive toolkit to help perform this data augmentation for LLMs. Proposed Solution That's where LlamaIndex comes in. LlamaIndex is a "data framework" to help you build LLM apps. It provides the following tools: Offers data connectors to ingest your existing data sources and data formats (APIs, PDFs, docs, SQL, etc.) Provides ways to structure your data (indices, graphs) so that this data can be easily used with LLMs. Provides an advanced retrieval/query interface over your data: Feed in any LLM input prompt, get back retrieved context and knowledge-augmented output. Allows easy integrations with your outer application framework (e.g. with LangChain, Flask, Docker, ChatGPT, anything else). LlamaIndex provides tools for both beginner users and advanced users. Our high-level API allows beginner users to use LlamaIndex to ingest and query their data in 5 lines of code. Our lower-level APIs allow advanced users to customize and extend any module (data connectors, indices, retrievers, query engines, reranking modules), to fit their needs. """ _LLAMA_INDEX_COLORS = { "llama_pink": "38;2;237;90;200", "llama_blue": "38;2;90;149;237", "llama_turquoise": "38;2;11;159;203", "llama_lavender": "38;2;155;135;227", } _ANSI_COLORS = { "red": "31", "green": "32", "yellow": "33", "blue": "34", "magenta": "35", "cyan": "36", "pink": "38;5;200", } def get_color_mapping( items: List[str], use_llama_index_colors: bool = True ) -> Dict[str, str]: """ Get a mapping of items to colors. Args: items (List[str]): List of items to be mapped to colors. use_llama_index_colors (bool, optional): Flag to indicate whether to use LlamaIndex colors or ANSI colors. Defaults to True. Returns: Dict[str, str]: Mapping of items to colors. """ if use_llama_index_colors: color_palette = _LLAMA_INDEX_COLORS else: color_palette = _ANSI_COLORS colors = list(color_palette.keys()) return {item: colors[i % len(colors)] for i, item in enumerate(items)} def _get_colored_text(text: str, color: str) -> str: """ Get the colored version of the input text. Args: text (str): Input text. color (str): Color to be applied to the text. Returns: str: Colored version of the input text. """ all_colors = {**_LLAMA_INDEX_COLORS, **_ANSI_COLORS} if color not in all_colors: return f"\033[1;3m{text}\033[0m" # just bolded and italicized color = all_colors[color] return f"\033[1;3;{color}m{text}\033[0m" def print_text(text: str, color: Optional[str] = None, end: str = "") -> None: """ Print the text with the specified color. Args: text (str): Text to be printed. color (str, optional): Color to be applied to the text. Supported colors are: llama_pink, llama_blue, llama_turquoise, llama_lavender, red, green, yellow, blue, magenta, cyan, pink. end (str, optional): String appended after the last character of the text. 
def infer_torch_device() -> str:
    """Infer the input to torch.device."""
    try:
        has_cuda = torch.cuda.is_available()
    except NameError:
        # torch is imported lazily here so this module does not hard-require it
        import torch

        has_cuda = torch.cuda.is_available()
    if has_cuda:
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def unit_generator(x: Any) -> Generator[Any, None, None]:
    """A function that returns a generator of a single element.

    Args:
        x (Any): the element to yield

    Yields:
        Any: the single element
    """
    yield x


async def async_unit_generator(x: Any) -> AsyncGenerator[Any, None]:
    """A function that returns an async generator of a single element.

    Args:
        x (Any): the element to yield

    Yields:
        Any: the single element
    """
    yield x
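

# Illustrative smoke test (a sketch, run only when executed as a script; not
# part of the library API): demonstrates consuming the sync and async
# single-element generators.
if __name__ == "__main__":
    print(list(unit_generator("hello")))  # ['hello']

    async def _drain_async() -> None:
        async for item in async_unit_generator("world"):
            print(item)  # world

    asyncio.run(_drain_async())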