core: Add ruff rules RUF (#29353)

See https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
Mostly:
* [RUF022](https://docs.astral.sh/ruff/rules/unsorted-dunder-all/)
(unsorted `__all__`)
* [RUF100](https://docs.astral.sh/ruff/rules/unused-noqa/) (unused noqa)
*
[RUF021](https://docs.astral.sh/ruff/rules/parenthesize-chained-operators/)
(parenthesize-chained-operators)
*
[RUF015](https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element/)
(unnecessary-iterable-allocation-for-first-element)
*
[RUF005](https://docs.astral.sh/ruff/rules/collection-literal-concatenation/)
(collection-literal-concatenation)
* [RUF046](https://docs.astral.sh/ruff/rules/unnecessary-cast-to-int/)
(unnecessary-cast-to-int)

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Christophe Bornet 2025-05-15 21:43:57 +02:00 committed by GitHub
parent 6cd1aadf60
commit a8f2ddee31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
55 changed files with 3175 additions and 3183 deletions

View File

@ -30,15 +30,15 @@ if TYPE_CHECKING:
from .path import as_import_path, get_relative_path
__all__ = (
"LangChainBetaWarning",
"LangChainDeprecationWarning",
"as_import_path",
"beta",
"deprecated",
"get_relative_path",
"LangChainBetaWarning",
"LangChainDeprecationWarning",
"suppress_langchain_beta_warning",
"surface_langchain_beta_warnings",
"suppress_langchain_deprecation_warning",
"surface_langchain_beta_warnings",
"surface_langchain_deprecation_warnings",
"warn_deprecated",
)

View File

@ -54,39 +54,39 @@ if TYPE_CHECKING:
)
__all__ = (
"dispatch_custom_event",
"adispatch_custom_event",
"RetrieverManagerMixin",
"LLMManagerMixin",
"ChainManagerMixin",
"ToolManagerMixin",
"Callbacks",
"CallbackManagerMixin",
"RunManagerMixin",
"BaseCallbackHandler",
"AsyncCallbackHandler",
"BaseCallbackManager",
"BaseRunManager",
"RunManager",
"ParentRunManager",
"AsyncRunManager",
"AsyncParentRunManager",
"CallbackManagerForLLMRun",
"AsyncCallbackManagerForLLMRun",
"CallbackManagerForChainRun",
"AsyncCallbackManagerForChainRun",
"CallbackManagerForToolRun",
"AsyncCallbackManagerForToolRun",
"CallbackManagerForRetrieverRun",
"AsyncCallbackManagerForRetrieverRun",
"CallbackManager",
"CallbackManagerForChainGroup",
"AsyncCallbackManager",
"AsyncCallbackManagerForChainGroup",
"AsyncCallbackManagerForChainRun",
"AsyncCallbackManagerForLLMRun",
"AsyncCallbackManagerForRetrieverRun",
"AsyncCallbackManagerForToolRun",
"AsyncParentRunManager",
"AsyncRunManager",
"BaseCallbackHandler",
"BaseCallbackManager",
"BaseRunManager",
"CallbackManager",
"CallbackManagerForChainGroup",
"CallbackManagerForChainRun",
"CallbackManagerForLLMRun",
"CallbackManagerForRetrieverRun",
"CallbackManagerForToolRun",
"CallbackManagerMixin",
"Callbacks",
"ChainManagerMixin",
"FileCallbackHandler",
"LLMManagerMixin",
"ParentRunManager",
"RetrieverManagerMixin",
"RunManager",
"RunManagerMixin",
"StdOutCallbackHandler",
"StreamingStdOutCallbackHandler",
"FileCallbackHandler",
"ToolManagerMixin",
"UsageMetadataCallbackHandler",
"adispatch_custom_event",
"dispatch_custom_event",
"get_usage_metadata_callback",
)

View File

@ -14,8 +14,8 @@ __all__ = (
"BaseLoader",
"Blob",
"BlobLoader",
"PathLike",
"LangSmithLoader",
"PathLike",
)
_dynamic_imports = {

View File

@ -14,7 +14,7 @@ if TYPE_CHECKING:
from .compressor import BaseDocumentCompressor
from .transformers import BaseDocumentTransformer
__all__ = ("Document", "BaseDocumentTransformer", "BaseDocumentCompressor")
__all__ = ("BaseDocumentCompressor", "BaseDocumentTransformer", "Document")
_dynamic_imports = {
"Document": "base",

View File

@ -20,14 +20,14 @@ if TYPE_CHECKING:
)
__all__ = (
"aindex",
"DeleteResponse",
"DocumentIndex",
"index",
"IndexingResult",
"InMemoryRecordManager",
"IndexingResult",
"RecordManager",
"UpsertResponse",
"aindex",
"index",
)
_dynamic_imports = {

View File

@ -68,22 +68,22 @@ if TYPE_CHECKING:
from langchain_core.language_models.llms import LLM, BaseLLM
__all__ = (
"BaseLanguageModel",
"BaseChatModel",
"SimpleChatModel",
"BaseLLM",
"LLM",
"LanguageModelInput",
"get_tokenizer",
"LangSmithParams",
"LanguageModelOutput",
"LanguageModelLike",
"FakeListLLM",
"FakeStreamingListLLM",
"BaseChatModel",
"BaseLLM",
"BaseLanguageModel",
"FakeListChatModel",
"FakeListLLM",
"FakeMessagesListChatModel",
"FakeStreamingListLLM",
"GenericFakeChatModel",
"LangSmithParams",
"LanguageModelInput",
"LanguageModelLike",
"LanguageModelOutput",
"ParrotFakeChatModel",
"SimpleChatModel",
"get_tokenizer",
)
_dynamic_imports = {

View File

@ -23,7 +23,7 @@ def _is_openai_data_block(block: dict) -> bool:
if isinstance(file_data, str):
return True
elif block.get("type") == "input_audio": # noqa: SIM102
elif block.get("type") == "input_audio":
if (input_audio := block.get("input_audio")) and isinstance(input_audio, dict):
audio_data = input_audio.get("data")
audio_format = input_audio.get("format")

View File

@ -354,7 +354,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
@override
def invoke(
@ -1203,7 +1203,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
if isinstance(generation, ChatGeneration):
return generation.message
msg = "Unexpected generation type"
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
async def _call_async(
self,
@ -1219,7 +1219,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
if isinstance(generation, ChatGeneration):
return generation.message
msg = "Unexpected generation type"
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
@deprecated("0.1.7", alternative="invoke", removal="1.0")
def call_as_llm(
@ -1261,7 +1261,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
if isinstance(result.content, str):
return result.content
msg = "Cannot use predict when output is not a string."
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
@deprecated("0.1.7", alternative="invoke", removal="1.0")
@override
@ -1287,7 +1287,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
if isinstance(result.content, str):
return result.content
msg = "Cannot use predict when output is not a string."
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
@deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@override

View File

@ -103,7 +103,9 @@ def create_base_retry_decorator(
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(coro)
# TODO: Fix RUF006 - this task should have a reference
# and be awaited somewhere
loop.create_task(coro) # noqa: RUF006
else:
asyncio.run(coro)
except Exception as e:
@ -336,7 +338,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
def _get_ls_params(
self,

View File

@ -15,7 +15,7 @@ if TYPE_CHECKING:
# the `from langchain_core.load.load import load` absolute import should also work.
from langchain_core.load.load import load
__all__ = ("dumpd", "dumps", "load", "loads", "Serializable")
__all__ = ("Serializable", "dumpd", "dumps", "load", "loads")
_dynamic_imports = {
"dumpd": "dump",

View File

@ -3,7 +3,7 @@
This module contains memory abstractions from LangChain v0.0.x.
These abstractions are now deprecated and will be removed in LangChain v1.0.0.
""" # noqa: E501
"""
from __future__ import annotations

View File

@ -76,28 +76,28 @@ __all__ = (
"HumanMessageChunk",
"InvalidToolCall",
"MessageLikeRepresentation",
"RemoveMessage",
"SystemMessage",
"SystemMessageChunk",
"ToolCall",
"ToolCallChunk",
"ToolMessage",
"ToolMessageChunk",
"RemoveMessage",
"_message_from_dict",
"convert_to_messages",
"convert_to_openai_data_block",
"convert_to_openai_image_block",
"convert_to_messages",
"convert_to_openai_messages",
"filter_messages",
"get_buffer_string",
"is_data_content_block",
"merge_content",
"merge_message_runs",
"message_chunk_to_message",
"message_to_dict",
"messages_from_dict",
"messages_to_dict",
"filter_messages",
"merge_message_runs",
"trim_messages",
"convert_to_openai_messages",
)
_dynamic_imports = {

View File

@ -423,13 +423,13 @@ def add_ai_message_chunks(
id = None
candidates = [left.id] + [o.id for o in others]
# first pass: pick the first nonrun-* id
# first pass: pick the first non-run-* id
for id_ in candidates:
if id_ and not id_.startswith(_LC_ID_PREFIX):
id = id_
break
else:
# second pass: no provider-assigned id found, just take the first nonnull
# second pass: no provider-assigned id found, just take the first non-null
for id_ in candidates:
if id_:
id = id_

View File

@ -101,8 +101,7 @@ class BaseMessage(Serializable):
block
for block in self.content
if isinstance(block, str)
or block.get("type") == "text"
and isinstance(block.get("text"), str)
or (block.get("type") == "text" and isinstance(block.get("text"), str))
]
return "".join(
block if isinstance(block, str) else block["text"] for block in blocks
@ -161,7 +160,7 @@ def merge_content(
merged += content
# If the next chunk is a list, add the current to the start of the list
else:
merged = [merged] + content # type: ignore[assignment,operator]
merged = [merged, *content]
elif isinstance(content, list):
# If both are lists
merged = merge_lists(cast("list", merged), content) # type: ignore[assignment]

View File

@ -885,7 +885,7 @@ def trim_messages(
list_token_counter = token_counter.get_num_tokens_from_messages
elif callable(token_counter):
if (
list(inspect.signature(token_counter).parameters.values())[0].annotation
next(iter(inspect.signature(token_counter).parameters.values())).annotation
is BaseMessage
):
@ -1460,7 +1460,7 @@ def _last_max_tokens(
# Re-reverse the messages and add back the system message if needed
result = reversed_result[::-1]
if system_message:
result = [system_message] + result
result = [system_message, *result]
return result
@ -1543,7 +1543,7 @@ def _get_message_openai_role(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
return message.role
msg = f"Unknown BaseMessage type {message.__class__}."
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
def _convert_to_openai_tool_calls(tool_calls: list[ToolCall]) -> list[dict]:

View File

@ -47,23 +47,23 @@ if TYPE_CHECKING:
from langchain_core.output_parsers.xml import XMLOutputParser
__all__ = [
"BaseLLMOutputParser",
"BaseGenerationOutputParser",
"BaseOutputParser",
"ListOutputParser",
"CommaSeparatedListOutputParser",
"NumberedListOutputParser",
"MarkdownListOutputParser",
"StrOutputParser",
"BaseTransformOutputParser",
"BaseCumulativeTransformOutputParser",
"SimpleJsonOutputParser",
"XMLOutputParser",
"JsonOutputParser",
"PydanticOutputParser",
"JsonOutputToolsParser",
"BaseGenerationOutputParser",
"BaseLLMOutputParser",
"BaseOutputParser",
"BaseTransformOutputParser",
"CommaSeparatedListOutputParser",
"JsonOutputKeyToolsParser",
"JsonOutputParser",
"JsonOutputToolsParser",
"ListOutputParser",
"MarkdownListOutputParser",
"NumberedListOutputParser",
"PydanticOutputParser",
"PydanticToolsParser",
"SimpleJsonOutputParser",
"StrOutputParser",
"XMLOutputParser",
]
_dynamic_imports = {

View File

@ -132,6 +132,6 @@ SimpleJsonOutputParser = JsonOutputParser
__all__ = [
"JsonOutputParser",
"SimpleJsonOutputParser", # For backwards compatibility
"parse_partial_json", # For backwards compatibility
"parse_and_check_json_markdown", # For backwards compatibility
"parse_partial_json", # For backwards compatibility
]

View File

@ -3,7 +3,7 @@
import contextlib
import re
import xml
import xml.etree.ElementTree as ET # noqa: N817
import xml.etree.ElementTree as ET
from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder

View File

@ -70,21 +70,21 @@ __all__ = (
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"DictPromptTemplate",
"FewShotChatMessagePromptTemplate",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"FewShotChatMessagePromptTemplate",
"HumanMessagePromptTemplate",
"MessagesPlaceholder",
"PipelinePromptTemplate",
"PromptTemplate",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"format_document",
"aformat_document",
"check_valid_template",
"format_document",
"get_template_variables",
"jinja2_formatter",
"load_prompt",
"validate_jinja2",
)

View File

@ -445,9 +445,8 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
raise ValueError(msg)
prompt = []
for tmpl in template:
if (
isinstance(tmpl, str)
or isinstance(tmpl, dict)
if isinstance(tmpl, str) or (
isinstance(tmpl, dict)
and "text" in tmpl
and set(tmpl.keys()) <= {"type", "text"}
):
@ -524,7 +523,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
raise ValueError(msg)
return cls(prompt=prompt, **kwargs)
msg = f"Invalid template: {template}"
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
@classmethod
def from_template_file(
@ -1000,7 +999,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
if isinstance(
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
):
return ChatPromptTemplate(messages=self.messages + [other]).partial(
return ChatPromptTemplate(messages=[*self.messages, other]).partial(
**partials
)
if isinstance(other, (list, tuple)):
@ -1010,7 +1009,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
)
if isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
return ChatPromptTemplate(messages=self.messages + [prompt]).partial(
return ChatPromptTemplate(messages=[*self.messages, prompt]).partial(
**partials
)
msg = f"Unsupported operand type for +: {type(other)}"

View File

@ -84,15 +84,11 @@ def _load_examples(config: dict) -> dict:
def _load_output_parser(config: dict) -> dict:
"""Load output parser."""
if "output_parser" in config and config["output_parser"]:
_config = config.pop("output_parser")
output_parser_type = _config.pop("_type")
if output_parser_type == "default":
output_parser = StrOutputParser(**_config)
else:
if _config := config.get("output_parser"):
if output_parser_type := _config.get("_type") != "default":
msg = f"Unsupported output parser {output_parser_type}"
raise ValueError(msg)
config["output_parser"] = output_parser
config["output_parser"] = StrOutputParser(**_config)
return config

View File

@ -153,10 +153,8 @@ class StructuredPrompt(ChatPromptTemplate):
NotImplementedError: If the first element of `others`
is not a language model.
"""
if (
others
and isinstance(others[0], BaseLanguageModel)
or hasattr(others[0], "with_structured_output")
if (others and isinstance(others[0], BaseLanguageModel)) or hasattr(
others[0], "with_structured_output"
):
return RunnableSequence(
self,

View File

@ -60,19 +60,15 @@ if TYPE_CHECKING:
)
__all__ = (
"chain",
"AddableDict",
"ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",
"ConfigurableFieldSingleOption",
"ConfigurableFieldSpec",
"ensure_config",
"run_in_executor",
"patch_config",
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableSerializable",
"RunnableAssign",
"RunnableBinding",
"RunnableBranch",
"RunnableConfig",
@ -81,14 +77,18 @@ __all__ = (
"RunnableMap",
"RunnableParallel",
"RunnablePassthrough",
"RunnableAssign",
"RunnablePick",
"RunnableSequence",
"RunnableSerializable",
"RunnableWithFallbacks",
"RunnableWithMessageHistory",
"get_config_list",
"aadd",
"add",
"chain",
"ensure_config",
"get_config_list",
"patch_config",
"run_in_executor",
)
_dynamic_imports = {

View File

@ -2799,7 +2799,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
Returns:
A list of Runnables.
"""
return [self.first] + self.middle + [self.last]
return [self.first, *self.middle, self.last]
@classmethod
@override
@ -3353,7 +3353,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) -> Iterator[Output]:
from langchain_core.beta.runnables.context import config_with_context
steps = [self.first] + self.middle + [self.last]
steps = [self.first, *self.middle, self.last]
config = config_with_context(config, self.steps)
# transform the input stream of each step with the next
@ -3380,7 +3380,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) -> AsyncIterator[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context
steps = [self.first] + self.middle + [self.last]
steps = [self.first, *self.middle, self.last]
config = aconfig_with_context(config, self.steps)
# stream the last steps
@ -4203,7 +4203,7 @@ class RunnableGenerator(Runnable[Input, Output]):
**kwargs: Any,
) -> Iterator[Output]:
if not hasattr(self, "_transform"):
msg = f"{repr(self)} only supports async methods."
msg = f"{self!r} only supports async methods."
raise NotImplementedError(msg)
return self._transform_stream_with_config(
input,
@ -4238,7 +4238,7 @@ class RunnableGenerator(Runnable[Input, Output]):
**kwargs: Any,
) -> AsyncIterator[Output]:
if not hasattr(self, "_atransform"):
msg = f"{repr(self)} only supports sync methods."
msg = f"{self!r} only supports sync methods."
raise NotImplementedError(msg)
return self._atransform_stream_with_config(
@ -5781,7 +5781,7 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
config_factories=[listener_config_factory] + self.config_factories,
config_factories=[listener_config_factory, *self.config_factories],
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)

View File

@ -562,7 +562,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
self.which.name or self.which.id,
(
(v, v)
for v in list(self.alternatives.keys()) + [self.default_key]
for v in [*list(self.alternatives.keys()), self.default_key]
),
)
_enums_for_spec[self.which] = cast("type[StrEnum]", which_enum)

View File

@ -111,15 +111,15 @@ class AsciiCanvas:
self.point(x0, y0, char)
elif abs(dx) >= abs(dy):
for x in range(x0, x1 + 1):
y = y0 if dx == 0 else y0 + int(round((x - x0) * dy / float(dx)))
y = y0 if dx == 0 else y0 + round((x - x0) * dy / float(dx))
self.point(x, y, char)
elif y0 < y1:
for y in range(y0, y1 + 1):
x = x0 if dy == 0 else x0 + int(round((y - y0) * dx / float(dy)))
x = x0 if dy == 0 else x0 + round((y - y0) * dx / float(dy))
self.point(x, y, char)
else:
for y in range(y1, y0 + 1):
x = x0 if dy == 0 else x1 + int(round((y - y1) * dx / float(dy)))
x = x0 if dy == 0 else x1 + round((y - y1) * dx / float(dy))
self.point(x, y, char)
def text(self, x: int, y: int, text: str) -> None:
@ -291,8 +291,8 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
maxx = max(xlist)
maxy = max(ylist)
canvas_cols = int(math.ceil(math.ceil(maxx) - math.floor(minx))) + 1
canvas_lines = int(round(maxy - miny))
canvas_cols = math.ceil(math.ceil(maxx) - math.floor(minx)) + 1
canvas_lines = round(maxy - miny)
canvas = AsciiCanvas(canvas_cols, canvas_lines)
@ -305,10 +305,10 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
start = edge.view.pts[index - 1]
end = edge.view.pts[index]
start_x = int(round(start[0] - minx))
start_y = int(round(start[1] - miny))
end_x = int(round(end[0] - minx))
end_y = int(round(end[1] - miny))
start_x = round(start[0] - minx)
start_y = round(start[1] - miny)
end_x = round(end[0] - minx)
end_y = round(end[1] - miny)
if start_x < 0 or start_y < 0 or end_x < 0 or end_y < 0:
msg = (
@ -328,12 +328,12 @@ def draw_ascii(vertices: Mapping[str, str], edges: Sequence[LangEdge]) -> str:
y = vertex.view.xy[1]
canvas.box(
int(round(x - minx)),
int(round(y - miny)),
round(x - minx),
round(y - miny),
vertex.view.w,
vertex.view.h,
)
canvas.text(int(round(x - minx)) + 1, int(round(y - miny)) + 1, vertex.data)
canvas.text(round(x - minx) + 1, round(y - miny) + 1, vertex.data)
return canvas.draw()

View File

@ -442,7 +442,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
if self.input_messages_key:
key = self.input_messages_key
elif len(input_val) == 1:
key = list(input_val.keys())[0]
key = next(iter(input_val.keys()))
else:
key = "input"
input_val = input_val[key]
@ -472,7 +472,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
f"Expected str, BaseMessage, list[BaseMessage], or tuple[BaseMessage]. "
f"Got {input_val}."
)
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
def _get_output_messages(
self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
@ -484,7 +484,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
if self.output_messages_key:
key = self.output_messages_key
elif len(output_val) == 1:
key = list(output_val.keys())[0]
key = next(iter(output_val.keys()))
else:
key = "output"
# If you are wrapping a chat model directly
@ -507,7 +507,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
f"Expected str, BaseMessage, list[BaseMessage], or tuple[BaseMessage]. "
f"Got {output_val}."
)
raise ValueError(msg) # noqa: TRY004
raise ValueError(msg)
def _enter_history(self, input: Any, config: RunnableConfig) -> list[BaseMessage]:
hist: BaseChatMessageHistory = config["configurable"]["message_history"]

View File

@ -699,7 +699,7 @@ def get_unique_config_specs(
else:
msg = (
"RunnableSequence contains conflicting config specs"
f"for {id}: {[first] + others}"
f"for {id}: {[first, *others]}"
)
raise ValueError(msg)
return unique
@ -772,9 +772,8 @@ def is_async_generator(
TypeGuard[Callable[..., AsyncIterator]: True if the function is
an async generator, False otherwise.
"""
return (
inspect.isasyncgenfunction(func)
or hasattr(func, "__call__") # noqa: B004
return inspect.isasyncgenfunction(func) or (
hasattr(func, "__call__") # noqa: B004
and inspect.isasyncgenfunction(func.__call__)
)
@ -791,8 +790,7 @@ def is_async_callable(
TypeGuard[Callable[..., Awaitable]: True if the function is async,
False otherwise.
"""
return (
asyncio.iscoroutinefunction(func)
or hasattr(func, "__call__") # noqa: B004
return asyncio.iscoroutinefunction(func) or (
hasattr(func, "__call__") # noqa: B004
and asyncio.iscoroutinefunction(func.__call__)
)

View File

@ -67,7 +67,7 @@ def print_sys_info(*, additional_pkgs: Sequence[str] = ()) -> None:
for pkg in reversed(order_by):
if pkg in all_packages:
all_packages.remove(pkg)
all_packages = [pkg] + list(all_packages)
all_packages = [pkg, *list(all_packages)]
system_info = {
"OS": platform.system(),

View File

@ -53,25 +53,25 @@ if TYPE_CHECKING:
from langchain_core.tools.structured import StructuredTool
__all__ = (
"FILTERED_ARGS",
"ArgsSchema",
"BaseTool",
"BaseToolkit",
"FILTERED_ARGS",
"SchemaAnnotationError",
"ToolException",
"InjectedToolArg",
"InjectedToolCallId",
"_get_runnable_config_param",
"create_schema_from_function",
"convert_runnable_to_tool",
"tool",
"RetrieverInput",
"SchemaAnnotationError",
"StructuredTool",
"Tool",
"ToolException",
"ToolsRenderer",
"_get_runnable_config_param",
"convert_runnable_to_tool",
"create_retriever_tool",
"create_schema_from_function",
"render_text_description",
"render_text_description_and_args",
"RetrieverInput",
"create_retriever_tool",
"Tool",
"StructuredTool",
"tool",
)
_dynamic_imports = {

View File

@ -273,7 +273,7 @@ def create_schema_from_function(
# Handle classmethods and instance methods
existing_params: list[str] = list(sig.parameters.keys())
if existing_params and existing_params[0] in ("self", "cls") and in_class:
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
filter_args_ = [existing_params[0], *list(FILTERED_ARGS)]
else:
filter_args_ = list(FILTERED_ARGS)
@ -991,10 +991,8 @@ def _format_output(
def _is_message_content_type(obj: Any) -> bool:
"""Check for OpenAI or Anthropic format tool message content."""
return (
isinstance(obj, str)
or isinstance(obj, list)
and all(_is_message_content_block(e) for e in obj)
return isinstance(obj, str) or (
isinstance(obj, list) and all(_is_message_content_block(e) for e in obj)
)

View File

@ -214,7 +214,7 @@ def tool(
monkey: The baz.
\"\"\"
return bar
""" # noqa: D214,D405,D410,D411,D412,D416
""" # noqa: D214, D410, D411
def _create_tool_factory(
tool_name: str,
@ -367,7 +367,7 @@ def _get_schema_from_runnable_and_arg_types(
msg = (
"Tool input must be str or dict. If dict, dict arguments must be "
"typed. Either annotate types (e.g., with TypedDict) or pass "
f"arg_types into `.as_tool` to specify. {str(e)}"
f"arg_types into `.as_tool` to specify. {e}"
)
raise TypeError(msg) from e
fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()}

View File

@ -26,13 +26,13 @@ if TYPE_CHECKING:
__all__ = (
"BaseTracer",
"ConsoleCallbackHandler",
"EvaluatorCallbackHandler",
"LangChainTracer",
"ConsoleCallbackHandler",
"LogStreamCallbackHandler",
"Run",
"RunLog",
"RunLogPatch",
"LogStreamCallbackHandler",
)
_dynamic_imports = {

View File

@ -955,7 +955,7 @@ async def _astream_events_implementation_v2(
if callbacks is None:
config["callbacks"] = [event_streamer]
elif isinstance(callbacks, list):
config["callbacks"] = callbacks + [event_streamer]
config["callbacks"] = [*callbacks, event_streamer]
elif isinstance(callbacks, BaseCallbackManager):
callbacks = callbacks.copy()
callbacks.add_handler(event_streamer, inherit=True)

View File

@ -632,7 +632,7 @@ async def _astream_log_implementation(
if callbacks is None:
config["callbacks"] = [stream]
elif isinstance(callbacks, list):
config["callbacks"] = callbacks + [stream]
config["callbacks"] = [*callbacks, stream]
elif isinstance(callbacks, BaseCallbackManager):
callbacks = callbacks.copy()
callbacks.add_handler(stream, inherit=True)

View File

@ -95,9 +95,7 @@ class FunctionCallbackHandler(BaseTracer):
parents = self.get_parents(run)[::-1]
return " > ".join(
f"{parent.run_type}:{parent.name}"
if i != len(parents) - 1
else f"{parent.run_type}:{parent.name}"
for i, parent in enumerate(parents + [run])
for i, parent in enumerate([*parents, run])
)
# logging methods

View File

@ -38,32 +38,32 @@ if TYPE_CHECKING:
)
__all__ = (
"build_extra_kwargs",
"StrictFormatter",
"abatch_iterate",
"batch_iterate",
"build_extra_kwargs",
"check_package_version",
"comma_list",
"convert_to_secret_str",
"formatter",
"from_env",
"get_bolded_text",
"get_color_mapping",
"get_colored_text",
"get_from_dict_or_env",
"get_from_env",
"get_pydantic_field_names",
"guard_import",
"image",
"mock_now",
"pre_init",
"print_text",
"raise_for_status_with_text",
"xor_args",
"try_load_from_hub",
"image",
"get_from_env",
"get_from_dict_or_env",
"stringify_dict",
"comma_list",
"stringify_value",
"pre_init",
"batch_iterate",
"abatch_iterate",
"from_env",
"secret_from_env",
"stringify_dict",
"stringify_value",
"try_load_from_hub",
"xor_args",
)
_dynamic_imports = {

View File

@ -31,7 +31,9 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]
merged = left.copy()
for right in others:
for right_k, right_v in right.items():
if right_k not in merged or right_v is not None and merged[right_k] is None:
if right_k not in merged or (
right_v is not None and merged[right_k] is None
):
merged[right_k] = right_v
elif right_v is None:
continue

View File

@ -144,7 +144,7 @@ async def tee_peer(
yield buffer.popleft()
finally:
async with lock:
# this peer is done remove its buffer
# this peer is done - remove its buffer
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)

View File

@ -42,8 +42,8 @@ def get_from_dict_or_env(
"""
if isinstance(key, (list, tuple)):
for k in key:
if k in data and data[k]:
return data[k]
if value := data.get(k):
return value
if isinstance(key, str) and key in data and data[key]:
return data[key]
@ -70,8 +70,8 @@ def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
ValueError: If the key is not in the dictionary and no default value is
provided or if the environment variable is not set.
"""
if env_key in os.environ and os.environ[env_key]:
return os.environ[env_key]
if env_value := os.getenv(env_key):
return env_value
if default is not None:
return default
msg = (

View File

@ -81,7 +81,7 @@ def tee_peer(
yield buffer.popleft()
finally:
with lock:
# this peer is done remove its buffer
# this peer is done - remove its buffer
for idx, peer_buffer in enumerate(peers): # pragma: no branch
if peer_buffer is buffer:
peers.pop(idx)

View File

@ -571,7 +571,7 @@ def render(
padding=padding,
def_ldel=def_ldel,
def_rdel=def_rdel,
scopes=data and [data] + scopes or scopes,
scopes=(data and [data, *scopes]) or scopes,
warn=warn,
keep=keep,
),
@ -601,7 +601,7 @@ def render(
# For every item in the scope
for thing in scope:
# Append it as the most recent scope and render
new_scope = [thing] + scopes
new_scope = [thing, *scopes]
rend = render(
template=tags,
scopes=new_scope,

View File

@ -9,10 +9,10 @@ if TYPE_CHECKING:
from langchain_core.vectorstores.in_memory import InMemoryVectorStore
__all__ = (
"VectorStore",
"VST",
"VectorStoreRetriever",
"InMemoryVectorStore",
"VectorStore",
"VectorStoreRetriever",
)
_dynamic_imports = {

View File

@ -995,7 +995,7 @@ class VectorStore(ABC):
search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
)
"""
tags = kwargs.pop("tags", None) or [] + self._get_retriever_tags()
tags = kwargs.pop("tags", None) or [*self._get_retriever_tags()]
return VectorStoreRetriever(vectorstore=self, tags=tags, **kwargs)

View File

@ -155,7 +155,7 @@ class InMemoryVectorStore(VectorStore):
[Document(id='2', metadata={'bar': 'baz'}, page_content='thud')]
""" # noqa: E501
"""
def __init__(self, embedding: Embeddings) -> None:
"""Initialize with the given embedding function.

View File

@ -91,6 +91,7 @@ ignore = [
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
"PLR09", # Too many something (arg, statements, etc)
"RUF012", # Doesn't play well with Pydantic
"TC001", # Doesn't play well with Pydantic
"TC002", # Doesn't play well with Pydantic
"TC003", # Doesn't play well with Pydantic
@ -104,7 +105,6 @@ ignore = [
"BLE",
"ERA",
"PLR2004",
"RUF",
]
flake8-type-checking.runtime-evaluated-base-classes = ["pydantic.BaseModel","langchain_core.load.serializable.Serializable","langchain_core.runnables.base.RunnableSerializable"]
flake8-annotations.allow-star-arg-any = true

View File

@ -12,8 +12,7 @@ if __name__ == "__main__":
for file in files:
try:
module_name = "".join(
random.choice(string.ascii_letters)
for _ in range(20) # noqa: S311
random.choice(string.ascii_letters) for _ in range(20)
)
SourceFileLoader(module_name, file).load_module()
except Exception:

View File

@ -37,7 +37,7 @@ def test_selector_add_example(selector: LengthBasedExampleSelector) -> None:
selector.add_example(new_example)
short_question = "Short question?"
output = selector.select_examples({"question": short_question})
assert output == EXAMPLES + [new_example]
assert output == [*EXAMPLES, new_example]
def test_selector_trims_one_example(selector: LengthBasedExampleSelector) -> None:

View File

@ -3,7 +3,7 @@ from langchain_core.indexing import __all__
def test_all() -> None:
"""Use to catch obvious breaking changes."""
assert list(__all__) == sorted(__all__, key=str.lower)
assert list(__all__) == sorted(__all__, key=str)
assert set(__all__) == {
"aindex",
"DeleteResponse",

View File

@ -191,14 +191,14 @@ def test_format_instructions_preserves_language() -> None:
description = (
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001
)
class Foo(BaseModel):
hello: str = Field(
description=(
"你好, こんにちは, नमस्ते, Bonjour, Hola, "
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου"
"Olá, 안녕하세요, Jambo, Merhaba, Γειά σου" # noqa: RUF001
)
)

View File

@ -380,7 +380,7 @@ def test_chat_prompt_template_with_messages(
messages: list[BaseMessagePromptTemplate],
) -> None:
chat_prompt_template = ChatPromptTemplate.from_messages(
messages + [HumanMessage(content="foo")]
[*messages, HumanMessage(content="foo")]
)
assert sorted(chat_prompt_template.input_variables) == sorted(
[

View File

@ -168,7 +168,7 @@ class FakeTracer(BaseTracer):
self.runs.append(self._copy_run(run))
def flattened_runs(self) -> list[Run]:
q = [] + self.runs
q = [*self.runs]
result = []
while q:
parent = q.pop()

View File

@ -2312,7 +2312,7 @@ def test_injected_arg_with_complex_type() -> None:
self.value = "bar"
@tool
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: # noqa: ARG001
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
"""Tool that has an injected tool arg."""
return foo.value
@ -2488,7 +2488,7 @@ def test_simple_tool_args_schema_dict() -> None:
def test_empty_string_tool_call_id() -> None:
@tool
def foo(x: int) -> str: # noqa: ARG001
def foo(x: int) -> str:
"""Foo."""
return "hi"
@ -2500,7 +2500,7 @@ def test_empty_string_tool_call_id() -> None:
def test_tool_decorator_description() -> None:
# test basic tool
@tool
def foo(x: int) -> str: # noqa: ARG001
def foo(x: int) -> str:
"""Foo."""
return "hi"
@ -2512,7 +2512,7 @@ def test_tool_decorator_description() -> None:
# test basic tool with description
@tool(description="description")
def foo_description(x: int) -> str: # noqa: ARG001
def foo_description(x: int) -> str:
"""Foo."""
return "hi"
@ -2531,7 +2531,7 @@ def test_tool_decorator_description() -> None:
x: int
@tool(args_schema=ArgsSchema)
def foo_args_schema(x: int) -> str: # noqa: ARG001
def foo_args_schema(x: int) -> str:
return "hi"
assert foo_args_schema.description == "Bar."
@ -2543,7 +2543,7 @@ def test_tool_decorator_description() -> None:
)
@tool(description="description", args_schema=ArgsSchema)
def foo_args_schema_description(x: int) -> str: # noqa: ARG001
def foo_args_schema_description(x: int) -> str:
return "hi"
assert foo_args_schema_description.description == "description"
@ -2565,11 +2565,11 @@ def test_tool_decorator_description() -> None:
}
@tool(args_schema=args_json_schema)
def foo_args_jsons_schema(x: int) -> str: # noqa: ARG001
def foo_args_jsons_schema(x: int) -> str:
return "hi"
@tool(description="description", args_schema=args_json_schema)
def foo_args_jsons_schema_with_description(x: int) -> str: # noqa: ARG001
def foo_args_jsons_schema_with_description(x: int) -> str:
return "hi"
assert foo_args_jsons_schema.description == "JSON Schema."
@ -2629,10 +2629,10 @@ def test_title_property_preserved() -> None:
async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
"""Verify that the inputs are not mutated when invoking a tool asynchronously."""
def sync_no_op(foo: int) -> str: # noqa: ARG001
def sync_no_op(foo: int) -> str:
return "good"
async def async_no_op(foo: int) -> str: # noqa: ARG001
async def async_no_op(foo: int) -> str:
return "good"
tool = StructuredTool(
@ -2677,10 +2677,10 @@ async def test_tool_ainvoke_does_not_mutate_inputs() -> None:
def test_tool_invoke_does_not_mutate_inputs() -> None:
"""Verify that the inputs are not mutated when invoking a tool synchronously."""
def sync_no_op(foo: int) -> str: # noqa: ARG001
def sync_no_op(foo: int) -> str:
return "good"
async def async_no_op(foo: int) -> str: # noqa: ARG001
async def async_no_op(foo: int) -> str:
return "good"
tool = StructuredTool(

View File

@ -39,7 +39,7 @@ async def test_same_event_loop() -> None:
**item,
}
asyncio.create_task(producer())
producer_task = asyncio.create_task(producer())
items = [item async for item in consumer()]
@ -57,6 +57,8 @@ async def test_same_event_loop() -> None:
f"delta_time: {delta_time}"
)
await producer_task
async def test_queue_for_streaming_via_sync_call() -> None:
"""Test via async -> sync -> async path."""

5933
uv.lock

File diff suppressed because it is too large Load Diff