core[patch]: Add ruff rules for flake8-simplify (SIM) (#26848)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Christophe Bornet 2024-09-27 22:13:23 +02:00 committed by GitHub
parent de0b48c41a
commit 7809b31b95
46 changed files with 277 additions and 384 deletions
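
The rules exercised below come from flake8-simplify, as implemented by ruff's SIM category. For orientation, a minimal illustrative sketch (not repo code) of the three rewrites that dominate this diff: SIM105 (contextlib.suppress instead of try/except/pass), SIM108 (conditional expression instead of an if/else assignment), and SIM118 (iterate a dict directly instead of calling .keys()):

    import contextlib

    data = {"a": 1}

    # SIM105: a try/except/pass around a single statement becomes contextlib.suppress
    with contextlib.suppress(KeyError):
        del data["b"]

    # SIM108: a four-line if/else assignment collapses into a conditional expression
    verbose = True
    mode = "on" if verbose else "off"

    # SIM118: iterate the dict itself rather than .keys()
    for key in data:
        print(key)

Most of the remaining hunks are combined-conditional rewrites (SIM102/SIM114), where nested ifs are merged with "and" and branches with identical bodies are merged with "or".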

View File

@@ -127,10 +127,9 @@ def beta(
     def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
         """Finalize the annotation of a class."""
-        try:
-            obj.__doc__ = new_doc
-        except AttributeError:  # Can't set on some extension objects.
-            pass
+        # Can't set new_doc on some extension objects.
+        with contextlib.suppress(AttributeError):
+            obj.__doc__ = new_doc
 
     def warn_if_direct_instance(
         self: Any, *args: Any, **kwargs: Any

View File

@@ -198,10 +198,9 @@ def deprecated(
     def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
         """Finalize the deprecation of a class."""
-        try:
-            obj.__doc__ = new_doc
-        except AttributeError:  # Can't set on some extension objects.
-            pass
+        # Can't set new_doc on some extension objects.
+        with contextlib.suppress(AttributeError):
+            obj.__doc__ = new_doc
 
     def warn_if_direct_instance(
         self: Any, *args: Any, **kwargs: Any

View File

@@ -27,7 +27,7 @@ class FileCallbackHandler(BaseCallbackHandler):
             mode: The mode to open the file in. Defaults to "a".
             color: The color to use for the text. Defaults to None.
         """
-        self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))
+        self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))  # noqa: SIM115
         self.color = color
 
     def __del__(self) -> None:
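
SIM115 flags open() calls whose handle is not managed by a with-block. The noqa is warranted here: the handler deliberately keeps the file open across callbacks and closes it in __del__. For contrast, a short sketch of the pattern the rule prefers when the handle is short-lived (illustrative only, not repo code):

    from pathlib import Path

    def append_line(path: Path, line: str) -> None:
        # The with-block closes the handle on exit, which satisfies SIM115.
        with open(path, "a", encoding="utf-8") as f:
            f.write(line + "\n")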

View File

@@ -2252,14 +2252,14 @@ def _configure(
         else:
             parent_run_id_ = inheritable_callbacks.parent_run_id
         # Break ties between the external tracing context and inherited context
-        if parent_run_id is not None:
-            if parent_run_id_ is None:
-                parent_run_id_ = parent_run_id
+        if parent_run_id is not None and (
+            parent_run_id_ is None
             # If the LC parent has already been reflected
             # in the run tree, we know the run_tree is either the
             # same parent or a child of the parent.
-            elif run_tree and str(parent_run_id_) in run_tree.dotted_order:
-                parent_run_id_ = parent_run_id
+            or (run_tree and str(parent_run_id_) in run_tree.dotted_order)
+        ):
+            parent_run_id_ = parent_run_id
         # Otherwise, we assume the LC context has progressed
         # beyond the run tree and we should not inherit the parent.
         callback_manager = callback_manager_cls(

View File

@@ -40,12 +40,11 @@ class OutputParserException(ValueError, LangChainException):  # noqa: N818
         send_to_llm: bool = False,
     ):
         super().__init__(error)
-        if send_to_llm:
-            if observation is None or llm_output is None:
-                raise ValueError(
-                    "Arguments 'observation' & 'llm_output'"
-                    " are required if 'send_to_llm' is True"
-                )
+        if send_to_llm and (observation is None or llm_output is None):
+            raise ValueError(
+                "Arguments 'observation' & 'llm_output'"
+                " are required if 'send_to_llm' is True"
+            )
         self.observation = observation
         self.llm_output = llm_output
         self.send_to_llm = send_to_llm

View File

@@ -92,7 +92,7 @@ class _HashedDocument(Document):
         values["metadata_hash"] = metadata_hash
         values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash))
 
-        _uid = values.get("uid", None)
+        _uid = values.get("uid")
 
         if _uid is None:
             values["uid"] = values["hash_"]

View File

@@ -802,10 +802,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if isinstance(self.cache, BaseCache):
-            llm_cache = self.cache
-        else:
-            llm_cache = get_llm_cache()
+        llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
         # We should check the cache unless it's explicitly set to False
         # A None cache means we should use the default global cache
         # if it's configured.
@@ -879,10 +876,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if isinstance(self.cache, BaseCache):
-            llm_cache = self.cache
-        else:
-            llm_cache = get_llm_cache()
+        llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
         # We should check the cache unless it's explicitly set to False
         # A None cache means we should use the default global cache
         # if it's configured.
@@ -1054,10 +1048,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     def predict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
         if isinstance(result.content, str):
             return result.content
@@ -1072,20 +1063,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[Sequence[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return self(messages, stop=_stop, **kwargs)
 
     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
     async def apredict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         result = await self._call_async(
             [HumanMessage(content=text)], stop=_stop, **kwargs
         )
@@ -1102,10 +1087,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[Sequence[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return await self._call_async(messages, stop=_stop, **kwargs)
 
     @property
@@ -1333,9 +1315,12 @@ def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
     if not isinstance(serialized, dict):
         return
 
-    if "type" in serialized and serialized["type"] == "not_implemented":
-        if "repr" in serialized:
-            del serialized["repr"]
+    if (
+        "type" in serialized
+        and serialized["type"] == "not_implemented"
+        and "repr" in serialized
+    ):
+        del serialized["repr"]
 
     if "graph" in serialized:
         del serialized["graph"]

View File

@@ -194,10 +194,7 @@ class GenericFakeChatModel(BaseChatModel):
     ) -> ChatResult:
         """Top Level call"""
         message = next(self.messages)
-        if isinstance(message, str):
-            message_ = AIMessage(content=message)
-        else:
-            message_ = message
+        message_ = AIMessage(content=message) if isinstance(message, str) else message
         generation = ChatGeneration(message=message_)
         return ChatResult(generations=[generation])

View File

@@ -1305,10 +1305,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def predict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return self(text, stop=_stop, **kwargs)
 
     @deprecated("0.1.7", alternative="invoke", removal="1.0")
@@ -1320,10 +1317,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         **kwargs: Any,
     ) -> BaseMessage:
         text = get_buffer_string(messages)
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         content = self(text, stop=_stop, **kwargs)
         return AIMessage(content=content)
@@ -1331,10 +1325,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     async def apredict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return await self._call_async(text, stop=_stop, **kwargs)
 
     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@@ -1346,10 +1337,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         **kwargs: Any,
     ) -> BaseMessage:
         text = get_buffer_string(messages)
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         content = await self._call_async(text, stop=_stop, **kwargs)
         return AIMessage(content=content)
@@ -1384,10 +1372,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 llm.save(file_path="path/llm.yaml")
         """
         # Convert file to Path object.
-        if isinstance(file_path, str):
-            save_path = Path(file_path)
-        else:
-            save_path = file_path
+        save_path = Path(file_path) if isinstance(file_path, str) else file_path
 
         directory_path = save_path.parent
         directory_path.mkdir(parents=True, exist_ok=True)

View File

@@ -86,9 +86,9 @@ class Reviver:
     def __call__(self, value: dict[str, Any]) -> Any:
         if (
-            value.get("lc", None) == 1
-            and value.get("type", None) == "secret"
-            and value.get("id", None) is not None
+            value.get("lc") == 1
+            and value.get("type") == "secret"
+            and value.get("id") is not None
         ):
             [key] = value["id"]
             if key in self.secrets_map:
@@ -99,9 +99,9 @@ class Reviver:
                 raise KeyError(f'Missing key "{key}" in load(secrets_map)')
 
         if (
-            value.get("lc", None) == 1
-            and value.get("type", None) == "not_implemented"
-            and value.get("id", None) is not None
+            value.get("lc") == 1
+            and value.get("type") == "not_implemented"
+            and value.get("id") is not None
         ):
             raise NotImplementedError(
                 "Trying to load an object that doesn't implement "
@@ -109,17 +109,18 @@ class Reviver:
             )
 
         if (
-            value.get("lc", None) == 1
-            and value.get("type", None) == "constructor"
-            and value.get("id", None) is not None
+            value.get("lc") == 1
+            and value.get("type") == "constructor"
+            and value.get("id") is not None
         ):
             [*namespace, name] = value["id"]
             mapping_key = tuple(value["id"])
 
-            if namespace[0] not in self.valid_namespaces:
-                raise ValueError(f"Invalid namespace: {value}")
-            # The root namespace ["langchain"] is not a valid identifier.
-            elif namespace == ["langchain"]:
+            if (
+                namespace[0] not in self.valid_namespaces
+                # The root namespace ["langchain"] is not a valid identifier.
+                or namespace == ["langchain"]
+            ):
                 raise ValueError(f"Invalid namespace: {value}")
 
             # Has explicit import path.
             elif mapping_key in self.import_mappings:

View File

@@ -1,3 +1,4 @@
+import contextlib
 from abc import ABC
 from typing import (
     Any,
@@ -238,7 +239,7 @@ class Serializable(BaseModel, ABC):
             # include all secrets, even if not specified in kwargs
             # as these secrets may be passed as an environment variable instead
-            for key in secrets.keys():
+            for key in secrets:
                 secret_value = getattr(self, key, None) or lc_kwargs.get(key)
                 if secret_value is not None:
                     lc_kwargs.update({key: secret_value})
@@ -357,8 +358,6 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
         "id": _id,
         "repr": None,
     }
-    try:
-        result["repr"] = repr(obj)
-    except Exception:
-        pass
+    with contextlib.suppress(Exception):
+        result["repr"] = repr(obj)
     return result

View File

@@ -435,23 +435,22 @@ def filter_messages(
     messages = convert_to_messages(messages)
     filtered: list[BaseMessage] = []
     for msg in messages:
-        if exclude_names and msg.name in exclude_names:
-            continue
-        elif exclude_types and _is_message_type(msg, exclude_types):
-            continue
-        elif exclude_ids and msg.id in exclude_ids:
+        if (
+            (exclude_names and msg.name in exclude_names)
+            or (exclude_types and _is_message_type(msg, exclude_types))
+            or (exclude_ids and msg.id in exclude_ids)
+        ):
             continue
-        else:
-            pass
 
         # default to inclusion when no inclusion criteria given.
-        if not (include_types or include_ids or include_names):
-            filtered.append(msg)
-        elif include_names and msg.name in include_names:
-            filtered.append(msg)
-        elif include_types and _is_message_type(msg, include_types):
-            filtered.append(msg)
-        elif include_ids and msg.id in include_ids:
+        if (
+            not (include_types or include_ids or include_names)
+            or (include_names and msg.name in include_names)
+            or (include_types and _is_message_type(msg, include_types))
+            or (include_ids and msg.id in include_ids)
+        ):
             filtered.append(msg)
-        else:
-            pass
@@ -961,10 +960,7 @@ def _last_max_tokens(
         while messages and not _is_message_type(messages[-1], end_on):
             messages.pop()
 
     swapped_system = include_system and isinstance(messages[0], SystemMessage)
-    if swapped_system:
-        reversed_ = messages[:1] + messages[1:][::-1]
-    else:
-        reversed_ = messages[::-1]
+    reversed_ = messages[:1] + messages[1:][::-1] if swapped_system else messages[::-1]
 
     reversed_ = _first_max_tokens(
         reversed_,

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
@@ -311,8 +312,6 @@ class BaseOutputParser(
     def dict(self, **kwargs: Any) -> dict:
         """Return dictionary representation of output parser."""
         output_parser_dict = super().dict(**kwargs)
-        try:
-            output_parser_dict["_type"] = self._type
-        except NotImplementedError:
-            pass
+        with contextlib.suppress(NotImplementedError):
+            output_parser_dict["_type"] = self._type
         return output_parser_dict

View File

@@ -49,12 +49,10 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
         return jsonpatch.make_patch(prev, next).patch
 
     def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
-        if PYDANTIC_MAJOR_VERSION == 2:
-            if issubclass(pydantic_object, pydantic.BaseModel):
-                return pydantic_object.model_json_schema()
-            elif issubclass(pydantic_object, pydantic.v1.BaseModel):
-                return pydantic_object.model_json_schema()
-        return pydantic_object.model_json_schema()
+        if issubclass(pydantic_object, pydantic.BaseModel):
+            return pydantic_object.model_json_schema()
+        elif issubclass(pydantic_object, pydantic.v1.BaseModel):
+            return pydantic_object.schema()
 
     def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
         """Parse the result of an LLM call to a JSON object.

View File

@@ -110,10 +110,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
     def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
         prev_parsed = None
-        acc_gen = None
+        acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
         for chunk in input:
+            chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
             if isinstance(chunk, BaseMessageChunk):
-                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
+                chunk_gen = ChatGenerationChunk(message=chunk)
             elif isinstance(chunk, BaseMessage):
                 chunk_gen = ChatGenerationChunk(
                     message=BaseMessageChunk(**chunk.dict())
@@ -121,10 +122,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
             else:
                 chunk_gen = GenerationChunk(text=chunk)
 
-            if acc_gen is None:
-                acc_gen = chunk_gen
-            else:
-                acc_gen = acc_gen + chunk_gen
+            acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen  # type: ignore[operator]
 
             parsed = self.parse_result([acc_gen], partial=True)
             if parsed is not None and parsed != prev_parsed:
@@ -138,10 +136,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
         self, input: AsyncIterator[Union[str, BaseMessage]]
     ) -> AsyncIterator[T]:
         prev_parsed = None
-        acc_gen = None
+        acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
         async for chunk in input:
+            chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
             if isinstance(chunk, BaseMessageChunk):
-                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
+                chunk_gen = ChatGenerationChunk(message=chunk)
             elif isinstance(chunk, BaseMessage):
                 chunk_gen = ChatGenerationChunk(
                     message=BaseMessageChunk(**chunk.dict())
@@ -149,10 +148,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
             else:
                 chunk_gen = GenerationChunk(text=chunk)
 
-            if acc_gen is None:
-                acc_gen = chunk_gen
-            else:
-                acc_gen = acc_gen + chunk_gen
+            acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen  # type: ignore[operator]
 
             parsed = await self.aparse_result([acc_gen], partial=True)
             if parsed is not None and parsed != prev_parsed:

View File

@@ -1,3 +1,4 @@
+import contextlib
 import re
 import xml
 import xml.etree.ElementTree as ET  # noqa: N817
@@ -131,11 +132,9 @@ class _StreamingParser:
         Raises:
             xml.etree.ElementTree.ParseError: If the XML is not well-formed.
         """
-        try:
-            self.pull_parser.close()
-        except xml.etree.ElementTree.ParseError:
-            # Ignore. This will ignore any incomplete XML at the end of the input
-            pass
+        # Ignore ParseError. This will ignore any incomplete XML at the end of the input
+        with contextlib.suppress(xml.etree.ElementTree.ParseError):
+            self.pull_parser.close()
 
 
 class XMLOutputParser(BaseTransformOutputParser):

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import json
 import typing
 from abc import ABC, abstractmethod
@@ -319,10 +320,8 @@ class BasePromptTemplate(
             NotImplementedError: If the prompt type is not implemented.
         """
         prompt_dict = super().model_dump(**kwargs)
-        try:
-            prompt_dict["_type"] = self._prompt_type
-        except NotImplementedError:
-            pass
+        with contextlib.suppress(NotImplementedError):
+            prompt_dict["_type"] = self._prompt_type
         return prompt_dict
 
     def save(self, file_path: Union[Path, str]) -> None:
@@ -350,10 +349,7 @@ class BasePromptTemplate(
             raise NotImplementedError(f"Prompt {self} does not support saving.")
 
         # Convert file to Path object.
-        if isinstance(file_path, str):
-            save_path = Path(file_path)
-        else:
-            save_path = file_path
+        save_path = Path(file_path) if isinstance(file_path, str) else file_path
 
         directory_path = save_path.parent
         directory_path.mkdir(parents=True, exist_ok=True)

View File

@@ -59,8 +59,8 @@ class _FewShotPromptTemplateMixin(BaseModel):
             ValueError: If neither or both examples and example_selector are provided.
             ValueError: If both examples and example_selector are provided.
         """
-        examples = values.get("examples", None)
-        example_selector = values.get("example_selector", None)
+        examples = values.get("examples")
+        example_selector = values.get("example_selector")
         if examples and example_selector:
             raise ValueError(
                 "Only one of 'examples' and 'example_selector' should be provided"

View File

@@ -51,8 +51,8 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
     @classmethod
     def check_examples_and_selector(cls, values: dict) -> Any:
         """Check that one and only one of examples/example_selector are provided."""
-        examples = values.get("examples", None)
-        example_selector = values.get("example_selector", None)
+        examples = values.get("examples")
+        example_selector = values.get("example_selector")
         if examples and example_selector:
             raise ValueError(
                 "Only one of 'examples' and 'example_selector' should be provided"
@@ -138,7 +138,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
             prefix_kwargs = {
                 k: v for k, v in kwargs.items() if k in self.prefix.input_variables
             }
-            for k in prefix_kwargs.keys():
+            for k in prefix_kwargs:
                 kwargs.pop(k)
             prefix = self.prefix.format(**prefix_kwargs)
@@ -146,7 +146,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
         suffix_kwargs = {
             k: v for k, v in kwargs.items() if k in self.suffix.input_variables
         }
-        for k in suffix_kwargs.keys():
+        for k in suffix_kwargs:
             kwargs.pop(k)
 
         suffix = self.suffix.format(
             **suffix_kwargs,
@@ -182,7 +182,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
             prefix_kwargs = {
                 k: v for k, v in kwargs.items() if k in self.prefix.input_variables
             }
-            for k in prefix_kwargs.keys():
+            for k in prefix_kwargs:
                 kwargs.pop(k)
             prefix = await self.prefix.aformat(**prefix_kwargs)
@@ -190,7 +190,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
         suffix_kwargs = {
             k: v for k, v in kwargs.items() if k in self.suffix.input_variables
         }
-        for k in suffix_kwargs.keys():
+        for k in suffix_kwargs:
            kwargs.pop(k)
 
         suffix = await self.suffix.aformat(
             **suffix_kwargs,

View File

@@ -164,10 +164,7 @@ def _load_prompt_from_file(
 ) -> BasePromptTemplate:
     """Load prompt from file."""
     # Convert file to a Path object.
-    if isinstance(file, str):
-        file_path = Path(file)
-    else:
-        file_path = file
+    file_path = Path(file) if isinstance(file, str) else file
 
     # Load from either json or yaml.
     if file_path.suffix == ".json":
         with open(file_path, encoding=encoding) as f:

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import collections
+import contextlib
 import functools
 import inspect
 import threading
@@ -2466,10 +2467,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
             A JSON-serializable representation of the Runnable.
         """
         dumped = super().to_json()
-        try:
-            dumped["name"] = self.get_name()
-        except Exception:
-            pass
+        with contextlib.suppress(Exception):
+            dumped["name"] = self.get_name()
         return dumped
 
     def configurable_fields(
@@ -2763,9 +2762,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
             ValueError: If the sequence has less than 2 steps.
         """
         steps_flat: list[Runnable] = []
-        if not steps:
-            if first is not None and last is not None:
-                steps_flat = [first] + (middle or []) + [last]
+        if not steps and first is not None and last is not None:
+            steps_flat = [first] + (middle or []) + [last]
         for step in steps:
             if isinstance(step, RunnableSequence):
                 steps_flat.extend(step.steps)
@@ -4180,12 +4178,9 @@ class RunnableGenerator(Runnable[Input, Output]):
     def invoke(
         self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
     ) -> Output:
-        final = None
+        final: Optional[Output] = None
         for output in self.stream(input, config, **kwargs):
-            if final is None:
-                final = output
-            else:
-                final = final + output
+            final = output if final is None else final + output  # type: ignore[operator]
         return cast(Output, final)
 
     def atransform(
@@ -4215,12 +4210,9 @@ class RunnableGenerator(Runnable[Input, Output]):
     async def ainvoke(
         self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
     ) -> Output:
-        final = None
+        final: Optional[Output] = None
         async for output in self.astream(input, config, **kwargs):
-            if final is None:
-                final = output
-            else:
-                final = final + output
+            final = output if final is None else final + output  # type: ignore[operator]
         return cast(Output, final)

View File

@@ -139,11 +139,11 @@ def _set_config_context(config: RunnableConfig) -> None:
                 None,
             )
         )
+        and (run := tracer.run_map.get(str(parent_run_id)))
     ):
-        if run := tracer.run_map.get(str(parent_run_id)):
-            from langsmith.run_helpers import _set_tracing_context
+        from langsmith.run_helpers import _set_tracing_context
 
-            _set_tracing_context({"parent": run})
+        _set_tracing_context({"parent": run})
 
 
 def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:

View File

@@ -99,24 +99,15 @@ class AsciiCanvas:
             self.point(x0, y0, char)
         elif abs(dx) >= abs(dy):
             for x in range(x0, x1 + 1):
-                if dx == 0:
-                    y = y0
-                else:
-                    y = y0 + int(round((x - x0) * dy / float(dx)))
+                y = y0 if dx == 0 else y0 + int(round((x - x0) * dy / float(dx)))
                 self.point(x, y, char)
         elif y0 < y1:
             for y in range(y0, y1 + 1):
-                if dy == 0:
-                    x = x0
-                else:
-                    x = x0 + int(round((y - y0) * dx / float(dy)))
+                x = x0 if dy == 0 else x0 + int(round((y - y0) * dx / float(dy)))
                 self.point(x, y, char)
         else:
             for y in range(y1, y0 + 1):
-                if dy == 0:
-                    x = x0
-                else:
-                    x = x1 + int(round((y - y1) * dx / float(dy)))
+                x = x0 if dy == 0 else x1 + int(round((y - y1) * dx / float(dy)))
                 self.point(x, y, char)
 
     def text(self, x: int, y: int, text: str) -> None:

View File

@@ -131,10 +131,7 @@ def draw_mermaid(
             else:
                 edge_label = f" -- &nbsp;{edge_data}&nbsp; --> "
         else:
-            if edge.conditional:
-                edge_label = " -.-> "
-            else:
-                edge_label = " --> "
+            edge_label = " -.-> " if edge.conditional else " --> "
 
         mermaid_graph += (
             f"\t{_escape_node_label(source)}{edge_label}"
@@ -142,7 +139,7 @@ def draw_mermaid(
         )
 
         # Recursively add nested subgraphs
-        for nested_prefix in edge_groups.keys():
+        for nested_prefix in edge_groups:
             if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
                 continue
             add_subgraph(edge_groups[nested_prefix], nested_prefix)
@@ -154,7 +151,7 @@ def draw_mermaid(
     add_subgraph(edge_groups.get("", []), "")
 
     # Add remaining subgraphs
-    for prefix in edge_groups.keys():
+    for prefix in edge_groups:
         if ":" in prefix or prefix == "":
             continue
         add_subgraph(edge_groups[prefix], prefix)

View File

@@ -496,12 +496,9 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
     Returns:
         Optional[Addable]: The result of adding the addable objects.
     """
-    final = None
+    final: Optional[Addable] = None
     for chunk in addables:
-        if final is None:
-            final = chunk
-        else:
-            final = final + chunk
+        final = chunk if final is None else final + chunk
     return final
@@ -514,12 +511,9 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
     Returns:
         Optional[Addable]: The result of adding the addable objects.
     """
-    final = None
+    final: Optional[Addable] = None
    async for chunk in addables:
-        if final is None:
-            final = chunk
-        else:
-            final = final + chunk
+        final = chunk if final is None else final + chunk
     return final
@@ -642,9 +636,7 @@ def get_unique_config_specs(
     for id, dupes in grouped:
         first = next(dupes)
         others = list(dupes)
-        if len(others) == 0:
-            unique.append(first)
-        elif all(o == first for o in others):
+        if len(others) == 0 or all(o == first for o in others):
             unique.append(first)
         else:
             raise ValueError(

View File

@@ -258,7 +258,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
         if prefix is None:
             yield from self.store.keys()
         else:
-            for key in self.store.keys():
+            for key in self.store:
                 if key.startswith(prefix):
                     yield key
@@ -272,10 +272,10 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
             AsyncIterator[str]: An async iterator over keys that match the given prefix.
         """
         if prefix is None:
-            for key in self.store.keys():
+            for key in self.store:
                 yield key
         else:
-            for key in self.store.keys():
+            for key in self.store:
                 if key.startswith(prefix):
                     yield key

View File

@@ -19,18 +19,24 @@ class Visitor(ABC):
     """Allowed operators for the visitor."""
 
     def _validate_func(self, func: Union[Operator, Comparator]) -> None:
-        if isinstance(func, Operator) and self.allowed_operators is not None:
-            if func not in self.allowed_operators:
-                raise ValueError(
-                    f"Received disallowed operator {func}. Allowed "
-                    f"comparators are {self.allowed_operators}"
-                )
-        if isinstance(func, Comparator) and self.allowed_comparators is not None:
-            if func not in self.allowed_comparators:
-                raise ValueError(
-                    f"Received disallowed comparator {func}. Allowed "
-                    f"comparators are {self.allowed_comparators}"
-                )
+        if (
+            isinstance(func, Operator)
+            and self.allowed_operators is not None
+            and func not in self.allowed_operators
+        ):
+            raise ValueError(
+                f"Received disallowed operator {func}. Allowed "
+                f"comparators are {self.allowed_operators}"
+            )
+        if (
+            isinstance(func, Comparator)
+            and self.allowed_comparators is not None
+            and func not in self.allowed_comparators
+        ):
+            raise ValueError(
+                f"Received disallowed comparator {func}. Allowed "
+                f"comparators are {self.allowed_comparators}"
+            )
 
     @abstractmethod
     def visit_operation(self, operation: Operation) -> Any:

View File

@@ -248,11 +248,8 @@ def create_schema_from_function(
     validated = validate_arguments(func, config=_SchemaConfig)  # type: ignore
 
     # Let's ignore `self` and `cls` arguments for class and instance methods
-    if func.__qualname__ and "." in func.__qualname__:
-        # Then it likely belongs in a class namespace
-        in_class = True
-    else:
-        in_class = False
+    # If qualified name has a ".", then it likely belongs in a class namespace
+    in_class = bool(func.__qualname__ and "." in func.__qualname__)
 
     has_args = False
     has_kwargs = False
@@ -289,12 +286,10 @@ def create_schema_from_function(
     # Pydantic adds placeholder virtual fields we need to strip
     valid_properties = []
     for field in get_fields(inferred_model):
-        if not has_args:
-            if field == "args":
-                continue
-        if not has_kwargs:
-            if field == "kwargs":
-                continue
+        if not has_args and field == "args":
+            continue
+        if not has_kwargs and field == "kwargs":
+            continue
         if field == "v__duplicate_kwargs":  # Internal pydantic field
             continue
@@ -422,12 +417,15 @@ class ChildTool(BaseTool):
     def __init__(self, **kwargs: Any) -> None:
         """Initialize the tool."""
-        if "args_schema" in kwargs and kwargs["args_schema"] is not None:
-            if not is_basemodel_subclass(kwargs["args_schema"]):
-                raise TypeError(
-                    f"args_schema must be a subclass of pydantic BaseModel. "
-                    f"Got: {kwargs['args_schema']}."
-                )
+        if (
+            "args_schema" in kwargs
+            and kwargs["args_schema"] is not None
+            and not is_basemodel_subclass(kwargs["args_schema"])
+        ):
+            raise TypeError(
+                f"args_schema must be a subclass of pydantic BaseModel. "
+                f"Got: {kwargs['args_schema']}."
+            )
         super().__init__(**kwargs)
 
     model_config = ConfigDict(
@@ -840,10 +838,7 @@ def _handle_tool_error(
     flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
 ) -> str:
     if isinstance(flag, bool):
-        if e.args:
-            content = e.args[0]
-        else:
-            content = "Tool execution error"
+        content = e.args[0] if e.args else "Tool execution error"
     elif isinstance(flag, str):
         content = flag
     elif callable(flag):
@@ -902,12 +897,11 @@ def _format_output(
 def _is_message_content_type(obj: Any) -> bool:
     """Check for OpenAI or Anthropic format tool message content."""
-    if isinstance(obj, str):
-        return True
-    elif isinstance(obj, list) and all(_is_message_content_block(e) for e in obj):
-        return True
-    else:
-        return False
+    return (
+        isinstance(obj, str)
+        or isinstance(obj, list)
+        and all(_is_message_content_block(e) for e in obj)
+    )
 
 
 def _is_message_content_block(obj: Any) -> bool:
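
Note that the merged boolean return above relies on "and" binding tighter than "or", so it parses as isinstance(obj, str) or (isinstance(obj, list) and all(...)). A quick illustrative check of that precedence:

    assert (True or False and False) is True  # evaluated as True or (False and False)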

View File

@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import asyncio
+import contextlib
 import logging
 from collections.abc import AsyncIterator, Iterator, Sequence
 from typing import (
@@ -814,10 +815,7 @@ async def _astream_events_implementation_v1(
             data: EventData = {}
             log_entry: LogEntry = run_log.state["logs"][path]
             if log_entry["end_time"] is None:
-                if log_entry["streamed_output"]:
-                    event_type = "stream"
-                else:
-                    event_type = "start"
+                event_type = "stream" if log_entry["streamed_output"] else "start"
             else:
                 event_type = "end"
@@ -983,14 +981,15 @@ async def _astream_events_implementation_v2(
                 yield event
                 continue
 
-            if event["run_id"] == first_event_run_id and event["event"].endswith(
-                "_end"
+            # If it's the end event corresponding to the root runnable
+            # we dont include the input in the event since it's guaranteed
+            # to be included in the first event.
+            if (
+                event["run_id"] == first_event_run_id
+                and event["event"].endswith("_end")
+                and "input" in event["data"]
             ):
-                # If it's the end event corresponding to the root runnable
-                # we dont include the input in the event since it's guaranteed
-                # to be included in the first event.
-                if "input" in event["data"]:
-                    del event["data"]["input"]
+                del event["data"]["input"]
 
             yield event
     except asyncio.CancelledError as exc:
@@ -1001,7 +1000,5 @@ async def _astream_events_implementation_v2(
         # Cancel the task if it's still running
         task.cancel()
         # Await it anyway, to run any cleanup code, and propagate any exceptions
-        try:
-            await task
-        except asyncio.CancelledError:
-            pass
+        with contextlib.suppress(asyncio.CancelledError):
            await task

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import contextlib
 import copy
 import threading
 from collections import defaultdict
@@ -253,18 +254,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
         """
         async for chunk in output:
             # root run is handled in .astream_log()
-            if run_id != self.root_id:
-                # if we can't find the run silently ignore
-                # eg. because this run wasn't included in the log
-                if key := self._key_map_by_run_id.get(run_id):
-                    if not self.send(
+            # if we can't find the run silently ignore
+            # eg. because this run wasn't included in the log
+            if (
+                run_id != self.root_id
+                and (key := self._key_map_by_run_id.get(run_id))
+                and (
+                    not self.send(
                         {
                             "op": "add",
                             "path": f"/logs/{key}/streamed_output/-",
                             "value": chunk,
                         }
-                    ):
-                        break
+                    )
+                )
+            ):
+                break
 
             yield chunk
@@ -280,18 +285,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
         """
         for chunk in output:
             # root run is handled in .astream_log()
-            if run_id != self.root_id:
-                # if we can't find the run silently ignore
-                # eg. because this run wasn't included in the log
-                if key := self._key_map_by_run_id.get(run_id):
-                    if not self.send(
+            # if we can't find the run silently ignore
+            # eg. because this run wasn't included in the log
+            if (
+                run_id != self.root_id
+                and (key := self._key_map_by_run_id.get(run_id))
+                and (
+                    not self.send(
                         {
                             "op": "add",
                             "path": f"/logs/{key}/streamed_output/-",
                             "value": chunk,
                         }
-                    ):
-                        break
+                    )
+                )
+            ):
+                break
 
             yield chunk
@@ -439,9 +448,8 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
                 self.send(*ops)
         finally:
-            if run.id == self.root_id:
-                if self.auto_close:
-                    self.send_stream.close()
+            if run.id == self.root_id and self.auto_close:
+                self.send_stream.close()
 
     def _on_llm_new_token(
         self,
@@ -662,7 +670,5 @@ async def _astream_log_implementation(
             yield state
     finally:
         # Wait for the runnable to finish, if not cancelled (eg. by break)
-        try:
-            await task
-        except asyncio.CancelledError:
-            pass
+        with contextlib.suppress(asyncio.CancelledError):
+            await task

View File

@@ -29,9 +29,7 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]
     merged = left.copy()
     for right in others:
         for right_k, right_v in right.items():
-            if right_k not in merged:
-                merged[right_k] = right_v
-            elif right_v is not None and merged[right_k] is None:
+            if right_k not in merged or right_v is not None and merged[right_k] is None:
                 merged[right_k] = right_v
             elif right_v is None:
                 continue

View File

@@ -43,14 +43,10 @@ def get_from_dict_or_env(
             if k in data and data[k]:
                 return data[k]
 
-    if isinstance(key, str):
-        if key in data and data[key]:
-            return data[key]
+    if isinstance(key, str) and key in data and data[key]:
+        return data[key]
 
-    if isinstance(key, (list, tuple)):
-        key_for_err = key[0]
-    else:
-        key_for_err = key
+    key_for_err = key[0] if isinstance(key, (list, tuple)) else key
 
     return get_from_env(key_for_err, env_key, default=default)

View File

@@ -64,7 +64,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
     new_kv = {}
     for k, v in kv.items():
         if k == "title":
-            if isinstance(v, dict) and prev_key == "properties" and "title" in v.keys():
+            if isinstance(v, dict) and prev_key == "properties" and "title" in v:
                 new_kv[k] = _rm_titles(v, k)
             else:
                 continue

View File

@@ -139,11 +139,8 @@ def parse_json_markdown(
         match = _json_markdown_re.search(json_string)
 
         # If no match found, assume the entire string is a JSON string
-        if match is None:
-            json_str = json_string
-        else:
-            # If match found, use the content within the backticks
-            json_str = match.group(2)
+        # Else, use the content within the backticks
+        json_str = json_string if match is None else match.group(2)
 
     return _parse_json(json_str, parser=parser)

View File

@@ -80,12 +80,9 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
         padding = literal.split("\n")[-1]
 
         # If all the characters since the last newline are spaces
-        if padding.isspace() or padding == "":
-            # Then the next tag could be a standalone
-            return True
-        else:
-            # Otherwise it can't be
-            return False
+        # Then the next tag could be a standalone
+        # Otherwise it can't be
+        return padding.isspace() or padding == ""
     else:
         return False
@@ -107,10 +104,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
         on_newline = template.split("\n", 1)
 
         # If the stuff to the right of us are spaces we're a standalone
-        if on_newline[0].isspace() or not on_newline[0]:
-            return True
-        else:
-            return False
+        return on_newline[0].isspace() or not on_newline[0]
 
     # If we're a tag can't be a standalone
     else:
@@ -174,14 +168,18 @@ def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]
             "unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}"
         )
 
-    # If we might be a no html escape tag
-    elif tag_type == "no escape?":
+    elif (
+        # If we might be a no html escape tag
+        tag_type == "no escape?"
         # And we have a third curly brace
         # (And are using curly braces as delimiters)
-        if l_del == "{{" and r_del == "}}" and template.startswith("}"):
-            # Then we are a no html escape tag
-            template = template[1:]
-            tag_type = "no escape"
+        and l_del == "{{"
+        and r_del == "}}"
+        and template.startswith("}")
+    ):
+        # Then we are a no html escape tag
+        template = template[1:]
+        tag_type = "no escape"
 
     # Strip the whitespace off the key and return
     return ((tag_type, tag.strip()), template)

View File

@@ -179,22 +179,27 @@ def pre_init(func: Callable) -> Any:
         for name, field_info in fields.items():
             # Check if allow_population_by_field_name is enabled
             # If yes, then set the field name to the alias
-            if hasattr(cls, "Config"):
-                if hasattr(cls.Config, "allow_population_by_field_name"):
-                    if cls.Config.allow_population_by_field_name:
-                        if field_info.alias in values:
-                            values[name] = values.pop(field_info.alias)
-            if hasattr(cls, "model_config"):
-                if cls.model_config.get("populate_by_name"):
-                    if field_info.alias in values:
-                        values[name] = values.pop(field_info.alias)
+            if (
+                hasattr(cls, "Config")
+                and hasattr(cls.Config, "allow_population_by_field_name")
+                and cls.Config.allow_population_by_field_name
+                and field_info.alias in values
+            ):
+                values[name] = values.pop(field_info.alias)
+            if (
+                hasattr(cls, "model_config")
+                and cls.model_config.get("populate_by_name")
+                and field_info.alias in values
+            ):
+                values[name] = values.pop(field_info.alias)
 
-            if name not in values or values[name] is None:
-                if not field_info.is_required():
+            if (
+                name not in values or values[name] is None
+            ) and not field_info.is_required():
                 if field_info.default_factory is not None:
                     values[name] = field_info.default_factory()
                 else:
                     values[name] = field_info.default
 
         # Call the decorated function
         return func(cls, values)

View File

@@ -332,9 +332,8 @@ def from_env(
             for k in key:
                 if k in os.environ:
                     return os.environ[k]
-        if isinstance(key, str):
-            if key in os.environ:
-                return os.environ[key]
+        if isinstance(key, str) and key in os.environ:
+            return os.environ[key]
 
         if isinstance(default, (str, type(None))):
             return default
@@ -395,9 +394,8 @@ def secret_from_env(
             for k in key:
                 if k in os.environ:
                     return SecretStr(os.environ[k])
-        if isinstance(key, str):
-            if key in os.environ:
-                return SecretStr(os.environ[key])
+        if isinstance(key, str) and key in os.environ:
+            return SecretStr(os.environ[key])
 
         if isinstance(default, str):
             return SecretStr(default)
         elif isinstance(default, type(None)):

View File

@@ -44,7 +44,7 @@ python = ">=3.12.4"
 
 [tool.poetry.extras]
 
 [tool.ruff.lint]
-select = [ "B", "C4", "E", "F", "I", "N", "PIE", "T201", "UP",]
+select = [ "B", "C4", "E", "F", "I", "N", "PIE", "SIM", "T201", "UP",]
 ignore = [ "UP007",]
 
 [tool.coverage.run]
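
With "SIM" in the select list, the violations fixed in this commit can be reproduced locally. An illustrative invocation (paths assumed):

    ruff check --select SIM libs/core

ruff check --fix applies the subset of these rewrites that ruff considers safe to autofix.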

View File

@@ -148,7 +148,7 @@ def test_pydantic_output_parser() -> None:
     result = pydantic_parser.parse(DEF_RESULT)
     print("parse_result:", result)  # noqa: T201
-    assert DEF_EXPECTED_RESULT == result
+    assert result == DEF_EXPECTED_RESULT
     assert pydantic_parser.OutputType is TestModel

View File

@@ -115,10 +115,7 @@ def _normalize_schema(obj: Any) -> dict[str, Any]:
     Args:
         obj: The object to generate the schema for
     """
-    if isinstance(obj, BaseModel):
-        data = obj.model_json_schema()
-    else:
-        data = obj
+    data = obj.model_json_schema() if isinstance(obj, BaseModel) else obj
     remove_all_none_default(data)
     replace_all_of_with_ref(data)
     _remove_enum(data)

View File

@@ -3345,9 +3345,9 @@ def test_bind_with_lambda() -> None:
         return 3 + kwargs.get("n", 0)
 
     runnable = RunnableLambda(my_function).bind(n=1)
-    assert 4 == runnable.invoke({})
+    assert runnable.invoke({}) == 4
     chunks = list(runnable.stream({}))
-    assert [4] == chunks
+    assert chunks == [4]
 
 
 async def test_bind_with_lambda_async() -> None:
@@ -3355,9 +3355,9 @@ async def test_bind_with_lambda_async() -> None:
         return 3 + kwargs.get("n", 0)
 
     runnable = RunnableLambda(my_function).bind(n=1)
-    assert 4 == await runnable.ainvoke({})
+    assert await runnable.ainvoke({}) == 4
     chunks = [item async for item in runnable.astream({})]
-    assert [4] == chunks
+    assert chunks == [4]
 
 
 def test_deep_stream() -> None:
@@ -5140,13 +5140,10 @@ async def test_astream_log_deep_copies() -> None:
     chain = RunnableLambda(add_one)
     chunks = []
-    final_output = None
+    final_output: Optional[RunLogPatch] = None
     async for chunk in chain.astream_log(1):
         chunks.append(chunk)
-        if final_output is None:
-            final_output = chunk
-        else:
-            final_output = final_output + chunk
+        final_output = chunk if final_output is None else final_output + chunk
 
     run_log = _get_run_log(chunks)
     state = run_log.state.copy()

View File

@@ -209,9 +209,11 @@ def test_tracing_enable_disable(
     get_env_var.cache_clear()
     env_on = env == "true"
-    with patch.dict("os.environ", {"LANGSMITH_TRACING": env}):
-        with tracing_context(enabled=enabled):
-            RunnableLambda(my_func).invoke(1)
+    with (
+        patch.dict("os.environ", {"LANGSMITH_TRACING": env}),
+        tracing_context(enabled=enabled),
+    ):
+        RunnableLambda(my_func).invoke(1)
 
     mock_posts = _get_posts(mock_client_)
     if enabled is True:

View File

@@ -1,6 +1,6 @@
 import unittest
 import uuid
-from typing import Union
+from typing import Optional, Union
 
 import pytest
@@ -10,6 +10,7 @@ from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
     ChatMessageChunk,
     FunctionMessage,
@@ -630,14 +631,11 @@ def test_tool_calls_merge() -> None:
         {"content": ""},
     ]
 
-    final = None
+    final: Optional[BaseMessageChunk] = None
 
     for chunk in chunks:
         msg = AIMessageChunk(**chunk)
-        if final is None:
-            final = msg
-        else:
-            final = final + msg
+        final = msg if final is None else final + msg
 
     assert final == AIMessageChunk(
         content="",
View File

@@ -133,25 +133,24 @@ class LangChainProjectNameTest(unittest.TestCase):
         for case in cases:
             get_env_var.cache_clear()
             get_tracer_project.cache_clear()
-            with self.subTest(msg=case.test_name):
-                with pytest.MonkeyPatch.context() as mp:
-                    for k, v in case.envvars.items():
-                        mp.setenv(k, v)
+            with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp:
+                for k, v in case.envvars.items():
+                    mp.setenv(k, v)
 
-                    client = unittest.mock.MagicMock(spec=Client)
-                    tracer = LangChainTracer(client=client)
-                    projects = []
+                client = unittest.mock.MagicMock(spec=Client)
+                tracer = LangChainTracer(client=client)
+                projects = []
 
-                    def mock_create_run(**kwargs: Any) -> Any:
-                        projects.append(kwargs.get("project_name"))  # noqa: B023
-                        return unittest.mock.MagicMock()
+                def mock_create_run(**kwargs: Any) -> Any:
+                    projects.append(kwargs.get("project_name"))  # noqa: B023
+                    return unittest.mock.MagicMock()
 
-                    client.create_run = mock_create_run
+                client.create_run = mock_create_run
 
-                    tracer.on_llm_start(
-                        {"name": "example_1"},
-                        ["foo"],
-                        run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
-                    )
-                    tracer.wait_for_futures()
-                    assert projects == [case.expected_project_name]
+                tracer.on_llm_start(
+                    {"name": "example_1"},
+                    ["foo"],
+                    run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
+                )
+                tracer.wait_for_futures()
+                assert projects == [case.expected_project_name]

View File

@@ -513,14 +513,8 @@ def test__convert_typed_dict_to_openai_function(
 def test__convert_typed_dict_to_openai_function(
     use_extension_typed_dict: bool, use_extension_annotated: bool
 ) -> None:
-    if use_extension_typed_dict:
-        typed_dict = ExtensionsTypedDict
-    else:
-        typed_dict = TypingTypedDict
-    if use_extension_annotated:
-        annotated = TypingAnnotated
-    else:
-        annotated = TypingAnnotated
+    typed_dict = ExtensionsTypedDict if use_extension_typed_dict else TypingTypedDict
+    annotated = TypingAnnotated if use_extension_annotated else TypingAnnotated
 
     class SubTool(typed_dict):
         """Subtool docstring"""

View File

@@ -116,10 +116,7 @@ def test_check_package_version(
 def test_merge_dicts(
     left: dict, right: dict, expected: Union[dict, AbstractContextManager]
 ) -> None:
-    if isinstance(expected, AbstractContextManager):
-        err = expected
-    else:
-        err = nullcontext()
+    err = expected if isinstance(expected, AbstractContextManager) else nullcontext()
 
     left_copy = deepcopy(left)
     right_copy = deepcopy(right)
@@ -147,10 +144,7 @@ def test_merge_dicts(
 def test_merge_dicts_0_3(
     left: dict, right: dict, expected: Union[dict, AbstractContextManager]
 ) -> None:
-    if isinstance(expected, AbstractContextManager):
-        err = expected
-    else:
-        err = nullcontext()
+    err = expected if isinstance(expected, AbstractContextManager) else nullcontext()
 
     left_copy = deepcopy(left)
     right_copy = deepcopy(right)