Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-24 23:54:14 +00:00)
core: Add ruff rules RUF (#29353)
See https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf

Mostly:

* [RUF022](https://docs.astral.sh/ruff/rules/unsorted-dunder-all/) (unsorted `__all__`)
* [RUF100](https://docs.astral.sh/ruff/rules/unused-noqa/) (unused noqa)
* [RUF021](https://docs.astral.sh/ruff/rules/parenthesize-chained-operators/) (parenthesize-chained-operators)
* [RUF015](https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element/) (unnecessary-iterable-allocation-for-first-element)
* [RUF005](https://docs.astral.sh/ruff/rules/collection-literal-concatenation/) (collection-literal-concatenation)
* [RUF046](https://docs.astral.sh/ruff/rules/unnecessary-cast-to-int/) (unnecessary-cast-to-int)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in: parent 6cd1aadf60, commit a8f2ddee31.
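For readers skimming the diff below, here is a minimal, self-contained sketch of the patterns these Ruff rules flag and the fixes they suggest. The snippets use made-up names and are illustrative only; they are not lines taken from this commit.

```python
# Illustrative before/after pairs for the Ruff rules this commit addresses.

first, middle, last = "a", ["b", "c"], "d"
data = {"x": 1}
value = 2.6
flag = True

# RUF005 (collection-literal-concatenation): unpack instead of concatenating lists.
steps_before = [first] + middle + [last]
steps_after = [first, *middle, last]
assert steps_before == steps_after

# RUF015 (unnecessary-iterable-allocation-for-first-element):
# don't materialize a whole list just to grab the first element.
key_before = list(data.keys())[0]
key_after = next(iter(data.keys()))
assert key_before == key_after

# RUF046 (unnecessary-cast-to-int): round() already returns an int.
assert int(round(value)) == round(value)

# RUF021 (parenthesize-chained-operators): make and/or precedence explicit.
ok_before = flag or flag and not flag
ok_after = flag or (flag and not flag)
assert ok_before == ok_after

# RUF022 (unsorted `__all__`): keep the public API tuple isort-sorted
# (uppercase names sort before lowercase ones).
__all__ = ("LangChainBetaWarning", "beta", "warn_deprecated")

# RUF100 (unused-noqa) simply removes `# noqa` directives that no longer
# suppress anything.
```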
@@ -30,15 +30,15 @@ if TYPE_CHECKING:
  from .path import as_import_path, get_relative_path

  __all__ = (
+ "LangChainBetaWarning",
+ "LangChainDeprecationWarning",
  "as_import_path",
  "beta",
  "deprecated",
  "get_relative_path",
- "LangChainBetaWarning",
- "LangChainDeprecationWarning",
  "suppress_langchain_beta_warning",
- "surface_langchain_beta_warnings",
  "suppress_langchain_deprecation_warning",
+ "surface_langchain_beta_warnings",
  "surface_langchain_deprecation_warnings",
  "warn_deprecated",
  )

@@ -54,39 +54,39 @@ if TYPE_CHECKING:
  )

  __all__ = (
- "dispatch_custom_event",
- "adispatch_custom_event",
- "RetrieverManagerMixin",
- "LLMManagerMixin",
- "ChainManagerMixin",
- "ToolManagerMixin",
- "Callbacks",
- "CallbackManagerMixin",
- "RunManagerMixin",
- "BaseCallbackHandler",
  "AsyncCallbackHandler",
- "BaseCallbackManager",
- "BaseRunManager",
- "RunManager",
- "ParentRunManager",
- "AsyncRunManager",
- "AsyncParentRunManager",
- "CallbackManagerForLLMRun",
- "AsyncCallbackManagerForLLMRun",
- "CallbackManagerForChainRun",
- "AsyncCallbackManagerForChainRun",
- "CallbackManagerForToolRun",
- "AsyncCallbackManagerForToolRun",
- "CallbackManagerForRetrieverRun",
- "AsyncCallbackManagerForRetrieverRun",
- "CallbackManager",
- "CallbackManagerForChainGroup",
  "AsyncCallbackManager",
  "AsyncCallbackManagerForChainGroup",
+ "AsyncCallbackManagerForChainRun",
+ "AsyncCallbackManagerForLLMRun",
+ "AsyncCallbackManagerForRetrieverRun",
+ "AsyncCallbackManagerForToolRun",
+ "AsyncParentRunManager",
+ "AsyncRunManager",
+ "BaseCallbackHandler",
+ "BaseCallbackManager",
+ "BaseRunManager",
+ "CallbackManager",
+ "CallbackManagerForChainGroup",
+ "CallbackManagerForChainRun",
+ "CallbackManagerForLLMRun",
+ "CallbackManagerForRetrieverRun",
+ "CallbackManagerForToolRun",
+ "CallbackManagerMixin",
+ "Callbacks",
+ "ChainManagerMixin",
+ "FileCallbackHandler",
+ "LLMManagerMixin",
+ "ParentRunManager",
+ "RetrieverManagerMixin",
+ "RunManager",
+ "RunManagerMixin",
  "StdOutCallbackHandler",
  "StreamingStdOutCallbackHandler",
- "FileCallbackHandler",
+ "ToolManagerMixin",
  "UsageMetadataCallbackHandler",
+ "adispatch_custom_event",
+ "dispatch_custom_event",
  "get_usage_metadata_callback",
  )

@@ -14,8 +14,8 @@ __all__ = (
  "BaseLoader",
  "Blob",
  "BlobLoader",
- "PathLike",
  "LangSmithLoader",
+ "PathLike",
  )

  _dynamic_imports = {

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
  from .compressor import BaseDocumentCompressor
  from .transformers import BaseDocumentTransformer

- __all__ = ("Document", "BaseDocumentTransformer", "BaseDocumentCompressor")
+ __all__ = ("BaseDocumentCompressor", "BaseDocumentTransformer", "Document")

  _dynamic_imports = {
  "Document": "base",

@@ -20,14 +20,14 @@ if TYPE_CHECKING:
  )

  __all__ = (
- "aindex",
  "DeleteResponse",
  "DocumentIndex",
- "index",
- "IndexingResult",
  "InMemoryRecordManager",
+ "IndexingResult",
  "RecordManager",
  "UpsertResponse",
+ "aindex",
+ "index",
  )

  _dynamic_imports = {

@@ -68,22 +68,22 @@ if TYPE_CHECKING:
  from langchain_core.language_models.llms import LLM, BaseLLM

  __all__ = (
- "BaseLanguageModel",
- "BaseChatModel",
- "SimpleChatModel",
- "BaseLLM",
  "LLM",
- "LanguageModelInput",
+ "BaseChatModel",
- "get_tokenizer",
+ "BaseLLM",
- "LangSmithParams",
+ "BaseLanguageModel",
- "LanguageModelOutput",
- "LanguageModelLike",
- "FakeListLLM",
- "FakeStreamingListLLM",
  "FakeListChatModel",
+ "FakeListLLM",
  "FakeMessagesListChatModel",
+ "FakeStreamingListLLM",
  "GenericFakeChatModel",
+ "LangSmithParams",
+ "LanguageModelInput",
+ "LanguageModelLike",
+ "LanguageModelOutput",
  "ParrotFakeChatModel",
+ "SimpleChatModel",
+ "get_tokenizer",
  )

  _dynamic_imports = {
@@ -23,7 +23,7 @@ def _is_openai_data_block(block: dict) -> bool:
  if isinstance(file_data, str):
  return True

- elif block.get("type") == "input_audio": # noqa: SIM102
+ elif block.get("type") == "input_audio":
  if (input_audio := block.get("input_audio")) and isinstance(input_audio, dict):
  audio_data = input_audio.get("data")
  audio_format = input_audio.get("format")

@@ -354,7 +354,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
  f"Invalid input type {type(input)}. "
  "Must be a PromptValue, str, or list of BaseMessages."
  )
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)

  @override
  def invoke(

@@ -1203,7 +1203,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
  if isinstance(generation, ChatGeneration):
  return generation.message
  msg = "Unexpected generation type"
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)

  async def _call_async(
  self,

@@ -1219,7 +1219,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
  if isinstance(generation, ChatGeneration):
  return generation.message
  msg = "Unexpected generation type"
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)

  @deprecated("0.1.7", alternative="invoke", removal="1.0")
  def call_as_llm(

@@ -1261,7 +1261,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
  if isinstance(result.content, str):
  return result.content
  msg = "Cannot use predict when output is not a string."
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)

  @deprecated("0.1.7", alternative="invoke", removal="1.0")
  @override

@@ -1287,7 +1287,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
  if isinstance(result.content, str):
  return result.content
  msg = "Cannot use predict when output is not a string."
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)

  @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
  @override

@@ -103,7 +103,9 @@ def create_base_retry_decorator(
  try:
  loop = asyncio.get_event_loop()
  if loop.is_running():
- loop.create_task(coro)
+ # TODO: Fix RUF006 - this task should have a reference
+ # and be awaited somewhere
+ loop.create_task(coro) # noqa: RUF006
  else:
  asyncio.run(coro)
  except Exception as e:

@@ -336,7 +338,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
  f"Invalid input type {type(input)}. "
  "Must be a PromptValue, str, or list of BaseMessages."
  )
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)

  def _get_ls_params(
  self,

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
  # the `from langchain_core.load.load import load` absolute import should also work.
  from langchain_core.load.load import load

- __all__ = ("dumpd", "dumps", "load", "loads", "Serializable")
+ __all__ = ("Serializable", "dumpd", "dumps", "load", "loads")

  _dynamic_imports = {
  "dumpd": "dump",
@@ -3,7 +3,7 @@
  This module contains memory abstractions from LangChain v0.0.x.

  These abstractions are now deprecated and will be removed in LangChain v1.0.0.
- """ # noqa: E501
+ """

  from __future__ import annotations


@@ -76,28 +76,28 @@ __all__ = (
  "HumanMessageChunk",
  "InvalidToolCall",
  "MessageLikeRepresentation",
+ "RemoveMessage",
  "SystemMessage",
  "SystemMessageChunk",
  "ToolCall",
  "ToolCallChunk",
  "ToolMessage",
  "ToolMessageChunk",
- "RemoveMessage",
  "_message_from_dict",
+ "convert_to_messages",
  "convert_to_openai_data_block",
  "convert_to_openai_image_block",
- "convert_to_messages",
+ "convert_to_openai_messages",
+ "filter_messages",
  "get_buffer_string",
  "is_data_content_block",
  "merge_content",
+ "merge_message_runs",
  "message_chunk_to_message",
  "message_to_dict",
  "messages_from_dict",
  "messages_to_dict",
- "filter_messages",
- "merge_message_runs",
  "trim_messages",
- "convert_to_openai_messages",
  )

  _dynamic_imports = {

@@ -423,13 +423,13 @@ def add_ai_message_chunks(

  id = None
  candidates = [left.id] + [o.id for o in others]
- # first pass: pick the first non‐run-* id
+ # first pass: pick the first non-run-* id
  for id_ in candidates:
  if id_ and not id_.startswith(_LC_ID_PREFIX):
  id = id_
  break
  else:
- # second pass: no provider-assigned id found, just take the first non‐null
+ # second pass: no provider-assigned id found, just take the first non-null
  for id_ in candidates:
  if id_:
  id = id_

@@ -101,8 +101,7 @@ class BaseMessage(Serializable):
  block
  for block in self.content
  if isinstance(block, str)
- or block.get("type") == "text"
+ or (block.get("type") == "text" and isinstance(block.get("text"), str))
- and isinstance(block.get("text"), str)
  ]
  return "".join(
  block if isinstance(block, str) else block["text"] for block in blocks

@@ -161,7 +160,7 @@ def merge_content(
  merged += content
  # If the next chunk is a list, add the current to the start of the list
  else:
- merged = [merged] + content # type: ignore[assignment,operator]
+ merged = [merged, *content]
  elif isinstance(content, list):
  # If both are lists
  merged = merge_lists(cast("list", merged), content) # type: ignore[assignment]

@@ -885,7 +885,7 @@ def trim_messages(
  list_token_counter = token_counter.get_num_tokens_from_messages
  elif callable(token_counter):
  if (
- list(inspect.signature(token_counter).parameters.values())[0].annotation
+ next(iter(inspect.signature(token_counter).parameters.values())).annotation
  is BaseMessage
  ):

@@ -1460,7 +1460,7 @@ def _last_max_tokens(
  # Re-reverse the messages and add back the system message if needed
  result = reversed_result[::-1]
  if system_message:
- result = [system_message] + result
+ result = [system_message, *result]

  return result

@@ -1543,7 +1543,7 @@ def _get_message_openai_role(message: BaseMessage) -> str:
  if isinstance(message, ChatMessage):
  return message.role
  msg = f"Unknown BaseMessage type {message.__class__}."
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)


  def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]:
@@ -47,23 +47,23 @@ if TYPE_CHECKING:
  from langchain_core.output_parsers.xml import XMLOutputParser

  __all__ = [
- "BaseLLMOutputParser",
- "BaseGenerationOutputParser",
- "BaseOutputParser",
- "ListOutputParser",
- "CommaSeparatedListOutputParser",
- "NumberedListOutputParser",
- "MarkdownListOutputParser",
- "StrOutputParser",
- "BaseTransformOutputParser",
  "BaseCumulativeTransformOutputParser",
- "SimpleJsonOutputParser",
+ "BaseGenerationOutputParser",
- "XMLOutputParser",
+ "BaseLLMOutputParser",
- "JsonOutputParser",
+ "BaseOutputParser",
- "PydanticOutputParser",
+ "BaseTransformOutputParser",
- "JsonOutputToolsParser",
+ "CommaSeparatedListOutputParser",
  "JsonOutputKeyToolsParser",
+ "JsonOutputParser",
+ "JsonOutputToolsParser",
+ "ListOutputParser",
+ "MarkdownListOutputParser",
+ "NumberedListOutputParser",
+ "PydanticOutputParser",
  "PydanticToolsParser",
+ "SimpleJsonOutputParser",
+ "StrOutputParser",
+ "XMLOutputParser",
  ]

  _dynamic_imports = {

@@ -132,6 +132,6 @@ SimpleJsonOutputParser = JsonOutputParser
  __all__ = [
  "JsonOutputParser",
  "SimpleJsonOutputParser", # For backwards compatibility
- "parse_partial_json", # For backwards compatibility
  "parse_and_check_json_markdown", # For backwards compatibility
+ "parse_partial_json", # For backwards compatibility
  ]

@@ -3,7 +3,7 @@
  import contextlib
  import re
  import xml
- import xml.etree.ElementTree as ET # noqa: N817
+ import xml.etree.ElementTree as ET
  from collections.abc import AsyncIterator, Iterator
  from typing import Any, Literal, Optional, Union
  from xml.etree.ElementTree import TreeBuilder

@@ -70,21 +70,21 @@ __all__ = (
  "ChatMessagePromptTemplate",
  "ChatPromptTemplate",
  "DictPromptTemplate",
+ "FewShotChatMessagePromptTemplate",
  "FewShotPromptTemplate",
  "FewShotPromptWithTemplates",
- "FewShotChatMessagePromptTemplate",
  "HumanMessagePromptTemplate",
  "MessagesPlaceholder",
  "PipelinePromptTemplate",
  "PromptTemplate",
  "StringPromptTemplate",
  "SystemMessagePromptTemplate",
- "load_prompt",
- "format_document",
  "aformat_document",
  "check_valid_template",
+ "format_document",
  "get_template_variables",
  "jinja2_formatter",
+ "load_prompt",
  "validate_jinja2",
  )

@@ -445,9 +445,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
  raise ValueError(msg)
  prompt = []
  for tmpl in template:
- if (
+ if isinstance(tmpl, str) or (
- isinstance(tmpl, str)
+ isinstance(tmpl, dict)
- or isinstance(tmpl, dict)
  and "text" in tmpl
  and set(tmpl.keys()) <= {"type", "text"}
  ):

@@ -524,7 +523,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
  raise ValueError(msg)
  return cls(prompt=prompt, **kwargs)
  msg = f"Invalid template: {template}"
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)

  @classmethod
  def from_template_file(

@@ -1000,7 +999,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
  if isinstance(
  other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
  ):
- return ChatPromptTemplate(messages=self.messages + [other]).partial(
+ return ChatPromptTemplate(messages=[*self.messages, other]).partial(
  **partials
  )
  if isinstance(other, (list, tuple)):

@@ -1010,7 +1009,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
  )
  if isinstance(other, str):
  prompt = HumanMessagePromptTemplate.from_template(other)
- return ChatPromptTemplate(messages=self.messages + [prompt]).partial(
+ return ChatPromptTemplate(messages=[*self.messages, prompt]).partial(
  **partials
  )
  msg = f"Unsupported operand type for +: {type(other)}"
@@ -84,15 +84,11 @@ def _load_examples(config: dict) -> dict:

  def _load_output_parser(config: dict) -> dict:
  """Load output parser."""
- if "output_parser" in config and config["output_parser"]:
+ if _config := config.get("output_parser"):
- _config = config.pop("output_parser")
+ if output_parser_type := _config.get("_type") != "default":
- output_parser_type = _config.pop("_type")
- if output_parser_type == "default":
- output_parser = StrOutputParser(**_config)
- else:
  msg = f"Unsupported output parser {output_parser_type}"
  raise ValueError(msg)
- config["output_parser"] = output_parser
+ config["output_parser"] = StrOutputParser(**_config)
  return config


@@ -153,10 +153,8 @@ class StructuredPrompt(ChatPromptTemplate):
  NotImplementedError: If the first element of `others`
  is not a language model.
  """
- if (
+ if (others and isinstance(others[0], BaseLanguageModel)) or hasattr(
- others
+ others[0], "with_structured_output"
- and isinstance(others[0], BaseLanguageModel)
- or hasattr(others[0], "with_structured_output")
  ):
  return RunnableSequence(
  self,

@@ -60,19 +60,15 @@ if TYPE_CHECKING:
  )

  __all__ = (
- "chain",
  "AddableDict",
  "ConfigurableField",
- "ConfigurableFieldSingleOption",
  "ConfigurableFieldMultiOption",
+ "ConfigurableFieldSingleOption",
  "ConfigurableFieldSpec",
- "ensure_config",
- "run_in_executor",
- "patch_config",
  "RouterInput",
  "RouterRunnable",
  "Runnable",
- "RunnableSerializable",
+ "RunnableAssign",
  "RunnableBinding",
  "RunnableBranch",
  "RunnableConfig",

@@ -81,14 +77,18 @@ __all__ = (
  "RunnableMap",
  "RunnableParallel",
  "RunnablePassthrough",
- "RunnableAssign",
  "RunnablePick",
  "RunnableSequence",
+ "RunnableSerializable",
  "RunnableWithFallbacks",
  "RunnableWithMessageHistory",
- "get_config_list",
  "aadd",
  "add",
+ "chain",
+ "ensure_config",
+ "get_config_list",
+ "patch_config",
+ "run_in_executor",
  )

  _dynamic_imports = {

@@ -2799,7 +2799,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
  Returns:
  A list of Runnables.
  """
- return [self.first] + self.middle + [self.last]
+ return [self.first, *self.middle, self.last]

  @classmethod
  @override

@@ -3353,7 +3353,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
  ) -> Iterator[Output]:
  from langchain_core.beta.runnables.context import config_with_context

- steps = [self.first] + self.middle + [self.last]
+ steps = [self.first, *self.middle, self.last]
  config = config_with_context(config, self.steps)

  # transform the input stream of each step with the next

@@ -3380,7 +3380,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
  ) -> AsyncIterator[Output]:
  from langchain_core.beta.runnables.context import aconfig_with_context

- steps = [self.first] + self.middle + [self.last]
+ steps = [self.first, *self.middle, self.last]
  config = aconfig_with_context(config, self.steps)

  # stream the last steps

@@ -4203,7 +4203,7 @@ class RunnableGenerator(Runnable[Input, Output]):
  **kwargs: Any,
  ) -> Iterator[Output]:
  if not hasattr(self, "_transform"):
- msg = f"{repr(self)} only supports async methods."
+ msg = f"{self!r} only supports async methods."
  raise NotImplementedError(msg)
  return self._transform_stream_with_config(
  input,

@@ -4238,7 +4238,7 @@ class RunnableGenerator(Runnable[Input, Output]):
  **kwargs: Any,
  ) -> AsyncIterator[Output]:
  if not hasattr(self, "_atransform"):
- msg = f"{repr(self)} only supports sync methods."
+ msg = f"{self!r} only supports sync methods."
  raise NotImplementedError(msg)

  return self._atransform_stream_with_config(

@@ -5781,7 +5781,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
  bound=self.bound,
  kwargs=self.kwargs,
  config=self.config,
- config_factories=[listener_config_factory] + self.config_factories,
+ config_factories=[listener_config_factory, *self.config_factories],
  custom_input_type=self.custom_input_type,
  custom_output_type=self.custom_output_type,
  )

@@ -562,7 +562,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
  self.which.name or self.which.id,
  (
  (v, v)
- for v in list(self.alternatives.keys()) + [self.default_key]
+ for v in [*list(self.alternatives.keys()), self.default_key]
  ),
  )
  _enums_for_spec[self.which] = cast("type[StrEnum]", which_enum)
@@ -111,15 +111,15 @@ class AsciiCanvas:
  self.point(x0, y0, char)
  elif abs(dx) >= abs(dy):
  for x in range(x0, x1 + 1):
- y = y0 if dx == 0 else y0 + int(round((x - x0) * dy / float(dx)))
+ y = y0 if dx == 0 else y0 + round((x - x0) * dy / float(dx))
  self.point(x, y, char)
  elif y0 < y1:
  for y in range(y0, y1 + 1):
- x = x0 if dy == 0 else x0 + int(round((y - y0) * dx / float(dy)))
+ x = x0 if dy == 0 else x0 + round((y - y0) * dx / float(dy))
  self.point(x, y, char)
  else:
  for y in range(y1, y0 + 1):
- x = x0 if dy == 0 else x1 + int(round((y - y1) * dx / float(dy)))
+ x = x0 if dy == 0 else x1 + round((y - y1) * dx / float(dy))
  self.point(x, y, char)

  def text(self, x: int, y: int, text: str) -> None:

@@ -291,8 +291,8 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
  maxx = max(xlist)
  maxy = max(ylist)

- canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1
+ canvas_cols = math.ceil(math.ceil(maxx) - math.floor(minx)) + 1
- canvas_lines = int(round(maxy - miny))
+ canvas_lines = round(maxy - miny)

  canvas = AsciiCanvas(canvas_cols, canvas_lines)

@@ -305,10 +305,10 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
  start = edge.view.pts[index - 1]
  end = edge.view.pts[index]

- start_x = int(round(start[0] - minx))
+ start_x = round(start[0] - minx)
- start_y = int(round(start[1] - miny))
+ start_y = round(start[1] - miny)
- end_x = int(round(end[0] - minx))
+ end_x = round(end[0] - minx)
- end_y = int(round(end[1] - miny))
+ end_y = round(end[1] - miny)

  if start_x < 0 or start_y < 0 or end_x < 0 or end_y < 0:
  msg = (

@@ -328,12 +328,12 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
  y = vertex.view.xy[1]

  canvas.box(
- int(round(x - minx)),
+ round(x - minx),
- int(round(y - miny)),
+ round(y - miny),
  vertex.view.w,
  vertex.view.h,
  )

- canvas.text(int(round(x - minx)) + 1, int(round(y - miny)) + 1, vertex.data)
+ canvas.text(round(x - minx) + 1, round(y - miny) + 1, vertex.data)

  return canvas.draw()

@@ -442,7 +442,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
  if self.input_messages_key:
  key = self.input_messages_key
  elif len(input_val) == 1:
- key = list(input_val.keys())[0]
+ key = next(iter(input_val.keys()))
  else:
  key = "input"
  input_val = input_val[key]

@@ -472,7 +472,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
  f"Expected str, BaseMessage, list[BaseMessage], or tuple[BaseMessage]. "
  f"Got {input_val}."
  )
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)

  def _get_output_messages(
  self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]

@@ -484,7 +484,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
  if self.output_messages_key:
  key = self.output_messages_key
  elif len(output_val) == 1:
- key = list(output_val.keys())[0]
+ key = next(iter(output_val.keys()))
  else:
  key = "output"
  # If you are wrapping a chat model directly

@@ -507,7 +507,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
  f"Expected str, BaseMessage, list[BaseMessage], or tuple[BaseMessage]. "
  f"Got {output_val}."
  )
- raise ValueError(msg) # noqa: TRY004
+ raise ValueError(msg)

  def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
  hist: BaseChatMessageHistory = config["configurable"]["message_history"]

@@ -699,7 +699,7 @@ def get_unique_config_specs(
  else:
  msg = (
  "RunnableSequence contains conflicting config specs"
- f"for {id}: {[first] + others}"
+ f"for {id}: {[first, *others]}"
  )
  raise ValueError(msg)
  return unique

@@ -772,9 +772,8 @@ def is_async_generator(
  TypeGuard[Callable[..., AsyncIterator]: True if the function is
  an async generator, False otherwise.
  """
- return (
+ return inspect.isasyncgenfunction(func) or (
- inspect.isasyncgenfunction(func)
+ hasattr(func, "__call__") # noqa: B004
- or hasattr(func, "__call__") # noqa: B004
  and inspect.isasyncgenfunction(func.__call__)
  )

@@ -791,8 +790,7 @@ def is_async_callable(
  TypeGuard[Callable[..., Awaitable]: True if the function is async,
  False otherwise.
  """
- return (
+ return asyncio.iscoroutinefunction(func) or (
- asyncio.iscoroutinefunction(func)
+ hasattr(func, "__call__") # noqa: B004
- or hasattr(func, "__call__") # noqa: B004
  and asyncio.iscoroutinefunction(func.__call__)
  )
@@ -67,7 +67,7 @@ def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None:
  for pkg in reversed(order_by):
  if pkg in all_packages:
  all_packages.remove(pkg)
- all_packages = [pkg] + list(all_packages)
+ all_packages = [pkg, *list(all_packages)]

  system_info = {
  "OS": platform.system(),

@@ -53,25 +53,25 @@ if TYPE_CHECKING:
  from langchain_core.tools.structured import StructuredTool

  __all__ = (
+ "FILTERED_ARGS",
  "ArgsSchema",
  "BaseTool",
  "BaseToolkit",
- "FILTERED_ARGS",
- "SchemaAnnotationError",
- "ToolException",
  "InjectedToolArg",
  "InjectedToolCallId",
- "_get_runnable_config_param",
+ "RetrieverInput",
- "create_schema_from_function",
+ "SchemaAnnotationError",
- "convert_runnable_to_tool",
+ "StructuredTool",
- "tool",
+ "Tool",
+ "ToolException",
  "ToolsRenderer",
+ "_get_runnable_config_param",
+ "convert_runnable_to_tool",
+ "create_retriever_tool",
+ "create_schema_from_function",
  "render_text_description",
  "render_text_description_and_args",
- "RetrieverInput",
+ "tool",
- "create_retriever_tool",
- "Tool",
- "StructuredTool",
  )

  _dynamic_imports = {

@@ -273,7 +273,7 @@ def create_schema_from_function(
  # Handle classmethods and instance methods
  existing_params: list[str] = list(sig.parameters.keys())
  if existing_params and existing_params[0] in ("self", "cls") and in_class:
- filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
+ filter_args_ = [existing_params[0], *list(FILTERED_ARGS)]
  else:
  filter_args_ = list(FILTERED_ARGS)

@@ -991,10 +991,8 @@ def _format_output(

  def _is_message_content_type(obj: Any) -> bool:
  """Check for OpenAI or Anthropic format tool message content."""
- return (
+ return isinstance(obj, str) or (
- isinstance(obj, str)
+ isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
- or isinstance(obj, list)
- and all(_is_message_content_block(e) for e in obj)
  )


@@ -214,7 +214,7 @@ def tool(
  monkey: The baz.
  \"\"\"
  return bar
- """ # noqa: D214,D405,D410,D411,D412,D416
+ """ # noqa: D214, D410, D411

  def _create_tool_factory(
  tool_name: str,

@@ -367,7 +367,7 @@ def _get_schema_from_runnable_and_arg_types(
  msg = (
  "Tool input must be str or dict. If dict, dict arguments must be "
  "typed. Either annotate types (e.g., with TypedDict) or pass "
- f"arg_types into `.as_tool` to specify. {str(e)}"
+ f"arg_types into `.as_tool` to specify. {e}"
  )
  raise TypeError(msg) from e
  fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()}

@@ -26,13 +26,13 @@ if TYPE_CHECKING:

  __all__ = (
  "BaseTracer",
+ "ConsoleCallbackHandler",
  "EvaluatorCallbackHandler",
  "LangChainTracer",
- "ConsoleCallbackHandler",
+ "LogStreamCallbackHandler",
  "Run",
  "RunLog",
  "RunLogPatch",
- "LogStreamCallbackHandler",
  )

  _dynamic_imports = {

@@ -955,7 +955,7 @@ async def _astream_events_implementation_v2(
  if callbacks is None:
  config["callbacks"] = [event_streamer]
  elif isinstance(callbacks, list):
- config["callbacks"] = callbacks + [event_streamer]
+ config["callbacks"] = [*callbacks, event_streamer]
  elif isinstance(callbacks, BaseCallbackManager):
  callbacks = callbacks.copy()
  callbacks.add_handler(event_streamer, inherit=True)

@@ -632,7 +632,7 @@ async def _astream_log_implementation(
  if callbacks is None:
  config["callbacks"] = [stream]
  elif isinstance(callbacks, list):
- config["callbacks"] = callbacks + [stream]
+ config["callbacks"] = [*callbacks, stream]
  elif isinstance(callbacks, BaseCallbackManager):
  callbacks = callbacks.copy()
  callbacks.add_handler(stream, inherit=True)
@@ -95,9 +95,7 @@ class FunctionCallbackHandler(BaseTracer):
  parents = self.get_parents(run)[::-1]
  return " > ".join(
  f"{parent.run_type}:{parent.name}"
- if i != len(parents) - 1
+ for i, parent in enumerate([*parents, run])
- else f"{parent.run_type}:{parent.name}"
- for i, parent in enumerate(parents + [run])
  )

  # logging methods

@@ -38,32 +38,32 @@ if TYPE_CHECKING:
  )

  __all__ = (
- "build_extra_kwargs",
  "StrictFormatter",
+ "abatch_iterate",
+ "batch_iterate",
+ "build_extra_kwargs",
  "check_package_version",
+ "comma_list",
  "convert_to_secret_str",
  "formatter",
+ "from_env",
  "get_bolded_text",
  "get_color_mapping",
  "get_colored_text",
+ "get_from_dict_or_env",
+ "get_from_env",
  "get_pydantic_field_names",
  "guard_import",
+ "image",
  "mock_now",
+ "pre_init",
  "print_text",
  "raise_for_status_with_text",
- "xor_args",
- "try_load_from_hub",
- "image",
- "get_from_env",
- "get_from_dict_or_env",
- "stringify_dict",
- "comma_list",
- "stringify_value",
- "pre_init",
- "batch_iterate",
- "abatch_iterate",
- "from_env",
  "secret_from_env",
+ "stringify_dict",
+ "stringify_value",
+ "try_load_from_hub",
+ "xor_args",
  )

  _dynamic_imports = {

@@ -31,7 +31,9 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]
  merged = left.copy()
  for right in others:
  for right_k, right_v in right.items():
- if right_k not in merged or right_v is not None and merged[right_k] is None:
+ if right_k not in merged or (
+ right_v is not None and merged[right_k] is None
+ ):
  merged[right_k] = right_v
  elif right_v is None:
  continue

@@ -144,7 +144,7 @@ async def tee_peer(
  yield buffer.popleft()
  finally:
  async with lock:
- # this peer is done – remove its buffer
+ # this peer is done - remove its buffer
  for idx, peer_buffer in enumerate(peers): # pragma: no branch
  if peer_buffer is buffer:
  peers.pop(idx)

@@ -42,8 +42,8 @@ def get_from_dict_or_env(
  """
  if isinstance(key, (list, tuple)):
  for k in key:
- if k in data and data[k]:
+ if value := data.get(k):
- return data[k]
+ return value

  if isinstance(key, str) and key in data and data[key]:
  return data[key]

@@ -70,8 +70,8 @@ def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
  ValueError: If the key is not in the dictionary and no default value is
  provided or if the environment variable is not set.
  """
- if env_key in os.environ and os.environ[env_key]:
+ if env_value := os.getenv(env_key):
- return os.environ[env_key]
+ return env_value
  if default is not None:
  return default
  msg = (

@@ -81,7 +81,7 @@ def tee_peer(
  yield buffer.popleft()
  finally:
  with lock:
- # this peer is done – remove its buffer
+ # this peer is done - remove its buffer
  for idx, peer_buffer in enumerate(peers): # pragma: no branch
  if peer_buffer is buffer:
  peers.pop(idx)

@@ -571,7 +571,7 @@ def render(
  padding=padding,
  def_ldel=def_ldel,
  def_rdel=def_rdel,
- scopes=data and [data] + scopes or scopes,
+ scopes=(data and [data, *scopes]) or scopes,
  warn=warn,
  keep=keep,
  ),

@@ -601,7 +601,7 @@ def render(
  # For every item in the scope
  for thing in scope:
  # Append it as the most recent scope and render
- new_scope = [thing] + scopes
+ new_scope = [thing, *scopes]
  rend = render(
  template=tags,
  scopes=new_scope,
@@ -9,10 +9,10 @@ if TYPE_CHECKING:
  from langchain_core.vectorstores.in_memory import InMemoryVectorStore

  __all__ = (
- "VectorStore",
  "VST",
- "VectorStoreRetriever",
  "InMemoryVectorStore",
+ "VectorStore",
+ "VectorStoreRetriever",
  )

  _dynamic_imports = {

@@ -995,7 +995,7 @@ class VectorStore(ABC):
  search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
  )
  """
- tags = kwargs.pop("tags", None) or [] + self._get_retriever_tags()
+ tags = kwargs.pop("tags", None) or [*self._get_retriever_tags()]
  return VectorStoreRetriever(vectorstore=self, tags=tags, **kwargs)


@@ -155,7 +155,7 @@ class InMemoryVectorStore(VectorStore):

  [Document(id='2', metadata={'bar': 'baz'}, page_content='thud')]

- """ # noqa: E501
+ """

  def __init__(self, embedding: Embeddings) -> None:
  """Initialize with the given embedding function.

@@ -91,6 +91,7 @@ ignore = [
  "ISC001", # Messes with the formatter
  "PERF203", # Rarely useful
  "PLR09", # Too many something (arg, statements, etc)
+ "RUF012", # Doesn't play well with Pydantic
  "TC001", # Doesn't play well with Pydantic
  "TC002", # Doesn't play well with Pydantic
  "TC003", # Doesn't play well with Pydantic

@@ -104,7 +105,6 @@ ignore = [
  "BLE",
  "ERA",
  "PLR2004",
- "RUF",
  ]
  flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
  flake8-annotations.allow-star-arg-any = true

@@ -12,8 +12,7 @@ if __name__ == "__main__":
  for file in files:
  try:
  module_name = "".join(
- random.choice(string.ascii_letters)
+ random.choice(string.ascii_letters) for _ in range(20)
- for _ in range(20) # noqa: S311
  )
  SourceFileLoader(module_name, file).load_module()
  except Exception:

@@ -37,7 +37,7 @@ def test_selector_add_example(selector: LengthBasedExampleSelector) -> None:
  selector.add_example(new_example)
  short_question = "Short question?"
  output = selector.select_examples({"question": short_question})
- assert output == EXAMPLES + [new_example]
+ assert output == [*EXAMPLES, new_example]


  def test_selector_trims_one_example(selector: LengthBasedExampleSelector) -> None:

@@ -3,7 +3,7 @@ from langchain_core.indexing import __all__

  def test_all() -> None:
  """Use to catch obvious breaking changes."""
- assert list(__all__) == sorted(__all__, key=str.lower)
+ assert list(__all__) == sorted(__all__, key=str)
  assert set(__all__) == {
  "aindex",
  "DeleteResponse",
@@ -191,14 +191,14 @@ def test_format_instructions_preserves_language() -> None:

  description = (
  "你好, こんにちは, नमस्ते, Bonjour, Hola, "
- "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
+ "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001
  )

  class Foo(BaseModel):
  hello: str = Field(
  description=(
  "你好, こんにちは, नमस्ते, Bonjour, Hola, "
- "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
+ "Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001
  )
  )

@@ -380,7 +380,7 @@ def test_chat_prompt_template_with_messages(
  messages: list[BaseMessagePromptTemplate],
  ) -> None:
  chat_prompt_template = ChatPromptTemplate.from_messages(
- messages + [HumanMessage(content="foo")]
+ [*messages, HumanMessage(content="foo")]
  )
  assert sorted(chat_prompt_template.input_variables) == sorted(
  [

@@ -168,7 +168,7 @@ class FakeTracer(BaseTracer):
  self.runs.append(self._copy_run(run))

  def flattened_runs(self) -> list[Run]:
- q = [] + self.runs
+ q = [*self.runs]
  result = []
  while q:
  parent = q.pop()

@@ -2312,7 +2312,7 @@ def test_injected_arg_with_complex_type() -> None:
  self.value = "bar"

  @tool
- def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: # noqa: ARG001
+ def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
  """Tool that has an injected tool arg."""
  return foo.value

@@ -2488,7 +2488,7 @@ def test_simple_tool_args_schema_dict() -> None:

  def test_empty_string_tool_call_id() -> None:
  @tool
- def foo(x: int) -> str: # noqa: ARG001
+ def foo(x: int) -> str:
  """Foo."""
  return "hi"

@@ -2500,7 +2500,7 @@ def test_empty_string_tool_call_id() -> None:
  def test_tool_decorator_description() -> None:
  # test basic tool
  @tool
- def foo(x: int) -> str: # noqa: ARG001
+ def foo(x: int) -> str:
  """Foo."""
  return "hi"

@@ -2512,7 +2512,7 @@ def test_tool_decorator_description() -> None:

  # test basic tool with description
  @tool(description="description")
- def foo_description(x: int) -> str: # noqa: ARG001
+ def foo_description(x: int) -> str:
  """Foo."""
  return "hi"

@@ -2531,7 +2531,7 @@ def test_tool_decorator_description() -> None:
  x: int

  @tool(args_schema=ArgsSchema)
- def foo_args_schema(x: int) -> str: # noqa: ARG001
+ def foo_args_schema(x: int) -> str:
  return "hi"

  assert foo_args_schema.description == "Bar."

@@ -2543,7 +2543,7 @@ def test_tool_decorator_description() -> None:
  )

  @tool(description="description", args_schema=ArgsSchema)
- def foo_args_schema_description(x: int) -> str: # noqa: ARG001
+ def foo_args_schema_description(x: int) -> str:
  return "hi"

  assert foo_args_schema_description.description == "description"

@@ -2565,11 +2565,11 @@ def test_tool_decorator_description() -> None:
  }

  @tool(args_schema=args_json_schema)
- def foo_args_jsons_schema(x: int) -> str: # noqa: ARG001
+ def foo_args_jsons_schema(x: int) -> str:
  return "hi"

  @tool(description="description", args_schema=args_json_schema)
- def foo_args_jsons_schema_with_description(x: int) -> str: # noqa: ARG001
+ def foo_args_jsons_schema_with_description(x: int) -> str:
  return "hi"

  assert foo_args_jsons_schema.description == "JSON Schema."

@@ -2629,10 +2629,10 @@ def test_title_property_preserved() -> None:
  async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
  """Verify that the inputs are not mutated when invoking a tool asynchronously."""

- def sync_no_op(foo: int) -> str: # noqa: ARG001
+ def sync_no_op(foo: int) -> str:
  return "good"

- async def async_no_op(foo: int) -> str: # noqa: ARG001
+ async def async_no_op(foo: int) -> str:
  return "good"

  tool = StructuredTool(

@@ -2677,10 +2677,10 @@ async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
  def test_tool_invoke_does_not_mutate_inputs() -> None:
  """Verify that the inputs are not mutated when invoking a tool synchronously."""

- def sync_no_op(foo: int) -> str: # noqa: ARG001
+ def sync_no_op(foo: int) -> str:
  return "good"

- async def async_no_op(foo: int) -> str: # noqa: ARG001
+ async def async_no_op(foo: int) -> str:
  return "good"

  tool = StructuredTool(

@@ -39,7 +39,7 @@ async def test_same_event_loop() -> None:
  **item,
  }

- asyncio.create_task(producer())
+ producer_task = asyncio.create_task(producer())

  items = [item async for item in consumer()]

@@ -57,6 +57,8 @@ async def test_same_event_loop() -> None:
  f"delta_time: {delta_time}"
  )

+ await producer_task


  async def test_queue_for_streaming_via_sync_call() -> None:
  """Test via async -> sync -> async path."""