core: docstrings indexing (#23785)

Added missing docstrings. Formatted docstrings into a consistent form.
Leonid Ganeline 2024-07-03 08:27:34 -07:00 committed by GitHub
parent 30fdc2dbe7
commit 716a316654
8 changed files with 310 additions and 33 deletions

View File

@ -4,6 +4,7 @@ This package contains helper logic to help deal with indexing data into
a vectorstore while avoiding duplicated content and over-writing content
if it's unchanged.
"""
from langchain_core.indexing.api import IndexingResult, aindex, index
from langchain_core.indexing.base import InMemoryRecordManager, RecordManager

View File

@ -1,4 +1,5 @@
"""Module contains logic for indexing documents into vector stores."""
from __future__ import annotations
import hashlib
@ -232,8 +233,8 @@ def index(
record_manager: Timestamped set to keep track of which documents were
updated.
vector_store: Vector store to index the documents into.
batch_size: Batch size to use when indexing.
cleanup: How to handle clean up of documents.
batch_size: Batch size to use when indexing. Default is 100.
cleanup: How to handle clean up of documents. Default is None.
- Incremental: Cleans up all documents that haven't been updated AND
that are associated with source ids that were seen
during indexing.
@ -246,14 +247,23 @@ def index(
This means that users may see duplicated content during indexing.
- None: Do not delete any documents.
source_id_key: Optional key that helps identify the original source
of the document.
of the document. Default is None.
cleanup_batch_size: Batch size to use when cleaning up documents.
Default is 1_000.
force_update: Force update documents even if they are present in the
record manager. Useful if you are re-indexing with updated embeddings.
Default is False.
Returns:
Indexing result which contains information about how many documents
were added, updated, deleted, or skipped.
Raises:
ValueError: If cleanup mode is not one of 'incremental', 'full', or None.
ValueError: If cleanup mode is incremental and source_id_key is None.
ValueError: If vectorstore does not have the
required "delete" and "add_documents" methods.
ValueError: If source_id_key is not None, but is not a string or callable.
"""
if cleanup not in {"incremental", "full", None}:
raise ValueError(
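A minimal usage sketch of the index API documented above (not part of this commit). It assumes langchain_core's InMemoryVectorStore and DeterministicFakeEmbedding are available, which may vary by version; the documents and sources are hypothetical.

.. code-block:: python

    from langchain_core.documents import Document
    from langchain_core.embeddings import DeterministicFakeEmbedding
    from langchain_core.indexing import InMemoryRecordManager, index
    from langchain_core.vectorstores import InMemoryVectorStore

    # Any vector store exposing `add_documents` and `delete`, and any
    # RecordManager implementation, can be swapped in here.
    record_manager = InMemoryRecordManager(namespace="demo")
    record_manager.create_schema()
    vector_store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=32))

    docs = [
        Document(page_content="hello", metadata={"source": "a.txt"}),
        Document(page_content="world", metadata={"source": "b.txt"}),
    ]

    # First run adds both documents; rerunning with unchanged content skips
    # them because the record manager already holds their hashes.
    result = index(
        docs,
        record_manager,
        vector_store,
        cleanup="incremental",
        source_id_key="source",
    )
    print(result)  # e.g. {'num_added': 2, 'num_updated': 0, 'num_skipped': 0, 'num_deleted': 0}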
@ -415,7 +425,7 @@ async def aindex(
cleanup_batch_size: int = 1_000,
force_update: bool = False,
) -> IndexingResult:
"""Index data from the loader into the vector store.
"""Async index data from the loader into the vector store.
Indexing functionality uses a manager to keep track of which documents
are in the vector store.
@ -437,8 +447,8 @@ async def aindex(
record_manager: Timestamped set to keep track of which documents were
updated.
vector_store: Vector store to index the documents into.
batch_size: Batch size to use when indexing.
cleanup: How to handle clean up of documents.
batch_size: Batch size to use when indexing. Default is 100.
cleanup: How to handle clean up of documents. Default is None.
- Incremental: Cleans up all documents that haven't been updated AND
that are associated with source ids that were seen
during indexing.
@ -450,14 +460,23 @@ async def aindex(
This means that users may see duplicated content during indexing.
- None: Do not delete any documents.
source_id_key: Optional key that helps identify the original source
of the document.
of the document. Default is None.
cleanup_batch_size: Batch size to use when cleaning up documents.
Default is 1_000.
force_update: Force update documents even if they are present in the
record manager. Useful if you are re-indexing with updated embeddings.
Default is False.
Returns:
Indexing result which contains information about how many documents
were added, updated, deleted, or skipped.
Raises:
ValueError: If cleanup mode is not one of 'incremental', 'full', or None.
ValueError: If cleanup mode is incremental and source_id_key is None.
ValueError: If vectorstore does not have the
required "adelete" and "aadd_documents" methods.
ValueError: If source_id_key is not None, but is not a string or callable.
"""
if cleanup not in {"incremental", "full", None}:
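The async variant takes the same arguments; a short sketch continuing the objects from the previous example, assuming the vector store also provides the async aadd_documents/adelete methods:

.. code-block:: python

    import asyncio

    async def reindex() -> None:
        # Same call shape as `index`, awaited; "full" cleanup also deletes
        # previously indexed documents that are absent from `docs`.
        result = await aindex(
            docs,
            record_manager,
            vector_store,
            cleanup="full",
            source_id_key="source",
        )
        print(result)

    asyncio.run(reindex())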

View File

@ -37,7 +37,7 @@ class RecordManager(ABC):
2. The record manager is currently implemented separately from the
vectorstore, which means that the overall system becomes distributed
and may create issues with consistency. For example, writing to
record manager succeeds but corresponding writing to vectorstore fails.
record manager succeeds, but corresponding writing to vectorstore fails.
"""
def __init__(
@ -227,6 +227,11 @@ class InMemoryRecordManager(RecordManager):
"""An in-memory record manager for testing purposes."""
def __init__(self, namespace: str) -> None:
"""Initialize the in-memory record manager.
Args:
namespace (str): The namespace for the record manager.
"""
super().__init__(namespace)
# Each key points to a dictionary
# of {'group_id': group_id, 'updated_at': timestamp}
@ -237,14 +242,16 @@ class InMemoryRecordManager(RecordManager):
"""In-memory schema creation is simply ensuring the structure is initialized."""
async def acreate_schema(self) -> None:
"""In-memory schema creation is simply ensuring the structure is initialized."""
"""Async in-memory schema creation is simply ensuring
the structure is initialized.
"""
def get_time(self) -> float:
"""Get the current server time as a high resolution timestamp!"""
return time.time()
async def aget_time(self) -> float:
"""Get the current server time as a high resolution timestamp!"""
"""Async get the current server time as a high resolution timestamp!"""
return self.get_time()
def update(
@ -254,6 +261,27 @@ class InMemoryRecordManager(RecordManager):
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
"""Upsert records into the database.
Args:
keys: A list of record keys to upsert.
group_ids: A list of group IDs corresponding to the keys.
Defaults to None.
time_at_least: Optional timestamp. An implementation can use this
to verify that the timestamp is at least this time
in the system that stores the data. Defaults to None.
E.g., use it to validate that the time in the Postgres database
is equal to or larger than the given timestamp; if not,
raise an error.
This is meant to help prevent time-drift issues, since
time may not be monotonically increasing.
Raises:
ValueError: If the length of keys doesn't match the length of group
ids.
ValueError: If time_at_least is in the future.
"""
if group_ids and len(keys) != len(group_ids):
raise ValueError("Length of keys must match length of group_ids")
for index, key in enumerate(keys):
@ -269,12 +297,48 @@ class InMemoryRecordManager(RecordManager):
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
"""Async upsert records into the database.
Args:
keys: A list of record keys to upsert.
group_ids: A list of group IDs corresponding to the keys.
Defaults to None.
time_at_least: Optional timestamp. An implementation can use this
to verify that the timestamp is at least this time
in the system that stores the data. Defaults to None.
E.g., use it to validate that the time in the Postgres database
is equal to or larger than the given timestamp; if not,
raise an error.
This is meant to help prevent time-drift issues, since
time may not be monotonically increasing.
Raises:
ValueError: If the length of keys doesn't match the length of group
ids.
ValueError: If time_at_least is in the future.
"""
self.update(keys, group_ids=group_ids, time_at_least=time_at_least)
def exists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the provided keys exist in the database.
Args:
keys: A list of keys to check.
Returns:
A list of boolean values indicating the existence of each key.
"""
return [key in self.records for key in keys]
async def aexists(self, keys: Sequence[str]) -> List[bool]:
"""Async check if the provided keys exist in the database.
Args:
keys: A list of keys to check.
Returns:
A list of boolean values indicating the existence of each key.
"""
return self.exists(keys)
def list_keys(
@ -285,6 +349,21 @@ class InMemoryRecordManager(RecordManager):
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
"""List records in the database based on the provided filters.
Args:
before: Filter to list records updated before this time.
Defaults to None.
after: Filter to list records updated after this time.
Defaults to None.
group_ids: Filter to list records with specific group IDs.
Defaults to None.
limit: Optional limit on the number of records to return.
Defaults to None.
Returns:
A list of keys for the matching records.
"""
result = []
for key, data in self.records.items():
if before and data["updated_at"] >= before:
@ -306,14 +385,39 @@ class InMemoryRecordManager(RecordManager):
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
"""Async list records in the database based on the provided filters.
Args:
before: Filter to list records updated before this time.
Defaults to None.
after: Filter to list records updated after this time.
Defaults to None.
group_ids: Filter to list records with specific group IDs.
Defaults to None.
limit: Optional limit on the number of records to return.
Defaults to None.
Returns:
A list of keys for the matching records.
"""
return self.list_keys(
before=before, after=after, group_ids=group_ids, limit=limit
)
def delete_keys(self, keys: Sequence[str]) -> None:
"""Delete specified records from the database.
Args:
keys: A list of keys to delete.
"""
for key in keys:
if key in self.records:
del self.records[key]
async def adelete_keys(self, keys: Sequence[str]) -> None:
"""Async delete specified records from the database.
Args:
keys: A list of keys to delete.
"""
self.delete_keys(keys)
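A small sketch exercising the InMemoryRecordManager methods documented above, with hypothetical keys and group ids:

.. code-block:: python

    from langchain_core.indexing import InMemoryRecordManager

    manager = InMemoryRecordManager(namespace="demo")
    manager.create_schema()

    # Upsert two records attributed to the same source group.
    manager.update(["doc-1", "doc-2"], group_ids=["src-a", "src-a"])

    print(manager.exists(["doc-1", "doc-3"]))      # [True, False]
    print(manager.list_keys(group_ids=["src-a"]))  # ['doc-1', 'doc-2']

    # Records can also be filtered by update time.
    print(manager.list_keys(after=manager.get_time() - 60))

    manager.delete_keys(["doc-1"])
    print(manager.exists(["doc-1"]))               # [False]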

View File

@ -39,6 +39,11 @@ if TYPE_CHECKING:
@lru_cache(maxsize=None) # Cache the tokenizer
def get_tokenizer() -> Any:
"""Get a GPT-2 tokenizer instance.
This function is cached to avoid re-loading the tokenizer
every time it is called.
"""
try:
from transformers import GPT2TokenizerFast # type: ignore[import]
except ImportError:
@ -77,7 +82,7 @@ class BaseLanguageModel(
):
"""Abstract base class for interfacing with language models.
All language model wrappers inherit from BaseLanguageModel.
All language model wrappers inherit from BaseLanguageModel.
"""
cache: Union[BaseCache, bool, None] = None
@ -108,6 +113,12 @@ class BaseLanguageModel(
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
Args:
verbose: The verbosity setting to use.
Returns:
The verbosity setting to use.
"""
if verbose is None:
return _get_verbosity()
@ -324,7 +335,7 @@ class BaseLanguageModel(
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.
Useful for checking if an input will fit in a model's context window.
Useful for checking if an input fits in a model's context window.
Args:
text: The string input to tokenize.
@ -337,7 +348,7 @@ class BaseLanguageModel(
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Useful for checking if an input fits in a model's context window.
Args:
messages: The message inputs to tokenize.
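A sketch of the default token-counting path these docstrings describe. It assumes the optional transformers dependency is installed, since the default implementation falls back to the cached GPT-2 tokenizer from get_tokenizer; FakeListLLM is used only as a stand-in model.

.. code-block:: python

    from langchain_core.language_models import FakeListLLM

    # FakeListLLM does not override get_num_tokens, so this exercises the
    # BaseLanguageModel default (GPT-2 tokenizer via `get_tokenizer`).
    llm = FakeListLLM(responses=["unused"])
    print(llm.get_num_tokens("How many tokens is this sentence?"))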

View File

@ -68,16 +68,31 @@ if TYPE_CHECKING:
class LangSmithParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""
ls_provider: str
"""Provider of the model."""
ls_model_name: str
"""Name of the model."""
ls_model_type: Literal["chat"]
"""Type of the model. Should be 'chat'."""
ls_temperature: Optional[float]
"""Temperature for generation."""
ls_max_tokens: Optional[int]
"""Max tokens for generation."""
ls_stop: Optional[List[str]]
"""Stop words for generation."""
def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
"""Generate from a stream."""
"""Generate from a stream.
Args:
stream: Iterator of ChatGenerationChunk.
Returns:
ChatResult: Chat result.
"""
generation: Optional[ChatGenerationChunk] = None
for chunk in stream:
@ -99,7 +114,14 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
async def agenerate_from_stream(
stream: AsyncIterator[ChatGenerationChunk],
) -> ChatResult:
"""Async generate from a stream."""
"""Async generate from a stream.
Args:
stream: AsyncIterator of ChatGenerationChunk.
Returns:
ChatResult: Chat result.
"""
generation: Optional[ChatGenerationChunk] = None
async for chunk in stream:
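A sketch of how a chunk stream collapses into a single ChatResult, with hypothetical chunks built from AIMessageChunk:

.. code-block:: python

    from langchain_core.messages import AIMessageChunk
    from langchain_core.outputs import ChatGenerationChunk

    chunks = [
        ChatGenerationChunk(message=AIMessageChunk(content="Hel")),
        ChatGenerationChunk(message=AIMessageChunk(content="lo!")),
    ]

    # The chunks are concatenated into a single generation.
    result = generate_from_stream(iter(chunks))
    print(result.generations[0].message.content)  # "Hello!"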
@ -200,7 +222,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
"""Raise deprecation warning if callback_manager is used.
Args:
values (Dict): Values to validate.
Returns:
Dict: Validated values.
Raises:
DeprecationWarning: If callback_manager is used.
"""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",

View File

@ -1,4 +1,5 @@
"""Fake ChatModel for testing purposes."""
import asyncio
import re
import time

View File

@ -78,7 +78,20 @@ def create_base_retry_decorator(
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
] = None,
) -> Callable[[Any], Any]:
"""Create a retry decorator for a given LLM and provided list of error types."""
"""Create a retry decorator for a given LLM and provided
a list of error types.
Args:
error_types: List of error types to retry on.
max_retries: Number of retries. Default is 1.
run_manager: Callback manager for the run. Default is None.
Returns:
A retry decorator.
"""
_logging = before_sleep_log(logger, logging.WARNING)
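A sketch of applying the returned decorator to a flaky function; the error type and function are hypothetical:

.. code-block:: python

    retrying = create_base_retry_decorator(
        error_types=[TimeoutError],
        max_retries=3,
    )

    @retrying
    def flaky_call() -> str:
        # A TimeoutError raised here is retried with exponential backoff
        # until the attempt limit is reached, then re-raised.
        return "ok"

    print(flaky_call())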
@ -141,7 +154,20 @@ def get_prompts(
prompts: List[str],
cache: Optional[Union[BaseCache, bool, None]] = None,
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
"""Get prompts that are already cached."""
"""Get prompts that are already cached.
Args:
params: Dictionary of parameters.
prompts: List of prompts.
cache: Cache object. Default is None.
Returns:
A tuple of existing prompts, llm_string, missing prompt indexes,
and missing prompts.
Raises:
ValueError: If the cache is not set and cache is True.
"""
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
missing_prompt_idxs = []
@ -164,7 +190,20 @@ async def aget_prompts(
prompts: List[str],
cache: Optional[Union[BaseCache, bool, None]] = None,
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
"""Get prompts that are already cached. Async version."""
"""Get prompts that are already cached. Async version.
Args:
params: Dictionary of parameters.
prompts: List of prompts.
cache: Cache object. Default is None.
Returns:
A tuple of existing prompts, llm_string, missing prompt indexes,
and missing prompts.
Raises:
ValueError: If the cache is not set and cache is True.
"""
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
missing_prompt_idxs = []
@ -189,7 +228,22 @@ def update_cache(
new_results: LLMResult,
prompts: List[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output."""
"""Update the cache and get the LLM output.
Args:
cache: Cache object.
existing_prompts: Dictionary of existing prompts.
llm_string: LLM string.
missing_prompt_idxs: List of missing prompt indexes.
new_results: LLMResult object.
prompts: List of prompts.
Returns:
LLM output.
Raises:
ValueError: If the cache is not set and cache is True.
"""
llm_cache = _resolve_cache(cache)
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
@ -208,7 +262,23 @@ async def aupdate_cache(
new_results: LLMResult,
prompts: List[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output. Async version"""
"""Update the cache and get the LLM output. Async version.
Args:
cache: Cache object.
existing_prompts: Dictionary of existing prompts.
llm_string: LLM string.
missing_prompt_idxs: List of missing prompt indexes.
new_results: LLMResult object.
prompts: List of prompts.
Returns:
LLM output.
Raises:
ValueError: If the cache is not set and cache is True.
"""
llm_cache = _resolve_cache(cache)
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
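A sketch of the cache round trip these helpers implement, using langchain_core's InMemoryCache with hypothetical params and a hypothetical prompt:

.. code-block:: python

    from langchain_core.caches import InMemoryCache
    from langchain_core.outputs import Generation, LLMResult

    cache = InMemoryCache()
    params = {"model_name": "my-llm", "temperature": 0.0}
    prompts = ["What is 2 + 2?"]

    # Cold cache: every prompt is reported as missing.
    existing, llm_string, missing_idxs, missing = get_prompts(params, prompts, cache=cache)
    assert existing == {} and missing == prompts

    # Pretend the model answered the missing prompt, then write it back.
    new_results = LLMResult(generations=[[Generation(text="4")]])
    update_cache(cache, existing, llm_string, missing_idxs, new_results, prompts)

    # Warm cache: the same lookup now returns the cached generation.
    existing, _, _, missing = get_prompts(params, prompts, cache=cache)
    assert missing == [] and existing[0][0].text == "4"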
@ -706,6 +776,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
tags: List of tags to associate with each prompt. If provided, the length
of the list must match the length of the prompts list.
metadata: List of metadata dictionaries to associate with each prompt. If
provided, the length of the list must match the length of the prompts
list.
run_name: List of run names to associate with each prompt. If provided, the
length of the list must match the length of the prompts list.
run_id: List of run IDs to associate with each prompt. If provided, the
length of the list must match the length of the prompts list.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
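A sketch of the per-prompt lists described above, using the FakeListLLM test model with hypothetical tags and metadata:

.. code-block:: python

    from langchain_core.language_models import FakeListLLM

    llm = FakeListLLM(responses=["alpha", "beta"])

    # When tags / metadata / run_name are given as lists, each entry is paired
    # with the prompt at the same index, so list lengths must match the prompts.
    result = llm.generate(
        ["first prompt", "second prompt"],
        tags=[["demo"], ["demo"]],
        metadata=[{"source": "a.txt"}, {"source": "b.txt"}],
        run_name=["run-1", "run-2"],
    )
    print([gen[0].text for gen in result.generations])  # ['alpha', 'beta']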
@ -924,6 +1003,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
tags: List of tags to associate with each prompt. If provided, the length
of the list must match the length of the prompts list.
metadata: List of metadata dictionaries to associate with each prompt. If
provided, the length of the list must match the length of the prompts
list.
run_name: List of run names to associate with each prompt. If provided, the
length of the list must match the length of the prompts list.
run_id: List of run IDs to associate with each prompt. If provided, the
length of the list must match the length of the prompts list.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
@ -1075,7 +1163,25 @@ class BaseLLM(BaseLanguageModel[str], ABC):
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
"""Check Cache and run the LLM on the given prompt and input.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
callbacks: Callbacks to pass through. Used for executing additional
functionality, such as logging or streaming, throughout generation.
tags: List of tags to associate with the prompt.
metadata: Metadata to associate with the prompt.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
The generated text.
Raises:
ValueError: If the prompt is not a string.
"""
if not isinstance(prompt, str):
raise ValueError(
"Argument `prompt` is expected to be a string. Instead found "
@ -1190,6 +1296,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
Args:
file_path: Path to file to save the LLM to.
Raises:
ValueError: If the file path is not a string or Path object.
Example:
.. code-block:: python
@ -1333,7 +1442,7 @@ class LLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
"""Async run the LLM on the given prompt and input."""
generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
for prompt in prompts: