core[patch[: docstring update (#21036)

Added missed docstrings. Updated docstrings to consistent format.
This commit is contained in:
Leonid Ganeline 2024-04-29 12:35:34 -07:00 committed by GitHub
parent f479a337cc
commit 1a2ff56cd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 159 additions and 38 deletions

View File

@ -73,10 +73,10 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
"""Example selector that selects examples based on SemanticSimilarity."""
"""Select examples based on semantic similarity."""
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity."""
"""Select examples based on semantic similarity."""
# Get the docs with the highest similarity.
vectorstore_kwargs = self.vectorstore_kwargs or {}
example_docs = self.vectorstore.similarity_search(
@ -87,7 +87,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
return self._documents_to_examples(example_docs)
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity."""
"""Asynchronously select examples based on semantic similarity."""
# Get the docs with the highest similarity.
vectorstore_kwargs = self.vectorstore_kwargs or {}
example_docs = await self.vectorstore.asimilarity_search(
@ -187,7 +187,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
"""ExampleSelector that selects examples based on Max Marginal Relevance.
"""Select examples based on Max Marginal Relevance.
This was shown to improve performance in this paper:
https://arxiv.org/pdf/2211.13892.pdf
@ -197,6 +197,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
"""Number of examples to fetch to rerank."""
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select examples based on Max Marginal Relevance.
Args:
input_variables: The input variables to use for search.
Returns:
The selected examples.
"""
example_docs = self.vectorstore.max_marginal_relevance_search(
self._example_to_text(input_variables, self.input_keys),
k=self.k,
@ -205,6 +213,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
return self._documents_to_examples(example_docs)
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Asynchronously select examples based on Max Marginal Relevance.
Args:
input_variables: The input variables to use for search.
Returns:
The selected examples.
"""
example_docs = await self.vectorstore.amax_marginal_relevance_search(
self._example_to_text(input_variables, self.input_keys),
k=self.k,
@ -272,7 +288,8 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
vectorstore_kwargs: Optional[dict] = None,
**vectorstore_cls_kwargs: Any,
) -> MaxMarginalRelevanceExampleSelector:
"""Create k-shot example selector using example list and embeddings.
"""Asynchronously create k-shot example selector using example list and
embeddings.
Reshuffles examples dynamically based on Max Marginal Relevance.

View File

@ -5,7 +5,7 @@ from typing import List, Optional, Sequence
class RecordManager(ABC):
"""An abstract base class representing the interface for a record manager."""
"""Abstract base class representing the interface for a record manager."""
def __init__(
self,
@ -24,7 +24,7 @@ class RecordManager(ABC):
@abstractmethod
async def acreate_schema(self) -> None:
"""Create the database schema for the record manager."""
"""Asynchronously create the database schema for the record manager."""
@abstractmethod
def get_time(self) -> float:
@ -39,7 +39,7 @@ class RecordManager(ABC):
@abstractmethod
async def aget_time(self) -> float:
"""Get the current server time as a high resolution timestamp!
"""Asynchronously get the current server time as a high resolution timestamp.
It's important to get this from the server to ensure a monotonic clock,
otherwise there may be data loss when cleaning up old documents!
@ -84,7 +84,7 @@ class RecordManager(ABC):
group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None,
) -> None:
"""Upsert records into the database.
"""Asynchronously upsert records into the database.
Args:
keys: A list of record keys to upsert.
@ -117,7 +117,7 @@ class RecordManager(ABC):
@abstractmethod
async def aexists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the provided keys exist in the database.
"""Asynchronously check if the provided keys exist in the database.
Args:
keys: A list of keys to check.
@ -156,7 +156,7 @@ class RecordManager(ABC):
group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None,
) -> List[str]:
"""List records in the database based on the provided filters.
"""Asynchronously list records in the database based on the provided filters.
Args:
before: Filter to list records updated before this time.
@ -178,7 +178,7 @@ class RecordManager(ABC):
@abstractmethod
async def adelete_keys(self, keys: Sequence[str]) -> None:
"""Delete specified records from the database.
"""Asynchronously delete specified records from the database.
Args:
keys: A list of keys to delete.

View File

@ -913,7 +913,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
class SimpleChatModel(BaseChatModel):
"""A simplified implementation for a chat model to inherit from."""
"""Simplified implementation for a chat model to inherit from."""
def _generate(
self,

View File

@ -151,7 +151,7 @@ class FakeChatModel(SimpleChatModel):
class GenericFakeChatModel(BaseChatModel):
"""A generic fake chat model that can be used to test the chat model interface.
"""Generic fake chat model that can be used to test the chat model interface.
* Chat model should be usable in both sync and async tests
* Invokes on_llm_new_token to allow for testing of callback related code for new
@ -288,7 +288,7 @@ class GenericFakeChatModel(BaseChatModel):
class ParrotFakeChatModel(BaseChatModel):
"""A generic fake chat model that can be used to test the chat model interface.
"""Generic fake chat model that can be used to test the chat model interface.
* Chat model should be usable in both sync and async tests
"""

View File

@ -1218,7 +1218,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
class LLM(BaseLLM):
"""This class exposes a simple interface for implementing a custom LLM.
"""Simple interface for implementing a custom LLM.
You should subclass this class and implement the following:

View File

@ -90,6 +90,8 @@ class ChatPromptValue(PromptValue):
class ImageURL(TypedDict, total=False):
"""Image URL."""
detail: Literal["auto", "low", "high"]
"""Specifies the detail level of the image."""

View File

@ -8,7 +8,7 @@ from langchain_core.utils import image as image_utils
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
"""An image prompt template for a multimodal model."""
"""Image prompt template for a multimodal model."""
template: dict = Field(default_factory=dict)
"""Template for the prompt."""

View File

@ -17,7 +17,7 @@ from langchain_core.runnables.config import RunnableConfig
class PromptTemplate(StringPromptTemplate):
"""A prompt template for a language model.
"""Prompt template for a language model.
A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model.

View File

@ -33,7 +33,10 @@ from langchain_core.runnables.base import (
@beta()
class StructuredPrompt(ChatPromptTemplate):
"""Structured prompt template for a language model."""
schema_: Union[Dict, Type[BaseModel]]
"""Schema for the structured prompt."""
@classmethod
def get_lc_namespace(cls) -> List[str]:

View File

@ -641,7 +641,8 @@ def make_options_spec(
description: Optional[str],
) -> ConfigurableFieldSpec:
"""Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
ConfigurableFieldMultiOption."""
ConfigurableFieldMultiOption.
"""
with _enums_for_spec_lock:
if enum := _enums_for_spec.get(spec):
pass

View File

@ -26,11 +26,23 @@ if TYPE_CHECKING:
class LabelsDict(TypedDict):
"""Dictionary of labels for nodes and edges in a graph."""
nodes: dict[str, str]
"""Labels for nodes."""
edges: dict[str, str]
"""Labels for edges."""
def is_uuid(value: str) -> bool:
"""Check if a string is a valid UUID.
Args:
value: The string to check.
Returns:
True if the string is a valid UUID, False otherwise.
"""
try:
UUID(value)
return True
@ -95,6 +107,14 @@ class MermaidDrawMethod(Enum):
def node_data_str(node: Node) -> str:
"""Convert the data of a node to a string.
Args:
node: The node to convert.
Returns:
A string representation of the data.
"""
from langchain_core.runnables.base import Runnable
if not is_uuid(node.id):
@ -120,6 +140,16 @@ def node_data_str(node: Node) -> str:
def node_data_json(
node: Node, *, with_schemas: bool = False
) -> Dict[str, Union[str, Dict[str, Any]]]:
"""Convert the data of a node to a JSON-serializable format.
Args:
node: The node to convert.
with_schemas: Whether to include the schema of the data if
it is a Pydantic model.
Returns:
A dictionary with the type of the data and the data itself.
"""
from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable

View File

@ -4,9 +4,9 @@ from langchain_core.runnables.graph import Graph, LabelsDict
class PngDrawer:
"""
A helper class to draw a state graph into a PNG file.
Requires graphviz and pygraphviz to be installed.
"""Helper class to draw a state graph into a PNG file.
It requires graphviz and pygraphviz to be installed.
:param fontname: The font to use for the labels
:param labels: A dictionary of label overrides. The dictionary
should have the following format:
@ -30,18 +30,62 @@ class PngDrawer:
def __init__(
self, fontname: Optional[str] = None, labels: Optional[LabelsDict] = None
) -> None:
"""Initializes the PNG drawer.
Args:
fontname: The font to use for the labels
labels: A dictionary of label overrides. The dictionary
should have the following format:
{
"nodes": {
"node1": "CustomLabel1",
"node2": "CustomLabel2",
"__end__": "End Node"
},
"edges": {
"continue": "ContinueLabel",
"end": "EndLabel"
}
}
The keys are the original labels, and the values are the new labels.
"""
self.fontname = fontname or "arial"
self.labels = labels or LabelsDict(nodes={}, edges={})
def get_node_label(self, label: str) -> str:
"""Returns the label to use for a node.
Args:
label: The original label
Returns:
The new label.
"""
label = self.labels.get("nodes", {}).get(label, label)
return f"<<B>{label}</B>>"
def get_edge_label(self, label: str) -> str:
"""Returns the label to use for an edge.
Args:
label: The original label
Returns:
The new label.
"""
label = self.labels.get("edges", {}).get(label, label)
return f"<<U>{label}</U>>"
def add_node(self, viz: Any, node: str) -> None:
"""Adds a node to the graph.
Args:
viz: The graphviz object
node: The node to add
Returns:
None
"""
viz.add_node(
node,
label=self.get_node_label(node),
@ -59,6 +103,18 @@ class PngDrawer:
label: Optional[str] = None,
conditional: bool = False,
) -> None:
"""Adds an edge to the graph.
Args:
viz: The graphviz object
source: The source node
target: The target node
label: The label for the edge. Defaults to None.
conditional: Whether the edge is conditional. Defaults to False.
Returns:
None
"""
viz.add_edge(
source,
target,
@ -69,8 +125,8 @@ class PngDrawer:
)
def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
"""
Draws the given state graph into a PNG file.
"""Draw the given state graph into a PNG file.
Requires graphviz and pygraphviz to be installed.
:param graph: The graph to draw
:param output_path: The path to save the PNG. If None, PNG bytes are returned.

View File

@ -329,8 +329,7 @@ _graph_passthrough: RunnablePassthrough = RunnablePassthrough()
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
"""
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
"""Runnable that assigns key-value pairs to Dict[str, Any] inputs.
The `RunnableAssign` class takes input dictionaries and, through a
`RunnableParallel` instance, applies transformations, then combines

View File

@ -507,6 +507,15 @@ def create_model(
__model_name: str,
**field_definitions: Any,
) -> Type[BaseModel]:
"""Create a pydantic model with the given field definitions.
Args:
__model_name: The name of the model.
**field_definitions: The field definitions for the model.
Returns:
Type[BaseModel]: The created model.
"""
try:
return _create_model_cached(__model_name, **field_definitions)
except TypeError:

View File

@ -93,11 +93,11 @@ class Comparator(str, Enum):
class FilterDirective(Expr, ABC):
"""A filtering expression."""
"""Filtering expression."""
class Comparison(FilterDirective):
"""A comparison to a value."""
"""Comparison to a value."""
comparator: Comparator
attribute: str
@ -112,7 +112,7 @@ class Comparison(FilterDirective):
class Operation(FilterDirective):
"""A logical operation over other directives."""
"""Llogical operation over other directives."""
operator: Operator
arguments: List[FilterDirective]
@ -124,7 +124,7 @@ class Operation(FilterDirective):
class StructuredQuery(Expr):
"""A structured query."""
"""Structured query."""
query: str
"""Query string."""

View File

@ -43,7 +43,7 @@ run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVa
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSessionV1, None, None]:
"""Throws an error because this has been replaced by tracing_v2_enabled."""
"""Throw an error because this has been replaced by tracing_v2_enabled."""
raise RuntimeError(
"tracing_enabled is no longer supported. Please use tracing_enabled_v2 instead."
)

View File

@ -2,6 +2,7 @@ from typing import Any
def get_headers(*args: Any, **kwargs: Any) -> Any:
"""Throw an error because this has been replaced by get_headers."""
raise RuntimeError(
"get_headers for LangChainTracerV1 is no longer supported. "
"Please use LangChainTracer instead."
@ -9,6 +10,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any:
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any:
"""Throw an error because this has been replaced by LangChainTracer."""
raise RuntimeError(
"LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."
)

View File

@ -120,7 +120,7 @@ async def tee_peer(
class Tee(Generic[T]):
"""
Create ``n`` separate asynchronous iterators over ``iterable``
Create ``n`` separate asynchronous iterators over ``iterable``.
This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.

View File

@ -31,6 +31,8 @@ _LAST_TAG_LINE = None
class ChevronError(SyntaxError):
"""Custom exception for Chevron errors."""
pass
@ -40,7 +42,7 @@ class ChevronError(SyntaxError):
def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
"""Parse a literal from the template"""
"""Parse a literal from the template."""
global _CURRENT_LINE
@ -57,7 +59,7 @@ def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
"""Do a preliminary check to see if a tag could be a standalone"""
"""Do a preliminary check to see if a tag could be a standalone."""
# If there is a newline, or the previous tag was a standalone
if literal.find("\n") != -1 or is_standalone:
@ -75,7 +77,7 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
"""Do a final checkto see if a tag could be a standalone"""
"""Do a final check to see if a tag could be a standalone."""
# Check right side if we might be a standalone
if is_standalone and tag_type not in ["variable", "no escape"]:
@ -93,7 +95,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
"""Parse a tag from a template"""
"""Parse a tag from a template."""
global _CURRENT_LINE
global _LAST_TAG_LINE
@ -157,7 +159,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
def tokenize(
template: str, def_ldel: str = "{{", def_rdel: str = "}}"
) -> Iterator[Tuple[str, str]]:
"""Tokenize a mustache template
"""Tokenize a mustache template.
Tokenizes a mustache template in a generator fashion,
using file-like objects. It also accepts a string containing

View File

@ -84,7 +84,7 @@ def mock_now(dt_value): # type: ignore
def guard_import(
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any:
"""Dynamically imports a module and raises a helpful exception if the module is not
"""Dynamically import a module and raise an exception if the module is not
installed."""
try:
module = importlib.import_module(module_name, package)