core: Auto-fix some docstrings (#29337)

Christophe Bornet 2025-01-21 19:29:53 +01:00 committed by GitHub
parent 86a0720310
commit 1c4ce7b42b
95 changed files with 364 additions and 457 deletions
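
The hunks below all apply the same small family of mechanical docstring fixes. As a hedged illustration (the tooling is an assumption; the commit message only says "auto-fix", though the patterns match pydocstyle-style rules): the summary sentence moves onto the opening quotes, a terminal period is added, "Arguments:" and numpydoc section headers are normalized to "Args:", and the closing quotes of multi-line docstrings get their own line.

def before(x: int) -> int:
    """
    Add one to x

    Arguments:
        x : int
            The value
    """
    return x + 1


def after(x: int) -> int:
    """Add one to x.

    Args:
        x: The value.
    """
    return x + 1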


@ -50,7 +50,7 @@ def beta(
``@beta`` would mess up ``__init__`` inheritance when installing its
own (annotation-emitting) ``C.__init__``).
Arguments:
Args:
message : str, optional
Override the default beta message. The %(since)s,
%(name)s, %(alternative)s, %(obj_type)s, %(addendum)s,
@ -63,8 +63,7 @@ def beta(
addendum : str, optional
Additional text appended directly to the final message.
Examples
--------
Examples:
.. code-block:: python


@ -95,7 +95,7 @@ def deprecated(
defaults to 'class' if decorating a class, 'attribute' if decorating a
property, and 'function' otherwise.
Arguments:
Args:
since : str
The release at which this API became deprecated.
message : str, optional
@ -122,8 +122,7 @@ def deprecated(
since. Set to other Falsy values to not schedule a removal
date. Cannot be used together with pending.
Examples
--------
Examples:
.. code-block:: python
@ -183,7 +182,6 @@ def deprecated(
async def awarning_emitting_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Same as warning_emitting_wrapper, but for async functions."""
nonlocal warned
if not warned and not is_caller_internal():
warned = True


@ -74,7 +74,8 @@ class AgentAction(Serializable):
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "agent"]."""
Default is ["langchain", "schema", "agent"].
"""
return ["langchain", "schema", "agent"]
@property
@ -189,7 +190,6 @@ def _convert_agent_observation_to_messages(
Returns:
AIMessage that corresponds to the original tool invocation.
"""
if isinstance(agent_action, AgentActionMessageLog):
return [_create_function_message(agent_action, observation)]
else:


@ -307,8 +307,7 @@ class ContextSet(RunnableSerializable):
class Context:
"""
Context for a runnable.
"""Context for a runnable.
The `Context` class provides methods for creating context scopes,
getters, and setters within a runnable. It allows for managing


@ -131,7 +131,8 @@ class ChainManagerMixin:
outputs (Dict[str, Any]): The outputs of the chain.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
kwargs (Any): Additional keyword arguments.
"""
def on_chain_error(
self,
@ -147,7 +148,8 @@ class ChainManagerMixin:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
kwargs (Any): Additional keyword arguments.
"""
def on_agent_action(
self,
@ -163,7 +165,8 @@ class ChainManagerMixin:
action (AgentAction): The agent action.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
kwargs (Any): Additional keyword arguments.
"""
def on_agent_finish(
self,
@ -179,7 +182,8 @@ class ChainManagerMixin:
finish (AgentFinish): The agent finish.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
kwargs (Any): Additional keyword arguments.
"""
class ToolManagerMixin:
@ -199,7 +203,8 @@ class ToolManagerMixin:
output (Any): The output of the tool.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
kwargs (Any): Additional keyword arguments.
"""
def on_tool_error(
self,
@ -215,7 +220,8 @@ class ToolManagerMixin:
error (BaseException): The error that occurred.
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
kwargs (Any): Additional keyword arguments."""
kwargs (Any): Additional keyword arguments.
"""
class CallbackManagerMixin:
@ -824,7 +830,8 @@ class AsyncCallbackHandler(BaseCallbackHandler):
run_id (UUID): The run ID. This is the ID of the current run.
parent_run_id (UUID): The parent run ID. This is the ID of the parent run.
tags (Optional[List[str]]): The tags.
kwargs (Any): Additional keyword arguments."""
kwargs (Any): Additional keyword arguments.
"""
async def on_retriever_error(
self,


@ -164,6 +164,7 @@ async def atrace_as_chain_group(
Defaults to None.
metadata (Dict[str, Any], optional): The metadata to apply to all runs.
Defaults to None.
Returns:
AsyncCallbackManager: The async callback manager for the chain group.
@ -216,8 +217,7 @@ Func = TypeVar("Func", bound=Callable)
def shielded(func: Func) -> Func:
"""
Makes so an awaitable method is always shielded from cancellation.
"""Makes so an awaitable method is always shielded from cancellation.
Args:
func (Callable): The function to shield.
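
Only the docstring is shown in this hunk. A minimal sketch of what such a decorator can look like (implementation assumed; the function body is not part of this diff):

import asyncio
from functools import wraps
from typing import Any, Callable, TypeVar, cast

Func = TypeVar("Func", bound=Callable)

def shielded(func: Func) -> Func:
    @wraps(func)
    async def wrapped(*args: Any, **kwargs: Any) -> Any:
        # asyncio.shield keeps the inner awaitable running even if the
        # caller is cancelled; the caller still observes CancelledError.
        return await asyncio.shield(func(*args, **kwargs))

    return cast(Func, wrapped)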
@ -1310,7 +1310,6 @@ class CallbackManager(BaseCallbackManager):
List[CallbackManagerForLLMRun]: A callback manager for each
list of messages as an LLM run.
"""
managers = []
for message_list in messages:
if run_id is not None:
@ -1729,7 +1728,6 @@ class AsyncCallbackManager(BaseCallbackManager):
callback managers, one for each LLM Run corresponding
to each prompt.
"""
inline_tasks = []
non_inline_tasks = []
inline_handlers = [handler for handler in self.handlers if handler.run_inline]


@ -1,6 +1,5 @@
"""**Chat message history** stores a history of the message interactions in a chat.
**Class hierarchy:**
.. code-block::
@ -187,10 +186,10 @@ class BaseChatMessageHistory(ABC):
@abstractmethod
def clear(self) -> None:
"""Remove all messages from the store"""
"""Remove all messages from the store."""
async def aclear(self) -> None:
"""Async remove all messages from the store"""
"""Async remove all messages from the store."""
from langchain_core.runnables.config import run_in_executor
await run_in_executor(None, self.clear)
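
This sync-to-async delegation via run_in_executor recurs throughout the diff (aclear, aadd_example, aselect_examples, and others). A hedged sketch of the helper's core behavior (assumed; its body is not part of this commit):

import asyncio
from concurrent.futures import Executor
from functools import partial
from typing import Any, Callable, Optional, TypeVar

T = TypeVar("T")

async def run_in_executor(
    executor: Optional[Executor],
    func: Callable[..., T],
    *args: Any,
    **kwargs: Any,
) -> T:
    # With executor=None the event loop's default thread pool is used,
    # so the blocking call does not stall the loop.
    return await asyncio.get_running_loop().run_in_executor(
        executor, partial(func, *args, **kwargs)
    )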


@ -8,7 +8,8 @@ from langchain_core.messages import BaseMessage
class ChatSession(TypedDict, total=False):
"""Chat Session represents a single
conversation, channel, or other group of messages."""
conversation, channel, or other group of messages.
"""
messages: Sequence[BaseMessage]
"""A sequence of the LangChain chat messages loaded from the source."""


@ -48,7 +48,6 @@ class BaseLoader(ABC): # noqa: B024
Returns:
List of Documents.
"""
if text_splitter is None:
try:
from langchain_text_splitters import RecursiveCharacterTextSplitter


@ -15,15 +15,16 @@ class BaseExampleSelector(ABC):
Args:
example: A dictionary with keys as input variables
and values as their values."""
and values as their values.
"""
async def aadd_example(self, example: dict[str, str]) -> Any:
"""Async add new example to store.
Args:
example: A dictionary with keys as input variables
and values as their values."""
and values as their values.
"""
return await run_in_executor(None, self.add_example, example)
@abstractmethod
@ -32,13 +33,14 @@ class BaseExampleSelector(ABC):
Args:
input_variables: A dictionary with keys as input variables
and values as their values."""
and values as their values.
"""
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
"""Async select which examples to use based on the inputs.
Args:
input_variables: A dictionary with keys as input variables
and values as their values."""
and values as their values.
"""
return await run_in_executor(None, self.select_examples, input_variables)


@ -50,7 +50,6 @@ class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
example: A dictionary with keys as input variables
and values as their values.
"""
self.add_example(example)
@model_validator(mode="after")


@ -241,7 +241,7 @@ def index(
For the time being, documents are indexed using their hashes, and users
are not able to specify the uid of the document.
IMPORTANT:
Important:
* In full mode, the loader should be returning
the entire dataset, and not just a subset of the dataset.
Otherwise, the auto_cleanup will remove documents that it is not
@ -546,7 +546,7 @@ async def aindex(
For the time being, documents are indexed using their hashes, and users
are not able to specify the uid of the document.
IMPORTANT:
Important:
* In full mode, the loader should be returning
the entire dataset, and not just a subset of the dataset.
Otherwise, the auto_cleanup will remove documents that it is not
@ -614,7 +614,6 @@ async def aindex(
* Added `scoped_full` cleanup mode.
"""
if cleanup not in {"incremental", "full", "scoped_full", None}:
msg = (
f"cleanup should be one of 'incremental', 'full', 'scoped_full' or None. "


@ -288,7 +288,6 @@ class InMemoryRecordManager(RecordManager):
ids.
ValueError: If time_at_least is in the future.
"""
if group_ids and len(keys) != len(group_ids):
msg = "Length of keys must match length of group_ids"
raise ValueError(msg)


@ -84,7 +84,6 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
Returns:
ChatResult: Chat result.
"""
generation = next(stream, None)
if generation:
generation += list(stream)
@ -112,7 +111,6 @@ async def agenerate_from_stream(
Returns:
ChatResult: Chat result.
"""
chunks = [chunk async for chunk in stream]
return await run_in_executor(None, generate_from_stream, iter(chunks))
@ -521,7 +519,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
# get default provider from class name
default_provider = self.__class__.__name__
if default_provider.startswith("Chat"):
@ -955,7 +952,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
async def _agenerate(
self,
@ -964,7 +961,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
return await run_in_executor(
None,
self._generate,


@ -42,7 +42,7 @@ class FakeListLLM(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Return next response"""
"""Return next response."""
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1
@ -57,7 +57,7 @@ class FakeListLLM(LLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Return next response"""
"""Return next response."""
response = self.responses[self.i]
if self.i < len(self.responses) - 1:
self.i += 1


@ -220,7 +220,7 @@ class GenericFakeChatModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
message = next(self.messages)
message_ = AIMessage(content=message) if isinstance(message, str) else message
generation = ChatGeneration(message=message_)
@ -342,7 +342,7 @@ class ParrotFakeChatModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
return ChatResult(generations=[ChatGeneration(message=messages[-1])])
@property


@ -91,7 +91,6 @@ def create_base_retry_decorator(
Raises:
ValueError: If the cache is not set and cache is True.
"""
_logging = before_sleep_log(logger, logging.WARNING)
def _before_sleep(retry_state: RetryCallState) -> None:
@ -278,7 +277,6 @@ async def aupdate_cache(
Raises:
ValueError: If the cache is not set and cache is True.
"""
llm_cache = _resolve_cache(cache)
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
@ -292,7 +290,8 @@ async def aupdate_cache(
class BaseLLM(BaseLanguageModel[str], ABC):
"""Base LLM abstract interface.
It should take in a prompt and return a string."""
It should take in a prompt and return a string.
"""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED]"""
@ -346,7 +345,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
**kwargs: Any,
) -> LangSmithParams:
"""Get standard params for tracing."""
# get default provider from class name
default_provider = self.__class__.__name__
if default_provider.endswith("LLM"):


@ -1,5 +1,4 @@
"""
This file contains a mapping between the lc_namespace path for a given
"""This file contains a mapping between the lc_namespace path for a given
subclass that implements from Serializable to the namespace
where that class is actually located.


@ -28,7 +28,8 @@ class FunctionMessage(BaseMessage):
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
@ -48,7 +49,8 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore


@ -41,7 +41,8 @@ class HumanMessage(BaseMessage):
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
def __init__(
@ -72,5 +73,6 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]


@ -28,7 +28,8 @@ class RemoveMessage(BaseMessage):
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]


@ -35,7 +35,8 @@ class SystemMessage(BaseMessage):
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
def __init__(
@ -66,5 +67,6 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]


@ -89,7 +89,8 @@ class ToolMessage(BaseMessage, ToolOutputMixin):
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Default is ["langchain", "schema", "messages"]."""
Default is ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]
@model_validator(mode="before")


@ -817,7 +817,6 @@ def trim_messages(
AIMessage( [{"type": "text", "text": "This is the FIRST 4 token block."}], id="second"),
]
""" # noqa: E501
if start_on and strategy == "first":
raise ValueError
if include_system and strategy == "first":


@ -46,8 +46,8 @@ class ListOutputParser(BaseTransformOutputParser[list[str]]):
Args:
text: The output of an LLM call.
Returns:
A list of strings.
Returns:
A list of strings.
"""
def parse_iter(self, text: str) -> Iterator[re.Match]:
@ -135,7 +135,9 @@ class CommaSeparatedListOutputParser(ListOutputParser):
@classmethod
def is_lc_serializable(cls) -> bool:
"""Check if the langchain object is serializable.
Returns True."""
Returns True.
"""
return True
@classmethod


@ -84,7 +84,6 @@ class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
if len(result) != 1:
msg = f"Expected exactly one result, but got {len(result)}"
raise OutputParserException(msg)
@ -189,7 +188,6 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
the provided schema.
Example:
... code-block:: python
message = AIMessage(


@ -168,7 +168,6 @@ class JsonOutputToolsParser(BaseCumulativeTransformOutputParser[Any]):
Raises:
OutputParserException: If the output is not valid JSON.
"""
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."


@ -129,7 +129,8 @@ class ImagePromptValue(PromptValue):
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
For use in external schemas.
"""
messages: Sequence[AnyMessage]
"""Sequence of messages."""


@ -98,13 +98,15 @@ class BasePromptTemplate(
@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Returns ["langchain", "schema", "prompt_template"]."""
Returns ["langchain", "schema", "prompt_template"].
"""
return ["langchain", "schema", "prompt_template"]
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable.
Returns True."""
Returns True.
"""
return True
model_config = ConfigDict(


@ -54,7 +54,8 @@ class BaseMessagePromptTemplate(Serializable, ABC):
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable.
Returns: True"""
Returns: True.
"""
return True
@classmethod
@ -392,8 +393,7 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@property
def input_variables(self) -> list[str]:
"""
Input variables for this prompt template.
"""Input variables for this prompt template.
Returns:
List of input variable names.
@ -624,8 +624,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
@property
def input_variables(self) -> list[str]:
"""
Input variables for this prompt template.
"""Input variables for this prompt template.
Returns:
List of input variable names.
@ -742,8 +741,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
@property
def lc_attributes(self) -> dict:
"""
Return a list of attribute names that should be included in the
"""Return a list of attribute names that should be included in the
serialized kwargs. These attributes must be accepted by the
constructor.
"""
@ -980,7 +978,6 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
A chat prompt template.
Examples:
Instantiation from a list of message templates:
.. code-block:: python
@ -1173,7 +1170,6 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""Create a chat prompt template from a variety of message formats.
Examples:
Instantiation from a list of message templates:
.. code-block:: python


@ -265,7 +265,6 @@ class FewShotChatMessagePromptTemplate(
to dynamically select examples based on the input.
Examples:
Prompt template with a fixed list of examples (matching the sample
conversation above):


@ -184,8 +184,7 @@ def _load_prompt_from_file(
def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
"""Load chat prompt from config"""
"""Load chat prompt from config."""
messages = config.pop("messages")
template = messages[0]["prompt"].pop("template") if messages else None
config.pop("input_variables")


@ -23,8 +23,7 @@ def _get_inputs(inputs: dict, input_variables: list[str]) -> dict:
),
)
class PipelinePromptTemplate(BasePromptTemplate):
"""
This has been deprecated in favor of chaining individual prompts together in your
"""This has been deprecated in favor of chaining individual prompts together in your
code. E.g. using a for loop, you could do:
.. code-block:: python


@ -284,7 +284,6 @@ class PromptTemplate(StringPromptTemplate):
Returns:
The prompt template loaded from the template.
"""
input_variables = get_template_variables(template, template_format)
_partial_variables = partial_variables or {}


@ -64,8 +64,7 @@ def jinja2_formatter(template: str, /, **kwargs: Any) -> str:
def validate_jinja2(template: str, input_variables: list[str]) -> None:
"""
Validate that the input variables are valid for the template.
"""Validate that the input variables are valid for the template.
Issues a warning if missing or extra variables are found.
Args:


@ -72,7 +72,6 @@ class StructuredPrompt(ChatPromptTemplate):
"""Create a chat prompt template from a variety of message formats.
Examples:
Instantiation from a list of message templates:
.. code-block:: python


@ -199,7 +199,6 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing."""
default_retriever_name = self.get_name()
if default_retriever_name.startswith("Retriever"):
default_retriever_name = default_retriever_name[9:]
@ -342,6 +341,7 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
Args:
query: String to find relevant documents for.
run_manager: The callback handler to use.
Returns:
List of relevant documents.
"""


@ -483,7 +483,6 @@ class Runnable(Generic[Input, Output], ABC):
Returns:
A pydantic model that can be used to validate config.
"""
include = include or []
config_specs = self.config_specs
configurable = (
@ -817,8 +816,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any],
) -> Iterator[tuple[int, Union[Output, Exception]]]:
"""Run invoke in parallel on a list of inputs,
yielding results as they complete."""
yielding results as they complete.
"""
if not inputs:
return
@ -949,7 +948,6 @@ class Runnable(Generic[Input, Output], ABC):
Yields:
A tuple of the index of the input and the output from the Runnable.
"""
if not inputs:
return
@ -981,8 +979,7 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
"""
Default implementation of stream, which calls invoke.
"""Default implementation of stream, which calls invoke.
Subclasses should override this method if they support streaming output.
Args:
@ -1001,8 +998,7 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
"""
Default implementation of astream, which calls ainvoke.
"""Default implementation of astream, which calls ainvoke.
Subclasses should override this method if they support streaming output.
Args:
@ -1064,8 +1060,7 @@ class Runnable(Generic[Input, Output], ABC):
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
"""
Stream all output from a Runnable, as reported to the callback system.
"""Stream all output from a Runnable, as reported to the callback system.
This includes all inner runs of LLMs, Retrievers, Tools, etc.
Output is streamed as Log objects, which include a list of
@ -1392,8 +1387,8 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
"""
Default implementation of transform, which buffers input and then calls stream.
"""Default implementation of transform, which buffers input and calls astream.
Subclasses should override this method if they can start producing output while
input is still being generated.
@ -1434,8 +1429,7 @@ class Runnable(Generic[Input, Output], ABC):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
"""
Default implementation of atransform, which buffers input and calls astream.
"""Default implementation of atransform, which buffers input and calls astream.
Subclasses should override this method if they can start producing output while
input is still being generated.
@ -1472,8 +1466,7 @@ class Runnable(Generic[Input, Output], ABC):
yield output
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
"""
Bind arguments to a Runnable, returning a new Runnable.
"""Bind arguments to a Runnable, returning a new Runnable.
Useful when a Runnable in a chain requires an argument that is not
in the output of the previous Runnable or included in the user input.
@ -1520,8 +1513,7 @@ class Runnable(Generic[Input, Output], ABC):
# Sadly Unpack is not well-supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
"""
Bind config to a Runnable, returning a new Runnable.
"""Bind config to a Runnable, returning a new Runnable.
Args:
config: The config to bind to the Runnable.
@ -1552,8 +1544,7 @@ class Runnable(Generic[Input, Output], ABC):
Union[Callable[[Run], None], Callable[[Run, RunnableConfig], None]]
] = None,
) -> Runnable[Input, Output]:
"""
Bind lifecycle listeners to a Runnable, returning a new Runnable.
"""Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the Runnable starts running, with the Run object.
on_end: Called after the Runnable finishes running, with the Run object.
@ -1620,8 +1611,7 @@ class Runnable(Generic[Input, Output], ABC):
on_end: Optional[AsyncListener] = None,
on_error: Optional[AsyncListener] = None,
) -> Runnable[Input, Output]:
"""
Bind asynchronous lifecycle listeners to a Runnable, returning a new Runnable.
"""Bind async lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Asynchronously called before the Runnable starts running.
on_end: Asynchronously called after the Runnable finishes running.
@ -1711,8 +1701,7 @@ class Runnable(Generic[Input, Output], ABC):
input_type: Optional[type[Input]] = None,
output_type: Optional[type[Output]] = None,
) -> Runnable[Input, Output]:
"""
Bind input and output types to a Runnable, returning a new Runnable.
"""Bind input and output types to a Runnable, returning a new Runnable.
Args:
input_type: The input type to bind to the Runnable. Defaults to None.
@ -1799,8 +1788,7 @@ class Runnable(Generic[Input, Output], ABC):
)
def map(self) -> Runnable[list[Input], list[Output]]:
"""
Return a new Runnable that maps a list of inputs to a list of outputs,
"""Return a new Runnable that maps a list of inputs to a list of outputs,
by calling invoke() with each input.
Returns:
@ -1906,7 +1894,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
with callbacks. Use this method to implement invoke() in subclasses.
"""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
@ -1955,7 +1944,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
with callbacks. Use this method to implement ainvoke() in subclasses.
"""
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
@ -2004,7 +1994,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any],
) -> list[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
with callbacks. Use this method to implement invoke() in subclasses.
"""
if not input:
return []
@ -2076,7 +2067,8 @@ class Runnable(Generic[Input, Output], ABC):
**kwargs: Optional[Any],
) -> list[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
with callbacks. Use this method to implement invoke() in subclasses.
"""
if not input:
return []
@ -2149,7 +2141,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
Use this to implement `stream()` or `transform()` in Runnable subclasses."""
Use this to implement `stream()` or `transform()` in Runnable subclasses.
"""
# Mixin that is used by both astream log and astream events implementation
from langchain_core.tracers._streaming import _StreamingCallbackHandler
@ -2249,7 +2242,8 @@ class Runnable(Generic[Input, Output], ABC):
) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
Use this to implement `astream()` or `atransform()` in Runnable subclasses."""
Use this to implement `astream()` or `atransform()` in Runnable subclasses.
"""
# Mixin that is used by both astream log and astream events implementation
from langchain_core.tracers._streaming import _StreamingCallbackHandler
@ -5601,7 +5595,6 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
- ``with_fallbacks``: Bind a fallback policy to the underlying Runnable.
Example:
`bind`: Bind kwargs to pass to the underlying Runnable when running it.
.. code-block:: python


@ -116,7 +116,7 @@ var_child_runnable_config = ContextVar(
def _set_config_context(config: RunnableConfig) -> None:
"""Set the child Runnable config + tracing context
"""Set the child Runnable config + tracing context.
Args:
config (RunnableConfig): The config to set.


@ -404,7 +404,8 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]:
"""Get a new RunnableConfigurableFields with the specified
configurable fields."""
configurable fields.
"""
return self.default.configurable_fields(**{**self.fields, **kwargs})
def _prepare(


@ -137,7 +137,7 @@ class Branch(NamedTuple):
class CurveStyle(Enum):
"""Enum for different curve styles supported by Mermaid"""
"""Enum for different curve styles supported by Mermaid."""
BASIS = "basis"
BUMP_X = "bumpX"
@ -169,7 +169,7 @@ class NodeStyles:
class MermaidDrawMethod(Enum):
"""Enum for different draw methods supported by Mermaid"""
"""Enum for different draw methods supported by Mermaid."""
PYPPETEER = "pyppeteer" # Uses Pyppeteer to render the graph
API = "api" # Uses Mermaid.INK API to render the graph
@ -306,7 +306,8 @@ class Graph:
def next_id(self) -> str:
"""Return a new unique node
identifier that can be used to add a node to the graph."""
identifier that can be used to add a node to the graph.
"""
return uuid4().hex
def add_node(
@ -422,7 +423,8 @@ class Graph:
def reid(self) -> Graph:
"""Return a new graph with all nodes re-identified,
using their unique, readable names where possible."""
using their unique, readable names where possible.
"""
node_name_to_ids = defaultdict(list)
for node in self.nodes.values():
node_name_to_ids[node.name].append(node.id)
@ -457,18 +459,21 @@ class Graph:
def first_node(self) -> Optional[Node]:
"""Find the single node that is not a target of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin."""
When drawing the graph, this node would be the origin.
"""
return _first_node(self)
def last_node(self) -> Optional[Node]:
"""Find the single node that is not a source of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination."""
When drawing the graph, this node would be the destination.
"""
return _last_node(self)
def trim_first_node(self) -> None:
"""Remove the first node if it exists and has a single outgoing edge,
i.e., if removing it would not leave the graph without a "first" node."""
i.e., if removing it would not leave the graph without a "first" node.
"""
first_node = self.first_node()
if (
first_node
@ -479,7 +484,8 @@ class Graph:
def trim_last_node(self) -> None:
"""Remove the last node if it exists and has a single incoming edge,
i.e., if removing it would not leave the graph without a "last" node."""
i.e., if removing it would not leave the graph without a "last" node.
"""
last_node = self.last_node()
if (
last_node
@ -634,7 +640,8 @@ def _first_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
"""Find the single node that is not a target of any edge.
Exclude nodes/sources with ids in the exclude list.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the origin."""
When drawing the graph, this node would be the origin.
"""
targets = {edge.target for edge in graph.edges if edge.source not in exclude}
found: list[Node] = []
for node in graph.nodes.values():
@ -647,7 +654,8 @@ def _last_node(graph: Graph, exclude: Sequence[str] = ()) -> Optional[Node]:
"""Find the single node that is not a source of any edge.
Exclude nodes/targets with ids in the exclude list.
If there is no such node, or there are multiple, return None.
When drawing the graph, this node would be the destination."""
When drawing the graph, this node would be the destination.
"""
sources = {edge.source for edge in graph.edges if edge.target not in exclude}
found: list[Node] = []
for node in graph.nodes.values():


@ -1,5 +1,6 @@
"""Draws DAG in ASCII.
Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py"""
Adapted from https://github.com/iterative/dvc/blob/main/dvc/dagascii.py.
"""
import math
import os
@ -239,7 +240,6 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
| 1 |
+---+
"""
# NOTE: coordinates might me negative, so we need to shift
# everything to the positive plane before we actually draw it.
xlist = []


@ -132,7 +132,6 @@ class PngDrawer:
:param graph: The graph to draw
:param output_path: The path to save the PNG. If None, PNG bytes are returned.
"""
try:
import pygraphviz as pgv # type: ignore[import]
except ImportError as exc:


@ -43,7 +43,6 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
way to use it is through the `.with_retry()` method on all Runnables.
Example:
Here's an example that uses a RunnableLambda to raise an exception
.. code-block:: python


@ -44,8 +44,7 @@ class RouterInput(TypedDict):
class RouterRunnable(RunnableSerializable[RouterInput, Output]):
"""
Runnable that routes to a set of Runnables based on Input['key'].
"""Runnable that routes to a set of Runnables based on Input['key'].
Returns the output of the selected Runnable.
Parameters:


@ -462,9 +462,7 @@ def indent_lines_after_first(text: str, prefix: str) -> str:
class AddableDict(dict[str, Any]):
"""
Dictionary that can be added to another dictionary.
"""
"""Dictionary that can be added to another dictionary."""
def __add__(self, other: AddableDict) -> AddableDict:
chunk = AddableDict(self)
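
The diff view cuts the method off here. A hedged sketch of the addition semantics the docstring implies (merge rule assumed, not shown in this commit):

class AddableDict(dict):
    def __add__(self, other: "AddableDict") -> "AddableDict":
        chunk = AddableDict(self)
        for key, value in other.items():
            if key not in chunk or chunk[key] is None:
                chunk[key] = value
            elif value is not None:
                # Values under a shared key are combined with "+", e.g.
                # concatenating streamed string chunks.
                chunk[key] = chunk[key] + value
        return chunk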


@ -162,7 +162,6 @@ class StructuredTool(BaseTool):
tool = StructuredTool.from_function(add)
tool.run(1, 2) # 3
"""
if func is not None:
source_function = func
elif coroutine is not None:


@ -531,8 +531,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
"""Persist a run."""
async def _start_trace(self, run: Run) -> None:
"""
Start a trace for a run.
"""Start a trace for a run.
Starting a trace will run concurrently with each _on_[run_type]_start method.
No _on_[run_type]_start callback should depend on operations in _start_trace.
@ -541,8 +540,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
await self._on_run_create(run)
async def _end_trace(self, run: Run) -> None:
"""
End a trace for a run.
"""End a trace for a run.
Ending a trace will run concurrently with each _on_[run_type]_end method.
No _on_[run_type]_end callback should depend on operations in _end_trace.


@ -40,8 +40,7 @@ SCHEMA_FORMAT_TYPE = Literal["original", "streaming_events"]
class _TracerCore(ABC):
"""
Abstract base class for tracers.
"""Abstract base class for tracers.
This class provides common methods, and reusable methods for tracers.
"""
@ -233,9 +232,7 @@ class _TracerCore(ABC):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Run:
"""
Append token event to LLM run and return the run.
"""
"""Append token event to LLM run and return the run."""
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
event_kwargs: dict[str, Any] = {"token": token}
if chunk:


@ -35,35 +35,33 @@ def wait_for_all_evaluators() -> None:
class EvaluatorCallbackHandler(BaseTracer):
"""Tracer that runs a run evaluator whenever a run is persisted.
Parameters
----------
evaluators : Sequence[RunEvaluator]
The run evaluators to apply to all top level runs.
client : LangSmith Client, optional
The LangSmith client instance to use for evaluating the runs.
If not specified, a new instance will be created.
example_id : Union[UUID, str], optional
The example ID to be associated with the runs.
project_name : str, optional
The LangSmith project name to be organize eval chain runs under.
Args:
evaluators : Sequence[RunEvaluator]
The run evaluators to apply to all top level runs.
client : LangSmith Client, optional
The LangSmith client instance to use for evaluating the runs.
If not specified, a new instance will be created.
example_id : Union[UUID, str], optional
The example ID to be associated with the runs.
project_name : str, optional
The LangSmith project name to be organize eval chain runs under.
Attributes
----------
example_id : Union[UUID, None]
The example ID associated with the runs.
client : Client
The LangSmith client instance used for evaluating the runs.
evaluators : Sequence[RunEvaluator]
The sequence of run evaluators to be executed.
executor : ThreadPoolExecutor
The thread pool executor used for running the evaluators.
futures : Set[Future]
The set of futures representing the running evaluators.
skip_unfinished : bool
Whether to skip runs that are not finished or raised
an error.
project_name : Optional[str]
The LangSmith project name to be organize eval chain runs under.
Attributes:
example_id : Union[UUID, None]
The example ID associated with the runs.
client : Client
The LangSmith client instance used for evaluating the runs.
evaluators : Sequence[RunEvaluator]
The sequence of run evaluators to be executed.
executor : ThreadPoolExecutor
The thread pool executor used for running the evaluators.
futures : Set[Future]
The set of futures representing the running evaluators.
skip_unfinished : bool
Whether to skip runs that are not finished or raised
an error.
project_name : Optional[str]
The LangSmith project name to be organize eval chain runs under.
"""
name: str = "evaluator_callback_handler"


@ -253,9 +253,7 @@ class LangChainTracer(BaseTracer):
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Run:
"""
Append token event to LLM run and return the run.
"""
"""Append token event to LLM run and return the run."""
return super()._llm_run_with_token_event(
# Drop the chunk; we don't need to save it
token,


@ -24,8 +24,7 @@ class RunCollectorCallbackHandler(BaseTracer):
def __init__(
self, example_id: Optional[Union[UUID, str]] = None, **kwargs: Any
) -> None:
"""
Initialize the RunCollectorCallbackHandler.
"""Initialize the RunCollectorCallbackHandler.
Parameters
----------
@ -41,8 +40,7 @@ class RunCollectorCallbackHandler(BaseTracer):
self.traced_runs: list[Run] = []
def _persist_run(self, run: Run) -> None:
"""
Persist a run by adding it to the traced_runs list.
"""Persist a run by adding it to the traced_runs list.
Parameters
----------


@ -1,5 +1,4 @@
"""
**Utility functions** for LangChain.
"""**Utility functions** for LangChain.
These functions do not depend on any other LangChain module.
"""


@ -1,7 +1,6 @@
"""
Adapted from
"""Adapted from
https://github.com/maxfischer2781/asyncstdlib/blob/master/asyncstdlib/itertools.py
MIT License
MIT License.
"""
from collections import deque
@ -54,7 +53,6 @@ def py_anext(
Raises:
TypeError: If the iterator is not an async iterator.
"""
try:
__anext__ = cast(
Callable[[AsyncIterator[T]], Awaitable[T]], type(iterator).__anext__
@ -147,8 +145,7 @@ async def tee_peer(
class Tee(Generic[T]):
"""
Create ``n`` separate asynchronous iterators over ``iterable``.
"""Create ``n`` separate asynchronous iterators over ``iterable``.
This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.


@ -1,4 +1,4 @@
"""Methods for creating function specs in the style of OpenAI Functions"""
"""Methods for creating function specs in the style of OpenAI Functions."""
from __future__ import annotations
@ -342,6 +342,7 @@ def convert_to_openai_function(
strict: Optional[bool] = None,
) -> dict[str, Any]:
"""Convert a raw function/class to an OpenAI function.
Args:
function:
A dictionary, Pydantic BaseModel class, TypedDict class, a LangChain


@ -70,6 +70,7 @@ def extract_sub_links(
exclude_prefixes: Exclude any URLs that start with one of these prefixes.
continue_on_failure: If True, continue if parsing a specific link raises an
exception. Otherwise, raise the exception.
Returns:
List[str]: sub links.
"""


@ -83,8 +83,7 @@ def tee_peer(
class Tee(Generic[T]):
"""
Create ``n`` separate asynchronous iterators over ``iterable``
"""Create ``n`` separate asynchronous iterators over ``iterable``.
This splits a single ``iterable`` into multiple iterators, each providing
the same items in the same order.


@ -18,11 +18,10 @@ def _replace_new_line(match: re.Match[str]) -> str:
def _custom_parser(multiline_string: str) -> str:
"""
The LLM response for `action_input` may be a multiline
"""The LLM response for `action_input` may be a multiline
string containing unescaped newlines, tabs or quotes. This function
replaces those characters with their escaped counterparts.
(newlines in JSON must be double-escaped: `\\n`)
(newlines in JSON must be double-escaped: `\\n`).
"""
if isinstance(multiline_string, (bytes, bytearray)):
multiline_string = multiline_string.decode()
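
A hedged, standalone illustration of the escaping described above (helper name assumed; the real parser targets the action_input value specifically):

def escape_control_chars(value: str) -> str:
    # Replace raw newlines, tabs and quotes with their JSON-escaped forms.
    return value.replace("\n", "\\n").replace("\t", "\\t").replace('"', '\\"')

print(escape_control_chars("line one\nline two"))  # line one\nline two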
@ -161,8 +160,7 @@ def _parse_json(
def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
"""
Parse a JSON string from a Markdown string and check that it
"""Parse a JSON string from a Markdown string and check that it
contains the expected keys.
Args:


@ -105,7 +105,6 @@ def dereference_refs(
Returns:
The dereferenced schema object.
"""
full_schema = full_schema or schema_obj
skip_keys = (
skip_keys


@ -1,6 +1,5 @@
"""
Adapted from https://github.com/noahmorrison/chevron
MIT License
"""Adapted from https://github.com/noahmorrison/chevron
MIT License.
"""
from __future__ import annotations
@ -48,7 +47,6 @@ def grab_literal(template: str, l_del: str) -> tuple[str, str]:
Returns:
Tuple[str, str]: The literal and the template.
"""
global _CURRENT_LINE
try:
@ -74,7 +72,6 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
Returns:
bool: Whether the tag could be a standalone.
"""
# If there is a newline, or the previous tag was a standalone
if literal.find("\n") != -1 or is_standalone:
padding = literal.split("\n")[-1]
@ -98,7 +95,6 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
Returns:
bool: Whether the tag could be a standalone.
"""
# Check right side if we might be a standalone
if is_standalone and tag_type not in ["variable", "no escape"]:
on_newline = template.split("\n", 1)
@ -199,36 +195,25 @@ def tokenize(
using file-like objects. It also accepts a string containing
the template.
Arguments:
template -- a file-like object, or a string of a mustache template
def_ldel -- The default left delimiter
("{{" by default, as in spec compliant mustache)
def_rdel -- The default right delimiter
("}}" by default, as in spec compliant mustache)
Args:
template: a file-like object, or a string of a mustache template
def_ldel: The default left delimiter
("{{" by default, as in spec compliant mustache)
def_rdel: The default right delimiter
("}}" by default, as in spec compliant mustache)
Returns:
A generator of mustache tags in the form of a tuple
-- (tag_type, tag_key)
Where tag_type is one of:
* literal
* section
* inverted section
* end
* partial
* no escape
And tag_key is either the key or in the case of a literal tag,
the literal itself.
A generator of mustache tags in the form of a tuple (tag_type, tag_key)
Where tag_type is one of:
* literal
* section
* inverted section
* end
* partial
* no escape
And tag_key is either the key or in the case of a literal tag,
the literal itself.
"""
global _CURRENT_LINE, _LAST_TAG_LINE
_CURRENT_LINE = 1
_LAST_TAG_LINE = None
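
A hedged usage example of the tokenizer just described (import path assumed from this file; the output shape is inferred from the docstring):

from langchain_core.utils.mustache import tokenize

print(list(tokenize("Hello, {{name}}!")))
# Inferred shape: [("literal", "Hello, "), ("variable", "name"), ("literal", "!")]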
@ -329,8 +314,7 @@ def tokenize(
def _html_escape(string: str) -> str:
"""HTML escape all of these " & < >"""
"""HTML escape all of these " & < >."""
html_codes = {
'"': "&quot;",
"<": "&lt;",
@ -352,8 +336,7 @@ def _get_key(
def_ldel: str,
def_rdel: str,
) -> Any:
"""Get a key from the current scope"""
"""Get a key from the current scope."""
# If the key is a dot
if key == ".":
# Then just return the current scope
@ -410,7 +393,7 @@ def _get_key(
def _get_partial(name: str, partials_dict: Mapping[str, str]) -> str:
"""Load a partial"""
"""Load a partial."""
try:
# Maybe the partial is in the dictionary
return partials_dict[name]
@ -441,45 +424,31 @@ def render(
Renders a mustache template with a data scope and inline partial capability.
Arguments:
template -- A file-like object or a string containing the template.
data -- A python dictionary with your data scope.
partials_path -- The path to where your partials are stored.
If set to None, then partials won't be loaded from the file system
(defaults to '.').
partials_ext -- The extension that you want the parser to look for
(defaults to 'mustache').
partials_dict -- A python dictionary which will be search for partials
before the filesystem is. {'include': 'foo'} is the same
as a file called include.mustache
(defaults to {}).
padding -- This is for padding partials, and shouldn't be used
(but can be if you really want to).
def_ldel -- The default left delimiter
("{{" by default, as in spec compliant mustache).
def_rdel -- The default right delimiter
("}}" by default, as in spec compliant mustache).
scopes -- The list of scopes that get_key will look through.
warn -- Log a warning when a template substitution isn't found in the data
keep -- Keep unreplaced tags when a substitution isn't found in the data.
Args:
template: A file-like object or a string containing the template.
data: A python dictionary with your data scope.
partials_path: The path to where your partials are stored.
If set to None, then partials won't be loaded from the file system
(defaults to '.').
partials_ext: The extension that you want the parser to look for
(defaults to 'mustache').
partials_dict: A python dictionary which will be search for partials
before the filesystem is. {'include': 'foo'} is the same
as a file called include.mustache
(defaults to {}).
padding: This is for padding partials, and shouldn't be used
(but can be if you really want to).
def_ldel: The default left delimiter
("{{" by default, as in spec compliant mustache).
def_rdel: The default right delimiter
("}}" by default, as in spec compliant mustache).
scopes: The list of scopes that get_key will look through.
warn: Log a warning when a template substitution isn't found in the data
keep: Keep unreplaced tags when a substitution isn't found in the data.
Returns:
A string containing the rendered template.
A string containing the rendered template.
"""
# If the template is a sequence but not derived from a string
if isinstance(template, Sequence) and not isinstance(template, str):
# Then we don't need to tokenize it
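
Given the Args above, a hedged usage example of render (import path assumed from this file):

from langchain_core.utils.mustache import render

# Minimal call: a template plus a data scope; partials are left at their defaults.
print(render("Hello, {{name}}!", {"name": "world"}))  # Hello, world!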


@ -172,7 +172,6 @@ def pre_init(func: Callable) -> Any:
Returns:
Any: The decorated function.
"""
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)


@ -20,7 +20,7 @@ from langchain_core.utils.pydantic import (
def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive."
"""Validate specified keyword args are mutually exclusive.".
Args:
*arg_groups (Tuple[str, ...]): Groups of mutually exclusive keyword args.
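
Only the docstring changes here. A hedged sketch of what such a validator can look like (exact rule assumed; the real helper may require exactly one rather than at most one per group):

from functools import wraps
from typing import Any, Callable

def xor_args(*arg_groups: tuple[str, ...]) -> Callable:
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for group in arg_groups:
                if sum(kwargs.get(name) is not None for name in group) > 1:
                    raise ValueError(f"Only one of {group} may be given.")
            return func(*args, **kwargs)

        return wrapper

    return decorator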


@ -132,7 +132,6 @@ class VectorStore(ABC):
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
"""
msg = "delete method must be implemented by subclass."
raise NotImplementedError(msg)
@ -423,7 +422,6 @@ class VectorStore(ABC):
@staticmethod
def _cosine_relevance_score_fn(distance: float) -> float:
"""Normalize the distance to a score on a scale [0, 1]."""
return 1.0 - distance
@staticmethod
@ -435,8 +433,7 @@ class VectorStore(ABC):
return -1.0 * distance
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
The 'correct' relevance function
"""The 'correct' relevance function
may differ depending on a few things, including:
- the distance / similarity metric used by the VectorStore
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
@ -473,7 +470,6 @@ class VectorStore(ABC):
Returns:
List of Tuples of (doc, similarity_score).
"""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
@ -487,8 +483,7 @@ class VectorStore(ABC):
k: int = 4,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""
Default similarity search with relevance scores. Modify if necessary
"""Default similarity search with relevance scores. Modify if necessary
in subclass.
Return docs and relevance scores in the range [0, 1].
@ -514,8 +509,7 @@ class VectorStore(ABC):
k: int = 4,
**kwargs: Any,
) -> list[tuple[Document, float]]:
"""
Default similarity search with relevance scores. Modify if necessary
"""Default similarity search with relevance scores. Modify if necessary
in subclass.
Return docs and relevance scores in the range [0, 1].
@ -644,7 +638,6 @@ class VectorStore(ABC):
Returns:
List of Documents most similar to the query.
"""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
@ -678,7 +671,6 @@ class VectorStore(ABC):
Returns:
List of Documents most similar to the query vector.
"""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
@ -741,7 +733,6 @@ class VectorStore(ABC):
Returns:
List of Documents selected by maximal marginal relevance.
"""
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
@ -1056,7 +1047,6 @@ class VectorStoreRetriever(BaseRetriever):
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing."""
_kwargs = self.search_kwargs | kwargs
ls_params = super()._get_ls_params(**_kwargs)


@ -55,46 +55,46 @@ def test_warn_beta(kwargs: dict[str, Any], expected_message: str) -> None:
@beta()
def beta_function() -> str:
"""original doc"""
"""Original doc."""
return "This is a beta function."
@beta()
async def beta_async_function() -> str:
"""original doc"""
"""Original doc."""
return "This is a beta async function."
class ClassWithBetaMethods:
def __init__(self) -> None:
"""original doc"""
"""Original doc."""
@beta()
def beta_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a beta method."
@beta()
async def beta_async_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a beta async method."
@classmethod
@beta()
def beta_classmethod(cls) -> str:
"""original doc"""
"""Original doc."""
return "This is a beta classmethod."
@staticmethod
@beta()
def beta_staticmethod() -> str:
"""original doc"""
"""Original doc."""
return "This is a beta staticmethod."
@property
@beta()
def beta_property(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a beta property."
@ -240,11 +240,11 @@ def test_whole_class_beta() -> None:
@beta()
class BetaClass:
def __init__(self) -> None:
"""original doc"""
"""Original doc."""
@beta()
def beta_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a beta method."
with warnings.catch_warnings(record=True) as warning_list:
@ -281,14 +281,14 @@ def test_whole_class_inherited_beta() -> None:
class BetaClass:
@beta()
def beta_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a beta method."
@beta()
class InheritedBetaClass(BetaClass):
@beta()
def beta_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a beta method 2."
with warnings.catch_warnings(record=True) as warning_list:
@ -344,7 +344,7 @@ def test_whole_class_inherited_beta() -> None:
class MyModel(BaseModel):
@beta()
def beta_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a beta method."


@ -75,46 +75,46 @@ def test_undefined_deprecation_schedule() -> None:
@deprecated(since="2.0.0", removal="3.0.0", pending=False)
def deprecated_function() -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated function."
@deprecated(since="2.0.0", removal="3.0.0", pending=False)
async def deprecated_async_function() -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated async function."
class ClassWithDeprecatedMethods:
def __init__(self) -> None:
"""original doc"""
"""Original doc."""
@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated method."
@deprecated(since="2.0.0", removal="3.0.0")
async def deprecated_async_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated async method."
@classmethod
@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_classmethod(cls) -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated classmethod."
@staticmethod
@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_staticmethod() -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated staticmethod."
@property
@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_property(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated property."
@ -264,11 +264,11 @@ def test_whole_class_deprecation() -> None:
@deprecated(since="2.0.0", removal="3.0.0")
class DeprecatedClass:
def __init__(self) -> None:
"""original doc"""
"""Original doc."""
@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated method."
with warnings.catch_warnings(record=True) as warning_list:
@ -306,11 +306,11 @@ def test_whole_class_inherited_deprecation() -> None:
@deprecated(since="2.0.0", removal="3.0.0")
class DeprecatedClass:
def __init__(self) -> None:
"""original doc"""
"""Original doc."""
@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated method."
@deprecated(since="2.2.0", removal="3.2.0")
@ -318,11 +318,11 @@ def test_whole_class_inherited_deprecation() -> None:
"""Inherited deprecated class."""
def __init__(self) -> None:
"""original doc"""
"""Original doc."""
@deprecated(since="2.2.0", removal="3.2.0")
def deprecated_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated method."
with warnings.catch_warnings(record=True) as warning_list:
@ -379,7 +379,7 @@ def test_whole_class_inherited_deprecation() -> None:
class MyModel(BaseModel):
@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_method(self) -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated method."
@ -408,7 +408,7 @@ def test_raise_error_for_bad_decorator() -> None:
@deprecated(since="2.0.0", alternative="NewClass", alternative_import="hello")
def deprecated_function() -> str:
"""original doc"""
"""Original doc."""
return "This is a deprecated function."
@ -417,7 +417,7 @@ def test_rename_parameter() -> None:
@rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name")
def foo(new_name: str) -> str:
"""original doc"""
"""Original doc."""
return new_name
with warnings.catch_warnings(record=True) as warning_list:
@ -427,7 +427,7 @@ def test_rename_parameter() -> None:
assert foo(new_name="hello") == "hello"
assert foo("hello") == "hello"
assert foo.__doc__ == "original doc"
assert foo.__doc__ == "Original doc."
with pytest.raises(TypeError):
foo(meow="hello") # type: ignore[call-arg]
with pytest.raises(TypeError):
@ -442,7 +442,7 @@ async def test_rename_parameter_for_async_func() -> None:
@rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name")
async def foo(new_name: str) -> str:
"""original doc"""
"""Original doc."""
return new_name
with warnings.catch_warnings(record=True) as warning_list:
@ -451,7 +451,7 @@ async def test_rename_parameter_for_async_func() -> None:
assert len(warning_list) == 1
assert await foo(new_name="hello") == "hello"
assert await foo("hello") == "hello"
assert foo.__doc__ == "original doc"
assert foo.__doc__ == "Original doc."
with pytest.raises(TypeError):
await foo(meow="hello") # type: ignore[call-arg]
with pytest.raises(TypeError):

View File
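
The file above exercises ``@deprecated`` and ``@rename_parameter``; a minimal standalone sketch (decorator signatures taken from the tests; the ``rename_parameter`` import path is assumed):

from langchain_core._api import deprecated
from langchain_core._api.deprecation import rename_parameter  # path assumed

@deprecated(since="2.0.0", removal="3.0.0")
def deprecated_function() -> str:
    """Original doc."""
    return "This is a deprecated function."

@rename_parameter(since="2.0.0", removal="3.0.0", old="old_name", new="new_name")
def foo(new_name: str) -> str:
    """Original doc."""
    return new_name

# Both call styles work during the deprecation window; the old keyword
# emits a warning and is forwarded to the new parameter.
assert foo(new_name="hello") == "hello"
assert foo("hello") == "hello"
assert foo.__doc__ == "Original doc."  # rename_parameter leaves the docstring intact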

@ -91,7 +91,6 @@ async def test_async_custom_event_implicit_config() -> None:
async def test_async_callback_manager() -> None:
"""Test async callback manager."""
callback = AsyncCustomCallbackHandler()
run_id = uuid.UUID(int=7)

View File

@ -2,8 +2,7 @@ from langchain_core.embeddings import DeterministicFakeEmbedding
def test_deterministic_fake_embeddings() -> None:
"""
Test that the deterministic fake embeddings return the same
"""Test that the deterministic fake embeddings return the same
embedding vector for the same text.
"""
fake = DeterministicFakeEmbedding(size=10)

View File
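
A minimal usage sketch of the embedding class under test, using the import shown above:

from langchain_core.embeddings import DeterministicFakeEmbedding

fake = DeterministicFakeEmbedding(size=10)
assert fake.embed_query("foo") == fake.embed_query("foo")  # same text -> same vector
assert len(fake.embed_query("foo")) == 10  # vectors have `size` dimensions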

@ -1,4 +1,4 @@
"""Test in memory indexer"""
"""Test in memory indexer."""
from collections.abc import AsyncGenerator, Generator

View File

@ -1167,7 +1167,7 @@ def test_incremental_delete_with_same_source(
def test_incremental_indexing_with_batch_size(
record_manager: InMemoryRecordManager, vector_store: InMemoryVectorStore
) -> None:
"""Test indexing with incremental indexing"""
"""Test indexing with incremental indexing."""
loader = ToyLoader(
documents=[
Document(
@ -2031,7 +2031,6 @@ def test_index_with_upsert_kwargs_for_document_indexer(
mocker: MockerFixture,
) -> None:
"""Test that kwargs are passed to the upsert method of the document indexer."""
document_index = InMemoryDocumentIndex()
upsert_spy = mocker.spy(document_index.__class__, "upsert")
docs = [
@ -2070,7 +2069,6 @@ async def test_aindex_with_upsert_kwargs_for_document_indexer(
mocker: MockerFixture,
) -> None:
"""Test that kwargs are passed to the upsert method of the document indexer."""
document_index = InMemoryDocumentIndex()
upsert_spy = mocker.spy(document_index.__class__, "aupsert")
docs = [

View File

@ -136,7 +136,7 @@ async def test_astream_fallback_to_ainvoke() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
message = AIMessage(content="hello")
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@ -164,7 +164,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
raise NotImplementedError
def _stream(
@ -209,7 +209,7 @@ async def test_astream_implementation_uses_astream() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
raise NotImplementedError
async def _astream( # type: ignore
@ -243,7 +243,6 @@ class FakeTracer(BaseTracer):
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
self.traced_run_ids.append(run.id)

View File

@ -8,7 +8,6 @@ from langchain_core.rate_limiters import InMemoryRateLimiter
def test_rate_limit_invoke() -> None:
"""Add rate limiter."""
model = GenericFakeChatModel(
messages=iter(["hello", "world"]),
rate_limiter=InMemoryRateLimiter(
@ -35,7 +34,6 @@ def test_rate_limit_invoke() -> None:
async def test_rate_limit_ainvoke() -> None:
"""Add rate limiter."""
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(

View File

@ -160,7 +160,7 @@ async def test_astream_implementation_fallback_to_stream() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
"""Top Level call."""
raise NotImplementedError
def _stream(
@ -197,7 +197,7 @@ async def test_astream_implementation_uses_astream() -> None:
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
"""Top Level call."""
raise NotImplementedError
async def _astream(

View File

@ -29,7 +29,8 @@ def test_single_item() -> None:
def test_multiple_items_with_spaces() -> None:
"""Test that a string with multiple comma-separated items
with spaces is parsed to a list."""
with spaces is parsed to a list.
"""
parser = CommaSeparatedListOutputParser()
text = "foo, bar, baz"
expected = ["foo", "bar", "baz"]
@ -66,7 +67,8 @@ def test_multiple_items() -> None:
def test_multiple_items_with_comma() -> None:
"""Test that a string with multiple comma-separated items with 1 item containing a
comma is parsed to a list."""
comma is parsed to a list.
"""
parser = CommaSeparatedListOutputParser()
text = '"foo, foo2",bar,baz'
expected = ["foo, foo2", "bar", "baz"]

View File
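
The two cases above reduce to these parses (import path assumed to be langchain_core.output_parsers):

from langchain_core.output_parsers import CommaSeparatedListOutputParser

parser = CommaSeparatedListOutputParser()
assert parser.parse("foo, bar, baz") == ["foo", "bar", "baz"]  # spaces stripped
assert parser.parse('"foo, foo2",bar,baz') == ["foo, foo2", "bar", "baz"]  # quoted comma kept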

@ -166,7 +166,6 @@ def test_pydantic_output_functions_parser() -> None:
def test_pydantic_output_functions_parser_multiple_schemas() -> None:
"""Test that the parser works if providing multiple pydantic schemas."""
message = AIMessage(
content="This is a test message",
additional_kwargs={

View File

@ -482,7 +482,7 @@ class Person(BaseModel):
class NameCollector(BaseModel):
"""record names of all people mentioned"""
"""record names of all people mentioned."""
names: list[str] = Field(..., description="all names mentioned")
person: Person = Field(..., description="info about the main subject")

View File

@ -1,4 +1,4 @@
"""Test PydanticOutputParser"""
"""Test PydanticOutputParser."""
from enum import Enum
from typing import Literal, Optional
@ -141,7 +141,6 @@ DEF_EXPECTED_RESULT = TestModel(
def test_pydantic_output_parser() -> None:
"""Test PydanticOutputParser."""
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
pydantic_object=TestModel
)
@ -154,7 +153,6 @@ def test_pydantic_output_parser() -> None:
def test_pydantic_output_parser_fail() -> None:
"""Test PydanticOutputParser where completion result fails schema validation."""
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
pydantic_object=TestModel
)

View File

@ -1,4 +1,4 @@
"""Test XMLOutputParser"""
"""Test XMLOutputParser."""
import importlib
from collections.abc import AsyncIterator, Iterable
@ -77,7 +77,7 @@ async def _as_iter(iterable: Iterable[str]) -> AsyncIterator[str]:
async def test_root_only_xml_output_parser() -> None:
"""Test XMLOutputParser when xml only contains the root level tag"""
"""Test XMLOutputParser when xml only contains the root level tag."""
xml_parser = XMLOutputParser(parser="xml")
assert xml_parser.parse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."}
assert await xml_parser.aparse(ROOT_LEVEL_ONLY) == {"body": "Text of the body."}
@ -125,7 +125,6 @@ async def test_xml_output_parser_defused(content: str) -> None:
@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])
def test_xml_output_parser_fail(result: str) -> None:
"""Test XMLOutputParser where complete output is not in XML format."""
xml_parser = XMLOutputParser(parser="xml")
with pytest.raises(OutputParserException) as e:

View File
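
For reference, a sketch of the root-only case tested above, assuming ROOT_LEVEL_ONLY is a single-root document like the one shown:

from langchain_core.output_parsers import XMLOutputParser

xml_parser = XMLOutputParser(parser="xml")
# A document containing only a root tag parses to a flat mapping:
assert xml_parser.parse("<body>Text of the body.</body>") == {"body": "Text of the body."}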

@ -110,7 +110,6 @@ def test_create_chat_prompt_template_from_template_partial() -> None:
def test_create_system_message_prompt_template_from_template_partial() -> None:
"""Create a system message prompt template with partials."""
graph_creator_content = """
Your instructions are:
{instructions}

View File

@ -348,7 +348,8 @@ def test_prompt_from_file() -> None:
def test_prompt_from_file_with_partial_variables() -> None:
"""Test prompt can be successfully constructed from a file
with partial variables."""
with partial variables.
"""
# given
template = "This is a {foo} test {bar}."
partial_variables = {"bar": "baz"}

View File

@ -75,7 +75,6 @@ def _remove_enum(obj: Any) -> None:
def _schema(obj: Any) -> dict:
"""Return the schema of the object."""
if not is_basemodel_subclass(obj):
msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}"
raise TypeError(msg)

View File

@ -67,7 +67,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
def test_doubly_set_configurable() -> None:
"""Test that setting a configurable field with a default value works"""
"""Test that setting a configurable field with a default value works."""
runnable = MyRunnable(my_property="a") # type: ignore
configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField(

View File

@ -314,7 +314,7 @@ class FakeStructuredOutputModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
return ChatResult(generations=[])
def bind_tools(
@ -344,7 +344,7 @@ class FakeModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
return ChatResult(generations=[])
def bind_tools(

View File

@ -426,7 +426,7 @@ def test_runnable_get_graph_with_invalid_output_type() -> None:
def test_graph_mermaid_escape_node_label() -> None:
"""Test that node labels are correctly preprocessed for draw_mermaid"""
"""Test that node labels are correctly preprocessed for draw_mermaid."""
assert _escape_node_label("foo") == "foo"
assert _escape_node_label("foo-bar") == "foo-bar"
assert _escape_node_label("foo_1") == "foo_1"

View File

@ -257,7 +257,7 @@ class LengthChatModel(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
"""Top Level call."""
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=str(len(messages))))]
)

View File
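
The fake chat models above all follow the same minimal BaseChatModel pattern; a self-contained sketch of the LengthChatModel from the diff (``_llm_type`` added here since BaseChatModel requires it):

from typing import Any, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult

class LengthChatModel(BaseChatModel):
    """Replies with the number of input messages."""

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call."""
        message = AIMessage(content=str(len(messages)))
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _llm_type(self) -> str:
        return "length-chat-model"

model = LengthChatModel()
assert model.invoke("hi").content == "1"  # one HumanMessage in, "1" out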

@ -96,7 +96,8 @@ PYDANTIC_VERSION = tuple(map(int, pydantic.__version__.split(".")))
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution.
It replaces run ids with deterministic UUIDs for snapshotting."""
It replaces run ids with deterministic UUIDs for snapshotting.
"""
def __init__(self) -> None:
"""Initialize the tracer."""
@ -158,7 +159,6 @@ class FakeTracer(BaseTracer):
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
self.runs.append(self._copy_run(run))
def flattened_runs(self) -> list[Run]:
@ -657,7 +657,7 @@ def test_lambda_schemas(snapshot: SnapshotAssertion) -> None:
def test_with_types_with_type_generics() -> None:
"""Verify that with_types works if we use things like List[int]"""
"""Verify that with_types works if we use things like List[int]."""
def foo(x: int) -> None:
"""Add one to the input."""
@ -3334,7 +3334,6 @@ def test_with_config_with_config() -> None:
def test_metadata_is_merged() -> None:
"""Test metadata and tags defined in with_config and at are merged/concatend."""
foo = RunnableLambda(lambda x: x).with_config({"metadata": {"my_key": "my_value"}})
expected_metadata = {
"my_key": "my_value",
@ -3349,7 +3348,6 @@ def test_metadata_is_merged() -> None:
def test_tags_are_appended() -> None:
"""Test tags from with_config are concatenated with those in invocation."""
foo = RunnableLambda(lambda x: x).with_config({"tags": ["my_key"]})
with collect_runs() as cb:
foo.invoke("hi", {"tags": ["invoked_key"]})
@ -4445,7 +4443,6 @@ async def test_runnable_branch_abatch() -> None:
def test_runnable_branch_stream() -> None:
"""Verify that stream works for RunnableBranch."""
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
@ -4503,7 +4500,6 @@ def test_runnable_branch_stream_with_callbacks() -> None:
async def test_runnable_branch_astream() -> None:
"""Verify that astream works for RunnableBranch."""
llm_res = "i'm a textbot"
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
@ -4694,8 +4690,8 @@ async def test_runnable_gen() -> None:
async def test_runnable_gen_context_config() -> None:
"""Test that a generator can call other runnables with config
propagated from the context."""
propagated from the context.
"""
fake = RunnableLambda(len)
def gen(input: Iterator[Any]) -> Iterator[int]:
@ -4829,8 +4825,8 @@ async def test_runnable_gen_context_config() -> None:
async def test_runnable_iter_context_config() -> None:
"""Test that a generator can call other runnables with config
propagated from the context."""
propagated from the context.
"""
fake = RunnableLambda(len)
@chain
@ -4946,8 +4942,8 @@ async def test_runnable_iter_context_config() -> None:
async def test_runnable_lambda_context_config() -> None:
"""Test that a function can call other runnables with config
propagated from the context."""
propagated from the context.
"""
fake = RunnableLambda(len)
@chain
@ -5098,7 +5094,8 @@ def test_with_config_callbacks() -> None:
async def test_ainvoke_on_returned_runnable() -> None:
"""Verify that a runnable returned by a sync runnable in the async path will
be runthroughaasync path (issue #13407)"""
be runthroughaasync path (issue #13407).
"""
def idchain_sync(__input: dict) -> bool:
return False
@ -5171,7 +5168,7 @@ async def test_astream_log_deep_copies() -> None:
"""
def _get_run_log(run_log_patches: Sequence[RunLogPatch]) -> RunLog:
"""Get run log"""
"""Get run log."""
run_log = RunLog(state=None) # type: ignore
for log_patch in run_log_patches:
run_log = run_log + log_patch
@ -5435,7 +5432,6 @@ def test_pydantic_protected_namespaces() -> None:
def test_schema_for_prompt_and_chat_model() -> None:
"""Testing that schema is generated properly when using variable names
that collide with pydantic attributes.
"""
prompt = ChatPromptTemplate([("system", "{model_json_schema}, {_private}, {json}")])

View File

@ -80,15 +80,15 @@ def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -> None:
async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool"""
"""Test the event stream with a function and tool."""
def foo(x: int) -> dict:
"""Foo"""
"""Foo."""
return {"x": 5}
@tool
def get_docs(x: int) -> list[Document]:
"""Hello Doc"""
"""Hello Doc."""
return [Document(page_content="hello")]
chain = RunnableLambda(foo) | get_docs
@ -345,7 +345,7 @@ async def test_event_stream_with_triple_lambda() -> None:
async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"""Test filtering based on tags / names"""
"""Test filtering based on tags / names."""
def reverse(s: str) -> str:
"""Reverse a string."""
@ -1822,7 +1822,7 @@ async def test_runnable_each() -> None:
async def test_events_astream_config() -> None:
"""Test that astream events support accepting config"""
"""Test that astream events support accepting config."""
infinite_cycle = cycle([AIMessage(content="hello world!", id="ai1")])
good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")])
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
@ -1912,7 +1912,7 @@ async def test_runnable_with_message_history() -> None:
store: dict = {}
def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history"""
"""Get a chat message history."""
if session_id not in store:
store[session_id] = []
return InMemoryHistory(messages=store[session_id])

View File

@ -90,15 +90,15 @@ async def _collect_events(
async def test_event_stream_with_simple_function_tool() -> None:
"""Test the event stream with a function and tool"""
"""Test the event stream with a function and tool."""
def foo(x: int) -> dict:
"""Foo"""
"""Foo."""
return {"x": 5}
@tool
def get_docs(x: int) -> list[Document]:
"""Hello Doc"""
"""Hello Doc."""
return [Document(page_content="hello")]
chain = RunnableLambda(foo) | get_docs
@ -371,7 +371,7 @@ async def test_event_stream_exception() -> None:
async def test_event_stream_with_triple_lambda_test_filtering() -> None:
"""Test filtering based on tags / names"""
"""Test filtering based on tags / names."""
def reverse(s: str) -> str:
"""Reverse a string."""
@ -1767,7 +1767,7 @@ async def test_runnable_each() -> None:
async def test_events_astream_config() -> None:
"""Test that astream events support accepting config"""
"""Test that astream events support accepting config."""
infinite_cycle = cycle([AIMessage(content="hello world!", id="ai1")])
good_world_on_repeat = cycle([AIMessage(content="Goodbye world", id="ai2")])
model = GenericFakeChatModel(messages=infinite_cycle).configurable_fields(
@ -1859,7 +1859,7 @@ async def test_runnable_with_message_history() -> None:
store: dict = {}
def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history"""
"""Get a chat message history."""
if session_id not in store:
store[session_id] = []
return InMemoryHistory(messages=store[session_id])
@ -2046,7 +2046,7 @@ async def test_sync_in_sync_lambdas() -> None:
class StreamingRunnable(Runnable[Input, Output]):
"""A custom runnable used for testing purposes"""
"""A custom runnable used for testing purposes."""
iterable: Iterable[Any]
@ -2734,7 +2734,7 @@ async def test_custom_event_root_dispatch_with_in_tool() -> None:
@tool
async def foo(x: int) -> int:
"""Foo"""
"""Foo."""
await adispatch_custom_event("event1", {"x": x})
return x + 1

View File
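
The custom-event tool above, as a standalone sketch (``adispatch_custom_event`` lives in langchain_core.callbacks.manager):

from langchain_core.callbacks.manager import adispatch_custom_event
from langchain_core.tools import tool

@tool
async def foo(x: int) -> int:
    """Foo."""
    await adispatch_custom_event("event1", {"x": x})
    return x + 1

# Events dispatched this way surface as on_custom_event entries when the
# tool runs under astream_events(..., version="v2").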

@ -23,7 +23,7 @@ from langchain_core.runnables.utils import (
],
)
def test_get_lambda_source(func: Callable, expected_source: str) -> None:
"""Test get_lambda_source function"""
"""Test get_lambda_source function."""
source = get_lambda_source(func)
assert source == expected_source
@ -36,7 +36,7 @@ def test_get_lambda_source(func: Callable, expected_source: str) -> None:
],
)
def test_indent_lines_after_first(text: str, prefix: str, expected_output: str) -> None:
"""Test indent_lines_after_first function"""
"""Test indent_lines_after_first function."""
indented_text = indent_lines_after_first(text, prefix)
assert indented_text == expected_output

View File

@ -25,7 +25,6 @@ from langchain_core.messages import (
def test_serde_any_message() -> None:
"""Test AnyMessage() serder."""
lc_objects = [
HumanMessage(content="human"),
HumanMessageChunk(content="human"),

View File

@ -388,7 +388,6 @@ def test_base_tool_inheritance_base_schema() -> None:
def test_tool_lambda_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
name="tool",
description="A tool",
@ -403,7 +402,7 @@ def test_structured_tool_from_function_docstring() -> None:
"""Test that structured tools can be created from functions."""
def foo(bar: int, baz: str) -> str:
"""Docstring
"""Docstring.
Args:
bar: the bar value
@ -437,7 +436,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None:
"""Test that structured tools can be created from functions."""
def foo(bar: int, baz: list[str]) -> str:
"""Docstring
"""Docstring.
Args:
bar: int
@ -526,7 +525,7 @@ def test_tool_from_function_with_run_manager() -> None:
def foo(bar: str, callbacks: Optional[CallbackManagerForToolRun] = None) -> str:
"""Docstring
Args:
bar: str
bar: str.
"""
assert callbacks is not None
return "foo" + bar
@ -544,7 +543,7 @@ def test_structured_tool_from_function_with_run_manager() -> None:
def foo(
bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None
) -> str:
"""Docstring
"""Docstring.
Args:
bar: int
@ -1381,7 +1380,7 @@ class _MockStructuredToolWithRawOutput(BaseTool):
def _mock_structured_tool_with_artifact(
arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> tuple[str, dict]:
"""A Structured Tool"""
"""A Structured Tool."""
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@ -1891,7 +1890,7 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
from langchain_core.tools import StructuredTool
def foo(a: int, b: str) -> str:
"""Hahaha"""
"""Hahaha."""
return "foo"
foo_tool = StructuredTool.from_function(
@ -2187,11 +2186,11 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
@tool(args_schema=Foo)
def foo(x): # type: ignore[no-untyped-def]
"""foo"""
"""Foo."""
return x
assert foo.tool_call_schema.model_json_schema() == {
"description": "foo",
"description": "Foo.",
"properties": {
"x": {
"description": "List of integers",
@ -2269,7 +2268,7 @@ def test_injected_arg_with_complex_type() -> None:
def test_tool_injected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
"""foo"""
"""Foo."""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
assert foo.invoke(
@ -2281,7 +2280,7 @@ def test_tool_injected_tool_call_id() -> None:
@tool
def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage:
"""foo"""
"""Foo."""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
assert foo2.invoke(
@ -2292,7 +2291,7 @@ def test_tool_injected_tool_call_id() -> None:
def test_tool_uninjected_tool_call_id() -> None:
@tool
def foo(x: int, tool_call_id: str) -> ToolMessage:
"""foo"""
"""Foo."""
return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore
with pytest.raises(ValueError):

View File
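
A sketch of the injected-tool-call-id pattern tested above (the invocation payload follows the ToolCall dict shape; the content is stringified here since ToolMessage expects string content):

from typing import Annotated

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool

@tool
def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage:
    """Foo."""
    return ToolMessage(str(x), tool_call_id=tool_call_id)

# tool_call_id is injected from the tool call itself, not exposed to the model:
result = foo.invoke({"name": "foo", "args": {"x": 0}, "id": "bar", "type": "tool_call"})
assert isinstance(result, ToolMessage)
assert result.tool_call_id == "bar"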

@ -83,7 +83,6 @@ def test_tracer_with_run_tree_parent() -> None:
def test_log_lock() -> None:
"""Test that example assigned at callback start/end is honored."""
client = unittest.mock.MagicMock(spec=Client)
tracer = LangChainTracer(client=client)
@ -96,9 +95,7 @@ def test_log_lock() -> None:
class LangChainProjectNameTest(unittest.TestCase):
"""
Test that the project name is set correctly for runs.
"""
"""Test that the project name is set correctly for runs."""
class SetProperTracerProjectTestCase:
def __init__(

View File

@ -39,7 +39,7 @@ from langchain_core.utils.function_calling import (
@pytest.fixture()
def pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801
"""dummy function"""
"""Dummy function."""
arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
@ -53,7 +53,7 @@ def annotated_function() -> Callable:
arg1: ExtensionsAnnotated[int, "foo"],
arg2: ExtensionsAnnotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
) -> None:
"""dummy function"""
"""Dummy function."""
return dummy_function
@ -61,7 +61,7 @@ def annotated_function() -> Callable:
@pytest.fixture()
def function() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
"""Dummy function.
Args:
arg1: foo
@ -74,7 +74,7 @@ def function() -> Callable:
@pytest.fixture()
def function_docstring_annotations() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
"""Dummy function.
Args:
arg1 (int): foo
@ -105,7 +105,7 @@ def dummy_tool() -> BaseTool:
class DummyFunction(BaseTool):
args_schema: type[BaseModel] = Schema
name: str = "dummy_function"
description: str = "dummy function"
description: str = "Dummy function."
def _run(self, *args: Any, **kwargs: Any) -> Any:
pass
@ -122,7 +122,7 @@ def dummy_structured_tool() -> StructuredTool:
return StructuredTool.from_function(
lambda x: None,
name="dummy_function",
description="dummy function",
description="Dummy function.",
args_schema=Schema,
)
@ -130,7 +130,7 @@ def dummy_structured_tool() -> StructuredTool:
@pytest.fixture()
def dummy_pydantic() -> type[BaseModel]:
class dummy_function(BaseModel): # noqa: N801
"""dummy function"""
"""Dummy function."""
arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
@ -141,7 +141,7 @@ def dummy_pydantic() -> type[BaseModel]:
@pytest.fixture()
def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
class dummy_function(BaseModelV2Maybe): # noqa: N801
"""dummy function"""
"""Dummy function."""
arg1: int = FieldV2Maybe(..., description="foo")
arg2: Literal["bar", "baz"] = FieldV2Maybe(
@ -154,7 +154,7 @@ def dummy_pydantic_v2() -> type[BaseModelV2Maybe]:
@pytest.fixture()
def dummy_typing_typed_dict() -> type:
class dummy_function(TypingTypedDict): # noqa: N801
"""dummy function"""
"""Dummy function."""
arg1: TypingAnnotated[int, ..., "foo"] # noqa: F821
arg2: TypingAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"] # noqa: F722
@ -165,7 +165,7 @@ def dummy_typing_typed_dict() -> type:
@pytest.fixture()
def dummy_typing_typed_dict_docstring() -> type:
class dummy_function(TypingTypedDict): # noqa: N801
"""dummy function
"""Dummy function.
Args:
arg1: foo
@ -181,7 +181,7 @@ def dummy_typing_typed_dict_docstring() -> type:
@pytest.fixture()
def dummy_extensions_typed_dict() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801
"""dummy function"""
"""Dummy function."""
arg1: ExtensionsAnnotated[int, ..., "foo"]
arg2: ExtensionsAnnotated[Literal["bar", "baz"], ..., "one of 'bar', 'baz'"]
@ -192,7 +192,7 @@ def dummy_extensions_typed_dict() -> type:
@pytest.fixture()
def dummy_extensions_typed_dict_docstring() -> type:
class dummy_function(ExtensionsTypedDict): # noqa: N801
"""dummy function
"""Dummy function.
Args:
arg1: foo
@ -209,7 +209,7 @@ def dummy_extensions_typed_dict_docstring() -> type:
def json_schema() -> dict:
return {
"title": "dummy_function",
"description": "dummy function",
"description": "Dummy function.",
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "integer"},
@ -227,7 +227,7 @@ def json_schema() -> dict:
def anthropic_tool() -> dict:
return {
"name": "dummy_function",
"description": "dummy function",
"description": "Dummy function.",
"input_schema": {
"type": "object",
"properties": {
@ -248,7 +248,7 @@ def bedrock_converse_tool() -> dict:
return {
"toolSpec": {
"name": "dummy_function",
"description": "dummy function",
"description": "Dummy function.",
"inputSchema": {
"json": {
"type": "object",
@ -269,7 +269,7 @@ def bedrock_converse_tool() -> dict:
class Dummy:
def dummy_function(self, arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
"""Dummy function.
Args:
arg1: foo
@ -280,7 +280,7 @@ class Dummy:
class DummyWithClassMethod:
@classmethod
def dummy_function(cls, arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
"""Dummy function.
Args:
arg1: foo
@ -307,7 +307,7 @@ def test_convert_to_openai_function(
) -> None:
expected = {
"name": "dummy_function",
"description": "dummy function",
"description": "Dummy function.",
"parameters": {
"type": "object",
"properties": {
@ -345,7 +345,7 @@ def test_convert_to_openai_function(
assert actual == expected
# Test runnables
actual = convert_to_openai_function(runnable.as_tool(description="dummy function"))
actual = convert_to_openai_function(runnable.as_tool(description="Dummy function."))
parameters = {
"type": "object",
"properties": {
@ -392,7 +392,7 @@ def test_convert_to_openai_function_nested_v2() -> None:
)
def my_function(arg1: NestedV2) -> None:
"""dummy function"""
"""Dummy function."""
convert_to_openai_function(my_function)
@ -405,11 +405,11 @@ def test_convert_to_openai_function_nested() -> None:
)
def my_function(arg1: Nested) -> None:
"""dummy function"""
"""Dummy function."""
expected = {
"name": "my_function",
"description": "dummy function",
"description": "Dummy function.",
"parameters": {
"type": "object",
"properties": {
@ -442,11 +442,11 @@ def test_convert_to_openai_function_nested_strict() -> None:
)
def my_function(arg1: Nested) -> None:
"""dummy function"""
"""Dummy function."""
expected = {
"name": "my_function",
"description": "dummy function",
"description": "Dummy function.",
"parameters": {
"type": "object",
"properties": {
@ -608,7 +608,7 @@ def test_function_optional_param() -> None:
b: str,
c: Optional[list[Optional[str]]],
) -> None:
"""A test function"""
"""A test function."""
func = convert_to_openai_function(func5)
req = func["parameters"]["required"]
@ -617,7 +617,7 @@ def test_function_optional_param() -> None:
def test_function_no_params() -> None:
def nullary_function() -> None:
"""nullary function"""
"""Nullary function."""
func = convert_to_openai_function(nullary_function)
req = func["parameters"].get("required")
@ -722,12 +722,12 @@ def test__convert_typed_dict_to_openai_function(
annotated = ExtensionsAnnotated if use_extension_annotated else TypingAnnotated
class SubTool(typed_dict):
"""Subtool docstring"""
"""Subtool docstring."""
args: annotated[dict[str, Any], {}, "this does bar"] # noqa: F722 # type: ignore
class Tool(typed_dict):
"""Docstring
"""Docstring.
Args:
arg1: foo
@ -753,7 +753,7 @@ def test__convert_typed_dict_to_openai_function(
expected = {
"name": "Tool",
"description": "Docstring",
"description": "Docstring.",
"parameters": {
"type": "object",
"properties": {
@ -768,7 +768,7 @@ def test__convert_typed_dict_to_openai_function(
"arg3": {
"type": "array",
"items": {
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -798,7 +798,7 @@ def test__convert_typed_dict_to_openai_function(
{"type": "array", "items": {}},
{
"title": "SubTool",
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -816,7 +816,7 @@ def test__convert_typed_dict_to_openai_function(
"arg7": {
"type": "array",
"items": {
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -834,7 +834,7 @@ def test__convert_typed_dict_to_openai_function(
"items": [
{
"title": "SubTool",
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -850,7 +850,7 @@ def test__convert_typed_dict_to_openai_function(
"arg9": {
"type": "array",
"items": {
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -864,7 +864,7 @@ def test__convert_typed_dict_to_openai_function(
"arg10": {
"type": "array",
"items": {
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -878,7 +878,7 @@ def test__convert_typed_dict_to_openai_function(
"arg11": {
"type": "array",
"items": {
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -893,7 +893,7 @@ def test__convert_typed_dict_to_openai_function(
"arg12": {
"type": "object",
"additionalProperties": {
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -907,7 +907,7 @@ def test__convert_typed_dict_to_openai_function(
"arg13": {
"type": "object",
"additionalProperties": {
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -921,7 +921,7 @@ def test__convert_typed_dict_to_openai_function(
"arg14": {
"type": "object",
"additionalProperties": {
"description": "Subtool docstring",
"description": "Subtool docstring.",
"type": "object",
"properties": {
"args": {
@ -981,13 +981,13 @@ def test_convert_union_type_py_39() -> None:
def test_convert_to_openai_function_no_args() -> None:
@tool
def empty_tool() -> str:
"""No args"""
"""No args."""
return "foo"
actual = convert_to_openai_function(empty_tool, strict=True)
assert actual == {
"name": "empty_tool",
"description": "No args",
"description": "No args.",
"parameters": {
"properties": {},
"additionalProperties": False,

View File
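
The fixtures above all normalize to the same OpenAI function schema; a minimal sketch using the docstring-annotated form (assertions mirror the expected dict in the tests):

from typing import Literal

from langchain_core.utils.function_calling import convert_to_openai_function

def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
    """Dummy function.

    Args:
        arg1: foo
        arg2: one of 'bar', 'baz'
    """

result = convert_to_openai_function(dummy_function)
assert result["name"] == "dummy_function"
assert result["description"] == "Dummy function."
assert result["parameters"]["properties"]["arg1"] == {"description": "foo", "type": "integer"}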

@ -140,7 +140,7 @@ def test_is_basemodel_instance() -> None:
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
def test_with_field_metadata() -> None:
"""Test pydantic with field metadata"""
"""Test pydantic with field metadata."""
from pydantic import BaseModel as BaseModelV2
from pydantic import Field as FieldV2
@ -202,7 +202,6 @@ def test_fields_pydantic_v1_from_2() -> None:
def test_create_model_v2() -> None:
"""Test that create model v2 works as expected."""
with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always") # Cause all warnings to always be triggered
foo = create_model_v2("Foo", field_definitions={"a": (int, None)})

View File

@ -35,7 +35,7 @@ async def test_inmemory_similarity_search() -> None:
async def test_inmemory_similarity_search_with_score() -> None:
"""Test end to end similarity search with score"""
"""Test end to end similarity search with score."""
store = await InMemoryVectorStore.afrom_texts(
["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
)
@ -63,7 +63,7 @@ async def test_add_by_ids() -> None:
async def test_inmemory_mmr() -> None:
"""Test MMR search"""
"""Test MMR search."""
texts = ["foo", "foo", "fou", "foy"]
docsearch = await InMemoryVectorStore.afrom_texts(
texts, DeterministicFakeEmbedding(size=6)
@ -147,7 +147,6 @@ async def test_inmemory_upsert() -> None:
async def test_inmemory_get_by_ids() -> None:
"""Test get by ids."""
store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=3))
store.upsert(

View File
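
A synchronous sketch of the similarity-search flow tested above (InMemoryVectorStore imported from langchain_core.vectorstores):

from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore

store = InMemoryVectorStore.from_texts(
    ["foo", "bar", "baz"], DeterministicFakeEmbedding(size=3)
)
docs = store.similarity_search("foo", k=1)
assert docs[0].page_content == "foo"  # exact text gives the closest vector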

@ -117,7 +117,6 @@ def test_default_add_documents(vs_class: type[VectorStore]) -> None:
"""Test that we can implement the upsert method of the CustomVectorStore
class without violating the Liskov Substitution Principle.
"""
store = vs_class()
# Check upsert with id