core[patch[: docstring update (#21036)

Added missed docstrings. Updated docstrings to consistent format.
This commit is contained in:
Leonid Ganeline 2024-04-29 12:35:34 -07:00 committed by GitHub
parent f479a337cc
commit 1a2ff56cd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 159 additions and 38 deletions

View File

@ -73,10 +73,10 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
"""Example selector that selects examples based on SemanticSimilarity.""" """Select examples based on semantic similarity."""
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity.""" """Select examples based on semantic similarity."""
# Get the docs with the highest similarity. # Get the docs with the highest similarity.
vectorstore_kwargs = self.vectorstore_kwargs or {} vectorstore_kwargs = self.vectorstore_kwargs or {}
example_docs = self.vectorstore.similarity_search( example_docs = self.vectorstore.similarity_search(
@ -87,7 +87,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
return self._documents_to_examples(example_docs) return self._documents_to_examples(example_docs)
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on semantic similarity.""" """Asynchronously select examples based on semantic similarity."""
# Get the docs with the highest similarity. # Get the docs with the highest similarity.
vectorstore_kwargs = self.vectorstore_kwargs or {} vectorstore_kwargs = self.vectorstore_kwargs or {}
example_docs = await self.vectorstore.asimilarity_search( example_docs = await self.vectorstore.asimilarity_search(
@ -187,7 +187,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
"""ExampleSelector that selects examples based on Max Marginal Relevance. """Select examples based on Max Marginal Relevance.
This was shown to improve performance in this paper: This was shown to improve performance in this paper:
https://arxiv.org/pdf/2211.13892.pdf https://arxiv.org/pdf/2211.13892.pdf
@ -197,6 +197,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
"""Number of examples to fetch to rerank.""" """Number of examples to fetch to rerank."""
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select examples based on Max Marginal Relevance.
Args:
input_variables: The input variables to use for search.
Returns:
The selected examples.
"""
example_docs = self.vectorstore.max_marginal_relevance_search( example_docs = self.vectorstore.max_marginal_relevance_search(
self._example_to_text(input_variables, self.input_keys), self._example_to_text(input_variables, self.input_keys),
k=self.k, k=self.k,
@ -205,6 +213,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
return self._documents_to_examples(example_docs) return self._documents_to_examples(example_docs)
async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Asynchronously select examples based on Max Marginal Relevance.
Args:
input_variables: The input variables to use for search.
Returns:
The selected examples.
"""
example_docs = await self.vectorstore.amax_marginal_relevance_search( example_docs = await self.vectorstore.amax_marginal_relevance_search(
self._example_to_text(input_variables, self.input_keys), self._example_to_text(input_variables, self.input_keys),
k=self.k, k=self.k,
@ -272,7 +288,8 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
vectorstore_kwargs: Optional[dict] = None, vectorstore_kwargs: Optional[dict] = None,
**vectorstore_cls_kwargs: Any, **vectorstore_cls_kwargs: Any,
) -> MaxMarginalRelevanceExampleSelector: ) -> MaxMarginalRelevanceExampleSelector:
"""Create k-shot example selector using example list and embeddings. """Asynchronously create k-shot example selector using example list and
embeddings.
Reshuffles examples dynamically based on Max Marginal Relevance. Reshuffles examples dynamically based on Max Marginal Relevance.

View File

@ -5,7 +5,7 @@ from typing import List, Optional, Sequence
class RecordManager(ABC): class RecordManager(ABC):
"""An abstract base class representing the interface for a record manager.""" """Abstract base class representing the interface for a record manager."""
def __init__( def __init__(
self, self,
@ -24,7 +24,7 @@ class RecordManager(ABC):
@abstractmethod @abstractmethod
async def acreate_schema(self) -> None: async def acreate_schema(self) -> None:
"""Create the database schema for the record manager.""" """Asynchronously create the database schema for the record manager."""
@abstractmethod @abstractmethod
def get_time(self) -> float: def get_time(self) -> float:
@ -39,7 +39,7 @@ class RecordManager(ABC):
@abstractmethod @abstractmethod
async def aget_time(self) -> float: async def aget_time(self) -> float:
"""Get the current server time as a high resolution timestamp! """Asynchronously get the current server time as a high resolution timestamp.
It's important to get this from the server to ensure a monotonic clock, It's important to get this from the server to ensure a monotonic clock,
otherwise there may be data loss when cleaning up old documents! otherwise there may be data loss when cleaning up old documents!
@ -84,7 +84,7 @@ class RecordManager(ABC):
group_ids: Optional[Sequence[Optional[str]]] = None, group_ids: Optional[Sequence[Optional[str]]] = None,
time_at_least: Optional[float] = None, time_at_least: Optional[float] = None,
) -> None: ) -> None:
"""Upsert records into the database. """Asynchronously upsert records into the database.
Args: Args:
keys: A list of record keys to upsert. keys: A list of record keys to upsert.
@ -117,7 +117,7 @@ class RecordManager(ABC):
@abstractmethod @abstractmethod
async def aexists(self, keys: Sequence[str]) -> List[bool]: async def aexists(self, keys: Sequence[str]) -> List[bool]:
"""Check if the provided keys exist in the database. """Asynchronously check if the provided keys exist in the database.
Args: Args:
keys: A list of keys to check. keys: A list of keys to check.
@ -156,7 +156,7 @@ class RecordManager(ABC):
group_ids: Optional[Sequence[str]] = None, group_ids: Optional[Sequence[str]] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
) -> List[str]: ) -> List[str]:
"""List records in the database based on the provided filters. """Asynchronously list records in the database based on the provided filters.
Args: Args:
before: Filter to list records updated before this time. before: Filter to list records updated before this time.
@ -178,7 +178,7 @@ class RecordManager(ABC):
@abstractmethod @abstractmethod
async def adelete_keys(self, keys: Sequence[str]) -> None: async def adelete_keys(self, keys: Sequence[str]) -> None:
"""Delete specified records from the database. """Asynchronously delete specified records from the database.
Args: Args:
keys: A list of keys to delete. keys: A list of keys to delete.

View File

@ -913,7 +913,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
class SimpleChatModel(BaseChatModel): class SimpleChatModel(BaseChatModel):
"""A simplified implementation for a chat model to inherit from.""" """Simplified implementation for a chat model to inherit from."""
def _generate( def _generate(
self, self,

View File

@ -151,7 +151,7 @@ class FakeChatModel(SimpleChatModel):
class GenericFakeChatModel(BaseChatModel): class GenericFakeChatModel(BaseChatModel):
"""A generic fake chat model that can be used to test the chat model interface. """Generic fake chat model that can be used to test the chat model interface.
* Chat model should be usable in both sync and async tests * Chat model should be usable in both sync and async tests
* Invokes on_llm_new_token to allow for testing of callback related code for new * Invokes on_llm_new_token to allow for testing of callback related code for new
@ -288,7 +288,7 @@ class GenericFakeChatModel(BaseChatModel):
class ParrotFakeChatModel(BaseChatModel): class ParrotFakeChatModel(BaseChatModel):
"""A generic fake chat model that can be used to test the chat model interface. """Generic fake chat model that can be used to test the chat model interface.
* Chat model should be usable in both sync and async tests * Chat model should be usable in both sync and async tests
""" """

View File

@ -1218,7 +1218,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
class LLM(BaseLLM): class LLM(BaseLLM):
"""This class exposes a simple interface for implementing a custom LLM. """Simple interface for implementing a custom LLM.
You should subclass this class and implement the following: You should subclass this class and implement the following:

View File

@ -90,6 +90,8 @@ class ChatPromptValue(PromptValue):
class ImageURL(TypedDict, total=False): class ImageURL(TypedDict, total=False):
"""Image URL."""
detail: Literal["auto", "low", "high"] detail: Literal["auto", "low", "high"]
"""Specifies the detail level of the image.""" """Specifies the detail level of the image."""

View File

@ -8,7 +8,7 @@ from langchain_core.utils import image as image_utils
class ImagePromptTemplate(BasePromptTemplate[ImageURL]): class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
"""An image prompt template for a multimodal model.""" """Image prompt template for a multimodal model."""
template: dict = Field(default_factory=dict) template: dict = Field(default_factory=dict)
"""Template for the prompt.""" """Template for the prompt."""

View File

@ -17,7 +17,7 @@ from langchain_core.runnables.config import RunnableConfig
class PromptTemplate(StringPromptTemplate): class PromptTemplate(StringPromptTemplate):
"""A prompt template for a language model. """Prompt template for a language model.
A prompt template consists of a string template. It accepts a set of parameters A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model. from the user that can be used to generate a prompt for a language model.

View File

@ -33,7 +33,10 @@ from langchain_core.runnables.base import (
@beta() @beta()
class StructuredPrompt(ChatPromptTemplate): class StructuredPrompt(ChatPromptTemplate):
"""Structured prompt template for a language model."""
schema_: Union[Dict, Type[BaseModel]] schema_: Union[Dict, Type[BaseModel]]
"""Schema for the structured prompt."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:

View File

@ -641,7 +641,8 @@ def make_options_spec(
description: Optional[str], description: Optional[str],
) -> ConfigurableFieldSpec: ) -> ConfigurableFieldSpec:
"""Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or """Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
ConfigurableFieldMultiOption.""" ConfigurableFieldMultiOption.
"""
with _enums_for_spec_lock: with _enums_for_spec_lock:
if enum := _enums_for_spec.get(spec): if enum := _enums_for_spec.get(spec):
pass pass

View File

@ -26,11 +26,23 @@ if TYPE_CHECKING:
class LabelsDict(TypedDict): class LabelsDict(TypedDict):
"""Dictionary of labels for nodes and edges in a graph."""
nodes: dict[str, str] nodes: dict[str, str]
"""Labels for nodes."""
edges: dict[str, str] edges: dict[str, str]
"""Labels for edges."""
def is_uuid(value: str) -> bool: def is_uuid(value: str) -> bool:
"""Check if a string is a valid UUID.
Args:
value: The string to check.
Returns:
True if the string is a valid UUID, False otherwise.
"""
try: try:
UUID(value) UUID(value)
return True return True
@ -95,6 +107,14 @@ class MermaidDrawMethod(Enum):
def node_data_str(node: Node) -> str: def node_data_str(node: Node) -> str:
"""Convert the data of a node to a string.
Args:
node: The node to convert.
Returns:
A string representation of the data.
"""
from langchain_core.runnables.base import Runnable from langchain_core.runnables.base import Runnable
if not is_uuid(node.id): if not is_uuid(node.id):
@ -120,6 +140,16 @@ def node_data_str(node: Node) -> str:
def node_data_json( def node_data_json(
node: Node, *, with_schemas: bool = False node: Node, *, with_schemas: bool = False
) -> Dict[str, Union[str, Dict[str, Any]]]: ) -> Dict[str, Union[str, Dict[str, Any]]]:
"""Convert the data of a node to a JSON-serializable format.
Args:
node: The node to convert.
with_schemas: Whether to include the schema of the data if
it is a Pydantic model.
Returns:
A dictionary with the type of the data and the data itself.
"""
from langchain_core.load.serializable import to_json_not_implemented from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.base import Runnable, RunnableSerializable

View File

@ -4,9 +4,9 @@ from langchain_core.runnables.graph import Graph, LabelsDict
class PngDrawer: class PngDrawer:
""" """Helper class to draw a state graph into a PNG file.
A helper class to draw a state graph into a PNG file.
Requires graphviz and pygraphviz to be installed. It requires graphviz and pygraphviz to be installed.
:param fontname: The font to use for the labels :param fontname: The font to use for the labels
:param labels: A dictionary of label overrides. The dictionary :param labels: A dictionary of label overrides. The dictionary
should have the following format: should have the following format:
@ -30,18 +30,62 @@ class PngDrawer:
def __init__( def __init__(
self, fontname: Optional[str] = None, labels: Optional[LabelsDict] = None self, fontname: Optional[str] = None, labels: Optional[LabelsDict] = None
) -> None: ) -> None:
"""Initializes the PNG drawer.
Args:
fontname: The font to use for the labels
labels: A dictionary of label overrides. The dictionary
should have the following format:
{
"nodes": {
"node1": "CustomLabel1",
"node2": "CustomLabel2",
"__end__": "End Node"
},
"edges": {
"continue": "ContinueLabel",
"end": "EndLabel"
}
}
The keys are the original labels, and the values are the new labels.
"""
self.fontname = fontname or "arial" self.fontname = fontname or "arial"
self.labels = labels or LabelsDict(nodes={}, edges={}) self.labels = labels or LabelsDict(nodes={}, edges={})
def get_node_label(self, label: str) -> str: def get_node_label(self, label: str) -> str:
"""Returns the label to use for a node.
Args:
label: The original label
Returns:
The new label.
"""
label = self.labels.get("nodes", {}).get(label, label) label = self.labels.get("nodes", {}).get(label, label)
return f"<<B>{label}</B>>" return f"<<B>{label}</B>>"
def get_edge_label(self, label: str) -> str: def get_edge_label(self, label: str) -> str:
"""Returns the label to use for an edge.
Args:
label: The original label
Returns:
The new label.
"""
label = self.labels.get("edges", {}).get(label, label) label = self.labels.get("edges", {}).get(label, label)
return f"<<U>{label}</U>>" return f"<<U>{label}</U>>"
def add_node(self, viz: Any, node: str) -> None: def add_node(self, viz: Any, node: str) -> None:
"""Adds a node to the graph.
Args:
viz: The graphviz object
node: The node to add
Returns:
None
"""
viz.add_node( viz.add_node(
node, node,
label=self.get_node_label(node), label=self.get_node_label(node),
@ -59,6 +103,18 @@ class PngDrawer:
label: Optional[str] = None, label: Optional[str] = None,
conditional: bool = False, conditional: bool = False,
) -> None: ) -> None:
"""Adds an edge to the graph.
Args:
viz: The graphviz object
source: The source node
target: The target node
label: The label for the edge. Defaults to None.
conditional: Whether the edge is conditional. Defaults to False.
Returns:
None
"""
viz.add_edge( viz.add_edge(
source, source,
target, target,
@ -69,8 +125,8 @@ class PngDrawer:
) )
def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]: def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
""" """Draw the given state graph into a PNG file.
Draws the given state graph into a PNG file.
Requires graphviz and pygraphviz to be installed. Requires graphviz and pygraphviz to be installed.
:param graph: The graph to draw :param graph: The graph to draw
:param output_path: The path to save the PNG. If None, PNG bytes are returned. :param output_path: The path to save the PNG. If None, PNG bytes are returned.

View File

@ -329,8 +329,7 @@ _graph_passthrough: RunnablePassthrough = RunnablePassthrough()
class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
""" """Runnable that assigns key-value pairs to Dict[str, Any] inputs.
A runnable that assigns key-value pairs to Dict[str, Any] inputs.
The `RunnableAssign` class takes input dictionaries and, through a The `RunnableAssign` class takes input dictionaries and, through a
`RunnableParallel` instance, applies transformations, then combines `RunnableParallel` instance, applies transformations, then combines

View File

@ -507,6 +507,15 @@ def create_model(
__model_name: str, __model_name: str,
**field_definitions: Any, **field_definitions: Any,
) -> Type[BaseModel]: ) -> Type[BaseModel]:
"""Create a pydantic model with the given field definitions.
Args:
__model_name: The name of the model.
**field_definitions: The field definitions for the model.
Returns:
Type[BaseModel]: The created model.
"""
try: try:
return _create_model_cached(__model_name, **field_definitions) return _create_model_cached(__model_name, **field_definitions)
except TypeError: except TypeError:

View File

@ -93,11 +93,11 @@ class Comparator(str, Enum):
class FilterDirective(Expr, ABC): class FilterDirective(Expr, ABC):
"""A filtering expression.""" """Filtering expression."""
class Comparison(FilterDirective): class Comparison(FilterDirective):
"""A comparison to a value.""" """Comparison to a value."""
comparator: Comparator comparator: Comparator
attribute: str attribute: str
@ -112,7 +112,7 @@ class Comparison(FilterDirective):
class Operation(FilterDirective): class Operation(FilterDirective):
"""A logical operation over other directives.""" """Llogical operation over other directives."""
operator: Operator operator: Operator
arguments: List[FilterDirective] arguments: List[FilterDirective]
@ -124,7 +124,7 @@ class Operation(FilterDirective):
class StructuredQuery(Expr): class StructuredQuery(Expr):
"""A structured query.""" """Structured query."""
query: str query: str
"""Query string.""" """Query string."""

View File

@ -43,7 +43,7 @@ run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVa
def tracing_enabled( def tracing_enabled(
session_name: str = "default", session_name: str = "default",
) -> Generator[TracerSessionV1, None, None]: ) -> Generator[TracerSessionV1, None, None]:
"""Throws an error because this has been replaced by tracing_v2_enabled.""" """Throw an error because this has been replaced by tracing_v2_enabled."""
raise RuntimeError( raise RuntimeError(
"tracing_enabled is no longer supported. Please use tracing_enabled_v2 instead." "tracing_enabled is no longer supported. Please use tracing_enabled_v2 instead."
) )

View File

@ -2,6 +2,7 @@ from typing import Any
def get_headers(*args: Any, **kwargs: Any) -> Any: def get_headers(*args: Any, **kwargs: Any) -> Any:
"""Throw an error because this has been replaced by get_headers."""
raise RuntimeError( raise RuntimeError(
"get_headers for LangChainTracerV1 is no longer supported. " "get_headers for LangChainTracerV1 is no longer supported. "
"Please use LangChainTracer instead." "Please use LangChainTracer instead."
@ -9,6 +10,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any:
def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any:
"""Throw an error because this has been replaced by LangChainTracer."""
raise RuntimeError( raise RuntimeError(
"LangChainTracerV1 is no longer supported. Please use LangChainTracer instead." "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead."
) )

View File

@ -120,7 +120,7 @@ async def tee_peer(
class Tee(Generic[T]): class Tee(Generic[T]):
""" """
Create ``n`` separate asynchronous iterators over ``iterable`` Create ``n`` separate asynchronous iterators over ``iterable``.
This splits a single ``iterable`` into multiple iterators, each providing This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order. the same items in the same order.

View File

@ -31,6 +31,8 @@ _LAST_TAG_LINE = None
class ChevronError(SyntaxError): class ChevronError(SyntaxError):
"""Custom exception for Chevron errors."""
pass pass
@ -40,7 +42,7 @@ class ChevronError(SyntaxError):
def grab_literal(template: str, l_del: str) -> Tuple[str, str]: def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
"""Parse a literal from the template""" """Parse a literal from the template."""
global _CURRENT_LINE global _CURRENT_LINE
@ -57,7 +59,7 @@ def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
"""Do a preliminary check to see if a tag could be a standalone""" """Do a preliminary check to see if a tag could be a standalone."""
# If there is a newline, or the previous tag was a standalone # If there is a newline, or the previous tag was a standalone
if literal.find("\n") != -1 or is_standalone: if literal.find("\n") != -1 or is_standalone:
@ -75,7 +77,7 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
"""Do a final checkto see if a tag could be a standalone""" """Do a final check to see if a tag could be a standalone."""
# Check right side if we might be a standalone # Check right side if we might be a standalone
if is_standalone and tag_type not in ["variable", "no escape"]: if is_standalone and tag_type not in ["variable", "no escape"]:
@ -93,7 +95,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]: def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
"""Parse a tag from a template""" """Parse a tag from a template."""
global _CURRENT_LINE global _CURRENT_LINE
global _LAST_TAG_LINE global _LAST_TAG_LINE
@ -157,7 +159,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s
def tokenize( def tokenize(
template: str, def_ldel: str = "{{", def_rdel: str = "}}" template: str, def_ldel: str = "{{", def_rdel: str = "}}"
) -> Iterator[Tuple[str, str]]: ) -> Iterator[Tuple[str, str]]:
"""Tokenize a mustache template """Tokenize a mustache template.
Tokenizes a mustache template in a generator fashion, Tokenizes a mustache template in a generator fashion,
using file-like objects. It also accepts a string containing using file-like objects. It also accepts a string containing

View File

@ -84,7 +84,7 @@ def mock_now(dt_value): # type: ignore
def guard_import( def guard_import(
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
) -> Any: ) -> Any:
"""Dynamically imports a module and raises a helpful exception if the module is not """Dynamically import a module and raise an exception if the module is not
installed.""" installed."""
try: try:
module = importlib.import_module(module_name, package) module = importlib.import_module(module_name, package)