mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
core: Add ruff rules RUF (#29353)
See https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf Mostly: * [RUF022](https://docs.astral.sh/ruff/rules/unsorted-dunder-all/) (unsorted `__all__`) * [RUF100](https://docs.astral.sh/ruff/rules/unused-noqa/) (unused noqa) * [RUF021](https://docs.astral.sh/ruff/rules/parenthesize-chained-operators/) (parenthesize-chained-operators) * [RUF015](https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element/) (unnecessary-iterable-allocation-for-first-element) * [RUF005](https://docs.astral.sh/ruff/rules/collection-literal-concatenation/) (collection-literal-concatenation) * [RUF046](https://docs.astral.sh/ruff/rules/unnecessary-cast-to-int/) (unnecessary-cast-to-int) --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
6cd1aadf60
commit
a8f2ddee31
@ -30,15 +30,15 @@ if TYPE_CHECKING:
|
||||
from .path import as_import_path, get_relative_path
|
||||
|
||||
__all__ = (
|
||||
"LangChainBetaWarning",
|
||||
"LangChainDeprecationWarning",
|
||||
"as_import_path",
|
||||
"beta",
|
||||
"deprecated",
|
||||
"get_relative_path",
|
||||
"LangChainBetaWarning",
|
||||
"LangChainDeprecationWarning",
|
||||
"suppress_langchain_beta_warning",
|
||||
"surface_langchain_beta_warnings",
|
||||
"suppress_langchain_deprecation_warning",
|
||||
"surface_langchain_beta_warnings",
|
||||
"surface_langchain_deprecation_warnings",
|
||||
"warn_deprecated",
|
||||
)
|
||||
|
@ -54,39 +54,39 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"dispatch_custom_event",
|
||||
"adispatch_custom_event",
|
||||
"RetrieverManagerMixin",
|
||||
"LLMManagerMixin",
|
||||
"ChainManagerMixin",
|
||||
"ToolManagerMixin",
|
||||
"Callbacks",
|
||||
"CallbackManagerMixin",
|
||||
"RunManagerMixin",
|
||||
"BaseCallbackHandler",
|
||||
"AsyncCallbackHandler",
|
||||
"BaseCallbackManager",
|
||||
"BaseRunManager",
|
||||
"RunManager",
|
||||
"ParentRunManager",
|
||||
"AsyncRunManager",
|
||||
"AsyncParentRunManager",
|
||||
"CallbackManagerForLLMRun",
|
||||
"AsyncCallbackManagerForLLMRun",
|
||||
"CallbackManagerForChainRun",
|
||||
"AsyncCallbackManagerForChainRun",
|
||||
"CallbackManagerForToolRun",
|
||||
"AsyncCallbackManagerForToolRun",
|
||||
"CallbackManagerForRetrieverRun",
|
||||
"AsyncCallbackManagerForRetrieverRun",
|
||||
"CallbackManager",
|
||||
"CallbackManagerForChainGroup",
|
||||
"AsyncCallbackManager",
|
||||
"AsyncCallbackManagerForChainGroup",
|
||||
"AsyncCallbackManagerForChainRun",
|
||||
"AsyncCallbackManagerForLLMRun",
|
||||
"AsyncCallbackManagerForRetrieverRun",
|
||||
"AsyncCallbackManagerForToolRun",
|
||||
"AsyncParentRunManager",
|
||||
"AsyncRunManager",
|
||||
"BaseCallbackHandler",
|
||||
"BaseCallbackManager",
|
||||
"BaseRunManager",
|
||||
"CallbackManager",
|
||||
"CallbackManagerForChainGroup",
|
||||
"CallbackManagerForChainRun",
|
||||
"CallbackManagerForLLMRun",
|
||||
"CallbackManagerForRetrieverRun",
|
||||
"CallbackManagerForToolRun",
|
||||
"CallbackManagerMixin",
|
||||
"Callbacks",
|
||||
"ChainManagerMixin",
|
||||
"FileCallbackHandler",
|
||||
"LLMManagerMixin",
|
||||
"ParentRunManager",
|
||||
"RetrieverManagerMixin",
|
||||
"RunManager",
|
||||
"RunManagerMixin",
|
||||
"StdOutCallbackHandler",
|
||||
"StreamingStdOutCallbackHandler",
|
||||
"FileCallbackHandler",
|
||||
"ToolManagerMixin",
|
||||
"UsageMetadataCallbackHandler",
|
||||
"adispatch_custom_event",
|
||||
"dispatch_custom_event",
|
||||
"get_usage_metadata_callback",
|
||||
)
|
||||
|
||||
|
@ -14,8 +14,8 @@ __all__ = (
|
||||
"BaseLoader",
|
||||
"Blob",
|
||||
"BlobLoader",
|
||||
"PathLike",
|
||||
"LangSmithLoader",
|
||||
"PathLike",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
from .compressor import BaseDocumentCompressor
|
||||
from .transformers import BaseDocumentTransformer
|
||||
|
||||
__all__ = ("Document", "BaseDocumentTransformer", "BaseDocumentCompressor")
|
||||
__all__ = ("BaseDocumentCompressor", "BaseDocumentTransformer", "Document")
|
||||
|
||||
_dynamic_imports = {
|
||||
"Document": "base",
|
||||
|
@ -20,14 +20,14 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"aindex",
|
||||
"DeleteResponse",
|
||||
"DocumentIndex",
|
||||
"index",
|
||||
"IndexingResult",
|
||||
"InMemoryRecordManager",
|
||||
"IndexingResult",
|
||||
"RecordManager",
|
||||
"UpsertResponse",
|
||||
"aindex",
|
||||
"index",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -68,22 +68,22 @@ if TYPE_CHECKING:
|
||||
from langchain_core.language_models.llms import LLM, BaseLLM
|
||||
|
||||
__all__ = (
|
||||
"BaseLanguageModel",
|
||||
"BaseChatModel",
|
||||
"SimpleChatModel",
|
||||
"BaseLLM",
|
||||
"LLM",
|
||||
"LanguageModelInput",
|
||||
"get_tokenizer",
|
||||
"LangSmithParams",
|
||||
"LanguageModelOutput",
|
||||
"LanguageModelLike",
|
||||
"FakeListLLM",
|
||||
"FakeStreamingListLLM",
|
||||
"BaseChatModel",
|
||||
"BaseLLM",
|
||||
"BaseLanguageModel",
|
||||
"FakeListChatModel",
|
||||
"FakeListLLM",
|
||||
"FakeMessagesListChatModel",
|
||||
"FakeStreamingListLLM",
|
||||
"GenericFakeChatModel",
|
||||
"LangSmithParams",
|
||||
"LanguageModelInput",
|
||||
"LanguageModelLike",
|
||||
"LanguageModelOutput",
|
||||
"ParrotFakeChatModel",
|
||||
"SimpleChatModel",
|
||||
"get_tokenizer",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -23,7 +23,7 @@ def _is_openai_data_block(block: dict) -> bool:
|
||||
if isinstance(file_data, str):
|
||||
return True
|
||||
|
||||
elif block.get("type") == "input_audio": # noqa: SIM102
|
||||
elif block.get("type") == "input_audio":
|
||||
if (input_audio := block.get("input_audio")) and isinstance(input_audio, dict):
|
||||
audio_data = input_audio.get("data")
|
||||
audio_format = input_audio.get("format")
|
||||
|
@ -354,7 +354,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
f"Invalid input type {type(input)}. "
|
||||
"Must be a PromptValue, str, or list of BaseMessages."
|
||||
)
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
@override
|
||||
def invoke(
|
||||
@ -1203,7 +1203,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
if isinstance(generation, ChatGeneration):
|
||||
return generation.message
|
||||
msg = "Unexpected generation type"
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
async def _call_async(
|
||||
self,
|
||||
@ -1219,7 +1219,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
if isinstance(generation, ChatGeneration):
|
||||
return generation.message
|
||||
msg = "Unexpected generation type"
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
def call_as_llm(
|
||||
@ -1261,7 +1261,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
if isinstance(result.content, str):
|
||||
return result.content
|
||||
msg = "Cannot use predict when output is not a string."
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
@deprecated("0.1.7", alternative="invoke", removal="1.0")
|
||||
@override
|
||||
@ -1287,7 +1287,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
if isinstance(result.content, str):
|
||||
return result.content
|
||||
msg = "Cannot use predict when output is not a string."
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
|
||||
@override
|
||||
|
@ -103,7 +103,9 @@ def create_base_retry_decorator(
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
loop.create_task(coro)
|
||||
# TODO: Fix RUF006 - this task should have a reference
|
||||
# and be awaited somewhere
|
||||
loop.create_task(coro) # noqa: RUF006
|
||||
else:
|
||||
asyncio.run(coro)
|
||||
except Exception as e:
|
||||
@ -336,7 +338,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
f"Invalid input type {type(input)}. "
|
||||
"Must be a PromptValue, str, or list of BaseMessages."
|
||||
)
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
def _get_ls_params(
|
||||
self,
|
||||
|
@ -15,7 +15,7 @@ if TYPE_CHECKING:
|
||||
# the `from langchain_core.load.load import load` absolute import should also work.
|
||||
from langchain_core.load.load import load
|
||||
|
||||
__all__ = ("dumpd", "dumps", "load", "loads", "Serializable")
|
||||
__all__ = ("Serializable", "dumpd", "dumps", "load", "loads")
|
||||
|
||||
_dynamic_imports = {
|
||||
"dumpd": "dump",
|
||||
|
@ -3,7 +3,7 @@
|
||||
This module contains memory abstractions from LangChain v0.0.x.
|
||||
|
||||
These abstractions are now deprecated and will be removed in LangChain v1.0.0.
|
||||
""" # noqa: E501
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
@ -76,28 +76,28 @@ __all__ = (
|
||||
"HumanMessageChunk",
|
||||
"InvalidToolCall",
|
||||
"MessageLikeRepresentation",
|
||||
"RemoveMessage",
|
||||
"SystemMessage",
|
||||
"SystemMessageChunk",
|
||||
"ToolCall",
|
||||
"ToolCallChunk",
|
||||
"ToolMessage",
|
||||
"ToolMessageChunk",
|
||||
"RemoveMessage",
|
||||
"_message_from_dict",
|
||||
"convert_to_messages",
|
||||
"convert_to_openai_data_block",
|
||||
"convert_to_openai_image_block",
|
||||
"convert_to_messages",
|
||||
"convert_to_openai_messages",
|
||||
"filter_messages",
|
||||
"get_buffer_string",
|
||||
"is_data_content_block",
|
||||
"merge_content",
|
||||
"merge_message_runs",
|
||||
"message_chunk_to_message",
|
||||
"message_to_dict",
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
"filter_messages",
|
||||
"merge_message_runs",
|
||||
"trim_messages",
|
||||
"convert_to_openai_messages",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -423,13 +423,13 @@ def add_ai_message_chunks(
|
||||
|
||||
id = None
|
||||
candidates = [left.id] + [o.id for o in others]
|
||||
# first pass: pick the first non‐run-* id
|
||||
# first pass: pick the first non-run-* id
|
||||
for id_ in candidates:
|
||||
if id_ and not id_.startswith(_LC_ID_PREFIX):
|
||||
id = id_
|
||||
break
|
||||
else:
|
||||
# second pass: no provider-assigned id found, just take the first non‐null
|
||||
# second pass: no provider-assigned id found, just take the first non-null
|
||||
for id_ in candidates:
|
||||
if id_:
|
||||
id = id_
|
||||
|
@ -101,8 +101,7 @@ class BaseMessage(Serializable):
|
||||
block
|
||||
for block in self.content
|
||||
if isinstance(block, str)
|
||||
or block.get("type") == "text"
|
||||
and isinstance(block.get("text"), str)
|
||||
or (block.get("type") == "text" and isinstance(block.get("text"), str))
|
||||
]
|
||||
return "".join(
|
||||
block if isinstance(block, str) else block["text"] for block in blocks
|
||||
@ -161,7 +160,7 @@ def merge_content(
|
||||
merged += content
|
||||
# If the next chunk is a list, add the current to the start of the list
|
||||
else:
|
||||
merged = [merged] + content # type: ignore[assignment,operator]
|
||||
merged = [merged, *content]
|
||||
elif isinstance(content, list):
|
||||
# If both are lists
|
||||
merged = merge_lists(cast("list", merged), content) # type: ignore[assignment]
|
||||
|
@ -885,7 +885,7 @@ def trim_messages(
|
||||
list_token_counter = token_counter.get_num_tokens_from_messages
|
||||
elif callable(token_counter):
|
||||
if (
|
||||
list(inspect.signature(token_counter).parameters.values())[0].annotation
|
||||
next(iter(inspect.signature(token_counter).parameters.values())).annotation
|
||||
is BaseMessage
|
||||
):
|
||||
|
||||
@ -1460,7 +1460,7 @@ def _last_max_tokens(
|
||||
# Re-reverse the messages and add back the system message if needed
|
||||
result = reversed_result[::-1]
|
||||
if system_message:
|
||||
result = [system_message] + result
|
||||
result = [system_message, *result]
|
||||
|
||||
return result
|
||||
|
||||
@ -1543,7 +1543,7 @@ def _get_message_openai_role(message: BaseMessage) -> str:
|
||||
if isinstance(message, ChatMessage):
|
||||
return message.role
|
||||
msg = f"Unknown BaseMessage type {message.__class__}."
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]:
|
||||
|
@ -47,23 +47,23 @@ if TYPE_CHECKING:
|
||||
from langchain_core.output_parsers.xml import XMLOutputParser
|
||||
|
||||
__all__ = [
|
||||
"BaseLLMOutputParser",
|
||||
"BaseGenerationOutputParser",
|
||||
"BaseOutputParser",
|
||||
"ListOutputParser",
|
||||
"CommaSeparatedListOutputParser",
|
||||
"NumberedListOutputParser",
|
||||
"MarkdownListOutputParser",
|
||||
"StrOutputParser",
|
||||
"BaseTransformOutputParser",
|
||||
"BaseCumulativeTransformOutputParser",
|
||||
"SimpleJsonOutputParser",
|
||||
"XMLOutputParser",
|
||||
"JsonOutputParser",
|
||||
"PydanticOutputParser",
|
||||
"JsonOutputToolsParser",
|
||||
"BaseGenerationOutputParser",
|
||||
"BaseLLMOutputParser",
|
||||
"BaseOutputParser",
|
||||
"BaseTransformOutputParser",
|
||||
"CommaSeparatedListOutputParser",
|
||||
"JsonOutputKeyToolsParser",
|
||||
"JsonOutputParser",
|
||||
"JsonOutputToolsParser",
|
||||
"ListOutputParser",
|
||||
"MarkdownListOutputParser",
|
||||
"NumberedListOutputParser",
|
||||
"PydanticOutputParser",
|
||||
"PydanticToolsParser",
|
||||
"SimpleJsonOutputParser",
|
||||
"StrOutputParser",
|
||||
"XMLOutputParser",
|
||||
]
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -132,6 +132,6 @@ SimpleJsonOutputParser = JsonOutputParser
|
||||
__all__ = [
|
||||
"JsonOutputParser",
|
||||
"SimpleJsonOutputParser", # For backwards compatibility
|
||||
"parse_partial_json", # For backwards compatibility
|
||||
"parse_and_check_json_markdown", # For backwards compatibility
|
||||
"parse_partial_json", # For backwards compatibility
|
||||
]
|
||||
|
@ -3,7 +3,7 @@
|
||||
import contextlib
|
||||
import re
|
||||
import xml
|
||||
import xml.etree.ElementTree as ET # noqa: N817
|
||||
import xml.etree.ElementTree as ET
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from xml.etree.ElementTree import TreeBuilder
|
||||
|
@ -70,21 +70,21 @@ __all__ = (
|
||||
"ChatMessagePromptTemplate",
|
||||
"ChatPromptTemplate",
|
||||
"DictPromptTemplate",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
"FewShotPromptTemplate",
|
||||
"FewShotPromptWithTemplates",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
"HumanMessagePromptTemplate",
|
||||
"MessagesPlaceholder",
|
||||
"PipelinePromptTemplate",
|
||||
"PromptTemplate",
|
||||
"StringPromptTemplate",
|
||||
"SystemMessagePromptTemplate",
|
||||
"load_prompt",
|
||||
"format_document",
|
||||
"aformat_document",
|
||||
"check_valid_template",
|
||||
"format_document",
|
||||
"get_template_variables",
|
||||
"jinja2_formatter",
|
||||
"load_prompt",
|
||||
"validate_jinja2",
|
||||
)
|
||||
|
||||
|
@ -445,9 +445,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
raise ValueError(msg)
|
||||
prompt = []
|
||||
for tmpl in template:
|
||||
if (
|
||||
isinstance(tmpl, str)
|
||||
or isinstance(tmpl, dict)
|
||||
if isinstance(tmpl, str) or (
|
||||
isinstance(tmpl, dict)
|
||||
and "text" in tmpl
|
||||
and set(tmpl.keys()) <= {"type", "text"}
|
||||
):
|
||||
@ -524,7 +523,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
|
||||
raise ValueError(msg)
|
||||
return cls(prompt=prompt, **kwargs)
|
||||
msg = f"Invalid template: {template}"
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
@classmethod
|
||||
def from_template_file(
|
||||
@ -1000,7 +999,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
if isinstance(
|
||||
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
|
||||
):
|
||||
return ChatPromptTemplate(messages=self.messages + [other]).partial(
|
||||
return ChatPromptTemplate(messages=[*self.messages, other]).partial(
|
||||
**partials
|
||||
)
|
||||
if isinstance(other, (list, tuple)):
|
||||
@ -1010,7 +1009,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
)
|
||||
if isinstance(other, str):
|
||||
prompt = HumanMessagePromptTemplate.from_template(other)
|
||||
return ChatPromptTemplate(messages=self.messages + [prompt]).partial(
|
||||
return ChatPromptTemplate(messages=[*self.messages, prompt]).partial(
|
||||
**partials
|
||||
)
|
||||
msg = f"Unsupported operand type for +: {type(other)}"
|
||||
|
@ -84,15 +84,11 @@ def _load_examples(config: dict) -> dict:
|
||||
|
||||
def _load_output_parser(config: dict) -> dict:
|
||||
"""Load output parser."""
|
||||
if "output_parser" in config and config["output_parser"]:
|
||||
_config = config.pop("output_parser")
|
||||
output_parser_type = _config.pop("_type")
|
||||
if output_parser_type == "default":
|
||||
output_parser = StrOutputParser(**_config)
|
||||
else:
|
||||
if _config := config.get("output_parser"):
|
||||
if output_parser_type := _config.get("_type") != "default":
|
||||
msg = f"Unsupported output parser {output_parser_type}"
|
||||
raise ValueError(msg)
|
||||
config["output_parser"] = output_parser
|
||||
config["output_parser"] = StrOutputParser(**_config)
|
||||
return config
|
||||
|
||||
|
||||
|
@ -153,10 +153,8 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
NotImplementedError: If the first element of `others`
|
||||
is not a language model.
|
||||
"""
|
||||
if (
|
||||
others
|
||||
and isinstance(others[0], BaseLanguageModel)
|
||||
or hasattr(others[0], "with_structured_output")
|
||||
if (others and isinstance(others[0], BaseLanguageModel)) or hasattr(
|
||||
others[0], "with_structured_output"
|
||||
):
|
||||
return RunnableSequence(
|
||||
self,
|
||||
|
@ -60,19 +60,15 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"chain",
|
||||
"AddableDict",
|
||||
"ConfigurableField",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldMultiOption",
|
||||
"ConfigurableFieldSingleOption",
|
||||
"ConfigurableFieldSpec",
|
||||
"ensure_config",
|
||||
"run_in_executor",
|
||||
"patch_config",
|
||||
"RouterInput",
|
||||
"RouterRunnable",
|
||||
"Runnable",
|
||||
"RunnableSerializable",
|
||||
"RunnableAssign",
|
||||
"RunnableBinding",
|
||||
"RunnableBranch",
|
||||
"RunnableConfig",
|
||||
@ -81,14 +77,18 @@ __all__ = (
|
||||
"RunnableMap",
|
||||
"RunnableParallel",
|
||||
"RunnablePassthrough",
|
||||
"RunnableAssign",
|
||||
"RunnablePick",
|
||||
"RunnableSequence",
|
||||
"RunnableSerializable",
|
||||
"RunnableWithFallbacks",
|
||||
"RunnableWithMessageHistory",
|
||||
"get_config_list",
|
||||
"aadd",
|
||||
"add",
|
||||
"chain",
|
||||
"ensure_config",
|
||||
"get_config_list",
|
||||
"patch_config",
|
||||
"run_in_executor",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -2799,7 +2799,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
Returns:
|
||||
A list of Runnables.
|
||||
"""
|
||||
return [self.first] + self.middle + [self.last]
|
||||
return [self.first, *self.middle, self.last]
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@ -3353,7 +3353,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
) -> Iterator[Output]:
|
||||
from langchain_core.beta.runnables.context import config_with_context
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
steps = [self.first, *self.middle, self.last]
|
||||
config = config_with_context(config, self.steps)
|
||||
|
||||
# transform the input stream of each step with the next
|
||||
@ -3380,7 +3380,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
) -> AsyncIterator[Output]:
|
||||
from langchain_core.beta.runnables.context import aconfig_with_context
|
||||
|
||||
steps = [self.first] + self.middle + [self.last]
|
||||
steps = [self.first, *self.middle, self.last]
|
||||
config = aconfig_with_context(config, self.steps)
|
||||
|
||||
# stream the last steps
|
||||
@ -4203,7 +4203,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Output]:
|
||||
if not hasattr(self, "_transform"):
|
||||
msg = f"{repr(self)} only supports async methods."
|
||||
msg = f"{self!r} only supports async methods."
|
||||
raise NotImplementedError(msg)
|
||||
return self._transform_stream_with_config(
|
||||
input,
|
||||
@ -4238,7 +4238,7 @@ class RunnableGenerator(Runnable[Input, Output]):
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Output]:
|
||||
if not hasattr(self, "_atransform"):
|
||||
msg = f"{repr(self)} only supports sync methods."
|
||||
msg = f"{self!r} only supports sync methods."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
return self._atransform_stream_with_config(
|
||||
@ -5781,7 +5781,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
config_factories=[listener_config_factory] + self.config_factories,
|
||||
config_factories=[listener_config_factory, *self.config_factories],
|
||||
custom_input_type=self.custom_input_type,
|
||||
custom_output_type=self.custom_output_type,
|
||||
)
|
||||
|
@ -562,7 +562,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
||||
self.which.name or self.which.id,
|
||||
(
|
||||
(v, v)
|
||||
for v in list(self.alternatives.keys()) + [self.default_key]
|
||||
for v in [*list(self.alternatives.keys()), self.default_key]
|
||||
),
|
||||
)
|
||||
_enums_for_spec[self.which] = cast("type[StrEnum]", which_enum)
|
||||
|
@ -111,15 +111,15 @@ class AsciiCanvas:
|
||||
self.point(x0, y0, char)
|
||||
elif abs(dx) >= abs(dy):
|
||||
for x in range(x0, x1 + 1):
|
||||
y = y0 if dx == 0 else y0 + int(round((x - x0) * dy / float(dx)))
|
||||
y = y0 if dx == 0 else y0 + round((x - x0) * dy / float(dx))
|
||||
self.point(x, y, char)
|
||||
elif y0 < y1:
|
||||
for y in range(y0, y1 + 1):
|
||||
x = x0 if dy == 0 else x0 + int(round((y - y0) * dx / float(dy)))
|
||||
x = x0 if dy == 0 else x0 + round((y - y0) * dx / float(dy))
|
||||
self.point(x, y, char)
|
||||
else:
|
||||
for y in range(y1, y0 + 1):
|
||||
x = x0 if dy == 0 else x1 + int(round((y - y1) * dx / float(dy)))
|
||||
x = x0 if dy == 0 else x1 + round((y - y1) * dx / float(dy))
|
||||
self.point(x, y, char)
|
||||
|
||||
def text(self, x: int, y: int, text: str) -> None:
|
||||
@ -291,8 +291,8 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
|
||||
maxx = max(xlist)
|
||||
maxy = max(ylist)
|
||||
|
||||
canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1
|
||||
canvas_lines = int(round(maxy - miny))
|
||||
canvas_cols = math.ceil(math.ceil(maxx) - math.floor(minx)) + 1
|
||||
canvas_lines = round(maxy - miny)
|
||||
|
||||
canvas = AsciiCanvas(canvas_cols, canvas_lines)
|
||||
|
||||
@ -305,10 +305,10 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
|
||||
start = edge.view.pts[index - 1]
|
||||
end = edge.view.pts[index]
|
||||
|
||||
start_x = int(round(start[0] - minx))
|
||||
start_y = int(round(start[1] - miny))
|
||||
end_x = int(round(end[0] - minx))
|
||||
end_y = int(round(end[1] - miny))
|
||||
start_x = round(start[0] - minx)
|
||||
start_y = round(start[1] - miny)
|
||||
end_x = round(end[0] - minx)
|
||||
end_y = round(end[1] - miny)
|
||||
|
||||
if start_x < 0 or start_y < 0 or end_x < 0 or end_y < 0:
|
||||
msg = (
|
||||
@ -328,12 +328,12 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
|
||||
y = vertex.view.xy[1]
|
||||
|
||||
canvas.box(
|
||||
int(round(x - minx)),
|
||||
int(round(y - miny)),
|
||||
round(x - minx),
|
||||
round(y - miny),
|
||||
vertex.view.w,
|
||||
vertex.view.h,
|
||||
)
|
||||
|
||||
canvas.text(int(round(x - minx)) + 1, int(round(y - miny)) + 1, vertex.data)
|
||||
canvas.text(round(x - minx) + 1, round(y - miny) + 1, vertex.data)
|
||||
|
||||
return canvas.draw()
|
||||
|
@ -442,7 +442,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
if self.input_messages_key:
|
||||
key = self.input_messages_key
|
||||
elif len(input_val) == 1:
|
||||
key = list(input_val.keys())[0]
|
||||
key = next(iter(input_val.keys()))
|
||||
else:
|
||||
key = "input"
|
||||
input_val = input_val[key]
|
||||
@ -472,7 +472,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
f"Expected str, BaseMessage, list[BaseMessage], or tuple[BaseMessage]. "
|
||||
f"Got {input_val}."
|
||||
)
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
def _get_output_messages(
|
||||
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
|
||||
@ -484,7 +484,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
if self.output_messages_key:
|
||||
key = self.output_messages_key
|
||||
elif len(output_val) == 1:
|
||||
key = list(output_val.keys())[0]
|
||||
key = next(iter(output_val.keys()))
|
||||
else:
|
||||
key = "output"
|
||||
# If you are wrapping a chat model directly
|
||||
@ -507,7 +507,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
f"Expected str, BaseMessage, list[BaseMessage], or tuple[BaseMessage]. "
|
||||
f"Got {output_val}."
|
||||
)
|
||||
raise ValueError(msg) # noqa: TRY004
|
||||
raise ValueError(msg)
|
||||
|
||||
def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
|
||||
hist: BaseChatMessageHistory = config["configurable"]["message_history"]
|
||||
|
@ -699,7 +699,7 @@ def get_unique_config_specs(
|
||||
else:
|
||||
msg = (
|
||||
"RunnableSequence contains conflicting config specs"
|
||||
f"for {id}: {[first] + others}"
|
||||
f"for {id}: {[first, *others]}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return unique
|
||||
@ -772,9 +772,8 @@ def is_async_generator(
|
||||
TypeGuard[Callable[..., AsyncIterator]: True if the function is
|
||||
an async generator, False otherwise.
|
||||
"""
|
||||
return (
|
||||
inspect.isasyncgenfunction(func)
|
||||
or hasattr(func, "__call__") # noqa: B004
|
||||
return inspect.isasyncgenfunction(func) or (
|
||||
hasattr(func, "__call__") # noqa: B004
|
||||
and inspect.isasyncgenfunction(func.__call__)
|
||||
)
|
||||
|
||||
@ -791,8 +790,7 @@ def is_async_callable(
|
||||
TypeGuard[Callable[..., Awaitable]: True if the function is async,
|
||||
False otherwise.
|
||||
"""
|
||||
return (
|
||||
asyncio.iscoroutinefunction(func)
|
||||
or hasattr(func, "__call__") # noqa: B004
|
||||
return asyncio.iscoroutinefunction(func) or (
|
||||
hasattr(func, "__call__") # noqa: B004
|
||||
and asyncio.iscoroutinefunction(func.__call__)
|
||||
)
|
||||
|
@ -67,7 +67,7 @@ def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None:
|
||||
for pkg in reversed(order_by):
|
||||
if pkg in all_packages:
|
||||
all_packages.remove(pkg)
|
||||
all_packages = [pkg] + list(all_packages)
|
||||
all_packages = [pkg, *list(all_packages)]
|
||||
|
||||
system_info = {
|
||||
"OS": platform.system(),
|
||||
|
@ -53,25 +53,25 @@ if TYPE_CHECKING:
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
|
||||
__all__ = (
|
||||
"FILTERED_ARGS",
|
||||
"ArgsSchema",
|
||||
"BaseTool",
|
||||
"BaseToolkit",
|
||||
"FILTERED_ARGS",
|
||||
"SchemaAnnotationError",
|
||||
"ToolException",
|
||||
"InjectedToolArg",
|
||||
"InjectedToolCallId",
|
||||
"_get_runnable_config_param",
|
||||
"create_schema_from_function",
|
||||
"convert_runnable_to_tool",
|
||||
"tool",
|
||||
"RetrieverInput",
|
||||
"SchemaAnnotationError",
|
||||
"StructuredTool",
|
||||
"Tool",
|
||||
"ToolException",
|
||||
"ToolsRenderer",
|
||||
"_get_runnable_config_param",
|
||||
"convert_runnable_to_tool",
|
||||
"create_retriever_tool",
|
||||
"create_schema_from_function",
|
||||
"render_text_description",
|
||||
"render_text_description_and_args",
|
||||
"RetrieverInput",
|
||||
"create_retriever_tool",
|
||||
"Tool",
|
||||
"StructuredTool",
|
||||
"tool",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -273,7 +273,7 @@ def create_schema_from_function(
|
||||
# Handle classmethods and instance methods
|
||||
existing_params: list[str] = list(sig.parameters.keys())
|
||||
if existing_params and existing_params[0] in ("self", "cls") and in_class:
|
||||
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
|
||||
filter_args_ = [existing_params[0], *list(FILTERED_ARGS)]
|
||||
else:
|
||||
filter_args_ = list(FILTERED_ARGS)
|
||||
|
||||
@ -991,10 +991,8 @@ def _format_output(
|
||||
|
||||
def _is_message_content_type(obj: Any) -> bool:
|
||||
"""Check for OpenAI or Anthropic format tool message content."""
|
||||
return (
|
||||
isinstance(obj, str)
|
||||
or isinstance(obj, list)
|
||||
and all(_is_message_content_block(e) for e in obj)
|
||||
return isinstance(obj, str) or (
|
||||
isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
|
||||
)
|
||||
|
||||
|
||||
|
@ -214,7 +214,7 @@ def tool(
|
||||
monkey: The baz.
|
||||
\"\"\"
|
||||
return bar
|
||||
""" # noqa: D214,D405,D410,D411,D412,D416
|
||||
""" # noqa: D214, D410, D411
|
||||
|
||||
def _create_tool_factory(
|
||||
tool_name: str,
|
||||
@ -367,7 +367,7 @@ def _get_schema_from_runnable_and_arg_types(
|
||||
msg = (
|
||||
"Tool input must be str or dict. If dict, dict arguments must be "
|
||||
"typed. Either annotate types (e.g., with TypedDict) or pass "
|
||||
f"arg_types into `.as_tool` to specify. {str(e)}"
|
||||
f"arg_types into `.as_tool` to specify. {e}"
|
||||
)
|
||||
raise TypeError(msg) from e
|
||||
fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()}
|
||||
|
@ -26,13 +26,13 @@ if TYPE_CHECKING:
|
||||
|
||||
__all__ = (
|
||||
"BaseTracer",
|
||||
"ConsoleCallbackHandler",
|
||||
"EvaluatorCallbackHandler",
|
||||
"LangChainTracer",
|
||||
"ConsoleCallbackHandler",
|
||||
"LogStreamCallbackHandler",
|
||||
"Run",
|
||||
"RunLog",
|
||||
"RunLogPatch",
|
||||
"LogStreamCallbackHandler",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -955,7 +955,7 @@ async def _astream_events_implementation_v2(
|
||||
if callbacks is None:
|
||||
config["callbacks"] = [event_streamer]
|
||||
elif isinstance(callbacks, list):
|
||||
config["callbacks"] = callbacks + [event_streamer]
|
||||
config["callbacks"] = [*callbacks, event_streamer]
|
||||
elif isinstance(callbacks, BaseCallbackManager):
|
||||
callbacks = callbacks.copy()
|
||||
callbacks.add_handler(event_streamer, inherit=True)
|
||||
|
@ -632,7 +632,7 @@ async def _astream_log_implementation(
|
||||
if callbacks is None:
|
||||
config["callbacks"] = [stream]
|
||||
elif isinstance(callbacks, list):
|
||||
config["callbacks"] = callbacks + [stream]
|
||||
config["callbacks"] = [*callbacks, stream]
|
||||
elif isinstance(callbacks, BaseCallbackManager):
|
||||
callbacks = callbacks.copy()
|
||||
callbacks.add_handler(stream, inherit=True)
|
||||
|
@ -95,9 +95,7 @@ class FunctionCallbackHandler(BaseTracer):
|
||||
parents = self.get_parents(run)[::-1]
|
||||
return " > ".join(
|
||||
f"{parent.run_type}:{parent.name}"
|
||||
if i != len(parents) - 1
|
||||
else f"{parent.run_type}:{parent.name}"
|
||||
for i, parent in enumerate(parents + [run])
|
||||
for i, parent in enumerate([*parents, run])
|
||||
)
|
||||
|
||||
# logging methods
|
||||
|
@ -38,32 +38,32 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
"build_extra_kwargs",
|
||||
"StrictFormatter",
|
||||
"abatch_iterate",
|
||||
"batch_iterate",
|
||||
"build_extra_kwargs",
|
||||
"check_package_version",
|
||||
"comma_list",
|
||||
"convert_to_secret_str",
|
||||
"formatter",
|
||||
"from_env",
|
||||
"get_bolded_text",
|
||||
"get_color_mapping",
|
||||
"get_colored_text",
|
||||
"get_from_dict_or_env",
|
||||
"get_from_env",
|
||||
"get_pydantic_field_names",
|
||||
"guard_import",
|
||||
"image",
|
||||
"mock_now",
|
||||
"pre_init",
|
||||
"print_text",
|
||||
"raise_for_status_with_text",
|
||||
"xor_args",
|
||||
"try_load_from_hub",
|
||||
"image",
|
||||
"get_from_env",
|
||||
"get_from_dict_or_env",
|
||||
"stringify_dict",
|
||||
"comma_list",
|
||||
"stringify_value",
|
||||
"pre_init",
|
||||
"batch_iterate",
|
||||
"abatch_iterate",
|
||||
"from_env",
|
||||
"secret_from_env",
|
||||
"stringify_dict",
|
||||
"stringify_value",
|
||||
"try_load_from_hub",
|
||||
"xor_args",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -31,7 +31,9 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]
|
||||
merged = left.copy()
|
||||
for right in others:
|
||||
for right_k, right_v in right.items():
|
||||
if right_k not in merged or right_v is not None and merged[right_k] is None:
|
||||
if right_k not in merged or (
|
||||
right_v is not None and merged[right_k] is None
|
||||
):
|
||||
merged[right_k] = right_v
|
||||
elif right_v is None:
|
||||
continue
|
||||
|
@ -144,7 +144,7 @@ async def tee_peer(
|
||||
yield buffer.popleft()
|
||||
finally:
|
||||
async with lock:
|
||||
# this peer is done – remove its buffer
|
||||
# this peer is done - remove its buffer
|
||||
for idx, peer_buffer in enumerate(peers): # pragma: no branch
|
||||
if peer_buffer is buffer:
|
||||
peers.pop(idx)
|
||||
|
@ -42,8 +42,8 @@ def get_from_dict_or_env(
|
||||
"""
|
||||
if isinstance(key, (list, tuple)):
|
||||
for k in key:
|
||||
if k in data and data[k]:
|
||||
return data[k]
|
||||
if value := data.get(k):
|
||||
return value
|
||||
|
||||
if isinstance(key, str) and key in data and data[key]:
|
||||
return data[key]
|
||||
@ -70,8 +70,8 @@ def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
|
||||
ValueError: If the key is not in the dictionary and no default value is
|
||||
provided or if the environment variable is not set.
|
||||
"""
|
||||
if env_key in os.environ and os.environ[env_key]:
|
||||
return os.environ[env_key]
|
||||
if env_value := os.getenv(env_key):
|
||||
return env_value
|
||||
if default is not None:
|
||||
return default
|
||||
msg = (
|
||||
|
@ -81,7 +81,7 @@ def tee_peer(
|
||||
yield buffer.popleft()
|
||||
finally:
|
||||
with lock:
|
||||
# this peer is done – remove its buffer
|
||||
# this peer is done - remove its buffer
|
||||
for idx, peer_buffer in enumerate(peers): # pragma: no branch
|
||||
if peer_buffer is buffer:
|
||||
peers.pop(idx)
|
||||
|
@ -571,7 +571,7 @@ def render(
|
||||
padding=padding,
|
||||
def_ldel=def_ldel,
|
||||
def_rdel=def_rdel,
|
||||
scopes=data and [data] + scopes or scopes,
|
||||
scopes=(data and [data, *scopes]) or scopes,
|
||||
warn=warn,
|
||||
keep=keep,
|
||||
),
|
||||
@ -601,7 +601,7 @@ def render(
|
||||
# For every item in the scope
|
||||
for thing in scope:
|
||||
# Append it as the most recent scope and render
|
||||
new_scope = [thing] + scopes
|
||||
new_scope = [thing, *scopes]
|
||||
rend = render(
|
||||
template=tags,
|
||||
scopes=new_scope,
|
||||
|
@ -9,10 +9,10 @@ if TYPE_CHECKING:
|
||||
from langchain_core.vectorstores.in_memory import InMemoryVectorStore
|
||||
|
||||
__all__ = (
|
||||
"VectorStore",
|
||||
"VST",
|
||||
"VectorStoreRetriever",
|
||||
"InMemoryVectorStore",
|
||||
"VectorStore",
|
||||
"VectorStoreRetriever",
|
||||
)
|
||||
|
||||
_dynamic_imports = {
|
||||
|
@ -995,7 +995,7 @@ class VectorStore(ABC):
|
||||
search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
|
||||
)
|
||||
"""
|
||||
tags = kwargs.pop("tags", None) or [] + self._get_retriever_tags()
|
||||
tags = kwargs.pop("tags", None) or [*self._get_retriever_tags()]
|
||||
return VectorStoreRetriever(vectorstore=self, tags=tags, **kwargs)
|
||||
|
||||
|
||||
|
@ -155,7 +155,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
[Document(id='2', metadata={'bar': 'baz'}, page_content='thud')]
|
||||
|
||||
""" # noqa: E501
|
||||
"""
|
||||
|
||||
def __init__(self, embedding: Embeddings) -> None:
|
||||
"""Initialize with the given embedding function.
|
||||
|
@ -91,6 +91,7 @@ ignore = [
|
||||
"ISC001", # Messes with the formatter
|
||||
"PERF203", # Rarely useful
|
||||
"PLR09", # Too many something (arg, statements, etc)
|
||||
"RUF012", # Doesn't play well with Pydantic
|
||||
"TC001", # Doesn't play well with Pydantic
|
||||
"TC002", # Doesn't play well with Pydantic
|
||||
"TC003", # Doesn't play well with Pydantic
|
||||
@ -104,7 +105,6 @@ ignore = [
|
||||
"BLE",
|
||||
"ERA",
|
||||
"PLR2004",
|
||||
"RUF",
|
||||
]
|
||||
flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
|
||||
flake8-annotations.allow-star-arg-any = true
|
||||
|
@ -12,8 +12,7 @@ if __name__ == "__main__":
|
||||
for file in files:
|
||||
try:
|
||||
module_name = "".join(
|
||||
random.choice(string.ascii_letters)
|
||||
for _ in range(20) # noqa: S311
|
||||
random.choice(string.ascii_letters) for _ in range(20)
|
||||
)
|
||||
SourceFileLoader(module_name, file).load_module()
|
||||
except Exception:
|
||||
|
@ -37,7 +37,7 @@ def test_selector_add_example(selector: LengthBasedExampleSelector) -> None:
|
||||
selector.add_example(new_example)
|
||||
short_question = "Short question?"
|
||||
output = selector.select_examples({"question": short_question})
|
||||
assert output == EXAMPLES + [new_example]
|
||||
assert output == [*EXAMPLES, new_example]
|
||||
|
||||
|
||||
def test_selector_trims_one_example(selector: LengthBasedExampleSelector) -> None:
|
||||
|
@ -3,7 +3,7 @@ from langchain_core.indexing import __all__
|
||||
|
||||
def test_all() -> None:
|
||||
"""Use to catch obvious breaking changes."""
|
||||
assert list(__all__) == sorted(__all__, key=str.lower)
|
||||
assert list(__all__) == sorted(__all__, key=str)
|
||||
assert set(__all__) == {
|
||||
"aindex",
|
||||
"DeleteResponse",
|
||||
|
@ -191,14 +191,14 @@ def test_format_instructions_preserves_language() -> None:
|
||||
|
||||
description = (
|
||||
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
|
||||
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
|
||||
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001
|
||||
)
|
||||
|
||||
class Foo(BaseModel):
|
||||
hello: str = Field(
|
||||
description=(
|
||||
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
|
||||
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
|
||||
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -380,7 +380,7 @@ def test_chat_prompt_template_with_messages(
|
||||
messages: list[BaseMessagePromptTemplate],
|
||||
) -> None:
|
||||
chat_prompt_template = ChatPromptTemplate.from_messages(
|
||||
messages + [HumanMessage(content="foo")]
|
||||
[*messages, HumanMessage(content="foo")]
|
||||
)
|
||||
assert sorted(chat_prompt_template.input_variables) == sorted(
|
||||
[
|
||||
|
@ -168,7 +168,7 @@ class FakeTracer(BaseTracer):
|
||||
self.runs.append(self._copy_run(run))
|
||||
|
||||
def flattened_runs(self) -> list[Run]:
|
||||
q = [] + self.runs
|
||||
q = [*self.runs]
|
||||
result = []
|
||||
while q:
|
||||
parent = q.pop()
|
||||
|
@ -2312,7 +2312,7 @@ def test_injected_arg_with_complex_type() -> None:
|
||||
self.value = "bar"
|
||||
|
||||
@tool
|
||||
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: # noqa: ARG001
|
||||
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
|
||||
"""Tool that has an injected tool arg."""
|
||||
return foo.value
|
||||
|
||||
@ -2488,7 +2488,7 @@ def test_simple_tool_args_schema_dict() -> None:
|
||||
|
||||
def test_empty_string_tool_call_id() -> None:
|
||||
@tool
|
||||
def foo(x: int) -> str: # noqa: ARG001
|
||||
def foo(x: int) -> str:
|
||||
"""Foo."""
|
||||
return "hi"
|
||||
|
||||
@ -2500,7 +2500,7 @@ def test_empty_string_tool_call_id() -> None:
|
||||
def test_tool_decorator_description() -> None:
|
||||
# test basic tool
|
||||
@tool
|
||||
def foo(x: int) -> str: # noqa: ARG001
|
||||
def foo(x: int) -> str:
|
||||
"""Foo."""
|
||||
return "hi"
|
||||
|
||||
@ -2512,7 +2512,7 @@ def test_tool_decorator_description() -> None:
|
||||
|
||||
# test basic tool with description
|
||||
@tool(description="description")
|
||||
def foo_description(x: int) -> str: # noqa: ARG001
|
||||
def foo_description(x: int) -> str:
|
||||
"""Foo."""
|
||||
return "hi"
|
||||
|
||||
@ -2531,7 +2531,7 @@ def test_tool_decorator_description() -> None:
|
||||
x: int
|
||||
|
||||
@tool(args_schema=ArgsSchema)
|
||||
def foo_args_schema(x: int) -> str: # noqa: ARG001
|
||||
def foo_args_schema(x: int) -> str:
|
||||
return "hi"
|
||||
|
||||
assert foo_args_schema.description == "Bar."
|
||||
@ -2543,7 +2543,7 @@ def test_tool_decorator_description() -> None:
|
||||
)
|
||||
|
||||
@tool(description="description", args_schema=ArgsSchema)
|
||||
def foo_args_schema_description(x: int) -> str: # noqa: ARG001
|
||||
def foo_args_schema_description(x: int) -> str:
|
||||
return "hi"
|
||||
|
||||
assert foo_args_schema_description.description == "description"
|
||||
@ -2565,11 +2565,11 @@ def test_tool_decorator_description() -> None:
|
||||
}
|
||||
|
||||
@tool(args_schema=args_json_schema)
|
||||
def foo_args_jsons_schema(x: int) -> str: # noqa: ARG001
|
||||
def foo_args_jsons_schema(x: int) -> str:
|
||||
return "hi"
|
||||
|
||||
@tool(description="description", args_schema=args_json_schema)
|
||||
def foo_args_jsons_schema_with_description(x: int) -> str: # noqa: ARG001
|
||||
def foo_args_jsons_schema_with_description(x: int) -> str:
|
||||
return "hi"
|
||||
|
||||
assert foo_args_jsons_schema.description == "JSON Schema."
|
||||
@ -2629,10 +2629,10 @@ def test_title_property_preserved() -> None:
|
||||
async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
|
||||
"""Verify that the inputs are not mutated when invoking a tool asynchronously."""
|
||||
|
||||
def sync_no_op(foo: int) -> str: # noqa: ARG001
|
||||
def sync_no_op(foo: int) -> str:
|
||||
return "good"
|
||||
|
||||
async def async_no_op(foo: int) -> str: # noqa: ARG001
|
||||
async def async_no_op(foo: int) -> str:
|
||||
return "good"
|
||||
|
||||
tool = StructuredTool(
|
||||
@ -2677,10 +2677,10 @@ async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
|
||||
def test_tool_invoke_does_not_mutate_inputs() -> None:
|
||||
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
|
||||
|
||||
def sync_no_op(foo: int) -> str: # noqa: ARG001
|
||||
def sync_no_op(foo: int) -> str:
|
||||
return "good"
|
||||
|
||||
async def async_no_op(foo: int) -> str: # noqa: ARG001
|
||||
async def async_no_op(foo: int) -> str:
|
||||
return "good"
|
||||
|
||||
tool = StructuredTool(
|
||||
|
@ -39,7 +39,7 @@ async def test_same_event_loop() -> None:
|
||||
**item,
|
||||
}
|
||||
|
||||
asyncio.create_task(producer())
|
||||
producer_task = asyncio.create_task(producer())
|
||||
|
||||
items = [item async for item in consumer()]
|
||||
|
||||
@ -57,6 +57,8 @@ async def test_same_event_loop() -> None:
|
||||
f"delta_time: {delta_time}"
|
||||
)
|
||||
|
||||
await producer_task
|
||||
|
||||
|
||||
async def test_queue_for_streaming_via_sync_call() -> None:
|
||||
"""Test via async -> sync -> async path."""
|
||||
|
Loading…
Reference in New Issue
Block a user