From 1a2ff56cd8c2c87aa06b3aba90f8bf597f8a2e23 Mon Sep 17 00:00:00 2001 From: Leonid Ganeline Date: Mon, 29 Apr 2024 12:35:34 -0700 Subject: [PATCH] core[patch[: docstring update (#21036) Added missed docstrings. Updated docstrings to consistent format. --- .../example_selectors/semantic_similarity.py | 27 ++++++-- libs/core/langchain_core/indexing/base.py | 14 ++-- .../language_models/chat_models.py | 2 +- .../language_models/fake_chat_models.py | 4 +- .../langchain_core/language_models/llms.py | 2 +- libs/core/langchain_core/prompt_values.py | 2 + libs/core/langchain_core/prompts/image.py | 2 +- libs/core/langchain_core/prompts/prompt.py | 2 +- .../core/langchain_core/prompts/structured.py | 3 + .../langchain_core/runnables/configurable.py | 3 +- libs/core/langchain_core/runnables/graph.py | 30 +++++++++ .../langchain_core/runnables/graph_png.py | 66 +++++++++++++++++-- .../langchain_core/runnables/passthrough.py | 3 +- libs/core/langchain_core/runnables/utils.py | 9 +++ libs/core/langchain_core/structured_query.py | 8 +-- libs/core/langchain_core/tracers/context.py | 2 +- .../langchain_core/tracers/langchain_v1.py | 2 + libs/core/langchain_core/utils/aiter.py | 2 +- libs/core/langchain_core/utils/mustache.py | 12 ++-- libs/core/langchain_core/utils/utils.py | 2 +- 20 files changed, 159 insertions(+), 38 deletions(-) diff --git a/libs/core/langchain_core/example_selectors/semantic_similarity.py b/libs/core/langchain_core/example_selectors/semantic_similarity.py index 1a6a9044e98..47480789c7f 100644 --- a/libs/core/langchain_core/example_selectors/semantic_similarity.py +++ b/libs/core/langchain_core/example_selectors/semantic_similarity.py @@ -73,10 +73,10 @@ class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC): class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): - """Example selector that selects examples based on SemanticSimilarity.""" + """Select examples based on semantic similarity.""" def select_examples(self, 
input_variables: Dict[str, str]) -> List[dict]: - """Select which examples to use based on semantic similarity.""" + """Select examples based on semantic similarity.""" # Get the docs with the highest similarity. vectorstore_kwargs = self.vectorstore_kwargs or {} example_docs = self.vectorstore.similarity_search( @@ -87,7 +87,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): return self._documents_to_examples(example_docs) async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: - """Select which examples to use based on semantic similarity.""" + """Asynchronously select examples based on semantic similarity.""" # Get the docs with the highest similarity. vectorstore_kwargs = self.vectorstore_kwargs or {} example_docs = await self.vectorstore.asimilarity_search( @@ -187,7 +187,7 @@ class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector): class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): - """ExampleSelector that selects examples based on Max Marginal Relevance. + """Select examples based on Max Marginal Relevance. This was shown to improve performance in this paper: https://arxiv.org/pdf/2211.13892.pdf @@ -197,6 +197,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): """Number of examples to fetch to rerank.""" def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: + """Select examples based on Max Marginal Relevance. + + Args: + input_variables: The input variables to use for search. + + Returns: + The selected examples. 
+ """ example_docs = self.vectorstore.max_marginal_relevance_search( self._example_to_text(input_variables, self.input_keys), k=self.k, @@ -205,6 +213,14 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): return self._documents_to_examples(example_docs) async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]: + """Asynchronously select examples based on Max Marginal Relevance. + + Args: + input_variables: The input variables to use for search. + + Returns: + The selected examples. + """ example_docs = await self.vectorstore.amax_marginal_relevance_search( self._example_to_text(input_variables, self.input_keys), k=self.k, @@ -272,7 +288,8 @@ class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector): vectorstore_kwargs: Optional[dict] = None, **vectorstore_cls_kwargs: Any, ) -> MaxMarginalRelevanceExampleSelector: - """Create k-shot example selector using example list and embeddings. + """Asynchronously create k-shot example selector using example list and + embeddings. Reshuffles examples dynamically based on Max Marginal Relevance. 
diff --git a/libs/core/langchain_core/indexing/base.py b/libs/core/langchain_core/indexing/base.py index c91ffa78fd3..776f1f1089b 100644 --- a/libs/core/langchain_core/indexing/base.py +++ b/libs/core/langchain_core/indexing/base.py @@ -5,7 +5,7 @@ from typing import List, Optional, Sequence class RecordManager(ABC): - """An abstract base class representing the interface for a record manager.""" + """Abstract base class representing the interface for a record manager.""" def __init__( self, @@ -24,7 +24,7 @@ class RecordManager(ABC): @abstractmethod async def acreate_schema(self) -> None: - """Create the database schema for the record manager.""" + """Asynchronously create the database schema for the record manager.""" @abstractmethod def get_time(self) -> float: @@ -39,7 +39,7 @@ class RecordManager(ABC): @abstractmethod async def aget_time(self) -> float: - """Get the current server time as a high resolution timestamp! + """Asynchronously get the current server time as a high resolution timestamp. It's important to get this from the server to ensure a monotonic clock, otherwise there may be data loss when cleaning up old documents! @@ -84,7 +84,7 @@ class RecordManager(ABC): group_ids: Optional[Sequence[Optional[str]]] = None, time_at_least: Optional[float] = None, ) -> None: - """Upsert records into the database. + """Asynchronously upsert records into the database. Args: keys: A list of record keys to upsert. @@ -117,7 +117,7 @@ class RecordManager(ABC): @abstractmethod async def aexists(self, keys: Sequence[str]) -> List[bool]: - """Check if the provided keys exist in the database. + """Asynchronously check if the provided keys exist in the database. Args: keys: A list of keys to check. @@ -156,7 +156,7 @@ class RecordManager(ABC): group_ids: Optional[Sequence[str]] = None, limit: Optional[int] = None, ) -> List[str]: - """List records in the database based on the provided filters. + """Asynchronously list records in the database based on the provided filters. 
Args: before: Filter to list records updated before this time. @@ -178,7 +178,7 @@ class RecordManager(ABC): @abstractmethod async def adelete_keys(self, keys: Sequence[str]) -> None: - """Delete specified records from the database. + """Asynchronously delete specified records from the database. Args: keys: A list of keys to delete. diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index f24c52e4311..0bf73646037 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -913,7 +913,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): class SimpleChatModel(BaseChatModel): - """A simplified implementation for a chat model to inherit from.""" + """Simplified implementation for a chat model to inherit from.""" def _generate( self, diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py index 8285d79674d..4a4069b2db1 100644 --- a/libs/core/langchain_core/language_models/fake_chat_models.py +++ b/libs/core/langchain_core/language_models/fake_chat_models.py @@ -151,7 +151,7 @@ class FakeChatModel(SimpleChatModel): class GenericFakeChatModel(BaseChatModel): - """A generic fake chat model that can be used to test the chat model interface. + """Generic fake chat model that can be used to test the chat model interface. * Chat model should be usable in both sync and async tests * Invokes on_llm_new_token to allow for testing of callback related code for new @@ -288,7 +288,7 @@ class GenericFakeChatModel(BaseChatModel): class ParrotFakeChatModel(BaseChatModel): - """A generic fake chat model that can be used to test the chat model interface. + """Generic fake chat model that can be used to test the chat model interface. 
* Chat model should be usable in both sync and async tests """ diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index e307e8035c3..7b0f8dfd390 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -1218,7 +1218,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): class LLM(BaseLLM): - """This class exposes a simple interface for implementing a custom LLM. + """Simple interface for implementing a custom LLM. You should subclass this class and implement the following: diff --git a/libs/core/langchain_core/prompt_values.py b/libs/core/langchain_core/prompt_values.py index 37957daa327..a3ad813cd51 100644 --- a/libs/core/langchain_core/prompt_values.py +++ b/libs/core/langchain_core/prompt_values.py @@ -90,6 +90,8 @@ class ChatPromptValue(PromptValue): class ImageURL(TypedDict, total=False): + """Image URL.""" + detail: Literal["auto", "low", "high"] """Specifies the detail level of the image.""" diff --git a/libs/core/langchain_core/prompts/image.py b/libs/core/langchain_core/prompts/image.py index d4f47779a3a..09d63db65db 100644 --- a/libs/core/langchain_core/prompts/image.py +++ b/libs/core/langchain_core/prompts/image.py @@ -8,7 +8,7 @@ from langchain_core.utils import image as image_utils class ImagePromptTemplate(BasePromptTemplate[ImageURL]): - """An image prompt template for a multimodal model.""" + """Image prompt template for a multimodal model.""" template: dict = Field(default_factory=dict) """Template for the prompt.""" diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index e909ee9088d..1dec3cb0f23 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -17,7 +17,7 @@ from langchain_core.runnables.config import RunnableConfig class PromptTemplate(StringPromptTemplate): - """A prompt template for a language model. 
+ """Prompt template for a language model. A prompt template consists of a string template. It accepts a set of parameters from the user that can be used to generate a prompt for a language model. diff --git a/libs/core/langchain_core/prompts/structured.py b/libs/core/langchain_core/prompts/structured.py index 882f34cc827..2f6a48c3016 100644 --- a/libs/core/langchain_core/prompts/structured.py +++ b/libs/core/langchain_core/prompts/structured.py @@ -33,7 +33,10 @@ from langchain_core.runnables.base import ( @beta() class StructuredPrompt(ChatPromptTemplate): + """Structured prompt template for a language model.""" + schema_: Union[Dict, Type[BaseModel]] + """Schema for the structured prompt.""" @classmethod def get_lc_namespace(cls) -> List[str]: diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index 410cd976f9d..d8b5e0a119b 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -641,7 +641,8 @@ def make_options_spec( description: Optional[str], ) -> ConfigurableFieldSpec: """Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or - ConfigurableFieldMultiOption.""" + ConfigurableFieldMultiOption. + """ with _enums_for_spec_lock: if enum := _enums_for_spec.get(spec): pass diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py index 4486fca5252..4f35205e853 100644 --- a/libs/core/langchain_core/runnables/graph.py +++ b/libs/core/langchain_core/runnables/graph.py @@ -26,11 +26,23 @@ if TYPE_CHECKING: class LabelsDict(TypedDict): + """Dictionary of labels for nodes and edges in a graph.""" + nodes: dict[str, str] + """Labels for nodes.""" edges: dict[str, str] + """Labels for edges.""" def is_uuid(value: str) -> bool: + """Check if a string is a valid UUID. + + Args: + value: The string to check. + + Returns: + True if the string is a valid UUID, False otherwise. 
+ """ try: UUID(value) return True @@ -95,6 +107,14 @@ class MermaidDrawMethod(Enum): def node_data_str(node: Node) -> str: + """Convert the data of a node to a string. + + Args: + node: The node to convert. + + Returns: + A string representation of the data. + """ from langchain_core.runnables.base import Runnable if not is_uuid(node.id): @@ -120,6 +140,16 @@ def node_data_str(node: Node) -> str: def node_data_json( node: Node, *, with_schemas: bool = False ) -> Dict[str, Union[str, Dict[str, Any]]]: + """Convert the data of a node to a JSON-serializable format. + + Args: + node: The node to convert. + with_schemas: Whether to include the schema of the data if + it is a Pydantic model. + + Returns: + A dictionary with the type of the data and the data itself. + """ from langchain_core.load.serializable import to_json_not_implemented from langchain_core.runnables.base import Runnable, RunnableSerializable diff --git a/libs/core/langchain_core/runnables/graph_png.py b/libs/core/langchain_core/runnables/graph_png.py index 9116fc81f9b..75983ad279c 100644 --- a/libs/core/langchain_core/runnables/graph_png.py +++ b/libs/core/langchain_core/runnables/graph_png.py @@ -4,9 +4,9 @@ from langchain_core.runnables.graph import Graph, LabelsDict class PngDrawer: - """ - A helper class to draw a state graph into a PNG file. - Requires graphviz and pygraphviz to be installed. + """Helper class to draw a state graph into a PNG file. + + It requires graphviz and pygraphviz to be installed. :param fontname: The font to use for the labels :param labels: A dictionary of label overrides. The dictionary should have the following format: @@ -30,18 +30,62 @@ class PngDrawer: def __init__( self, fontname: Optional[str] = None, labels: Optional[LabelsDict] = None ) -> None: + """Initializes the PNG drawer. + + Args: + fontname: The font to use for the labels + labels: A dictionary of label overrides. 
The dictionary + should have the following format: + { + "nodes": { + "node1": "CustomLabel1", + "node2": "CustomLabel2", + "__end__": "End Node" + }, + "edges": { + "continue": "ContinueLabel", + "end": "EndLabel" + } + } + The keys are the original labels, and the values are the new labels. + """ self.fontname = fontname or "arial" self.labels = labels or LabelsDict(nodes={}, edges={}) def get_node_label(self, label: str) -> str: + """Returns the label to use for a node. + + Args: + label: The original label + + Returns: + The new label. + """ label = self.labels.get("nodes", {}).get(label, label) return f"<{label}>" def get_edge_label(self, label: str) -> str: + """Returns the label to use for an edge. + + Args: + label: The original label + + Returns: + The new label. + """ label = self.labels.get("edges", {}).get(label, label) return f"<{label}>" def add_node(self, viz: Any, node: str) -> None: + """Adds a node to the graph. + + Args: + viz: The graphviz object + node: The node to add + + Returns: + None + """ viz.add_node( node, label=self.get_node_label(node), @@ -59,6 +103,18 @@ class PngDrawer: label: Optional[str] = None, conditional: bool = False, ) -> None: + """Adds an edge to the graph. + + Args: + viz: The graphviz object + source: The source node + target: The target node + label: The label for the edge. Defaults to None. + conditional: Whether the edge is conditional. Defaults to False. + + Returns: + None + """ viz.add_edge( source, target, @@ -69,8 +125,8 @@ class PngDrawer: ) def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]: - """ - Draws the given state graph into a PNG file. + """Draw the given state graph into a PNG file. + Requires graphviz and pygraphviz to be installed. :param graph: The graph to draw :param output_path: The path to save the PNG. If None, PNG bytes are returned. 
diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index ec081aea97f..f761e12ed22 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -329,8 +329,7 @@ _graph_passthrough: RunnablePassthrough = RunnablePassthrough() class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): - """ - A runnable that assigns key-value pairs to Dict[str, Any] inputs. + """Runnable that assigns key-value pairs to Dict[str, Any] inputs. The `RunnableAssign` class takes input dictionaries and, through a `RunnableParallel` instance, applies transformations, then combines diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 9e960fb4c7f..f77f756e666 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -507,6 +507,15 @@ def create_model( __model_name: str, **field_definitions: Any, ) -> Type[BaseModel]: + """Create a pydantic model with the given field definitions. + + Args: + __model_name: The name of the model. + **field_definitions: The field definitions for the model. + + Returns: + Type[BaseModel]: The created model. 
+ """ try: return _create_model_cached(__model_name, **field_definitions) except TypeError: diff --git a/libs/core/langchain_core/structured_query.py b/libs/core/langchain_core/structured_query.py index eb61cea6b50..a47bb479a86 100644 --- a/libs/core/langchain_core/structured_query.py +++ b/libs/core/langchain_core/structured_query.py @@ -93,11 +93,11 @@ class Comparator(str, Enum): class FilterDirective(Expr, ABC): - """A filtering expression.""" + """Filtering expression.""" class Comparison(FilterDirective): - """A comparison to a value.""" + """Comparison to a value.""" comparator: Comparator attribute: str @@ -112,7 +112,7 @@ class Comparison(FilterDirective): class Operation(FilterDirective): - """A logical operation over other directives.""" + """Logical operation over other directives.""" operator: Operator arguments: List[FilterDirective] @@ -124,7 +124,7 @@ class Operation(FilterDirective): class StructuredQuery(Expr): - """A structured query.""" + """Structured query.""" query: str """Query string.""" diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index 1db9530016f..caadcfc6856 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -43,7 +43,7 @@ run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVa def tracing_enabled( session_name: str = "default", ) -> Generator[TracerSessionV1, None, None]: - """Throws an error because this has been replaced by tracing_v2_enabled.""" + """Throw an error because this has been replaced by tracing_v2_enabled.""" raise RuntimeError( "tracing_enabled is no longer supported. Please use tracing_enabled_v2 instead."
) diff --git a/libs/core/langchain_core/tracers/langchain_v1.py b/libs/core/langchain_core/tracers/langchain_v1.py index aac99a72061..bf1237d66ab 100644 --- a/libs/core/langchain_core/tracers/langchain_v1.py +++ b/libs/core/langchain_core/tracers/langchain_v1.py @@ -2,6 +2,7 @@ from typing import Any def get_headers(*args: Any, **kwargs: Any) -> Any: + """Throw an error because this has been replaced by LangChainTracer.""" raise RuntimeError( "get_headers for LangChainTracerV1 is no longer supported. " "Please use LangChainTracer instead." @@ -9,6 +10,7 @@ def get_headers(*args: Any, **kwargs: Any) -> Any: def LangChainTracerV1(*args: Any, **kwargs: Any) -> Any: + """Throw an error because this has been replaced by LangChainTracer.""" raise RuntimeError( "LangChainTracerV1 is no longer supported. Please use LangChainTracer instead." ) diff --git a/libs/core/langchain_core/utils/aiter.py b/libs/core/langchain_core/utils/aiter.py index ca44dee3958..837b5473849 100644 --- a/libs/core/langchain_core/utils/aiter.py +++ b/libs/core/langchain_core/utils/aiter.py @@ -120,7 +120,7 @@ async def tee_peer( class Tee(Generic[T]): """ - Create ``n`` separate asynchronous iterators over ``iterable`` + Create ``n`` separate asynchronous iterators over ``iterable``. This splits a single ``iterable`` into multiple iterators, each providing the same items in the same order.
diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index 06ea9cd002f..06258375f5f 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -31,6 +31,8 @@ _LAST_TAG_LINE = None class ChevronError(SyntaxError): + """Custom exception for Chevron errors.""" + pass @@ -40,7 +42,7 @@ class ChevronError(SyntaxError): def grab_literal(template: str, l_del: str) -> Tuple[str, str]: - """Parse a literal from the template""" + """Parse a literal from the template.""" global _CURRENT_LINE @@ -57,7 +59,7 @@ def grab_literal(template: str, l_del: str) -> Tuple[str, str]: def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: - """Do a preliminary check to see if a tag could be a standalone""" + """Do a preliminary check to see if a tag could be a standalone.""" # If there is a newline, or the previous tag was a standalone if literal.find("\n") != -1 or is_standalone: @@ -75,7 +77,7 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: - """Do a final checkto see if a tag could be a standalone""" + """Do a final check to see if a tag could be a standalone.""" # Check right side if we might be a standalone if is_standalone and tag_type not in ["variable", "no escape"]: @@ -93,7 +95,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]: - """Parse a tag from a template""" + """Parse a tag from a template.""" global _CURRENT_LINE global _LAST_TAG_LINE @@ -157,7 +159,7 @@ def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], s def tokenize( template: str, def_ldel: str = "{{", def_rdel: str = "}}" ) -> Iterator[Tuple[str, str]]: - """Tokenize a mustache template + """Tokenize a mustache template. 
Tokenizes a mustache template in a generator fashion, using file-like objects. It also accepts a string containing diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 9b63ddf3ea6..0738b1ddb5f 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -84,7 +84,7 @@ def mock_now(dt_value): # type: ignore def guard_import( module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None ) -> Any: - """Dynamically imports a module and raises a helpful exception if the module is not + """Dynamically import a module and raise an exception if the module is not installed.""" try: module = importlib.import_module(module_name, package)