core: Auto-fix some docstrings (#29337)

This commit is contained in:
Christophe Bornet 2025-01-21 19:29:53 +01:00 committed by GitHub
parent 86a0720310
commit 1c4ce7b42b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
95 changed files with 364 additions and 457 deletions

View File

@ -50,7 +50,7 @@ def beta(
``@beta`` would mess up ``__init__`` inheritance when installing its ``@beta`` would mess up ``__init__`` inheritance when installing its
own (annotation-emitting) ``C.__init__``). own (annotation-emitting) ``C.__init__``).
Arguments: Args:
message : str, optional message : str, optional
Override the default beta message. The %(since)s, Override the default beta message. The %(since)s,
%(name)s, %(alternative)s, %(obj_type)s, %(addendum)s, %(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
@ -63,8 +63,7 @@ def beta(
addendum : str, optional addendum : str, optional
Additional text appended directly to the final message. Additional text appended directly to the final message.
Examples Examples:
--------
.. code-block:: python .. code-block:: python

View File

@ -95,7 +95,7 @@ def deprecated(
defaults to 'class' if decorating a class, 'attribute' if decorating a defaults to 'class' if decorating a class, 'attribute' if decorating a
property, and 'function' otherwise. property, and 'function' otherwise.
Arguments: Args:
since : str since : str
The release at which this API became deprecated. The release at which this API became deprecated.
message : str, optional message : str, optional
@ -122,8 +122,7 @@ def deprecated(
since. Set to other Falsy values to not schedule a removal since. Set to other Falsy values to not schedule a removal
date. Cannot be used together with pending. date. Cannot be used together with pending.
Examples Examples:
--------
.. code-block:: python .. code-block:: python
@ -183,7 +182,6 @@ def deprecated(
async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any: async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Same as warning_emitting_wrapper, but for async functions.""" """Same as warning_emitting_wrapper, but for async functions."""
nonlocal warned nonlocal warned
if not warned and not is_caller_internal(): if not warned and not is_caller_internal():
warned = True warned = True

View File

@ -74,7 +74,8 @@ class AgentAction(Serializable):
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "agent"].""" Default is ["langchain", "schema", "agent"].
"""
return ["langchain", "schema", "agent"] return ["langchain", "schema", "agent"]
@property @property
@ -189,7 +190,6 @@ def _convert_agent_observation_to_messages(
Returns: Returns:
AIMessage that corresponds to the original tool invocation. AIMessage that corresponds to the original tool invocation.
""" """
if isinstance(agent_action, AgentActionMessageLog): if isinstance(agent_action, AgentActionMessageLog):
return [_create_function_message(agent_action, observation)] return [_create_function_message(agent_action, observation)]
else: else:

View File

@ -307,8 +307,7 @@ class ContextSet(RunnableSerializable):
class Context: class Context:
""" """Context for a runnable.
Context for a runnable.
The `Context` class provides methods for creating context scopes, The `Context` class provides methods for creating context scopes,
getters, and setters within a runnable. It allows for managing getters, and setters within a runnable. It allows for managing

View File

@ -131,7 +131,8 @@ class ChainManagerMixin:
outputs (Dict[str, Any]): The outputs of the chain. outputs (Dict[str, Any]): The outputs of the chain.
run_id (UUID): The run ID. This is the ID of the current run. run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.""" kwargs (Any): Additional keyword arguments.
"""
def on_chain_error( def on_chain_error(
self, self,
@ -147,7 +148,8 @@ class ChainManagerMixin:
error (BaseException): The error that occurred. error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run. run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.""" kwargs (Any): Additional keyword arguments.
"""
def on_agent_action( def on_agent_action(
self, self,
@ -163,7 +165,8 @@ class ChainManagerMixin:
action (AgentAction): The agent action. action (AgentAction): The agent action.
run_id (UUID): The run ID. This is the ID of the current run. run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.""" kwargs (Any): Additional keyword arguments.
"""
def on_agent_finish( def on_agent_finish(
self, self,
@ -179,7 +182,8 @@ class ChainManagerMixin:
finish (AgentFinish): The agent finish. finish (AgentFinish): The agent finish.
run_id (UUID): The run ID. This is the ID of the current run. run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.""" kwargs (Any): Additional keyword arguments.
"""
class ToolManagerMixin: class ToolManagerMixin:
@ -199,7 +203,8 @@ class ToolManagerMixin:
output (Any): The output of the tool. output (Any): The output of the tool.
run_id (UUID): The run ID. This is the ID of the current run. run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.""" kwargs (Any): Additional keyword arguments.
"""
def on_tool_error( def on_tool_error(
self, self,
@ -215,7 +220,8 @@ class ToolManagerMixin:
error (BaseException): The error that occurred. error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run. run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments.""" kwargs (Any): Additional keyword arguments.
"""
class CallbackManagerMixin: class CallbackManagerMixin:
@ -824,7 +830,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
run_id (UUID): The run ID. This is the ID of the current run. run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run. parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags. tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments.""" kwargs (Any): Additional keyword arguments.
"""
async def on_retriever_error( async def on_retriever_error(
self, self,

View File

@ -164,6 +164,7 @@ async def atrace_as_chain_group(
Defaults to None. Defaults to None.
metadata (Dict[str, Any], optional): The metadata to apply to all runs. metadata (Dict[str, Any], optional): The metadata to apply to all runs.
Defaults to None. Defaults to None.
Returns: Returns:
AsyncCallbackManager: The async callback manager for the chain group. AsyncCallbackManager: The async callback manager for the chain group.
@ -216,8 +217,7 @@ Func = TypeVar("Func", bound=Callable)
def shielded(func: Func) -> Func: def shielded(func: Func) -> Func:
""" """Makes so an awaitable method is always shielded from cancellation.
Makes so an awaitable method is always shielded from cancellation.
Args: Args:
func (Callable): The function to shield. func (Callable): The function to shield.
@ -1310,7 +1310,6 @@ class CallbackManager(BaseCallbackManager):
List[CallbackManagerForLLMRun]: A callback manager for each List[CallbackManagerForLLMRun]: A callback manager for each
list of messages as an LLM run. list of messages as an LLM run.
""" """
managers = [] managers = []
for message_list in messages: for message_list in messages:
if run_id is not None: if run_id is not None:
@ -1729,7 +1728,6 @@ class AsyncCallbackManager(BaseCallbackManager):
callback managers, one for each LLM Run corresponding callback managers, one for each LLM Run corresponding
to each prompt. to each prompt.
""" """
inline_tasks = [] inline_tasks = []
non_inline_tasks = [] non_inline_tasks = []
inline_handlers = [handler for handler in self.handlers if handler.run_inline] inline_handlers = [handler for handler in self.handlers if handler.run_inline]

View File

@ -1,6 +1,5 @@
"""**Chat message history** stores a history of the message interactions in a chat. """**Chat message history** stores a history of the message interactions in a chat.
**Class hierarchy:** **Class hierarchy:**
.. code-block:: .. code-block::
@ -187,10 +186,10 @@ class BaseChatMessageHistory(ABC):
@abstractmethod @abstractmethod
def clear(self) -> None: def clear(self) -> None:
"""Remove all messages from the store""" """Remove all messages from the store."""
async def aclear(self) -> None: async def aclear(self) -> None:
"""Async remove all messages from the store""" """Async remove all messages from the store."""
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor
await run_in_executor(None, self.clear) await run_in_executor(None, self.clear)

View File

@ -8,7 +8,8 @@ from langchain_core.messages import BaseMessage
class ChatSession(TypedDict, total=False): class ChatSession(TypedDict, total=False):
"""Chat Session represents a single """Chat Session represents a single
conversation, channel, or other group of messages.""" conversation, channel, or other group of messages.
"""
messages: Sequence[BaseMessage] messages: Sequence[BaseMessage]
"""A sequence of the LangChain chat messages loaded from the source.""" """A sequence of the LangChain chat messages loaded from the source."""

View File

@ -48,7 +48,6 @@ class BaseLoader(ABC): # noqa: B024
Returns: Returns:
List of Documents. List of Documents.
""" """
if text_splitter is None: if text_splitter is None:
try: try:
from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter

View File

@ -15,15 +15,16 @@ class BaseExampleSelector(ABC):
Args: Args:
example: A dictionary with keys as input variables example: A dictionary with keys as input variables
and values as their values.""" and values as their values.
"""
async def aadd_example(self, example: dict[str, str]) -> Any: async def aadd_example(self, example: dict[str, str]) -> Any:
"""Async add new example to store. """Async add new example to store.
Args: Args:
example: A dictionary with keys as input variables example: A dictionary with keys as input variables
and values as their values.""" and values as their values.
"""
return await run_in_executor(None, self.add_example, example) return await run_in_executor(None, self.add_example, example)
@abstractmethod @abstractmethod
@ -32,13 +33,14 @@ class BaseExampleSelector(ABC):
Args: Args:
input_variables: A dictionary with keys as input variables input_variables: A dictionary with keys as input variables
and values as their values.""" and values as their values.
"""
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]: async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Async select which examples to use based on the inputs. """Async select which examples to use based on the inputs.
Args: Args:
input_variables: A dictionary with keys as input variables input_variables: A dictionary with keys as input variables
and values as their values.""" and values as their values.
"""
return await run_in_executor(None, self.select_examples, input_variables) return await run_in_executor(None, self.select_examples, input_variables)

View File

@ -50,7 +50,6 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
example: A dictionary with keys as input variables example: A dictionary with keys as input variables
and values as their values. and values as their values.
""" """
self.add_example(example) self.add_example(example)
@model_validator(mode="after") @model_validator(mode="after")

View File

@ -241,7 +241,7 @@ def index(
For the time being, documents are indexed using their hashes, and users For the time being, documents are indexed using their hashes, and users
are not able to specify the uid of the document. are not able to specify the uid of the document.
IMPORTANT: Important:
* In full mode, the loader should be returning * In full mode, the loader should be returning
the entire dataset, and not just a subset of the dataset. the entire dataset, and not just a subset of the dataset.
Otherwise, the auto_cleanup will remove documents that it is not Otherwise, the auto_cleanup will remove documents that it is not
@ -546,7 +546,7 @@ async def aindex(
For the time being, documents are indexed using their hashes, and users For the time being, documents are indexed using their hashes, and users
are not able to specify the uid of the document. are not able to specify the uid of the document.
IMPORTANT: Important:
* In full mode, the loader should be returning * In full mode, the loader should be returning
the entire dataset, and not just a subset of the dataset. the entire dataset, and not just a subset of the dataset.
Otherwise, the auto_cleanup will remove documents that it is not Otherwise, the auto_cleanup will remove documents that it is not
@ -614,7 +614,6 @@ async def aindex(
* Added `scoped_full` cleanup mode. * Added `scoped_full` cleanup mode.
""" """
if cleanup not in {"incremental", "full", "scoped_full", None}: if cleanup not in {"incremental", "full", "scoped_full", None}:
msg = ( msg = (
f"cleanup should be one of 'incremental', 'full', 'scoped_full' or None. " f"cleanup should be one of 'incremental', 'full', 'scoped_full' or None. "

View File

@ -288,7 +288,6 @@ class InMemoryRecordManager(RecordManager):
ids. ids.
ValueError: If time_at_least is in the future. ValueError: If time_at_least is in the future.
""" """
if group_ids and len(keys) != len(group_ids): if group_ids and len(keys) != len(group_ids):
msg = "Length of keys must match length of group_ids" msg = "Length of keys must match length of group_ids"
raise ValueError(msg) raise ValueError(msg)

View File

@ -84,7 +84,6 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
Returns: Returns:
ChatResult: Chat result. ChatResult: Chat result.
""" """
generation = next(stream, None) generation = next(stream, None)
if generation: if generation:
generation += list(stream) generation += list(stream)
@ -112,7 +111,6 @@ async def agenerate_from_stream(
Returns: Returns:
ChatResult: Chat result. ChatResult: Chat result.
""" """
chunks = [chunk async for chunk in stream] chunks = [chunk async for chunk in stream]
return await run_in_executor(None, generate_from_stream, iter(chunks)) return await run_in_executor(None, generate_from_stream, iter(chunks))
@ -521,7 +519,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
**kwargs: Any, **kwargs: Any,
) -> LangSmithParams: ) -> LangSmithParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
# get default provider from class name # get default provider from class name
default_provider = self.__class__.__name__ default_provider = self.__class__.__name__
if default_provider.startswith("Chat"): if default_provider.startswith("Chat"):
@ -955,7 +952,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
async def _agenerate( async def _agenerate(
self, self,
@ -964,7 +961,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
return await run_in_executor( return await run_in_executor(
None, None,
self._generate, self._generate,

View File

@ -42,7 +42,7 @@ class FakeListLLM(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Return next response""" """Return next response."""
response = self.responses[self.i] response = self.responses[self.i]
if self.i < len(self.responses) - 1: if self.i < len(self.responses) - 1:
self.i += 1 self.i += 1
@ -57,7 +57,7 @@ class FakeListLLM(LLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
"""Return next response""" """Return next response."""
response = self.responses[self.i] response = self.responses[self.i]
if self.i < len(self.responses) - 1: if self.i < len(self.responses) - 1:
self.i += 1 self.i += 1

View File

@ -220,7 +220,7 @@ class GenericFakeChatModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
message = next(self.messages) message = next(self.messages)
message_ = AIMessage(content=message) if isinstance(message, str) else message message_ = AIMessage(content=message) if isinstance(message, str) else message
generation = ChatGeneration(message=message_) generation = ChatGeneration(message=message_)
@ -342,7 +342,7 @@ class ParrotFakeChatModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
return ChatResult(generations=[ChatGeneration(message=messages[-1])]) return ChatResult(generations=[ChatGeneration(message=messages[-1])])
@property @property

View File

@ -91,7 +91,6 @@ def create_base_retry_decorator(
Raises: Raises:
ValueError: If the cache is not set and cache is True. ValueError: If the cache is not set and cache is True.
""" """
_logging = before_sleep_log(logger, logging.WARNING) _logging = before_sleep_log(logger, logging.WARNING)
def _before_sleep(retry_state: RetryCallState) -> None: def _before_sleep(retry_state: RetryCallState) -> None:
@ -278,7 +277,6 @@ async def aupdate_cache(
Raises: Raises:
ValueError: If the cache is not set and cache is True. ValueError: If the cache is not set and cache is True.
""" """
llm_cache = _resolve_cache(cache) llm_cache = _resolve_cache(cache)
for i, result in enumerate(new_results.generations): for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result existing_prompts[missing_prompt_idxs[i]] = result
@ -292,7 +290,8 @@ async def aupdate_cache(
class BaseLLM(BaseLanguageModel[str], ABC): class BaseLLM(BaseLanguageModel[str], ABC):
"""Base LLM abstract interface. """Base LLM abstract interface.
It should take in a prompt and return a string.""" It should take in a prompt and return a string.
"""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED]""" """[DEPRECATED]"""
@ -346,7 +345,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
**kwargs: Any, **kwargs: Any,
) -> LangSmithParams: ) -> LangSmithParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
# get default provider from class name # get default provider from class name
default_provider = self.__class__.__name__ default_provider = self.__class__.__name__
if default_provider.endswith("LLM"): if default_provider.endswith("LLM"):

View File

@ -1,5 +1,4 @@
""" """This file contains a mapping between the lc_namespace path for a given
This file contains a mapping between the lc_namespace path for a given
subclass that implements from Serializable to the namespace subclass that implements from Serializable to the namespace
where that class is actually located. where that class is actually located.

View File

@ -28,7 +28,8 @@ class FunctionMessage(BaseMessage):
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
@ -48,7 +49,8 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore

View File

@ -41,7 +41,8 @@ class HumanMessage(BaseMessage):
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
def __init__( def __init__(
@ -72,5 +73,6 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]

View File

@ -28,7 +28,8 @@ class RemoveMessage(BaseMessage):
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]

View File

@ -35,7 +35,8 @@ class SystemMessage(BaseMessage):
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
def __init__( def __init__(
@ -66,5 +67,6 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]

View File

@ -89,7 +89,8 @@ class ToolMessage(BaseMessage, ToolOutputMixin):
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"].""" Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"] return ["langchain", "schema", "messages"]
@model_validator(mode="before") @model_validator(mode="before")

View File

@ -817,7 +817,6 @@ def trim_messages(
AIMessage( [{"type": "text", "text": "This is the FIRST 4 token block."}], id="second"), AIMessage( [{"type": "text", "text": "This is the FIRST 4 token block."}], id="second"),
] ]
""" # noqa: E501 """ # noqa: E501
if start_on and strategy == "first": if start_on and strategy == "first":
raise ValueError raise ValueError
if include_system and strategy == "first": if include_system and strategy == "first":

View File

@ -46,8 +46,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
Args: Args:
text: The output of an LLM call. text: The output of an LLM call.
Returns: Returns:
A list of strings. A list of strings.
""" """
def parse_iter(self, text: str) -> Iterator[re.Match]: def parse_iter(self, text: str) -> Iterator[re.Match]:
@ -135,7 +135,9 @@ class CommaSeparatedListOutputParser(ListOutputParser):
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
"""Check if the langchain object is serializable. """Check if the langchain object is serializable.
Returns True."""
Returns True.
"""
return True return True
@classmethod @classmethod

View File

@ -84,7 +84,6 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
Raises: Raises:
OutputParserException: If the output is not valid JSON. OutputParserException: If the output is not valid JSON.
""" """
if len(result) != 1: if len(result) != 1:
msg = f"Expected exactly one result, but got {len(result)}" msg = f"Expected exactly one result, but got {len(result)}"
raise OutputParserException(msg) raise OutputParserException(msg)
@ -189,7 +188,6 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
the provided schema. the provided schema.
Example: Example:
... code-block:: python ... code-block:: python
message = AIMessage( message = AIMessage(

View File

@ -168,7 +168,6 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
Raises: Raises:
OutputParserException: If the output is not valid JSON. OutputParserException: If the output is not valid JSON.
""" """
generation = result[0] generation = result[0]
if not isinstance(generation, ChatGeneration): if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation." msg = "This output parser can only be used with a chat generation."

View File

@ -129,7 +129,8 @@ class ImagePromptValue(PromptValue):
class ChatPromptValueConcrete(ChatPromptValue): class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts. """Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas.""" For use in external schemas.
"""
messages: Sequence[AnyMessage] messages: Sequence[AnyMessage]
"""Sequence of messages.""" """Sequence of messages."""

View File

@ -98,13 +98,15 @@ class BasePromptTemplate(
@classmethod @classmethod
def get_lc_namespace(cls) -> list[str]: def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object. """Get the namespace of the langchain object.
Returns ["langchain", "schema", "prompt_template"].""" Returns ["langchain", "schema", "prompt_template"].
"""
return ["langchain", "schema", "prompt_template"] return ["langchain", "schema", "prompt_template"]
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable. """Return whether this class is serializable.
Returns True.""" Returns True.
"""
return True return True
model_config = ConfigDict( model_config = ConfigDict(

View File

@ -54,7 +54,8 @@ class BaseMessagePromptTemplate(Serializable, ABC):
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable. """Return whether or not the class is serializable.
Returns: True""" Returns: True.
"""
return True return True
@classmethod @classmethod
@ -392,8 +393,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@property @property
def input_variables(self) -> list[str]: def input_variables(self) -> list[str]:
""" """Input variables for this prompt template.
Input variables for this prompt template.
Returns: Returns:
List of input variable names. List of input variable names.
@ -624,8 +624,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@property @property
def input_variables(self) -> list[str]: def input_variables(self) -> list[str]:
""" """Input variables for this prompt template.
Input variables for this prompt template.
Returns: Returns:
List of input variable names. List of input variable names.
@ -742,8 +741,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
@property @property
def lc_attributes(self) -> dict: def lc_attributes(self) -> dict:
""" """Return a list of attribute names that should be included in the
Return a list of attribute names that should be included in the
serialized kwargs. These attributes must be accepted by the serialized kwargs. These attributes must be accepted by the
constructor. constructor.
""" """
@ -980,7 +978,6 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
A chat prompt template. A chat prompt template.
Examples: Examples:
Instantiation from a list of message templates: Instantiation from a list of message templates:
.. code-block:: python .. code-block:: python
@ -1173,7 +1170,6 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""Create a chat prompt template from a variety of message formats. """Create a chat prompt template from a variety of message formats.
Examples: Examples:
Instantiation from a list of message templates: Instantiation from a list of message templates:
.. code-block:: python .. code-block:: python

View File

@ -265,7 +265,6 @@ class FewShotChatMessagePromptTemplate(
to dynamically select examples based on the input. to dynamically select examples based on the input.
Examples: Examples:
Prompt template with a fixed list of examples (matching the sample Prompt template with a fixed list of examples (matching the sample
conversation above): conversation above):

View File

@ -184,8 +184,7 @@ def _load_prompt_from_file(
def _load_chat_prompt(config: dict) -> ChatPromptTemplate: def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
"""Load chat prompt from config""" """Load chat prompt from config."""
messages = config.pop("messages") messages = config.pop("messages")
template = messages[0]["prompt"].pop("template") if messages else None template = messages[0]["prompt"].pop("template") if messages else None
config.pop("input_variables") config.pop("input_variables")

View File

@ -23,8 +23,7 @@ def _get_inputs(inputs: dict, input_variables: list[str]) -> dict:
), ),
) )
class PipelinePromptTemplate(BasePromptTemplate): class PipelinePromptTemplate(BasePromptTemplate):
""" """This has been deprecated in favor of chaining individual prompts together in your
This has been deprecated in favor of chaining individual prompts together in your
code. E.g. using a for loop, you could do: code. E.g. using a for loop, you could do:
.. code-block:: python .. code-block:: python

View File

@ -284,7 +284,6 @@ class PromptTemplate(StringPromptTemplate):
Returns: Returns:
The prompt template loaded from the template. The prompt template loaded from the template.
""" """
input_variables = get_template_variables(template, template_format) input_variables = get_template_variables(template, template_format)
_partial_variables = partial_variables or {} _partial_variables = partial_variables or {}

View File

@ -64,8 +64,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
def validate_jinja2(template: str, input_variables: list[str]) -> None: def validate_jinja2(template: str, input_variables: list[str]) -> None:
""" """Validate that the input variables are valid for the template.
Validate that the input variables are valid for the template.
Issues a warning if missing or extra variables are found. Issues a warning if missing or extra variables are found.
Args: Args:

View File

@ -72,7 +72,6 @@ class StructuredPrompt(ChatPromptTemplate):
"""Create a chat prompt template from a variety of message formats. """Create a chat prompt template from a variety of message formats.
Examples: Examples:
Instantiation from a list of message templates: Instantiation from a list of message templates:
.. code-block:: python .. code-block:: python

View File

@ -199,7 +199,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams: def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
default_retriever_name = self.get_name() default_retriever_name = self.get_name()
if default_retriever_name.startswith("Retriever"): if default_retriever_name.startswith("Retriever"):
default_retriever_name = default_retriever_name[9:] default_retriever_name = default_retriever_name[9:]
@ -342,6 +341,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
Args: Args:
query: String to find relevant documents for. query: String to find relevant documents for.
run_manager: The callback handler to use. run_manager: The callback handler to use.
Returns: Returns:
List of relevant documents. List of relevant documents.
""" """

View File

@ -483,7 +483,6 @@ class Runnable(Generic[Input, Output], ABC):
Returns: Returns:
A pydantic model that can be used to validate config. A pydantic model that can be used to validate config.
""" """
include = include or [] include = include or []
config_specs = self.config_specs config_specs = self.config_specs
configurable = ( configurable = (
@ -817,8 +816,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[tuple[int, Union[Output, Exception]]]: ) -> Iterator[tuple[int, Union[Output, Exception]]]:
"""Run invoke in parallel on a list of inputs, """Run invoke in parallel on a list of inputs,
yielding results as they complete.""" yielding results as they complete.
"""
if not inputs: if not inputs:
return return
@ -949,7 +948,6 @@ class Runnable(Generic[Input, Output], ABC):
Yields: Yields:
A tuple of the index of the input and the output from the Runnable. A tuple of the index of the input and the output from the Runnable.
""" """
if not inputs: if not inputs:
return return
@ -981,8 +979,7 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
""" """Default implementation of stream, which calls invoke.
Default implementation of stream, which calls invoke.
Subclasses should override this method if they support streaming output. Subclasses should override this method if they support streaming output.
Args: Args:
@ -1001,8 +998,7 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
""" """Default implementation of astream, which calls ainvoke.
Default implementation of astream, which calls ainvoke.
Subclasses should override this method if they support streaming output. Subclasses should override this method if they support streaming output.
Args: Args:
@ -1064,8 +1060,7 @@ class Runnable(Generic[Input, Output], ABC):
exclude_tags: Optional[Sequence[str]] = None, exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]: ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
""" """Stream all output from a Runnable, as reported to the callback system.
Stream all output from a Runnable, as reported to the callback system.
This includes all inner runs of LLMs, Retrievers, Tools, etc. This includes all inner runs of LLMs, Retrievers, Tools, etc.
Output is streamed as Log objects, which include a list of Output is streamed as Log objects, which include a list of
@ -1392,8 +1387,8 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
""" """Default implementation of transform, which buffers input and calls astream.
Default implementation of transform, which buffers input and then calls stream.
Subclasses should override this method if they can start producing output while Subclasses should override this method if they can start producing output while
input is still being generated. input is still being generated.
@ -1434,8 +1429,7 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
""" """Default implementation of atransform, which buffers input and calls astream.
Default implementation of atransform, which buffers input and calls astream.
Subclasses should override this method if they can start producing output while Subclasses should override this method if they can start producing output while
input is still being generated. input is still being generated.
@ -1472,8 +1466,7 @@ class Runnable(Generic[Input, Output], ABC):
yield output yield output
def bind(self, **kwargs: Any) -> Runnable[Input, Output]: def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
""" """Bind arguments to a Runnable, returning a new Runnable.
Bind arguments to a Runnable, returning a new Runnable.
Useful when a Runnable in a chain requires an argument that is not Useful when a Runnable in a chain requires an argument that is not
in the output of the previous Runnable or included in the user input. in the output of the previous Runnable or included in the user input.
@ -1520,8 +1513,7 @@ class Runnable(Generic[Input, Output], ABC):
# Sadly Unpack is not well-supported by mypy so this will have to be untyped # Sadly Unpack is not well-supported by mypy so this will have to be untyped
**kwargs: Any, **kwargs: Any,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
""" """Bind config to a Runnable, returning a new Runnable.
Bind config to a Runnable, returning a new Runnable.
Args: Args:
config: The config to bind to the Runnable. config: The config to bind to the Runnable.
@ -1552,8 +1544,7 @@ class Runnable(Generic[Input, Output], ABC):
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]] Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
] = None, ] = None,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
""" """Bind lifecycle listeners to a Runnable, returning a new Runnable.
Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the Runnable starts running, with the Run object. on_start: Called before the Runnable starts running, with the Run object.
on_end: Called after the Runnable finishes running, with the Run object. on_end: Called after the Runnable finishes running, with the Run object.
@ -1620,8 +1611,7 @@ class Runnable(Generic[Input, Output], ABC):
on_end: Optional[AsyncListener] = None, on_end: Optional[AsyncListener] = None,
on_error: Optional[AsyncListener] = None, on_error: Optional[AsyncListener] = None,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
""" """Bind async lifecycle listeners to a Runnable, returning a new Runnable.
Bind asynchronous lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Asynchronously called before the Runnable starts running. on_start: Asynchronously called before the Runnable starts running.
on_end: Asynchronously called after the Runnable finishes running. on_end: Asynchronously called after the Runnable finishes running.
@ -1711,8 +1701,7 @@ class Runnable(Generic[Input, Output], ABC):
input_type: Optional[type[Input]] = None, input_type: Optional[type[Input]] = None,
output_type: Optional[type[Output]] = None, output_type: Optional[type[Output]] = None,
) -> Runnable[Input, Output]: ) -> Runnable[Input, Output]:
""" """Bind input and output types to a Runnable, returning a new Runnable.
Bind input and output types to a Runnable, returning a new Runnable.
Args: Args:
input_type: The input type to bind to the Runnable. Defaults to None. input_type: The input type to bind to the Runnable. Defaults to None.
@ -1799,8 +1788,7 @@ class Runnable(Generic[Input, Output], ABC):
) )
def map(self) -> Runnable[list[Input], list[Output]]: def map(self) -> Runnable[list[Input], list[Output]]:
""" """Return a new Runnable that maps a list of inputs to a list of outputs,
Return a new Runnable that maps a list of inputs to a list of outputs,
by calling invoke() with each input. by calling invoke() with each input.
Returns: Returns:
@ -1906,7 +1894,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses.
"""
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config) callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
@ -1955,7 +1944,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses.""" with callbacks. Use this method to implement ainvoke() in subclasses.
"""
config = ensure_config(config) config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config) callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
@ -2004,7 +1994,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> list[Output]: ) -> list[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses.
"""
if not input: if not input:
return [] return []
@ -2076,7 +2067,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> list[Output]: ) -> list[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses.
"""
if not input: if not input:
return [] return []
@ -2149,7 +2141,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Iterator[Output]: ) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of """Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks. Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses.""" Use this to implement `stream()` or `transform()` in Runnable subclasses.
"""
# Mixin that is used by both astream log and astream events implementation # Mixin that is used by both astream log and astream events implementation
from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers._streaming import _StreamingCallbackHandler
@ -2249,7 +2242,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async """Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks. Iterator of Output values, with callbacks.
Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" Use this to implement `astream()` or `atransform()` in Runnable subclasses.
"""
# Mixin that is used by both astream log and astream events implementation # Mixin that is used by both astream log and astream events implementation
from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers._streaming import _StreamingCallbackHandler
@ -5601,7 +5595,6 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
- ``with_fallbacks``: Bind a fallback policy to the underlying Runnable. - ``with_fallbacks``: Bind a fallback policy to the underlying Runnable.
Example: Example:
`bind`: Bind kwargs to pass to the underlying Runnable when running it. `bind`: Bind kwargs to pass to the underlying Runnable when running it.
.. code-block:: python .. code-block:: python

View File

@ -116,7 +116,7 @@ var_child_runnable_config = ContextVar(
def _set_config_context(config: RunnableConfig) -> None: def _set_config_context(config: RunnableConfig) -> None:
"""Set the child Runnable config + tracing context """Set the child Runnable config + tracing context.
Args: Args:
config (RunnableConfig): The config to set. config (RunnableConfig): The config to set.

View File

@ -404,7 +404,8 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
self, **kwargs: AnyConfigurableField self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]: ) -> RunnableSerializable[Input, Output]:
"""Get a new RunnableConfigurableFields with the specified """Get a new RunnableConfigurableFields with the specified
configurable fields.""" configurable fields.
"""
return self.default.configurable_fields(**{**self.fields, **kwargs}) return self.default.configurable_fields(**{**self.fields, **kwargs})
def _prepare( def _prepare(

View File

@ -137,7 +137,7 @@ class Branch(NamedTuple):
class CurveStyle(Enum): class CurveStyle(Enum):
"""Enum for different curve styles supported by Mermaid""" """Enum for different curve styles supported by Mermaid."""
BASIS = "basis" BASIS = "basis"
BUMP_X = "bumpX" BUMP_X = "bumpX"
@ -169,7 +169,7 @@ class NodeStyles:
class MermaidDrawMethod(Enum): class MermaidDrawMethod(Enum):
"""Enum for different draw methods supported by Mermaid""" """Enum for different draw methods supported by Mermaid."""
PYPPETEER = "pyppeteer" # Uses Pyppeteer to render the graph PYPPETEER = "pyppeteer" # Uses Pyppeteer to render the graph
API = "api" # Uses Mermaid.INK API to render the graph API = "api" # Uses Mermaid.INK API to render the graph
@ -306,7 +306,8 @@ class Graph:
def next_id(self) -> str: def next_id(self) -> str:
"""Return a new unique node """Return a new unique node
identifier that can be used to add a node to the graph.""" identifier that can be used to add a node to the graph.
"""
return uuid4().hex return uuid4().hex
def add_node( def add_node(
@ -422,7 +423,8 @@ class Graph:
def reid(self) -> Graph: def reid(self) -> Graph:
"""Return a new graph with all nodes re-identified, """Return a new graph with all nodes re-identified,
using their unique, readable names where possible.""" using their unique, readable names where possible.
"""
node_name_to_ids = defaultdict(list) node_name_to_ids = defaultdict(list)
for node in self.nodes.values(): for node in self.nodes.values():
node_name_to_ids[node.name].append(node.id) node_name_to_ids[node.name].append(node.id)
@ -457,18 +459,21 @@ class Graph:
def first_node(self) -> Optional[Node]: def first_node(self) -> Optional[Node]:
"""Find the single node that is not a target of any edge. """Find the single node that is not a target of any edge.
If there is no such node, or there are multiple, return None. If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin.""" When drawing the graph, this node would be the origin.
"""
return _first_node(self) return _first_node(self)
def last_node(self) -> Optional[Node]: def last_node(self) -> Optional[Node]:
"""Find the single node that is not a source of any edge. """Find the single node that is not a source of any edge.
If there is no such node, or there are multiple, return None. If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination.""" When drawing the graph, this node would be the destination.
"""
return _last_node(self) return _last_node(self)
def trim_first_node(self) -> None: def trim_first_node(self) -> None:
"""Remove the first node if it exists and has a single outgoing edge, """Remove the first node if it exists and has a single outgoing edge,
i.e., if removing it would not leave the graph without a "first" node.""" i.e., if removing it would not leave the graph without a "first" node.
"""
first_node = self.first_node() first_node = self.first_node()
if ( if (
first_node first_node
@ -479,7 +484,8 @@ class Graph:
def trim_last_node(self) -> None: def trim_last_node(self) -> None:
"""Remove the last node if it exists and has a single incoming edge, """Remove the last node if it exists and has a single incoming edge,
i.e., if removing it would not leave the graph without a "last" node.""" i.e., if removing it would not leave the graph without a "last" node.
"""
last_node = self.last_node() last_node = self.last_node()
if ( if (
last_node last_node
@ -634,7 +640,8 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
"""Find the single node that is not a target of any edge. """Find the single node that is not a target of any edge.
Exclude nodes/sources with ids in the exclude list. Exclude nodes/sources with ids in the exclude list.
If there is no such node, or there are multiple, return None. If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin.""" When drawing the graph, this node would be the origin.
"""
targets = {edge.target for edge in graph.edges if edge.source not in exclude} targets = {edge.target for edge in graph.edges if edge.source not in exclude}
found: list[Node] = [] found: list[Node] = []
for node in graph.nodes.values(): for node in graph.nodes.values():
@ -647,7 +654,8 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
"""Find the single node that is not a source of any edge. """Find the single node that is not a source of any edge.
Exclude nodes/targets with ids in the exclude list. Exclude nodes/targets with ids in the exclude list.
If there is no such node, or there are multiple, return None. If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination.""" When drawing the graph, this node would be the destination.
"""
sources = {edge.source for edge in graph.edges if edge.target not in exclude} sources = {edge.source for edge in graph.edges if edge.target not in exclude}
found: list[Node] = [] found: list[Node] = []
for node in graph.nodes.values(): for node in graph.nodes.values():

View File

@ -1,5 +1,6 @@
"""Draws DAG in ASCII. """Draws DAG in ASCII.
Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py""" Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py.
"""
import math import math
import os import os
@ -239,7 +240,6 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
| 1 | | 1 |
+---+ +---+
""" """
# NOTE: coordinates might me negative, so we need to shift # NOTE: coordinates might me negative, so we need to shift
# everything to the positive plane before we actually draw it. # everything to the positive plane before we actually draw it.
xlist = [] xlist = []

View File

@ -132,7 +132,6 @@ class PngDrawer:
:param graph: The graph to draw :param graph: The graph to draw
:param output_path: The path to save the PNG. If None, PNG bytes are returned. :param output_path: The path to save the PNG. If None, PNG bytes are returned.
""" """
try: try:
import pygraphviz as pgv # type: ignore[import] import pygraphviz as pgv # type: ignore[import]
except ImportError as exc: except ImportError as exc:

View File

@ -43,7 +43,6 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
way to use it is through the `.with_retry()` method on all Runnables. way to use it is through the `.with_retry()` method on all Runnables.
Example: Example:
Here's an example that uses a RunnableLambda to raise an exception Here's an example that uses a RunnableLambda to raise an exception
.. code-block:: python .. code-block:: python

View File

@ -44,8 +44,7 @@ class RouterInput(TypedDict):
class RouterRunnable(RunnableSerializable[RouterInput, Output]): class RouterRunnable(RunnableSerializable[RouterInput, Output]):
""" """Runnable that routes to a set of Runnables based on Input['key'].
Runnable that routes to a set of Runnables based on Input['key'].
Returns the output of the selected Runnable. Returns the output of the selected Runnable.
Parameters: Parameters:

View File

@ -462,9 +462,7 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
class AddableDict(dict[str, Any]): class AddableDict(dict[str, Any]):
""" """Dictionary that can be added to another dictionary."""
Dictionary that can be added to another dictionary.
"""
def __add__(self, other: AddableDict) -> AddableDict: def __add__(self, other: AddableDict) -> AddableDict:
chunk = AddableDict(self) chunk = AddableDict(self)

View File

@ -162,7 +162,6 @@ class StructuredTool(BaseTool):
tool = StructuredTool.from_function(add) tool = StructuredTool.from_function(add)
tool.run(1, 2) # 3 tool.run(1, 2) # 3
""" """
if func is not None: if func is not None:
source_function = func source_function = func
elif coroutine is not None: elif coroutine is not None:

View File

@ -531,8 +531,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
"""Persist a run.""" """Persist a run."""
async def _start_trace(self, run: Run) -> None: async def _start_trace(self, run: Run) -> None:
""" """Start a trace for a run.
Start a trace for a run.
Starting a trace will run concurrently with each _on_[run_type]_start method. Starting a trace will run concurrently with each _on_[run_type]_start method.
No _on_[run_type]_start callback should depend on operations in _start_trace. No _on_[run_type]_start callback should depend on operations in _start_trace.
@ -541,8 +540,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
await self._on_run_create(run) await self._on_run_create(run)
async def _end_trace(self, run: Run) -> None: async def _end_trace(self, run: Run) -> None:
""" """End a trace for a run.
End a trace for a run.
Ending a trace will run concurrently with each _on_[run_type]_end method. Ending a trace will run concurrently with each _on_[run_type]_end method.
No _on_[run_type]_end callback should depend on operations in _end_trace. No _on_[run_type]_end callback should depend on operations in _end_trace.

View File

@ -40,8 +40,7 @@ SCHEMA_FORMAT_TYPE = Literal["original", "streaming_events"]
class _TracerCore(ABC): class _TracerCore(ABC):
""" """Abstract base class for tracers.
Abstract base class for tracers.
This class provides common methods, and reusable methods for tracers. This class provides common methods, and reusable methods for tracers.
""" """
@ -233,9 +232,7 @@ class _TracerCore(ABC):
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
""" """Append token event to LLM run and return the run."""
Append token event to LLM run and return the run.
"""
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"}) llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
event_kwargs: dict[str, Any] = {"token": token} event_kwargs: dict[str, Any] = {"token": token}
if chunk: if chunk:

View File

@ -35,35 +35,33 @@ def wait_for_all_evaluators() -> None:
class EvaluatorCallbackHandler(BaseTracer): class EvaluatorCallbackHandler(BaseTracer):
"""Tracer that runs a run evaluator whenever a run is persisted. """Tracer that runs a run evaluator whenever a run is persisted.
Parameters Args:
---------- evaluators : Sequence[RunEvaluator]
evaluators : Sequence[RunEvaluator] The run evaluators to apply to all top level runs.
The run evaluators to apply to all top level runs. client : LangSmith Client, optional
client : LangSmith Client, optional The LangSmith client instance to use for evaluating the runs.
The LangSmith client instance to use for evaluating the runs. If not specified, a new instance will be created.
If not specified, a new instance will be created. example_id : Union[UUID, str], optional
example_id : Union[UUID, str], optional The example ID to be associated with the runs.
The example ID to be associated with the runs. project_name : str, optional
project_name : str, optional The LangSmith project name to be organize eval chain runs under.
The LangSmith project name to be organize eval chain runs under.
Attributes Attributes:
---------- example_id : Union[UUID, None]
example_id : Union[UUID, None] The example ID associated with the runs.
The example ID associated with the runs. client : Client
client : Client The LangSmith client instance used for evaluating the runs.
The LangSmith client instance used for evaluating the runs. evaluators : Sequence[RunEvaluator]
evaluators : Sequence[RunEvaluator] The sequence of run evaluators to be executed.
The sequence of run evaluators to be executed. executor : ThreadPoolExecutor
executor : ThreadPoolExecutor The thread pool executor used for running the evaluators.
The thread pool executor used for running the evaluators. futures : Set[Future]
futures : Set[Future] The set of futures representing the running evaluators.
The set of futures representing the running evaluators. skip_unfinished : bool
skip_unfinished : bool Whether to skip runs that are not finished or raised
Whether to skip runs that are not finished or raised an error.
an error. project_name : Optional[str]
project_name : Optional[str] The LangSmith project name to be organize eval chain runs under.
The LangSmith project name to be organize eval chain runs under.
""" """
name: str = "evaluator_callback_handler" name: str = "evaluator_callback_handler"

View File

@ -253,9 +253,7 @@ class LangChainTracer(BaseTracer):
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
""" """Append token event to LLM run and return the run."""
Append token event to LLM run and return the run.
"""
return super()._llm_run_with_token_event( return super()._llm_run_with_token_event(
# Drop the chunk; we don't need to save it # Drop the chunk; we don't need to save it
token, token,

View File

@ -24,8 +24,7 @@ class RunCollectorCallbackHandler(BaseTracer):
def __init__( def __init__(
self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any
) -> None: ) -> None:
""" """Initialize the RunCollectorCallbackHandler.
Initialize the RunCollectorCallbackHandler.
Parameters Parameters
---------- ----------
@ -41,8 +40,7 @@ class RunCollectorCallbackHandler(BaseTracer):
self.traced_runs: list[Run] = [] self.traced_runs: list[Run] = []
def _persist_run(self, run: Run) -> None: def _persist_run(self, run: Run) -> None:
""" """Persist a run by adding it to the traced_runs list.
Persist a run by adding it to the traced_runs list.
Parameters Parameters
---------- ----------

View File

@ -1,5 +1,4 @@
""" """**Utility functions** for LangChain.
**Utility functions** for LangChain.
These functions do not depend on any other LangChain module. These functions do not depend on any other LangChain module.
""" """

View File

@ -1,7 +1,6 @@
""" """Adapted from
Adapted from
https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
MIT License MIT License.
""" """
from collections import deque from collections import deque
@ -54,7 +53,6 @@ def py_anext(
Raises: Raises:
TypeError: If the iterator is not an async iterator. TypeError: If the iterator is not an async iterator.
""" """
try: try:
__anext__ = cast( __anext__ = cast(
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__ Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
@ -147,8 +145,7 @@ async def tee_peer(
class Tee(Generic[T]): class Tee(Generic[T]):
""" """Create ``n`` separate asynchronous iterators over ``iterable``.
Create ``n`` separate asynchronous iterators over ``iterable``.
This splits a single ``iterable`` into multiple iterators, each providing This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order. the same items in the same order.

View File

@ -1,4 +1,4 @@
"""Methods for creating function specs in the style of OpenAI Functions""" """Methods for creating function specs in the style of OpenAI Functions."""
from __future__ import annotations from __future__ import annotations
@ -342,6 +342,7 @@ def convert_to_openai_function(
strict: Optional[bool] = None, strict: Optional[bool] = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Convert a raw function/class to an OpenAI function. """Convert a raw function/class to an OpenAI function.
Args: Args:
function: function:
A dictionary, Pydantic BaseModel class, TypedDict class, a LangChain A dictionary, Pydantic BaseModel class, TypedDict class, a LangChain

View File

@ -70,6 +70,7 @@ def extract_sub_links(
exclude_prefixes: Exclude any URLs that start with one of these prefixes. exclude_prefixes: Exclude any URLs that start with one of these prefixes.
continue_on_failure: If True, continue if parsing a specific link raises an continue_on_failure: If True, continue if parsing a specific link raises an
exception. Otherwise, raise the exception. exception. Otherwise, raise the exception.
Returns: Returns:
List[str]: sub links. List[str]: sub links.
""" """

View File

@ -83,8 +83,7 @@ def tee_peer(
class Tee(Generic[T]): class Tee(Generic[T]):
""" """Create ``n`` separate asynchronous iterators over ``iterable``.
Create ``n`` separate asynchronous iterators over ``iterable``
This splits a single ``iterable`` into multiple iterators, each providing This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order. the same items in the same order.

View File

@ -18,11 +18,10 @@ def _replace_new_line(match: re.Match[str]) -> str:
def _custom_parser(multiline_string: str) -> str: def _custom_parser(multiline_string: str) -> str:
""" """The LLM response for `action_input` may be a multiline
The LLM response for `action_input` may be a multiline
string containing unescaped newlines, tabs or quotes. This function string containing unescaped newlines, tabs or quotes. This function
replaces those characters with their escaped counterparts. replaces those characters with their escaped counterparts.
(newlines in JSON must be double-escaped: `\\n`) (newlines in JSON must be double-escaped: `\\n`).
""" """
if isinstance(multiline_string, (bytes, bytearray)): if isinstance(multiline_string, (bytes, bytearray)):
multiline_string = multiline_string.decode() multiline_string = multiline_string.decode()
@ -161,8 +160,7 @@ def _parse_json(
def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
""" """Parse a JSON string from a Markdown string and check that it
Parse a JSON string from a Markdown string and check that it
contains the expected keys. contains the expected keys.
Args: Args:

View File

@ -105,7 +105,6 @@ def dereference_refs(
Returns: Returns:
The dereferenced schema object. The dereferenced schema object.
""" """
full_schema = full_schema or schema_obj full_schema = full_schema or schema_obj
skip_keys = ( skip_keys = (
skip_keys skip_keys

View File

@ -1,6 +1,5 @@
""" """Adapted from https://github.com/noahmorrison/chevron
Adapted from https://github.com/noahmorrison/chevron MIT License.
MIT License
""" """
from __future__ import annotations from __future__ import annotations
@ -48,7 +47,6 @@ def grab_literal(template: str, l_del: str) -> tuple[str, str]:
Returns: Returns:
Tuple[str, str]: The literal and the template. Tuple[str, str]: The literal and the template.
""" """
global _CURRENT_LINE global _CURRENT_LINE
try: try:
@ -74,7 +72,6 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
Returns: Returns:
bool: Whether the tag could be a standalone. bool: Whether the tag could be a standalone.
""" """
# If there is a newline, or the previous tag was a standalone # If there is a newline, or the previous tag was a standalone
if literal.find("\n") != -1 or is_standalone: if literal.find("\n") != -1 or is_standalone:
padding = literal.split("\n")[-1] padding = literal.split("\n")[-1]
@ -98,7 +95,6 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
Returns: Returns:
bool: Whether the tag could be a standalone. bool: Whether the tag could be a standalone.
""" """
# Check right side if we might be a standalone # Check right side if we might be a standalone
if is_standalone and tag_type not in ["variable", "no escape"]: if is_standalone and tag_type not in ["variable", "no escape"]:
on_newline = template.split("\n", 1) on_newline = template.split("\n", 1)
@ -199,36 +195,25 @@ def tokenize(
using file-like objects. It also accepts a string containing using file-like objects. It also accepts a string containing
the template. the template.
Args:
Arguments: template: a file-like object, or a string of a mustache template
def_ldel: The default left delimiter
template -- a file-like object, or a string of a mustache template ("{{" by default, as in spec compliant mustache)
def_rdel: The default right delimiter
def_ldel -- The default left delimiter ("}}" by default, as in spec compliant mustache)
("{{" by default, as in spec compliant mustache)
def_rdel -- The default right delimiter
("}}" by default, as in spec compliant mustache)
Returns: Returns:
A generator of mustache tags in the form of a tuple (tag_type, tag_key)
A generator of mustache tags in the form of a tuple Where tag_type is one of:
* literal
-- (tag_type, tag_key) * section
* inverted section
Where tag_type is one of: * end
* literal * partial
* section * no escape
* inverted section And tag_key is either the key or in the case of a literal tag,
* end the literal itself.
* partial
* no escape
And tag_key is either the key or in the case of a literal tag,
the literal itself.
""" """
global _CURRENT_LINE, _LAST_TAG_LINE global _CURRENT_LINE, _LAST_TAG_LINE
_CURRENT_LINE = 1 _CURRENT_LINE = 1
_LAST_TAG_LINE = None _LAST_TAG_LINE = None
@ -329,8 +314,7 @@ def tokenize(
def _html_escape(string: str) -> str: def _html_escape(string: str) -> str:
"""HTML escape all of these " & < >""" """HTML escape all of these " & < >."""
html_codes = { html_codes = {
'"': "&quot;", '"': "&quot;",
"<": "&lt;", "<": "&lt;",
@ -352,8 +336,7 @@ def _get_key(
def_ldel: str, def_ldel: str,
def_rdel: str, def_rdel: str,
) -> Any: ) -> Any:
"""Get a key from the current scope""" """Get a key from the current scope."""
# If the key is a dot # If the key is a dot
if key == ".": if key == ".":
# Then just return the current scope # Then just return the current scope
@ -410,7 +393,7 @@ def _get_key(
def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str: def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
"""Load a partial""" """Load a partial."""
try: try:
# Maybe the partial is in the dictionary # Maybe the partial is in the dictionary
return partials_dict[name] return partials_dict[name]
@ -441,45 +424,31 @@ def render(
Renders a mustache template with a data scope and inline partial capability. Renders a mustache template with a data scope and inline partial capability.
Arguments: Args:
template: A file-like object or a string containing the template.
template -- A file-like object or a string containing the template. data: A python dictionary with your data scope.
partials_path: The path to where your partials are stored.
data -- A python dictionary with your data scope. If set to None, then partials won't be loaded from the file system
(defaults to '.').
partials_path -- The path to where your partials are stored. partials_ext: The extension that you want the parser to look for
If set to None, then partials won't be loaded from the file system (defaults to 'mustache').
(defaults to '.'). partials_dict: A python dictionary which will be search for partials
before the filesystem is. {'include': 'foo'} is the same
partials_ext -- The extension that you want the parser to look for as a file called include.mustache
(defaults to 'mustache'). (defaults to {}).
padding: This is for padding partials, and shouldn't be used
partials_dict -- A python dictionary which will be search for partials (but can be if you really want to).
before the filesystem is. {'include': 'foo'} is the same def_ldel: The default left delimiter
as a file called include.mustache ("{{" by default, as in spec compliant mustache).
(defaults to {}). def_rdel: The default right delimiter
("}}" by default, as in spec compliant mustache).
padding -- This is for padding partials, and shouldn't be used scopes: The list of scopes that get_key will look through.
(but can be if you really want to). warn: Log a warning when a template substitution isn't found in the data
keep: Keep unreplaced tags when a substitution isn't found in the data.
def_ldel -- The default left delimiter
("{{" by default, as in spec compliant mustache).
def_rdel -- The default right delimiter
("}}" by default, as in spec compliant mustache).
scopes -- The list of scopes that get_key will look through.
warn -- Log a warning when a template substitution isn't found in the data
keep -- Keep unreplaced tags when a substitution isn't found in the data.
Returns: Returns:
A string containing the rendered template.
A string containing the rendered template.
""" """
# If the template is a sequence but not derived from a string # If the template is a sequence but not derived from a string
if isinstance(template, Sequence) and not isinstance(template, str): if isinstance(template, Sequence) and not isinstance(template, str):
# Then we don't need to tokenize it # Then we don't need to tokenize it

View File

@ -172,7 +172,6 @@ def pre_init(func: Callable) -> Any:
Returns: Returns:
Any: The decorated function. Any: The decorated function.
""" """
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning) warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)

View File

@ -20,7 +20,7 @@ from langchain_core.utils.pydantic import (
def xor_args(*arg_groups: tuple[str, ...]) -> Callable: def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive." """Validate specified keyword args are mutually exclusive.".
Args: Args:
*arg_groups (Tuple[str, ...]): Groups of mutually exclusive keyword args. *arg_groups (Tuple[str, ...]): Groups of mutually exclusive keyword args.

View File

@ -132,7 +132,6 @@ class VectorStore(ABC):
Optional[bool]: True if deletion is successful, Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented. False otherwise, None if not implemented.
""" """
msg = "delete method must be implemented by subclass." msg = "delete method must be implemented by subclass."
raise NotImplementedError(msg) raise NotImplementedError(msg)
@ -423,7 +422,6 @@ class VectorStore(ABC):
@staticmethod @staticmethod
def _cosine_relevance_score_fn(distance: float) -> float: def _cosine_relevance_score_fn(distance: float) -> float:
"""Normalize the distance to a score on a scale [0, 1].""" """Normalize the distance to a score on a scale [0, 1]."""
return 1.0 - distance return 1.0 - distance
@staticmethod @staticmethod
@ -435,8 +433,7 @@ class VectorStore(ABC):
return -1.0 * distance return -1.0 * distance
def _select_relevance_score_fn(self) -> Callable[[float], float]: def _select_relevance_score_fn(self) -> Callable[[float], float]:
""" """The 'correct' relevance function
The 'correct' relevance function
may differ depending on a few things, including: may differ depending on a few things, including:
- the distance / similarity metric used by the VectorStore - the distance / similarity metric used by the VectorStore
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!) - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
@ -473,7 +470,6 @@ class VectorStore(ABC):
Returns: Returns:
List of Tuples of (doc, similarity_score). List of Tuples of (doc, similarity_score).
""" """
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
@ -487,8 +483,7 @@ class VectorStore(ABC):
k: int = 4, k: int = 4,
**kwargs: Any, **kwargs: Any,
) -> list[tuple[Document, float]]: ) -> list[tuple[Document, float]]:
""" """Default similarity search with relevance scores. Modify if necessary
Default similarity search with relevance scores. Modify if necessary
in subclass. in subclass.
Return docs and relevance scores in the range [0, 1]. Return docs and relevance scores in the range [0, 1].
@ -514,8 +509,7 @@ class VectorStore(ABC):
k: int = 4, k: int = 4,
**kwargs: Any, **kwargs: Any,
) -> list[tuple[Document, float]]: ) -> list[tuple[Document, float]]:
""" """Default similarity search with relevance scores. Modify if necessary
Default similarity search with relevance scores. Modify if necessary
in subclass. in subclass.
Return docs and relevance scores in the range [0, 1]. Return docs and relevance scores in the range [0, 1].
@ -644,7 +638,6 @@ class VectorStore(ABC):
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
""" """
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
@ -678,7 +671,6 @@ class VectorStore(ABC):
Returns: Returns:
List of Documents most similar to the query vector. List of Documents most similar to the query vector.
""" """
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
@ -741,7 +733,6 @@ class VectorStore(ABC):
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
# This is a temporary workaround to make the similarity search # This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search # asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations. # asynchronous in the vector store implementations.
@ -1056,7 +1047,6 @@ class VectorStoreRetriever(BaseRetriever):
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams: def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing.""" """Get standard params for tracing."""
_kwargs = self.search_kwargs | kwargs _kwargs = self.search_kwargs | kwargs
ls_params = super()._get_ls_params(**_kwargs) ls_params = super()._get_ls_params(**_kwargs)

View File

@ -55,46 +55,46 @@ def test_warn_beta(kwargs: dict[str, Any], expected_message: str) -> None:
@beta() @beta()
def beta_function() -> str: def beta_function() -> str:
"""original doc""" """Original doc."""
return "This is a beta function." return "This is a beta function."
@beta() @beta()
async def beta_async_function() -> str: async def beta_async_function() -> str:
"""original doc""" """Original doc."""
return "This is a beta async function." return "This is a beta async function."
class ClassWithBetaMethods: class ClassWithBetaMethods:
def __init__(self) -> None: def __init__(self) -> None:
"""original doc""" """Original doc."""
@beta() @beta()
def beta_method(self) -> str: def beta_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a beta method." return "This is a beta method."
@beta() @beta()
async def beta_async_method(self) -> str: async def beta_async_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a beta async method." return "This is a beta async method."
@classmethod @classmethod
@beta() @beta()
def beta_classmethod(cls) -> str: def beta_classmethod(cls) -> str:
"""original doc""" """Original doc."""
return "This is a beta classmethod." return "This is a beta classmethod."
@staticmethod @staticmethod
@beta() @beta()
def beta_staticmethod() -> str: def beta_staticmethod() -> str:
"""original doc""" """Original doc."""
return "This is a beta staticmethod." return "This is a beta staticmethod."
@property @property
@beta() @beta()
def beta_property(self) -> str: def beta_property(self) -> str:
"""original doc""" """Original doc."""
return "This is a beta property." return "This is a beta property."
@ -240,11 +240,11 @@ def test_whole_class_beta() -> None:
@beta() @beta()
class BetaClass: class BetaClass:
def __init__(self) -> None: def __init__(self) -> None:
"""original doc""" """Original doc."""
@beta() @beta()
def beta_method(self) -> str: def beta_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a beta method." return "This is a beta method."
with warnings.catch_warnings(record=True) as warning_list: with warnings.catch_warnings(record=True) as warning_list:
@ -281,14 +281,14 @@ def test_whole_class_inherited_beta() -> None:
class BetaClass: class BetaClass:
@beta() @beta()
def beta_method(self) -> str: def beta_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a beta method." return "This is a beta method."
@beta() @beta()
class InheritedBetaClass(BetaClass): class InheritedBetaClass(BetaClass):
@beta() @beta()
def beta_method(self) -> str: def beta_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a beta method 2." return "This is a beta method 2."
with warnings.catch_warnings(record=True) as warning_list: with warnings.catch_warnings(record=True) as warning_list:
@ -344,7 +344,7 @@ def test_whole_class_inherited_beta() -> None:
class MyModel(BaseModel): class MyModel(BaseModel):
@beta() @beta()
def beta_method(self) -> str: def beta_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a beta method." return "This is a beta method."

View File

@ -75,46 +75,46 @@ def test_undefined_deprecation_schedule() -> None:
@deprecated(since="2.0.0", removal="3.0.0", pending=False) @deprecated(since="2.0.0", removal="3.0.0", pending=False)
def deprecated_function() -> str: def deprecated_function() -> str:
"""original doc""" """Original doc."""
return "This is a deprecated function." return "This is a deprecated function."
@deprecated(since="2.0.0", removal="3.0.0", pending=False) @deprecated(since="2.0.0", removal="3.0.0", pending=False)
async def deprecated_async_function() -> str: async def deprecated_async_function() -> str:
"""original doc""" """Original doc."""
return "This is a deprecated async function." return "This is a deprecated async function."
class ClassWithDeprecatedMethods: class ClassWithDeprecatedMethods:
def __init__(self) -> None: def __init__(self) -> None:
"""original doc""" """Original doc."""
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
def deprecated_method(self) -> str: def deprecated_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a deprecated method." return "This is a deprecated method."
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
async def deprecated_async_method(self) -> str: async def deprecated_async_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a deprecated async method." return "This is a deprecated async method."
@classmethod @classmethod
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
def deprecated_classmethod(cls) -> str: def deprecated_classmethod(cls) -> str:
"""original doc""" """Original doc."""
return "This is a deprecated classmethod." return "This is a deprecated classmethod."
@staticmethod @staticmethod
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
def deprecated_staticmethod() -> str: def deprecated_staticmethod() -> str:
"""original doc""" """Original doc."""
return "This is a deprecated staticmethod." return "This is a deprecated staticmethod."
@property @property
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
def deprecated_property(self) -> str: def deprecated_property(self) -> str:
"""original doc""" """Original doc."""
return "This is a deprecated property." return "This is a deprecated property."
@ -264,11 +264,11 @@ def test_whole_class_deprecation() -> None:
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
class DeprecatedClass: class DeprecatedClass:
def __init__(self) -> None: def __init__(self) -> None:
"""original doc""" """Original doc."""
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
def deprecated_method(self) -> str: def deprecated_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a deprecated method." return "This is a deprecated method."
with warnings.catch_warnings(record=True) as warning_list: with warnings.catch_warnings(record=True) as warning_list:
@ -306,11 +306,11 @@ def test_whole_class_inherited_deprecation() -> None:
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
class DeprecatedClass: class DeprecatedClass:
def __init__(self) -> None: def __init__(self) -> None:
"""original doc""" """Original doc."""
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
def deprecated_method(self) -> str: def deprecated_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a deprecated method." return "This is a deprecated method."
@deprecated(since="2.2.0", removal="3.2.0") @deprecated(since="2.2.0", removal="3.2.0")
@ -318,11 +318,11 @@ def test_whole_class_inherited_deprecation() -> None:
"""Inherited deprecated class.""" """Inherited deprecated class."""
def __init__(self) -> None: def __init__(self) -> None:
"""original doc""" """Original doc."""
@deprecated(since="2.2.0", removal="3.2.0") @deprecated(since="2.2.0", removal="3.2.0")
def deprecated_method(self) -> str: def deprecated_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a deprecated method." return "This is a deprecated method."
with warnings.catch_warnings(record=True) as warning_list: with warnings.catch_warnings(record=True) as warning_list:
@ -379,7 +379,7 @@ def test_whole_class_inherited_deprecation() -> None:
class MyModel(BaseModel): class MyModel(BaseModel):
@deprecated(since="2.0.0", removal="3.0.0") @deprecated(since="2.0.0", removal="3.0.0")
def deprecated_method(self) -> str: def deprecated_method(self) -> str:
"""original doc""" """Original doc."""
return "This is a deprecated method." return "This is a deprecated method."
@ -408,7 +408,7 @@ def test_raise_error_for_bad_decorator() -> None:
@deprecated(since="2.0.0", alternative="NewClass", alternative_import="hello") @deprecated(since="2.0.0", alternative="NewClass", alternative_import="hello")
def deprecated_function() -> str: def deprecated_function() -> str:
"""original doc""" """Original doc."""
return "This is a deprecated function." return "This is a deprecated function."
@ -417,7 +417,7 @@ def test_rename_parameter() -> None:
@rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name") @rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name")
def foo(new_name: str) -> str: def foo(new_name: str) -> str:
"""original doc""" """Original doc."""
return new_name return new_name
with warnings.catch_warnings(record=True) as warning_list: with warnings.catch_warnings(record=True) as warning_list:
@ -427,7 +427,7 @@ def test_rename_parameter() -> None:
assert foo(new_name="hello") == "hello" assert foo(new_name="hello") == "hello"
assert foo("hello") == "hello" assert foo("hello") == "hello"
assert foo.__doc__ == "original doc" assert foo.__doc__ == "Original doc."
with pytest.raises(TypeError): with pytest.raises(TypeError):
foo(meow="hello") # type: ignore[call-arg] foo(meow="hello") # type: ignore[call-arg]
with pytest.raises(TypeError): with pytest.raises(TypeError):
@ -442,7 +442,7 @@ async def test_rename_parameter_for_async_func() -> None:
@rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name") @rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name")
async def foo(new_name: str) -> str: async def foo(new_name: str) -> str:
"""original doc""" """Original doc."""
return new_name return new_name
with warnings.catch_warnings(record=True) as warning_list: with warnings.catch_warnings(record=True) as warning_list:
@ -451,7 +451,7 @@ async def test_rename_parameter_for_async_func() -> None:
assert len(warning_list) == 1 assert len(warning_list) == 1
assert await foo(new_name="hello") == "hello" assert await foo(new_name="hello") == "hello"
assert await foo("hello") == "hello" assert await foo("hello") == "hello"
assert foo.__doc__ == "original doc" assert foo.__doc__ == "Original doc."
with pytest.raises(TypeError): with pytest.raises(TypeError):
await foo(meow="hello") # type: ignore[call-arg] await foo(meow="hello") # type: ignore[call-arg]
with pytest.raises(TypeError): with pytest.raises(TypeError):

View File

@ -91,7 +91,6 @@ async def test_async_custom_event_implicit_config() -> None:
async def test_async_callback_manager() -> None: async def test_async_callback_manager() -> None:
"""Test async callback manager.""" """Test async callback manager."""
callback = AsyncCustomCallbackHandler() callback = AsyncCustomCallbackHandler()
run_id = uuid.UUID(int=7) run_id = uuid.UUID(int=7)

View File

@ -2,8 +2,7 @@ from langchain_core.embeddings import DeterministicFakeEmbedding
def test_deterministic_fake_embeddings() -> None: def test_deterministic_fake_embeddings() -> None:
""" """Test that the deterministic fake embeddings return the same
Test that the deterministic fake embeddings return the same
embedding vector for the same text. embedding vector for the same text.
""" """
fake = DeterministicFakeEmbedding(size=10) fake = DeterministicFakeEmbedding(size=10)

View File

@ -1,4 +1,4 @@
"""Test in memory indexer""" """Test in memory indexer."""
from collections.abc import AsyncGenerator, Generator from collections.abc import AsyncGenerator, Generator

View File

@ -1167,7 +1167,7 @@ def test_incremental_delete_with_same_source(
def test_incremental_indexing_with_batch_size( def test_incremental_indexing_with_batch_size(
record_manager: InMemoryRecordManager, vector_store: InMemoryVectorStore record_manager: InMemoryRecordManager, vector_store: InMemoryVectorStore
) -> None: ) -> None:
"""Test indexing with incremental indexing""" """Test indexing with incremental indexing."""
loader = ToyLoader( loader = ToyLoader(
documents=[ documents=[
Document( Document(
@ -2031,7 +2031,6 @@ def test_index_with_upsert_kwargs_for_document_indexer(
mocker: MockerFixture, mocker: MockerFixture,
) -> None: ) -> None:
"""Test that kwargs are passed to the upsert method of the document indexer.""" """Test that kwargs are passed to the upsert method of the document indexer."""
document_index = InMemoryDocumentIndex() document_index = InMemoryDocumentIndex()
upsert_spy = mocker.spy(document_index.__class__, "upsert") upsert_spy = mocker.spy(document_index.__class__, "upsert")
docs = [ docs = [
@ -2070,7 +2069,6 @@ async def test_aindex_with_upsert_kwargs_for_document_indexer(
mocker: MockerFixture, mocker: MockerFixture,
) -> None: ) -> None:
"""Test that kwargs are passed to the upsert method of the document indexer.""" """Test that kwargs are passed to the upsert method of the document indexer."""
document_index = InMemoryDocumentIndex() document_index = InMemoryDocumentIndex()
upsert_spy = mocker.spy(document_index.__class__, "aupsert") upsert_spy = mocker.spy(document_index.__class__, "aupsert")
docs = [ docs = [

View File

@ -136,7 +136,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
message = AIMessage(content="hello") message = AIMessage(content="hello")
generation = ChatGeneration(message=message) generation = ChatGeneration(message=message)
return ChatResult(generations=[generation]) return ChatResult(generations=[generation])
@ -164,7 +164,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
raise NotImplementedError raise NotImplementedError
def _stream( def _stream(
@ -209,7 +209,7 @@ async def test_astream_implementation_uses_astream() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
raise NotImplementedError raise NotImplementedError
async def _astream( # type: ignore async def _astream( # type: ignore
@ -243,7 +243,6 @@ class FakeTracer(BaseTracer):
def _persist_run(self, run: Run) -> None: def _persist_run(self, run: Run) -> None:
"""Persist a run.""" """Persist a run."""
self.traced_run_ids.append(run.id) self.traced_run_ids.append(run.id)

View File

@ -8,7 +8,6 @@ from langchain_core.rate_limiters import InMemoryRateLimiter
def test_rate_limit_invoke() -> None: def test_rate_limit_invoke() -> None:
"""Add rate limiter.""" """Add rate limiter."""
model = GenericFakeChatModel( model = GenericFakeChatModel(
messages=iter(["hello", "world"]), messages=iter(["hello", "world"]),
rate_limiter=InMemoryRateLimiter( rate_limiter=InMemoryRateLimiter(
@ -35,7 +34,6 @@ def test_rate_limit_invoke() -> None:
async def test_rate_limit_ainvoke() -> None: async def test_rate_limit_ainvoke() -> None:
"""Add rate limiter.""" """Add rate limiter."""
model = GenericFakeChatModel( model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]), messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter( rate_limiter=InMemoryRateLimiter(

View File

@ -160,7 +160,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Top Level call""" """Top Level call."""
raise NotImplementedError raise NotImplementedError
def _stream( def _stream(
@ -197,7 +197,7 @@ async def test_astream_implementation_uses_astream() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Top Level call""" """Top Level call."""
raise NotImplementedError raise NotImplementedError
async def _astream( async def _astream(

View File

@ -29,7 +29,8 @@ def test_single_item() -> None:
def test_multiple_items_with_spaces() -> None: def test_multiple_items_with_spaces() -> None:
"""Test that a string with multiple comma-separated items """Test that a string with multiple comma-separated items
with spaces is parsed to a list.""" with spaces is parsed to a list.
"""
parser = CommaSeparatedListOutputParser() parser = CommaSeparatedListOutputParser()
text = "foo, bar, baz" text = "foo, bar, baz"
expected = ["foo", "bar", "baz"] expected = ["foo", "bar", "baz"]
@ -66,7 +67,8 @@ def test_multiple_items() -> None:
def test_multiple_items_with_comma() -> None: def test_multiple_items_with_comma() -> None:
"""Test that a string with multiple comma-separated items with 1 item containing a """Test that a string with multiple comma-separated items with 1 item containing a
comma is parsed to a list.""" comma is parsed to a list.
"""
parser = CommaSeparatedListOutputParser() parser = CommaSeparatedListOutputParser()
text = '"foo, foo2",bar,baz' text = '"foo, foo2",bar,baz'
expected = ["foo, foo2", "bar", "baz"] expected = ["foo, foo2", "bar", "baz"]

View File

@ -166,7 +166,6 @@ def test_pydantic_output_functions_parser() -> None:
def test_pydantic_output_functions_parser_multiple_schemas() -> None: def test_pydantic_output_functions_parser_multiple_schemas() -> None:
"""Test that the parser works if providing multiple pydantic schemas.""" """Test that the parser works if providing multiple pydantic schemas."""
message = AIMessage( message = AIMessage(
content="This is a test message", content="This is a test message",
additional_kwargs={ additional_kwargs={

View File

@ -482,7 +482,7 @@ class Person(BaseModel):
class NameCollector(BaseModel): class NameCollector(BaseModel):
"""record names of all people mentioned""" """record names of all people mentioned."""
names: list[str] = Field(..., description="all names mentioned") names: list[str] = Field(..., description="all names mentioned")
person: Person = Field(..., description="info about the main subject") person: Person = Field(..., description="info about the main subject")

View File

@ -1,4 +1,4 @@
"""Test PydanticOutputParser""" """Test PydanticOutputParser."""
from enum import Enum from enum import Enum
from typing import Literal, Optional from typing import Literal, Optional
@ -141,7 +141,6 @@ DEF_EXPECTED_RESULT = TestModel(
def test_pydantic_output_parser() -> None: def test_pydantic_output_parser() -> None:
"""Test PydanticOutputParser.""" """Test PydanticOutputParser."""
pydantic_parser: PydanticOutputParser = PydanticOutputParser( pydantic_parser: PydanticOutputParser = PydanticOutputParser(
pydantic_object=TestModel pydantic_object=TestModel
) )
@ -154,7 +153,6 @@ def test_pydantic_output_parser() -> None:
def test_pydantic_output_parser_fail() -> None: def test_pydantic_output_parser_fail() -> None:
"""Test PydanticOutputParser where completion result fails schema validation.""" """Test PydanticOutputParser where completion result fails schema validation."""
pydantic_parser: PydanticOutputParser = PydanticOutputParser( pydantic_parser: PydanticOutputParser = PydanticOutputParser(
pydantic_object=TestModel pydantic_object=TestModel
) )

View File

@ -1,4 +1,4 @@
"""Test XMLOutputParser""" """Test XMLOutputParser."""
import importlib import importlib
from collections.abc import AsyncIterator, Iterable from collections.abc import AsyncIterator, Iterable
@ -77,7 +77,7 @@ async def _as_iter(iterable: Iterable[str]) -> AsyncIterator[str]:
async def test_root_only_xml_output_parser() -> None: async def test_root_only_xml_output_parser() -> None:
"""Test XMLOutputParser when xml only contains the root level tag""" """Test XMLOutputParser when xml only contains the root level tag."""
xml_parser = XMLOutputParser(parser="xml") xml_parser = XMLOutputParser(parser="xml")
assert xml_parser.parse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."} assert xml_parser.parse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."}
assert await xml_parser.aparse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."} assert await xml_parser.aparse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."}
@ -125,7 +125,6 @@ async def test_xml_output_parser_defused(content: str) -> None:
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"]) @pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])
def test_xml_output_parser_fail(result: str) -> None: def test_xml_output_parser_fail(result: str) -> None:
"""Test XMLOutputParser where complete output is not in XML format.""" """Test XMLOutputParser where complete output is not in XML format."""
xml_parser = XMLOutputParser(parser="xml") xml_parser = XMLOutputParser(parser="xml")
with pytest.raises(OutputParserException) as e: with pytest.raises(OutputParserException) as e:

View File

@ -110,7 +110,6 @@ def test_create_chat_prompt_template_from_template_partial() -> None:
def test_create_system_message_prompt_template_from_template_partial() -> None: def test_create_system_message_prompt_template_from_template_partial() -> None:
"""Create a system message prompt template with partials.""" """Create a system message prompt template with partials."""
graph_creator_content = """ graph_creator_content = """
Your instructions are: Your instructions are:
{instructions} {instructions}

View File

@ -348,7 +348,8 @@ def test_prompt_from_file() -> None:
def test_prompt_from_file_with_partial_variables() -> None: def test_prompt_from_file_with_partial_variables() -> None:
"""Test prompt can be successfully constructed from a file """Test prompt can be successfully constructed from a file
with partial variables.""" with partial variables.
"""
# given # given
template = "This is a {foo} test {bar}." template = "This is a {foo} test {bar}."
partial_variables = {"bar": "baz"} partial_variables = {"bar": "baz"}

View File

@ -75,7 +75,6 @@ def _remove_enum(obj: Any) -> None:
def _schema(obj: Any) -> dict: def _schema(obj: Any) -> dict:
"""Return the schema of the object.""" """Return the schema of the object."""
if not is_basemodel_subclass(obj): if not is_basemodel_subclass(obj):
msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}" msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
raise TypeError(msg) raise TypeError(msg)

View File

@ -67,7 +67,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
def test_doubly_set_configurable() -> None: def test_doubly_set_configurable() -> None:
"""Test that setting a configurable field with a default value works""" """Test that setting a configurable field with a default value works."""
runnable = MyRunnable(my_property="a") # type: ignore runnable = MyRunnable(my_property="a") # type: ignore
configurable_runnable = runnable.configurable_fields( configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField( my_property=ConfigurableField(

View File

@ -314,7 +314,7 @@ class FakeStructuredOutputModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
return ChatResult(generations=[]) return ChatResult(generations=[])
def bind_tools( def bind_tools(
@ -344,7 +344,7 @@ class FakeModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
return ChatResult(generations=[]) return ChatResult(generations=[])
def bind_tools( def bind_tools(

View File

@ -426,7 +426,7 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
def test_graph_mermaid_escape_node_label() -> None: def test_graph_mermaid_escape_node_label() -> None:
"""Test that node labels are correctly preprocessed for draw_mermaid""" """Test that node labels are correctly preprocessed for draw_mermaid."""
assert _escape_node_label("foo") == "foo" assert _escape_node_label("foo") == "foo"
assert _escape_node_label("foo-bar") == "foo-bar" assert _escape_node_label("foo-bar") == "foo-bar"
assert _escape_node_label("foo_1") == "foo_1" assert _escape_node_label("foo_1") == "foo_1"

View File

@ -257,7 +257,7 @@ class LengthChatModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any, **kwargs: Any,
) -> ChatResult: ) -> ChatResult:
"""Top Level call""" """Top Level call."""
return ChatResult( return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=str(len(messages))))] generations=[ChatGeneration(message=AIMessage(content=str(len(messages))))]
) )

View File

@ -96,7 +96,8 @@ PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
class FakeTracer(BaseTracer): class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution. """Fake tracer that records LangChain execution.
It replaces run ids with deterministic UUIDs for snapshotting.""" It replaces run ids with deterministic UUIDs for snapshotting.
"""
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the tracer.""" """Initialize the tracer."""
@ -158,7 +159,6 @@ class FakeTracer(BaseTracer):
def _persist_run(self, run: Run) -> None: def _persist_run(self, run: Run) -> None:
"""Persist a run.""" """Persist a run."""
self.runs.append(self._copy_run(run)) self.runs.append(self._copy_run(run))
def flattened_runs(self) -> list[Run]: def flattened_runs(self) -> list[Run]:
@ -657,7 +657,7 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
def test_with_types_with_type_generics() -> None: def test_with_types_with_type_generics() -> None:
"""Verify that with_types works if we use things like List[int]""" """Verify that with_types works if we use things like List[int]."""
def foo(x: int) -> None: def foo(x: int) -> None:
"""Add one to the input.""" """Add one to the input."""
@ -3334,7 +3334,6 @@ def test_with_config_with_config() -> None:
def test_metadata_is_merged() -> None: def test_metadata_is_merged() -> None:
"""Test metadata and tags defined in with_config and at are merged/concatend.""" """Test metadata and tags defined in with_config and at are merged/concatend."""
foo = RunnableLambda(lambda x: x).with_config({"metadata": {"my_key": "my_value"}}) foo = RunnableLambda(lambda x: x).with_config({"metadata": {"my_key": "my_value"}})
expected_metadata = { expected_metadata = {
"my_key": "my_value", "my_key": "my_value",
@ -3349,7 +3348,6 @@ def test_metadata_is_merged() -> None:
def test_tags_are_appended() -> None: def test_tags_are_appended() -> None:
"""Test tags from with_config are concatenated with those in invocation.""" """Test tags from with_config are concatenated with those in invocation."""
foo = RunnableLambda(lambda x: x).with_config({"tags": ["my_key"]}) foo = RunnableLambda(lambda x: x).with_config({"tags": ["my_key"]})
with collect_runs() as cb: with collect_runs() as cb:
foo.invoke("hi", {"tags": ["invoked_key"]}) foo.invoke("hi", {"tags": ["invoked_key"]})
@ -4445,7 +4443,6 @@ async def test_runnable_branch_abatch() -> None:
def test_runnable_branch_stream() -> None: def test_runnable_branch_stream() -> None:
"""Verify that stream works for RunnableBranch.""" """Verify that stream works for RunnableBranch."""
llm_res = "i'm a textbot" llm_res = "i'm a textbot"
# sleep to better simulate a real stream # sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
@ -4503,7 +4500,6 @@ def test_runnable_branch_stream_with_callbacks() -> None:
async def test_runnable_branch_astream() -> None: async def test_runnable_branch_astream() -> None:
"""Verify that astream works for RunnableBranch.""" """Verify that astream works for RunnableBranch."""
llm_res = "i'm a textbot" llm_res = "i'm a textbot"
# sleep to better simulate a real stream # sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
@ -4694,8 +4690,8 @@ async def test_runnable_gen() -> None:
async def test_runnable_gen_context_config() -> None: async def test_runnable_gen_context_config() -> None:
"""Test that a generator can call other runnables with config """Test that a generator can call other runnables with config
propagated from the context.""" propagated from the context.
"""
fake = RunnableLambda(len) fake = RunnableLambda(len)
def gen(input: Iterator[Any]) -> Iterator[int]: def gen(input: Iterator[Any]) -> Iterator[int]:
@ -4829,8 +4825,8 @@ async def test_runnable_gen_context_config() -> None:
async def test_runnable_iter_context_config() -> None: async def test_runnable_iter_context_config() -> None:
"""Test that a generator can call other runnables with config """Test that a generator can call other runnables with config
propagated from the context.""" propagated from the context.
"""
fake = RunnableLambda(len) fake = RunnableLambda(len)
@chain @chain
@ -4946,8 +4942,8 @@ async def test_runnable_iter_context_config() -> None:
async def test_runnable_lambda_context_config() -> None: async def test_runnable_lambda_context_config() -> None:
"""Test that a function can call other runnables with config """Test that a function can call other runnables with config
propagated from the context.""" propagated from the context.
"""
fake = RunnableLambda(len) fake = RunnableLambda(len)
@chain @chain
@ -5098,7 +5094,8 @@ def test_with_config_callbacks() -> None:
async def test_ainvoke_on_returned_runnable() -> None: async def test_ainvoke_on_returned_runnable() -> None:
"""Verify that a runnable returned by a sync runnable in the async path will """Verify that a runnable returned by a sync runnable in the async path will
be runthroughaasync path (issue #13407)""" be runthroughaasync path (issue #13407).
"""
def idchain_sync(__input: dict) -> bool: def idchain_sync(__input: dict) -> bool:
return False return False
@ -5171,7 +5168,7 @@ async def test_astream_log_deep_copies() -> None:
""" """
def _get_run_log(run_log_patches: Sequence[RunLogPatch]) -> RunLog: def _get_run_log(run_log_patches: Sequence[RunLogPatch]) -> RunLog:
"""Get run log""" """Get run log."""
run_log = RunLog(state=None) # type: ignore run_log = RunLog(state=None) # type: ignore
for log_patch in run_log_patches: for log_patch in run_log_patches:
run_log = run_log + log_patch run_log = run_log + log_patch
@ -5435,7 +5432,6 @@ def test_pydantic_protected_namespaces() -> None:
def test_schema_for_prompt_and_chat_model() -> None: def test_schema_for_prompt_and_chat_model() -> None:
"""Testing that schema is generated properly when using variable names """Testing that schema is generated properly when using variable names
that collide with pydantic attributes. that collide with pydantic attributes.
""" """
prompt = ChatPromptTemplate([("system", "{model_json_schema}, {_private}, {json}")]) prompt = ChatPromptTemplate([("system", "{model_json_schema}, {_private}, {json}")])

View File

@ -80,15 +80,15 @@ def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -
async def test_event_stream_with_simple_function_tool() -> None: async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool""" """Test the event stream with a function and tool."""
def foo(x: int) -> dict: def foo(x: int) -> dict:
"""Foo""" """Foo."""
return {"x": 5} return {"x": 5}
@tool @tool
def get_docs(x: int) -> list[Document]: def get_docs(x: int) -> list[Document]:
"""Hello Doc""" """Hello Doc."""
return [Document(page_content="hello")] return [Document(page_content="hello")]
chain = RunnableLambda(foo) | get_docs chain = RunnableLambda(foo) | get_docs
@ -345,7 +345,7 @@ async def test_event_stream_with_triple_lambda() -> None:
async def test_event_stream_with_triple_lambda_test_filtering() -> None: async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"""Test filtering based on tags / names""" """Test filtering based on tags / names."""
def reverse(s: str) -> str: def reverse(s: str) -> str:
"""Reverse a string.""" """Reverse a string."""
@ -1822,7 +1822,7 @@ async def test_runnable_each() -> None:
async def test_events_astream_config() -> None: async def test_events_astream_config() -> None:
"""Test that astream events support accepting config""" """Test that astream events support accepting config."""
infinite_cycle = cycle([AIMessage(content="hello world!", id="ai1")]) infinite_cycle = cycle([AIMessage(content="hello world!", id="ai1")])
good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")]) good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")])
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields( model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
@ -1912,7 +1912,7 @@ async def test_runnable_with_message_history() -> None:
store: dict = {} store: dict = {}
def get_by_session_id(session_id: str) -> BaseChatMessageHistory: def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history""" """Get a chat message history."""
if session_id not in store: if session_id not in store:
store[session_id] = [] store[session_id] = []
return InMemoryHistory(messages=store[session_id]) return InMemoryHistory(messages=store[session_id])

View File

@ -90,15 +90,15 @@ async def _collect_events(
async def test_event_stream_with_simple_function_tool() -> None: async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool""" """Test the event stream with a function and tool."""
def foo(x: int) -> dict: def foo(x: int) -> dict:
"""Foo""" """Foo."""
return {"x": 5} return {"x": 5}
@tool @tool
def get_docs(x: int) -> list[Document]: def get_docs(x: int) -> list[Document]:
"""Hello Doc""" """Hello Doc."""
return [Document(page_content="hello")] return [Document(page_content="hello")]
chain = RunnableLambda(foo) | get_docs chain = RunnableLambda(foo) | get_docs
@ -371,7 +371,7 @@ async def test_event_stream_exception() -> None:
async def test_event_stream_with_triple_lambda_test_filtering() -> None: async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"""Test filtering based on tags / names""" """Test filtering based on tags / names."""
def reverse(s: str) -> str: def reverse(s: str) -> str:
"""Reverse a string.""" """Reverse a string."""
@ -1767,7 +1767,7 @@ async def test_runnable_each() -> None:
async def test_events_astream_config() -> None: async def test_events_astream_config() -> None:
"""Test that astream events support accepting config""" """Test that astream events support accepting config."""
infinite_cycle = cycle([AIMessage(content="hello world!", id="ai1")]) infinite_cycle = cycle([AIMessage(content="hello world!", id="ai1")])
good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")]) good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")])
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields( model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
@ -1859,7 +1859,7 @@ async def test_runnable_with_message_history() -> None:
store: dict = {} store: dict = {}
def get_by_session_id(session_id: str) -> BaseChatMessageHistory: def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history""" """Get a chat message history."""
if session_id not in store: if session_id not in store:
store[session_id] = [] store[session_id] = []
return InMemoryHistory(messages=store[session_id]) return InMemoryHistory(messages=store[session_id])
@ -2046,7 +2046,7 @@ async def test_sync_in_sync_lambdas() -> None:
class StreamingRunnable(Runnable[Input, Output]): class StreamingRunnable(Runnable[Input, Output]):
"""A custom runnable used for testing purposes""" """A custom runnable used for testing purposes."""
iterable: Iterable[Any] iterable: Iterable[Any]
@ -2734,7 +2734,7 @@ async def test_custom_event_root_dispatch_with_in_tool() -> None:
@tool @tool
async def foo(x: int) -> int: async def foo(x: int) -> int:
"""Foo""" """Foo."""
await adispatch_custom_event("event1", {"x": x}) await adispatch_custom_event("event1", {"x": x})
return x + 1 return x + 1

View File

@ -23,7 +23,7 @@ from langchain_core.runnables.utils import (
], ],
) )
def test_get_lambda_source(func: Callable, expected_source: str) -> None: def test_get_lambda_source(func: Callable, expected_source: str) -> None:
"""Test get_lambda_source function""" """Test get_lambda_source function."""
source = get_lambda_source(func) source = get_lambda_source(func)
assert source == expected_source assert source == expected_source
@ -36,7 +36,7 @@ def test_get_lambda_source(func: Callable, expected_source: str) -> None:
], ],
) )
def test_indent_lines_after_first(text: str, prefix: str, expected_output: str) -> None: def test_indent_lines_after_first(text: str, prefix: str, expected_output: str) -> None:
"""Test indent_lines_after_first function""" """Test indent_lines_after_first function."""
indented_text = indent_lines_after_first(text, prefix) indented_text = indent_lines_after_first(text, prefix)
assert indented_text == expected_output assert indented_text == expected_output

View File

@ -25,7 +25,6 @@ from langchain_core.messages import (
def test_serde_any_message() -> None: def test_serde_any_message() -> None:
"""Test AnyMessage() serder.""" """Test AnyMessage() serder."""
lc_objects = [ lc_objects = [
HumanMessage(content="human"), HumanMessage(content="human"),
HumanMessageChunk(content="human"), HumanMessageChunk(content="human"),

View File

@ -388,7 +388,6 @@ def test_base_tool_inheritance_base_schema() -> None:
def test_tool_lambda_args_schema() -> None: def test_tool_lambda_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function.""" """Test args schema inference when the tool argument is a lambda function."""
tool = Tool( tool = Tool(
name="tool", name="tool",
description="A tool", description="A tool",
@ -403,7 +402,7 @@ def test_structured_tool_from_function_docstring() -> None:
"""Test that structured tools can be created from functions.""" """Test that structured tools can be created from functions."""
def foo(bar: int, baz: str) -> str: def foo(bar: int, baz: str) -> str:
"""Docstring """Docstring.
Args: Args:
bar: the bar value bar: the bar value
@ -437,7 +436,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
"""Test that structured tools can be created from functions.""" """Test that structured tools can be created from functions."""
def foo(bar: int, baz: list[str]) -> str: def foo(bar: int, baz: list[str]) -> str:
"""Docstring """Docstring.
Args: Args:
bar: int bar: int
@ -526,7 +525,7 @@ def test_tool_from_function_with_run_manager() -> None:
def foo(bar: str, callbacks: Optional[CallbackManagerForToolRun] = None) -> str: def foo(bar: str, callbacks: Optional[CallbackManagerForToolRun] = None) -> str:
"""Docstring """Docstring
Args: Args:
bar: str bar: str.
""" """
assert callbacks is not None assert callbacks is not None
return "foo" + bar return "foo" + bar
@ -544,7 +543,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
def foo( def foo(
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
) -> str: ) -> str:
"""Docstring """Docstring.
Args: Args:
bar: int bar: int
@ -1381,7 +1380,7 @@ class _MockStructuredToolWithRawOutput(BaseTool):
def _mock_structured_tool_with_artifact( def _mock_structured_tool_with_artifact(
arg1: int, arg2: bool, arg3: Optional[dict] = None arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> tuple[str, dict]: ) -> tuple[str, dict]:
"""A Structured Tool""" """A Structured Tool."""
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@ -1891,7 +1890,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
from langchain_core.tools import StructuredTool from langchain_core.tools import StructuredTool
def foo(a: int, b: str) -> str: def foo(a: int, b: str) -> str:
"""Hahaha""" """Hahaha."""
return "foo" return "foo"
foo_tool = StructuredTool.from_function( foo_tool = StructuredTool.from_function(
@ -2187,11 +2186,11 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
@tool(args_schema=Foo) @tool(args_schema=Foo)
def foo(x): # type: ignore[no-untyped-def] def foo(x): # type: ignore[no-untyped-def]
"""foo""" """Foo."""
return x return x
assert foo.tool_call_schema.model_json_schema() == { assert foo.tool_call_schema.model_json_schema() == {
"description": "foo", "description": "Foo.",
"properties": { "properties": {
"x": { "x": {
"description": "List of integers", "description": "List of integers",
@ -2269,7 +2268,7 @@ def test_injected_arg_with_complex_type() -> None:
def test_tool_injected_tool_call_id() -> None: def test_tool_injected_tool_call_id() -> None:
@tool @tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage: def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
"""foo""" """Foo."""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
assert foo.invoke( assert foo.invoke(
@ -2281,7 +2280,7 @@ def test_tool_injected_tool_call_id() -> None:
@tool @tool
def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage: def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage:
"""foo""" """Foo."""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
assert foo2.invoke( assert foo2.invoke(
@ -2292,7 +2291,7 @@ def test_tool_injected_tool_call_id() -> None:
def test_tool_uninjected_tool_call_id() -> None: def test_tool_uninjected_tool_call_id() -> None:
@tool @tool
def foo(x: int, tool_call_id: str) -> ToolMessage: def foo(x: int, tool_call_id: str) -> ToolMessage:
"""foo""" """Foo."""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
with pytest.raises(ValueError): with pytest.raises(ValueError):

View File

@ -83,7 +83,6 @@ def test_tracer_with_run_tree_parent() -> None:
def test_log_lock() -> None: def test_log_lock() -> None:
"""Test that example assigned at callback start/end is honored.""" """Test that example assigned at callback start/end is honored."""
client = unittest.mock.MagicMock(spec=Client) client = unittest.mock.MagicMock(spec=Client)
tracer = LangChainTracer(client=client) tracer = LangChainTracer(client=client)
@ -96,9 +95,7 @@ def test_log_lock() -> None:
class LangChainProjectNameTest(unittest.TestCase): class LangChainProjectNameTest(unittest.TestCase):
""" """Test that the project name is set correctly for runs."""
Test that the project name is set correctly for runs.
"""
class SetProperTracerProjectTestCase: class SetProperTracerProjectTestCase:
def __init__( def __init__(

View File

@ -39,7 +39,7 @@ from langchain_core.utils.function_calling import (
@pytest.fixture() @pytest.fixture()
def pydantic() -> type[BaseModel]: def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801 class dummy_function(BaseModel): # noqa: N801
"""dummy function""" """Dummy function."""
arg1: int = Field(..., description="foo") arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
@ -53,7 +53,7 @@ def annotated_function() -> Callable:
arg1: ExtensionsAnnotated[int, "foo"], arg1: ExtensionsAnnotated[int, "foo"],
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"], arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
) -> None: ) -> None:
"""dummy function""" """Dummy function."""
return dummy_function return dummy_function
@ -61,7 +61,7 @@ def annotated_function() -> Callable:
@pytest.fixture() @pytest.fixture()
def function() -> Callable: def function() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None: def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function """Dummy function.
Args: Args:
arg1: foo arg1: foo
@ -74,7 +74,7 @@ def function() -> Callable:
@pytest.fixture() @pytest.fixture()
def function_docstring_annotations() -> Callable: def function_docstring_annotations() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None: def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function """Dummy function.
Args: Args:
arg1 (int): foo arg1 (int): foo
@ -105,7 +105,7 @@ def dummy_tool() -> BaseTool:
class DummyFunction(BaseTool): class DummyFunction(BaseTool):
args_schema: type[BaseModel] = Schema args_schema: type[BaseModel] = Schema
name: str = "dummy_function" name: str = "dummy_function"
description: str = "dummy function" description: str = "Dummy function."
def _run(self, *args: Any, **kwargs: Any) -> Any: def _run(self, *args: Any, **kwargs: Any) -> Any:
pass pass
@ -122,7 +122,7 @@ def dummy_structured_tool() -> StructuredTool:
return StructuredTool.from_function( return StructuredTool.from_function(
lambda x: None, lambda x: None,
name="dummy_function", name="dummy_function",
description="dummy function", description="Dummy function.",
args_schema=Schema, args_schema=Schema,
) )
@ -130,7 +130,7 @@ def dummy_structured_tool() -> StructuredTool:
@pytest.fixture() @pytest.fixture()
def dummy_pydantic() -> type[BaseModel]: def dummy_pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801 class dummy_function(BaseModel): # noqa: N801
"""dummy function""" """Dummy function."""
arg1: int = Field(..., description="foo") arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'") arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
@ -141,7 +141,7 @@ def dummy_pydantic() -> type[BaseModel]:
@pytest.fixture() @pytest.fixture()
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]: def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
class dummy_function(BaseModelV2Maybe): # noqa: N801 class dummy_function(BaseModelV2Maybe): # noqa: N801
"""dummy function""" """Dummy function."""
arg1: int = FieldV2Maybe(..., description="foo") arg1: int = FieldV2Maybe(..., description="foo")
arg2: Literal["bar", "baz"] = FieldV2Maybe( arg2: Literal["bar", "baz"] = FieldV2Maybe(
@ -154,7 +154,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
@pytest.fixture() @pytest.fixture()
def dummy_typing_typed_dict() -> type: def dummy_typing_typed_dict() -> type:
class dummy_function(TypingTypedDict): # noqa: N801 class dummy_function(TypingTypedDict): # noqa: N801
"""dummy function""" """Dummy function."""
arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821 arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821
arg2: TypingAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"] # noqa: F722 arg2: TypingAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"] # noqa: F722
@ -165,7 +165,7 @@ def dummy_typing_typed_dict() -> type:
@pytest.fixture() @pytest.fixture()
def dummy_typing_typed_dict_docstring() -> type: def dummy_typing_typed_dict_docstring() -> type:
class dummy_function(TypingTypedDict): # noqa: N801 class dummy_function(TypingTypedDict): # noqa: N801
"""dummy function """Dummy function.
Args: Args:
arg1: foo arg1: foo
@ -181,7 +181,7 @@ def dummy_typing_typed_dict_docstring() -> type:
@pytest.fixture() @pytest.fixture()
def dummy_extensions_typed_dict() -> type: def dummy_extensions_typed_dict() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801 class dummy_function(ExtensionsTypedDict): # noqa: N801
"""dummy function""" """Dummy function."""
arg1: ExtensionsAnnotated[int, ..., "foo"] arg1: ExtensionsAnnotated[int, ..., "foo"]
arg2: ExtensionsAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"] arg2: ExtensionsAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"]
@ -192,7 +192,7 @@ def dummy_extensions_typed_dict() -> type:
@pytest.fixture() @pytest.fixture()
def dummy_extensions_typed_dict_docstring() -> type: def dummy_extensions_typed_dict_docstring() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801 class dummy_function(ExtensionsTypedDict): # noqa: N801
"""dummy function """Dummy function.
Args: Args:
arg1: foo arg1: foo
@ -209,7 +209,7 @@ def dummy_extensions_typed_dict_docstring() -> type:
def json_schema() -> dict: def json_schema() -> dict:
return { return {
"title": "dummy_function", "title": "dummy_function",
"description": "dummy function", "description": "Dummy function.",
"type": "object", "type": "object",
"properties": { "properties": {
"arg1": {"description": "foo", "type": "integer"}, "arg1": {"description": "foo", "type": "integer"},
@ -227,7 +227,7 @@ def json_schema() -> dict:
def anthropic_tool() -> dict: def anthropic_tool() -> dict:
return { return {
"name": "dummy_function", "name": "dummy_function",
"description": "dummy function", "description": "Dummy function.",
"input_schema": { "input_schema": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -248,7 +248,7 @@ def bedrock_converse_tool() -> dict:
return { return {
"toolSpec": { "toolSpec": {
"name": "dummy_function", "name": "dummy_function",
"description": "dummy function", "description": "Dummy function.",
"inputSchema": { "inputSchema": {
"json": { "json": {
"type": "object", "type": "object",
@ -269,7 +269,7 @@ def bedrock_converse_tool() -> dict:
class Dummy: class Dummy:
def dummy_function(self, arg1: int, arg2: Literal["bar", "baz"]) -> None: def dummy_function(self, arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function """Dummy function.
Args: Args:
arg1: foo arg1: foo
@ -280,7 +280,7 @@ class Dummy:
class DummyWithClassMethod: class DummyWithClassMethod:
@classmethod @classmethod
def dummy_function(cls, arg1: int, arg2: Literal["bar", "baz"]) -> None: def dummy_function(cls, arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function """Dummy function.
Args: Args:
arg1: foo arg1: foo
@ -307,7 +307,7 @@ def test_convert_to_openai_function(
) -> None: ) -> None:
expected = { expected = {
"name": "dummy_function", "name": "dummy_function",
"description": "dummy function", "description": "Dummy function.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -345,7 +345,7 @@ def test_convert_to_openai_function(
assert actual == expected assert actual == expected
# Test runnables # Test runnables
actual = convert_to_openai_function(runnable.as_tool(description="dummy function")) actual = convert_to_openai_function(runnable.as_tool(description="Dummy function."))
parameters = { parameters = {
"type": "object", "type": "object",
"properties": { "properties": {
@ -392,7 +392,7 @@ def test_convert_to_openai_function_nested_v2() -> None:
) )
def my_function(arg1: NestedV2) -> None: def my_function(arg1: NestedV2) -> None:
"""dummy function""" """Dummy function."""
convert_to_openai_function(my_function) convert_to_openai_function(my_function)
@ -405,11 +405,11 @@ def test_convert_to_openai_function_nested() -> None:
) )
def my_function(arg1: Nested) -> None: def my_function(arg1: Nested) -> None:
"""dummy function""" """Dummy function."""
expected = { expected = {
"name": "my_function", "name": "my_function",
"description": "dummy function", "description": "Dummy function.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -442,11 +442,11 @@ def test_convert_to_openai_function_nested_strict() -> None:
) )
def my_function(arg1: Nested) -> None: def my_function(arg1: Nested) -> None:
"""dummy function""" """Dummy function."""
expected = { expected = {
"name": "my_function", "name": "my_function",
"description": "dummy function", "description": "Dummy function.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -608,7 +608,7 @@ def test_function_optional_param() -> None:
b: str, b: str,
c: Optional[list[Optional[str]]], c: Optional[list[Optional[str]]],
) -> None: ) -> None:
"""A test function""" """A test function."""
func = convert_to_openai_function(func5) func = convert_to_openai_function(func5)
req = func["parameters"]["required"] req = func["parameters"]["required"]
@ -617,7 +617,7 @@ def test_function_optional_param() -> None:
def test_function_no_params() -> None: def test_function_no_params() -> None:
def nullary_function() -> None: def nullary_function() -> None:
"""nullary function""" """Nullary function."""
func = convert_to_openai_function(nullary_function) func = convert_to_openai_function(nullary_function)
req = func["parameters"].get("required") req = func["parameters"].get("required")
@ -722,12 +722,12 @@ def test__convert_typed_dict_to_openai_function(
annotated = TypingAnnotated if use_extension_annotated else TypingAnnotated annotated = TypingAnnotated if use_extension_annotated else TypingAnnotated
class SubTool(typed_dict): class SubTool(typed_dict):
"""Subtool docstring""" """Subtool docstring."""
args: annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore args: annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore
class Tool(typed_dict): class Tool(typed_dict):
"""Docstring """Docstring.
Args: Args:
arg1: foo arg1: foo
@ -753,7 +753,7 @@ def test__convert_typed_dict_to_openai_function(
expected = { expected = {
"name": "Tool", "name": "Tool",
"description": "Docstring", "description": "Docstring.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -768,7 +768,7 @@ def test__convert_typed_dict_to_openai_function(
"arg3": { "arg3": {
"type": "array", "type": "array",
"items": { "items": {
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -798,7 +798,7 @@ def test__convert_typed_dict_to_openai_function(
{"type": "array", "items": {}}, {"type": "array", "items": {}},
{ {
"title": "SubTool", "title": "SubTool",
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -816,7 +816,7 @@ def test__convert_typed_dict_to_openai_function(
"arg7": { "arg7": {
"type": "array", "type": "array",
"items": { "items": {
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -834,7 +834,7 @@ def test__convert_typed_dict_to_openai_function(
"items": [ "items": [
{ {
"title": "SubTool", "title": "SubTool",
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -850,7 +850,7 @@ def test__convert_typed_dict_to_openai_function(
"arg9": { "arg9": {
"type": "array", "type": "array",
"items": { "items": {
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -864,7 +864,7 @@ def test__convert_typed_dict_to_openai_function(
"arg10": { "arg10": {
"type": "array", "type": "array",
"items": { "items": {
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -878,7 +878,7 @@ def test__convert_typed_dict_to_openai_function(
"arg11": { "arg11": {
"type": "array", "type": "array",
"items": { "items": {
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -893,7 +893,7 @@ def test__convert_typed_dict_to_openai_function(
"arg12": { "arg12": {
"type": "object", "type": "object",
"additionalProperties": { "additionalProperties": {
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -907,7 +907,7 @@ def test__convert_typed_dict_to_openai_function(
"arg13": { "arg13": {
"type": "object", "type": "object",
"additionalProperties": { "additionalProperties": {
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -921,7 +921,7 @@ def test__convert_typed_dict_to_openai_function(
"arg14": { "arg14": {
"type": "object", "type": "object",
"additionalProperties": { "additionalProperties": {
"description": "Subtool docstring", "description": "Subtool docstring.",
"type": "object", "type": "object",
"properties": { "properties": {
"args": { "args": {
@ -981,13 +981,13 @@ def test_convert_union_type_py_39() -> None:
def test_convert_to_openai_function_no_args() -> None: def test_convert_to_openai_function_no_args() -> None:
@tool @tool
def empty_tool() -> str: def empty_tool() -> str:
"""No args""" """No args."""
return "foo" return "foo"
actual = convert_to_openai_function(empty_tool, strict=True) actual = convert_to_openai_function(empty_tool, strict=True)
assert actual == { assert actual == {
"name": "empty_tool", "name": "empty_tool",
"description": "No args", "description": "No args.",
"parameters": { "parameters": {
"properties": {}, "properties": {},
"additionalProperties": False, "additionalProperties": False,

View File

@ -140,7 +140,7 @@ def test_is_basemodel_instance() -> None:
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2") @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
def test_with_field_metadata() -> None: def test_with_field_metadata() -> None:
"""Test pydantic with field metadata""" """Test pydantic with field metadata."""
from pydantic import BaseModel as BaseModelV2 from pydantic import BaseModel as BaseModelV2
from pydantic import Field as FieldV2 from pydantic import Field as FieldV2
@ -202,7 +202,6 @@ def test_fields_pydantic_v1_from_2() -> None:
def test_create_model_v2() -> None: def test_create_model_v2() -> None:
"""Test that create model v2 works as expected.""" """Test that create model v2 works as expected."""
with warnings.catch_warnings(record=True) as record: with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always") # Cause all warnings to always be triggered warnings.simplefilter("always") # Cause all warnings to always be triggered
foo = create_model_v2("Foo", field_definitions={"a": (int, None)}) foo = create_model_v2("Foo", field_definitions={"a": (int, None)})

View File

@ -35,7 +35,7 @@ async def test_inmemory_similarity_search() -> None:
async def test_inmemory_similarity_search_with_score() -> None: async def test_inmemory_similarity_search_with_score() -> None:
"""Test end to end similarity search with score""" """Test end to end similarity search with score."""
store = await InMemoryVectorStore.afrom_texts( store = await InMemoryVectorStore.afrom_texts(
["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3) ["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
) )
@ -63,7 +63,7 @@ async def test_add_by_ids() -> None:
async def test_inmemory_mmr() -> None: async def test_inmemory_mmr() -> None:
"""Test MMR search""" """Test MMR search."""
texts = ["foo", "foo", "fou", "foy"] texts = ["foo", "foo", "fou", "foy"]
docsearch = await InMemoryVectorStore.afrom_texts( docsearch = await InMemoryVectorStore.afrom_texts(
texts, DeterministicFakeEmbedding(size=6) texts, DeterministicFakeEmbedding(size=6)
@ -147,7 +147,6 @@ async def test_inmemory_upsert() -> None:
async def test_inmemory_get_by_ids() -> None: async def test_inmemory_get_by_ids() -> None:
"""Test get by ids.""" """Test get by ids."""
store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=3)) store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=3))
store.upsert( store.upsert(

View File

@ -117,7 +117,6 @@ def test_default_add_documents(vs_class: type[VectorStore]) -> None:
"""Test that we can implement the upsert method of the CustomVectorStore """Test that we can implement the upsert method of the CustomVectorStore
class without violating the Liskov Substitution Principle. class without violating the Liskov Substitution Principle.
""" """
store = vs_class() store = vs_class()
# Check upsert with id # Check upsert with id