core[patch]: Add ruff rules for flake8-simplify (SIM) (#26848)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Christophe Bornet 2024-09-27 22:13:23 +02:00 committed by GitHub
parent de0b48c41a
commit 7809b31b95
46 changed files with 277 additions and 384 deletions
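
For reviewers unfamiliar with the plugin: flake8-simplify (SIM) flags control flow that can be collapsed without changing behavior. Most hunks below are instances of two rules, SIM105 (a try/except/pass over a single exception type becomes contextlib.suppress) and SIM108 (an if/else assignment becomes a conditional expression). A minimal sketch of both shapes, with hypothetical names not taken from this commit:

import contextlib
from typing import Optional, Sequence


def set_doc(obj: object, new_doc: str) -> None:
    # SIM105: try/except AttributeError/pass collapses into
    # a contextlib.suppress block.
    with contextlib.suppress(AttributeError):
        obj.__doc__ = new_doc


def normalize_stop(stop: Optional[Sequence[str]]) -> Optional[list]:
    # SIM108: a four-line if/else assignment collapses into a ternary.
    return None if stop is None else list(stop)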

View File

@@ -127,10 +127,9 @@ def beta(
         def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
             """Finalize the annotation of a class."""
-            try:
+            # Can't set new_doc on some extension objects.
+            with contextlib.suppress(AttributeError):
                 obj.__doc__ = new_doc
-            except AttributeError:  # Can't set on some extension objects.
-                pass

         def warn_if_direct_instance(
             self: Any, *args: Any, **kwargs: Any

View File

@@ -198,10 +198,9 @@ def deprecated(
         def finalize(wrapper: Callable[..., Any], new_doc: str) -> T:
             """Finalize the deprecation of a class."""
-            try:
+            # Can't set new_doc on some extension objects.
+            with contextlib.suppress(AttributeError):
                 obj.__doc__ = new_doc
-            except AttributeError:  # Can't set on some extension objects.
-                pass

         def warn_if_direct_instance(
             self: Any, *args: Any, **kwargs: Any

View File

@@ -27,7 +27,7 @@ class FileCallbackHandler(BaseCallbackHandler):
             mode: The mode to open the file in. Defaults to "a".
             color: The color to use for the text. Defaults to None.
         """
-        self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))
+        self.file = cast(TextIO, open(filename, mode, encoding="utf-8"))  # noqa: SIM115
         self.color = color

     def __del__(self) -> None:
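
The # noqa: SIM115 above opts out of SIM115, which asks that open() be used as a context manager. Here the handle must outlive __init__ (it is closed in __del__), so a with block would close it too early; the rule is silenced on that one line instead. A minimal sketch of the same pattern, with a hypothetical FileLogger standing in for FileCallbackHandler:

class FileLogger:
    """Keeps a file handle open across calls; the handle is closed in __del__."""

    def __init__(self, filename: str) -> None:
        # SIM115 would suggest `with open(...)`, but the handle must
        # outlive this method, so the rule is suppressed per line.
        self.file = open(filename, "a", encoding="utf-8")  # noqa: SIM115

    def __del__(self) -> None:
        self.file.close()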

View File

@@ -2252,14 +2252,14 @@ def _configure(
         else:
             parent_run_id_ = inheritable_callbacks.parent_run_id
     # Break ties between the external tracing context and inherited context
-    if parent_run_id is not None:
-        if parent_run_id_ is None:
-            parent_run_id_ = parent_run_id
+    if parent_run_id is not None and (
+        parent_run_id_ is None
         # If the LC parent has already been reflected
         # in the run tree, we know the run_tree is either the
         # same parent or a child of the parent.
-        elif run_tree and str(parent_run_id_) in run_tree.dotted_order:
-            parent_run_id_ = parent_run_id
+        or (run_tree and str(parent_run_id_) in run_tree.dotted_order)
+    ):
+        parent_run_id_ = parent_run_id
     # Otherwise, we assume the LC context has progressed
     # beyond the run tree and we should not inherit the parent.
     callback_manager = callback_manager_cls(

View File

@@ -40,12 +40,11 @@ class OutputParserException(ValueError, LangChainException):  # noqa: N818
         send_to_llm: bool = False,
     ):
         super().__init__(error)
-        if send_to_llm:
-            if observation is None or llm_output is None:
-                raise ValueError(
-                    "Arguments 'observation' & 'llm_output'"
-                    " are required if 'send_to_llm' is True"
-                )
+        if send_to_llm and (observation is None or llm_output is None):
+            raise ValueError(
+                "Arguments 'observation' & 'llm_output'"
+                " are required if 'send_to_llm' is True"
+            )
         self.observation = observation
         self.llm_output = llm_output
         self.send_to_llm = send_to_llm
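
The two hunks above are SIM102 fixes: an if nested directly inside another if, with nothing between them, merges into one condition joined by `and`. A standalone sketch of the rewrite (hypothetical signature):

def validate(send_to_llm: bool, observation: object, llm_output: object) -> None:
    # Before (SIM102):
    #     if send_to_llm:
    #         if observation is None or llm_output is None:
    #             raise ValueError(...)
    # After: one condition, same short-circuit behavior.
    if send_to_llm and (observation is None or llm_output is None):
        raise ValueError(
            "Arguments 'observation' & 'llm_output'"
            " are required if 'send_to_llm' is True"
        )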

View File

@@ -92,7 +92,7 @@ class _HashedDocument(Document):
         values["metadata_hash"] = metadata_hash
         values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash))

-        _uid = values.get("uid", None)
+        _uid = values.get("uid")
         if _uid is None:
             values["uid"] = values["hash_"]

View File

@@ -802,10 +802,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if isinstance(self.cache, BaseCache):
-            llm_cache = self.cache
-        else:
-            llm_cache = get_llm_cache()
+        llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
         # We should check the cache unless it's explicitly set to False
         # A None cache means we should use the default global cache
         # if it's configured.
@@ -879,10 +876,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if isinstance(self.cache, BaseCache):
-            llm_cache = self.cache
-        else:
-            llm_cache = get_llm_cache()
+        llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache()
         # We should check the cache unless it's explicitly set to False
         # A None cache means we should use the default global cache
         # if it's configured.
@@ -1054,10 +1048,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     def predict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
         if isinstance(result.content, str):
             return result.content
@@ -1072,20 +1063,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[Sequence[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return self(messages, stop=_stop, **kwargs)

     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
     async def apredict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         result = await self._call_async(
             [HumanMessage(content=text)], stop=_stop, **kwargs
         )
@@ -1102,10 +1087,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         stop: Optional[Sequence[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return await self._call_async(messages, stop=_stop, **kwargs)

     @property
@@ -1333,9 +1315,12 @@ def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
     if not isinstance(serialized, dict):
         return

-    if "type" in serialized and serialized["type"] == "not_implemented":
-        if "repr" in serialized:
-            del serialized["repr"]
+    if (
+        "type" in serialized
+        and serialized["type"] == "not_implemented"
+        and "repr" in serialized
+    ):
+        del serialized["repr"]

     if "graph" in serialized:
         del serialized["graph"]

View File

@@ -194,10 +194,7 @@ class GenericFakeChatModel(BaseChatModel):
     ) -> ChatResult:
         """Top Level call"""
         message = next(self.messages)
-        if isinstance(message, str):
-            message_ = AIMessage(content=message)
-        else:
-            message_ = message
+        message_ = AIMessage(content=message) if isinstance(message, str) else message
         generation = ChatGeneration(message=message_)
         return ChatResult(generations=[generation])

View File

@@ -1305,10 +1305,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     def predict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return self(text, stop=_stop, **kwargs)

     @deprecated("0.1.7", alternative="invoke", removal="1.0")
@@ -1320,10 +1317,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         **kwargs: Any,
     ) -> BaseMessage:
         text = get_buffer_string(messages)
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         content = self(text, stop=_stop, **kwargs)
         return AIMessage(content=content)
@@ -1331,10 +1325,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
     async def apredict(
         self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
     ) -> str:
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         return await self._call_async(text, stop=_stop, **kwargs)

     @deprecated("0.1.7", alternative="ainvoke", removal="1.0")
@@ -1346,10 +1337,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         **kwargs: Any,
     ) -> BaseMessage:
         text = get_buffer_string(messages)
-        if stop is None:
-            _stop = None
-        else:
-            _stop = list(stop)
+        _stop = None if stop is None else list(stop)
         content = await self._call_async(text, stop=_stop, **kwargs)
         return AIMessage(content=content)
@@ -1384,10 +1372,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         llm.save(file_path="path/llm.yaml")
         """
         # Convert file to Path object.
-        if isinstance(file_path, str):
-            save_path = Path(file_path)
-        else:
-            save_path = file_path
+        save_path = Path(file_path) if isinstance(file_path, str) else file_path

         directory_path = save_path.parent
         directory_path.mkdir(parents=True, exist_ok=True)

View File

@@ -86,9 +86,9 @@ class Reviver:
     def __call__(self, value: dict[str, Any]) -> Any:
         if (
-            value.get("lc", None) == 1
-            and value.get("type", None) == "secret"
-            and value.get("id", None) is not None
+            value.get("lc") == 1
+            and value.get("type") == "secret"
+            and value.get("id") is not None
         ):
             [key] = value["id"]
             if key in self.secrets_map:
@@ -99,9 +99,9 @@ class Reviver:
             raise KeyError(f'Missing key "{key}" in load(secrets_map)')

         if (
-            value.get("lc", None) == 1
-            and value.get("type", None) == "not_implemented"
-            and value.get("id", None) is not None
+            value.get("lc") == 1
+            and value.get("type") == "not_implemented"
+            and value.get("id") is not None
         ):
             raise NotImplementedError(
                 "Trying to load an object that doesn't implement "
@@ -109,17 +109,18 @@ class Reviver:
             )

         if (
-            value.get("lc", None) == 1
-            and value.get("type", None) == "constructor"
-            and value.get("id", None) is not None
+            value.get("lc") == 1
+            and value.get("type") == "constructor"
+            and value.get("id") is not None
         ):
             [*namespace, name] = value["id"]
             mapping_key = tuple(value["id"])

-            if namespace[0] not in self.valid_namespaces:
-                raise ValueError(f"Invalid namespace: {value}")
-            # The root namespace ["langchain"] is not a valid identifier.
-            elif namespace == ["langchain"]:
+            if (
+                namespace[0] not in self.valid_namespaces
+                # The root namespace ["langchain"] is not a valid identifier.
+                or namespace == ["langchain"]
+            ):
                 raise ValueError(f"Invalid namespace: {value}")
             # Has explicit import path.
             elif mapping_key in self.import_mappings:

View File

@@ -1,3 +1,4 @@
+import contextlib
 from abc import ABC
 from typing import (
     Any,
@@ -238,7 +239,7 @@ class Serializable(BaseModel, ABC):
         # include all secrets, even if not specified in kwargs
         # as these secrets may be passed as an environment variable instead
-        for key in secrets.keys():
+        for key in secrets:
             secret_value = getattr(self, key, None) or lc_kwargs.get(key)
             if secret_value is not None:
                 lc_kwargs.update({key: secret_value})
@@ -357,8 +358,6 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
         "id": _id,
         "repr": None,
     }
-    try:
-        result["repr"] = repr(obj)
-    except Exception:
-        pass
+    with contextlib.suppress(Exception):
+        result["repr"] = repr(obj)
     return result

View File

@@ -435,23 +435,22 @@ def filter_messages(
     messages = convert_to_messages(messages)
     filtered: list[BaseMessage] = []
     for msg in messages:
-        if exclude_names and msg.name in exclude_names:
-            continue
-        elif exclude_types and _is_message_type(msg, exclude_types):
-            continue
-        elif exclude_ids and msg.id in exclude_ids:
-            continue
+        if (
+            (exclude_names and msg.name in exclude_names)
+            or (exclude_types and _is_message_type(msg, exclude_types))
+            or (exclude_ids and msg.id in exclude_ids)
+        ):
+            continue
         else:
             pass

         # default to inclusion when no inclusion criteria given.
-        if not (include_types or include_ids or include_names):
-            filtered.append(msg)
-        elif include_names and msg.name in include_names:
-            filtered.append(msg)
-        elif include_types and _is_message_type(msg, include_types):
-            filtered.append(msg)
-        elif include_ids and msg.id in include_ids:
-            filtered.append(msg)
+        if (
+            not (include_types or include_ids or include_names)
+            or (include_names and msg.name in include_names)
+            or (include_types and _is_message_type(msg, include_types))
+            or (include_ids and msg.id in include_ids)
+        ):
+            filtered.append(msg)
         else:
             pass
@@ -961,10 +960,7 @@ def _last_max_tokens(
     while messages and not _is_message_type(messages[-1], end_on):
         messages.pop()

     swapped_system = include_system and isinstance(messages[0], SystemMessage)
-    if swapped_system:
-        reversed_ = messages[:1] + messages[1:][::-1]
-    else:
-        reversed_ = messages[::-1]
+    reversed_ = messages[:1] + messages[1:][::-1] if swapped_system else messages[::-1]

     reversed_ = _first_max_tokens(
         reversed_,
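
The filter_messages rewrite above applies SIM114: if/elif branches whose bodies are identical merge into a single condition joined by `or`, preserving evaluation order. A self-contained sketch under that assumption (hypothetical helper, not from this commit):

def should_keep(name, include_names, include_ids, msg_id) -> bool:
    # Before (SIM114): three branches, each ending in the same body.
    #     if not (include_names or include_ids):
    #         return True
    #     elif include_names and name in include_names:
    #         return True
    #     elif include_ids and msg_id in include_ids:
    #         return True
    #     return False
    # After: one boolean expression with the same short-circuit order.
    return bool(
        not (include_names or include_ids)
        or (include_names and name in include_names)
        or (include_ids and msg_id in include_ids)
    )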

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import contextlib
 from abc import ABC, abstractmethod
 from typing import (
     TYPE_CHECKING,
@@ -311,8 +312,6 @@ class BaseOutputParser(
     def dict(self, **kwargs: Any) -> dict:
         """Return dictionary representation of output parser."""
         output_parser_dict = super().dict(**kwargs)
-        try:
-            output_parser_dict["_type"] = self._type
-        except NotImplementedError:
-            pass
+        with contextlib.suppress(NotImplementedError):
+            output_parser_dict["_type"] = self._type
         return output_parser_dict

View File

@@ -49,12 +49,10 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
         return jsonpatch.make_patch(prev, next).patch

     def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]:
-        if PYDANTIC_MAJOR_VERSION == 2:
-            if issubclass(pydantic_object, pydantic.BaseModel):
-                return pydantic_object.model_json_schema()
-            elif issubclass(pydantic_object, pydantic.v1.BaseModel):
-                return pydantic_object.model_json_schema()
-        return pydantic_object.model_json_schema()
+        if issubclass(pydantic_object, pydantic.BaseModel):
+            return pydantic_object.model_json_schema()
+        elif issubclass(pydantic_object, pydantic.v1.BaseModel):
+            return pydantic_object.schema()

     def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any:
         """Parse the result of an LLM call to a JSON object.

View File

@@ -110,10 +110,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
     def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
         prev_parsed = None
-        acc_gen = None
+        acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
         for chunk in input:
+            chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
             if isinstance(chunk, BaseMessageChunk):
-                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
+                chunk_gen = ChatGenerationChunk(message=chunk)
             elif isinstance(chunk, BaseMessage):
                 chunk_gen = ChatGenerationChunk(
                     message=BaseMessageChunk(**chunk.dict())
@@ -121,10 +122,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
             else:
                 chunk_gen = GenerationChunk(text=chunk)

-            if acc_gen is None:
-                acc_gen = chunk_gen
-            else:
-                acc_gen = acc_gen + chunk_gen
+            acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen  # type: ignore[operator]

             parsed = self.parse_result([acc_gen], partial=True)
             if parsed is not None and parsed != prev_parsed:
@@ -138,10 +136,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
         self, input: AsyncIterator[Union[str, BaseMessage]]
     ) -> AsyncIterator[T]:
         prev_parsed = None
-        acc_gen = None
+        acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None
         async for chunk in input:
+            chunk_gen: Union[GenerationChunk, ChatGenerationChunk]
             if isinstance(chunk, BaseMessageChunk):
-                chunk_gen: Generation = ChatGenerationChunk(message=chunk)
+                chunk_gen = ChatGenerationChunk(message=chunk)
             elif isinstance(chunk, BaseMessage):
                 chunk_gen = ChatGenerationChunk(
                     message=BaseMessageChunk(**chunk.dict())
@@ -149,10 +148,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
             else:
                 chunk_gen = GenerationChunk(text=chunk)

-            if acc_gen is None:
-                acc_gen = chunk_gen
-            else:
-                acc_gen = acc_gen + chunk_gen
+            acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen  # type: ignore[operator]

             parsed = await self.aparse_result([acc_gen], partial=True)
             if parsed is not None and parsed != prev_parsed:

View File

@@ -1,3 +1,4 @@
+import contextlib
 import re
 import xml
 import xml.etree.ElementTree as ET  # noqa: N817
@@ -131,11 +132,9 @@ class _StreamingParser:
         Raises:
             xml.etree.ElementTree.ParseError: If the XML is not well-formed.
         """
-        try:
-            self.pull_parser.close()
-        except xml.etree.ElementTree.ParseError:
-            # Ignore. This will ignore any incomplete XML at the end of the input
-            pass
+        # Ignore ParseError. This will ignore any incomplete XML at the end of the input
+        with contextlib.suppress(xml.etree.ElementTree.ParseError):
+            self.pull_parser.close()


 class XMLOutputParser(BaseTransformOutputParser):

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import contextlib
 import json
 import typing
 from abc import ABC, abstractmethod
@@ -319,10 +320,8 @@ class BasePromptTemplate(
             NotImplementedError: If the prompt type is not implemented.
         """
         prompt_dict = super().model_dump(**kwargs)
-        try:
-            prompt_dict["_type"] = self._prompt_type
-        except NotImplementedError:
-            pass
+        with contextlib.suppress(NotImplementedError):
+            prompt_dict["_type"] = self._prompt_type
         return prompt_dict

     def save(self, file_path: Union[Path, str]) -> None:
@@ -350,10 +349,7 @@ class BasePromptTemplate(
             raise NotImplementedError(f"Prompt {self} does not support saving.")

         # Convert file to Path object.
-        if isinstance(file_path, str):
-            save_path = Path(file_path)
-        else:
-            save_path = file_path
+        save_path = Path(file_path) if isinstance(file_path, str) else file_path

         directory_path = save_path.parent
         directory_path.mkdir(parents=True, exist_ok=True)

View File

@@ -59,8 +59,8 @@ class _FewShotPromptTemplateMixin(BaseModel):
             ValueError: If neither or both examples and example_selector are provided.
             ValueError: If both examples and example_selector are provided.
         """
-        examples = values.get("examples", None)
-        example_selector = values.get("example_selector", None)
+        examples = values.get("examples")
+        example_selector = values.get("example_selector")
         if examples and example_selector:
             raise ValueError(
                 "Only one of 'examples' and 'example_selector' should be provided"

View File

@@ -51,8 +51,8 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
     @classmethod
     def check_examples_and_selector(cls, values: dict) -> Any:
         """Check that one and only one of examples/example_selector are provided."""
-        examples = values.get("examples", None)
-        example_selector = values.get("example_selector", None)
+        examples = values.get("examples")
+        example_selector = values.get("example_selector")
         if examples and example_selector:
             raise ValueError(
                 "Only one of 'examples' and 'example_selector' should be provided"
@@ -138,7 +138,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
             prefix_kwargs = {
                 k: v for k, v in kwargs.items() if k in self.prefix.input_variables
             }
-            for k in prefix_kwargs.keys():
+            for k in prefix_kwargs:
                 kwargs.pop(k)
             prefix = self.prefix.format(**prefix_kwargs)
@@ -146,7 +146,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
         suffix_kwargs = {
             k: v for k, v in kwargs.items() if k in self.suffix.input_variables
         }
-        for k in suffix_kwargs.keys():
+        for k in suffix_kwargs:
             kwargs.pop(k)
         suffix = self.suffix.format(
             **suffix_kwargs,
@@ -182,7 +182,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
             prefix_kwargs = {
                 k: v for k, v in kwargs.items() if k in self.prefix.input_variables
             }
-            for k in prefix_kwargs.keys():
+            for k in prefix_kwargs:
                 kwargs.pop(k)
             prefix = await self.prefix.aformat(**prefix_kwargs)
@@ -190,7 +190,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
         suffix_kwargs = {
             k: v for k, v in kwargs.items() if k in self.suffix.input_variables
         }
-        for k in suffix_kwargs.keys():
+        for k in suffix_kwargs:
             kwargs.pop(k)
         suffix = await self.suffix.aformat(
             **suffix_kwargs,

View File

@@ -164,10 +164,7 @@ def _load_prompt_from_file(
 ) -> BasePromptTemplate:
     """Load prompt from file."""
     # Convert file to a Path object.
-    if isinstance(file, str):
-        file_path = Path(file)
-    else:
-        file_path = file
+    file_path = Path(file) if isinstance(file, str) else file

     # Load from either json or yaml.
     if file_path.suffix == ".json":
         with open(file_path, encoding=encoding) as f:

View File

@@ -2,6 +2,7 @@ from __future__ import annotations

 import asyncio
 import collections
+import contextlib
 import functools
 import inspect
 import threading
@@ -2466,10 +2467,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
             A JSON-serializable representation of the Runnable.
         """
         dumped = super().to_json()
-        try:
-            dumped["name"] = self.get_name()
-        except Exception:
-            pass
+        with contextlib.suppress(Exception):
+            dumped["name"] = self.get_name()
         return dumped

     def configurable_fields(
@@ -2763,9 +2762,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
             ValueError: If the sequence has less than 2 steps.
         """
         steps_flat: list[Runnable] = []
-        if not steps:
-            if first is not None and last is not None:
-                steps_flat = [first] + (middle or []) + [last]
+        if not steps and first is not None and last is not None:
+            steps_flat = [first] + (middle or []) + [last]
         for step in steps:
             if isinstance(step, RunnableSequence):
                 steps_flat.extend(step.steps)
@@ -4180,12 +4178,9 @@ class RunnableGenerator(Runnable[Input, Output]):
     def invoke(
         self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
     ) -> Output:
-        final = None
+        final: Optional[Output] = None
         for output in self.stream(input, config, **kwargs):
-            if final is None:
-                final = output
-            else:
-                final = final + output
+            final = output if final is None else final + output  # type: ignore[operator]
         return cast(Output, final)

     def atransform(
@@ -4215,12 +4210,9 @@ class RunnableGenerator(Runnable[Input, Output]):
     async def ainvoke(
         self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
     ) -> Output:
-        final = None
+        final: Optional[Output] = None
         async for output in self.astream(input, config, **kwargs):
-            if final is None:
-                final = output
-            else:
-                final = final + output
+            final = output if final is None else final + output  # type: ignore[operator]
         return cast(Output, final)

View File

@@ -139,11 +139,11 @@ def _set_config_context(config: RunnableConfig) -> None:
                 None,
             )
         )
+        and (run := tracer.run_map.get(str(parent_run_id)))
     ):
-        if run := tracer.run_map.get(str(parent_run_id)):
-            from langsmith.run_helpers import _set_tracing_context
+        from langsmith.run_helpers import _set_tracing_context

-            _set_tracing_context({"parent": run})
+        _set_tracing_context({"parent": run})


 def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:

View File

@@ -99,24 +99,15 @@ class AsciiCanvas:
             self.point(x0, y0, char)
         elif abs(dx) >= abs(dy):
             for x in range(x0, x1 + 1):
-                if dx == 0:
-                    y = y0
-                else:
-                    y = y0 + int(round((x - x0) * dy / float(dx)))
+                y = y0 if dx == 0 else y0 + int(round((x - x0) * dy / float(dx)))
                 self.point(x, y, char)
         elif y0 < y1:
             for y in range(y0, y1 + 1):
-                if dy == 0:
-                    x = x0
-                else:
-                    x = x0 + int(round((y - y0) * dx / float(dy)))
+                x = x0 if dy == 0 else x0 + int(round((y - y0) * dx / float(dy)))
                 self.point(x, y, char)
         else:
             for y in range(y1, y0 + 1):
-                if dy == 0:
-                    x = x0
-                else:
-                    x = x1 + int(round((y - y1) * dx / float(dy)))
+                x = x0 if dy == 0 else x1 + int(round((y - y1) * dx / float(dy)))
                 self.point(x, y, char)

     def text(self, x: int, y: int, text: str) -> None:

View File

@@ -131,10 +131,7 @@ def draw_mermaid(
             else:
                 edge_label = f" -- &nbsp;{edge_data}&nbsp; --> "
         else:
-            if edge.conditional:
-                edge_label = " -.-> "
-            else:
-                edge_label = " --> "
+            edge_label = " -.-> " if edge.conditional else " --> "

         mermaid_graph += (
             f"\t{_escape_node_label(source)}{edge_label}"
@@ -142,7 +139,7 @@ def draw_mermaid(
         )

     # Recursively add nested subgraphs
-    for nested_prefix in edge_groups.keys():
+    for nested_prefix in edge_groups:
         if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
             continue
         add_subgraph(edge_groups[nested_prefix], nested_prefix)
@@ -154,7 +151,7 @@ def draw_mermaid(
     add_subgraph(edge_groups.get("", []), "")

     # Add remaining subgraphs
-    for prefix in edge_groups.keys():
+    for prefix in edge_groups:
         if ":" in prefix or prefix == "":
             continue
         add_subgraph(edge_groups[prefix], prefix)

View File

@@ -496,12 +496,9 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
     Returns:
         Optional[Addable]: The result of adding the addable objects.
     """
-    final = None
+    final: Optional[Addable] = None
     for chunk in addables:
-        if final is None:
-            final = chunk
-        else:
-            final = final + chunk
+        final = chunk if final is None else final + chunk
     return final
@@ -514,12 +511,9 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
     Returns:
         Optional[Addable]: The result of adding the addable objects.
     """
-    final = None
+    final: Optional[Addable] = None
     async for chunk in addables:
-        if final is None:
-            final = chunk
-        else:
-            final = final + chunk
+        final = chunk if final is None else final + chunk
     return final
@@ -642,9 +636,7 @@ def get_unique_config_specs(
     for id, dupes in grouped:
         first = next(dupes)
         others = list(dupes)
-        if len(others) == 0:
-            unique.append(first)
-        elif all(o == first for o in others):
+        if len(others) == 0 or all(o == first for o in others):
             unique.append(first)
         else:
             raise ValueError(

View File

@@ -258,7 +258,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
         if prefix is None:
             yield from self.store.keys()
         else:
-            for key in self.store.keys():
+            for key in self.store:
                 if key.startswith(prefix):
                     yield key
@@ -272,10 +272,10 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
             AsyncIterator[str]: An async iterator over keys that match the given prefix.
         """
         if prefix is None:
-            for key in self.store.keys():
+            for key in self.store:
                 yield key
         else:
-            for key in self.store.keys():
+            for key in self.store:
                 if key.startswith(prefix):
                     yield key
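
The store changes above are SIM118: iterating a dict already yields its keys, so `.keys()` adds nothing, and membership tests read better as `key in d`. For example:

store = {"user:1": "a", "user:2": "b", "order:9": "c"}

# `for k in store` is equivalent to `for k in store.keys()` (SIM118).
user_keys = [k for k in store if k.startswith("user:")]
assert user_keys == ["user:1", "user:2"]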

View File

@@ -19,18 +19,24 @@ class Visitor(ABC):
     """Allowed operators for the visitor."""

     def _validate_func(self, func: Union[Operator, Comparator]) -> None:
-        if isinstance(func, Operator) and self.allowed_operators is not None:
-            if func not in self.allowed_operators:
-                raise ValueError(
-                    f"Received disallowed operator {func}. Allowed "
-                    f"comparators are {self.allowed_operators}"
-                )
-        if isinstance(func, Comparator) and self.allowed_comparators is not None:
-            if func not in self.allowed_comparators:
-                raise ValueError(
-                    f"Received disallowed comparator {func}. Allowed "
-                    f"comparators are {self.allowed_comparators}"
-                )
+        if (
+            isinstance(func, Operator)
+            and self.allowed_operators is not None
+            and func not in self.allowed_operators
+        ):
+            raise ValueError(
+                f"Received disallowed operator {func}. Allowed "
+                f"comparators are {self.allowed_operators}"
+            )
+        if (
+            isinstance(func, Comparator)
+            and self.allowed_comparators is not None
+            and func not in self.allowed_comparators
+        ):
+            raise ValueError(
+                f"Received disallowed comparator {func}. Allowed "
+                f"comparators are {self.allowed_comparators}"
+            )

     @abstractmethod
     def visit_operation(self, operation: Operation) -> Any:

View File

@@ -248,11 +248,8 @@ def create_schema_from_function(
     validated = validate_arguments(func, config=_SchemaConfig)  # type: ignore

     # Let's ignore `self` and `cls` arguments for class and instance methods
-    if func.__qualname__ and "." in func.__qualname__:
-        # Then it likely belongs in a class namespace
-        in_class = True
-    else:
-        in_class = False
+    # If qualified name has a ".", then it likely belongs in a class namespace
+    in_class = bool(func.__qualname__ and "." in func.__qualname__)

     has_args = False
     has_kwargs = False
@@ -289,12 +286,10 @@ def create_schema_from_function(
     # Pydantic adds placeholder virtual fields we need to strip
     valid_properties = []
     for field in get_fields(inferred_model):
-        if not has_args:
-            if field == "args":
-                continue
-        if not has_kwargs:
-            if field == "kwargs":
-                continue
+        if not has_args and field == "args":
+            continue
+        if not has_kwargs and field == "kwargs":
+            continue
         if field == "v__duplicate_kwargs":  # Internal pydantic field
             continue
@@ -422,12 +417,15 @@ class ChildTool(BaseTool):
     def __init__(self, **kwargs: Any) -> None:
         """Initialize the tool."""
-        if "args_schema" in kwargs and kwargs["args_schema"] is not None:
-            if not is_basemodel_subclass(kwargs["args_schema"]):
-                raise TypeError(
-                    f"args_schema must be a subclass of pydantic BaseModel. "
-                    f"Got: {kwargs['args_schema']}."
-                )
+        if (
+            "args_schema" in kwargs
+            and kwargs["args_schema"] is not None
+            and not is_basemodel_subclass(kwargs["args_schema"])
+        ):
+            raise TypeError(
+                f"args_schema must be a subclass of pydantic BaseModel. "
+                f"Got: {kwargs['args_schema']}."
+            )
         super().__init__(**kwargs)

     model_config = ConfigDict(
@@ -840,10 +838,7 @@ def _handle_tool_error(
     flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
 ) -> str:
     if isinstance(flag, bool):
-        if e.args:
-            content = e.args[0]
-        else:
-            content = "Tool execution error"
+        content = e.args[0] if e.args else "Tool execution error"
     elif isinstance(flag, str):
         content = flag
     elif callable(flag):
@@ -902,12 +897,11 @@ def _format_output(
 def _is_message_content_type(obj: Any) -> bool:
     """Check for OpenAI or Anthropic format tool message content."""
-    if isinstance(obj, str):
-        return True
-    elif isinstance(obj, list) and all(_is_message_content_block(e) for e in obj):
-        return True
-    else:
-        return False
+    return (
+        isinstance(obj, str)
+        or isinstance(obj, list)
+        and all(_is_message_content_block(e) for e in obj)
+    )


 def _is_message_content_block(obj: Any) -> bool:
def _is_message_content_block(obj: Any) -> bool: def _is_message_content_block(obj: Any) -> bool:

View File

@@ -3,6 +3,7 @@
 from __future__ import annotations

 import asyncio
+import contextlib
 import logging
 from collections.abc import AsyncIterator, Iterator, Sequence
 from typing import (
@@ -814,10 +815,7 @@ async def _astream_events_implementation_v1(
                     data: EventData = {}
                     log_entry: LogEntry = run_log.state["logs"][path]
                     if log_entry["end_time"] is None:
-                        if log_entry["streamed_output"]:
-                            event_type = "stream"
-                        else:
-                            event_type = "start"
+                        event_type = "stream" if log_entry["streamed_output"] else "start"
                     else:
                         event_type = "end"
@@ -983,14 +981,15 @@ async def _astream_events_implementation_v2(
                 yield event
                 continue

-            if event["run_id"] == first_event_run_id and event["event"].endswith(
-                "_end"
+            # If it's the end event corresponding to the root runnable
+            # we dont include the input in the event since it's guaranteed
+            # to be included in the first event.
+            if (
+                event["run_id"] == first_event_run_id
+                and event["event"].endswith("_end")
+                and "input" in event["data"]
             ):
-                # If it's the end event corresponding to the root runnable
-                # we dont include the input in the event since it's guaranteed
-                # to be included in the first event.
-                if "input" in event["data"]:
-                    del event["data"]["input"]
+                del event["data"]["input"]

             yield event
     except asyncio.CancelledError as exc:
@@ -1001,7 +1000,5 @@ async def _astream_events_implementation_v2(
         # Cancel the task if it's still running
         task.cancel()
         # Await it anyway, to run any cleanup code, and propagate any exceptions
-        try:
-            await task
-        except asyncio.CancelledError:
-            pass
+        with contextlib.suppress(asyncio.CancelledError):
+            await task

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import contextlib
 import copy
 import threading
 from collections import defaultdict
@@ -253,18 +254,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
         """
         async for chunk in output:
             # root run is handled in .astream_log()
-            if run_id != self.root_id:
-                # if we can't find the run silently ignore
-                # eg. because this run wasn't included in the log
-                if key := self._key_map_by_run_id.get(run_id):
-                    if not self.send(
-                        {
-                            "op": "add",
-                            "path": f"/logs/{key}/streamed_output/-",
-                            "value": chunk,
-                        }
-                    ):
-                        break
+            # if we can't find the run silently ignore
+            # eg. because this run wasn't included in the log
+            if (
+                run_id != self.root_id
+                and (key := self._key_map_by_run_id.get(run_id))
+                and (
+                    not self.send(
+                        {
+                            "op": "add",
+                            "path": f"/logs/{key}/streamed_output/-",
+                            "value": chunk,
+                        }
+                    )
+                )
+            ):
+                break

             yield chunk
@@ -280,18 +285,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
         """
         for chunk in output:
             # root run is handled in .astream_log()
-            if run_id != self.root_id:
-                # if we can't find the run silently ignore
-                # eg. because this run wasn't included in the log
-                if key := self._key_map_by_run_id.get(run_id):
-                    if not self.send(
-                        {
-                            "op": "add",
-                            "path": f"/logs/{key}/streamed_output/-",
-                            "value": chunk,
-                        }
-                    ):
-                        break
+            # if we can't find the run silently ignore
+            # eg. because this run wasn't included in the log
+            if (
+                run_id != self.root_id
+                and (key := self._key_map_by_run_id.get(run_id))
+                and (
+                    not self.send(
+                        {
+                            "op": "add",
+                            "path": f"/logs/{key}/streamed_output/-",
+                            "value": chunk,
+                        }
+                    )
+                )
+            ):
+                break

             yield chunk
@@ -439,9 +448,8 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler):
                 self.send(*ops)
         finally:
-            if run.id == self.root_id:
-                if self.auto_close:
-                    self.send_stream.close()
+            if run.id == self.root_id and self.auto_close:
+                self.send_stream.close()

     def _on_llm_new_token(
         self,
@@ -662,7 +670,5 @@ async def _astream_log_implementation(
             yield state
     finally:
         # Wait for the runnable to finish, if not cancelled (eg. by break)
-        try:
-            await task
-        except asyncio.CancelledError:
-            pass
+        with contextlib.suppress(asyncio.CancelledError):
+            await task

View File

@@ -29,9 +29,7 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any]:
     merged = left.copy()
     for right in others:
         for right_k, right_v in right.items():
-            if right_k not in merged:
-                merged[right_k] = right_v
-            elif right_v is not None and merged[right_k] is None:
+            if right_k not in merged or right_v is not None and merged[right_k] is None:
                 merged[right_k] = right_v
             elif right_v is None:
                 continue

View File

@@ -43,14 +43,10 @@ def get_from_dict_or_env(
         if k in data and data[k]:
             return data[k]

-    if isinstance(key, str):
-        if key in data and data[key]:
-            return data[key]
-
-    if isinstance(key, (list, tuple)):
-        key_for_err = key[0]
-    else:
-        key_for_err = key
+    if isinstance(key, str) and key in data and data[key]:
+        return data[key]

+    key_for_err = key[0] if isinstance(key, (list, tuple)) else key
     return get_from_env(key_for_err, env_key, default=default)

View File

@@ -64,7 +64,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict:
     new_kv = {}
     for k, v in kv.items():
         if k == "title":
-            if isinstance(v, dict) and prev_key == "properties" and "title" in v.keys():
+            if isinstance(v, dict) and prev_key == "properties" and "title" in v:
                 new_kv[k] = _rm_titles(v, k)
             else:
                 continue

View File

@@ -139,11 +139,8 @@ def parse_json_markdown(
     match = _json_markdown_re.search(json_string)

     # If no match found, assume the entire string is a JSON string
-    if match is None:
-        json_str = json_string
-    else:
-        # If match found, use the content within the backticks
-        json_str = match.group(2)
+    # Else, use the content within the backticks
+    json_str = json_string if match is None else match.group(2)

     return _parse_json(json_str, parser=parser)

View File

@@ -80,12 +80,9 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
         padding = literal.split("\n")[-1]

         # If all the characters since the last newline are spaces
-        if padding.isspace() or padding == "":
-            # Then the next tag could be a standalone
-            return True
-        else:
-            # Otherwise it can't be
-            return False
+        # Then the next tag could be a standalone
+        # Otherwise it can't be
+        return padding.isspace() or padding == ""
     else:
         return False
@@ -107,10 +104,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
         on_newline = template.split("\n", 1)

         # If the stuff to the right of us are spaces we're a standalone
-        if on_newline[0].isspace() or not on_newline[0]:
-            return True
-        else:
-            return False
+        return on_newline[0].isspace() or not on_newline[0]

     # If we're a tag can't be a standalone
     else:
@@ -174,14 +168,18 @@ def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], str]:
             "unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}"
         )

-    # If we might be a no html escape tag
-    elif tag_type == "no escape?":
+    elif (
+        # If we might be a no html escape tag
+        tag_type == "no escape?"
         # And we have a third curly brace
         # (And are using curly braces as delimiters)
-        if l_del == "{{" and r_del == "}}" and template.startswith("}"):
-            # Then we are a no html escape tag
-            template = template[1:]
-            tag_type = "no escape"
+        and l_del == "{{"
+        and r_del == "}}"
+        and template.startswith("}")
+    ):
+        # Then we are a no html escape tag
+        template = template[1:]
+        tag_type = "no escape"

     # Strip the whitespace off the key and return
     return ((tag_type, tag.strip()), template)

View File

@@ -179,22 +179,27 @@ def pre_init(func: Callable) -> Any:
             for name, field_info in fields.items():
                 # Check if allow_population_by_field_name is enabled
                 # If yes, then set the field name to the alias
-                if hasattr(cls, "Config"):
-                    if hasattr(cls.Config, "allow_population_by_field_name"):
-                        if cls.Config.allow_population_by_field_name:
-                            if field_info.alias in values:
-                                values[name] = values.pop(field_info.alias)
-                if hasattr(cls, "model_config"):
-                    if cls.model_config.get("populate_by_name"):
-                        if field_info.alias in values:
-                            values[name] = values.pop(field_info.alias)
+                if (
+                    hasattr(cls, "Config")
+                    and hasattr(cls.Config, "allow_population_by_field_name")
+                    and cls.Config.allow_population_by_field_name
+                    and field_info.alias in values
+                ):
+                    values[name] = values.pop(field_info.alias)
+                if (
+                    hasattr(cls, "model_config")
+                    and cls.model_config.get("populate_by_name")
+                    and field_info.alias in values
+                ):
+                    values[name] = values.pop(field_info.alias)

-                if name not in values or values[name] is None:
-                    if not field_info.is_required():
-                        if field_info.default_factory is not None:
-                            values[name] = field_info.default_factory()
-                        else:
-                            values[name] = field_info.default
+                if (
+                    name not in values or values[name] is None
+                ) and not field_info.is_required():
+                    if field_info.default_factory is not None:
+                        values[name] = field_info.default_factory()
+                    else:
+                        values[name] = field_info.default

             # Call the decorated function
             return func(cls, values)

View File

@@ -332,9 +332,8 @@ def from_env(
         for k in key:
             if k in os.environ:
                 return os.environ[k]
-    if isinstance(key, str):
-        if key in os.environ:
-            return os.environ[key]
+    if isinstance(key, str) and key in os.environ:
+        return os.environ[key]

     if isinstance(default, (str, type(None))):
         return default
@@ -395,9 +394,8 @@ def secret_from_env(
         for k in key:
             if k in os.environ:
                 return SecretStr(os.environ[k])
-    if isinstance(key, str):
-        if key in os.environ:
-            return SecretStr(os.environ[key])
+    if isinstance(key, str) and key in os.environ:
+        return SecretStr(os.environ[key])
     if isinstance(default, str):
         return SecretStr(default)
     elif isinstance(default, type(None)):

View File

@@ -44,7 +44,7 @@ python = ">=3.12.4"
 [tool.poetry.extras]

 [tool.ruff.lint]
-select = [ "B", "C4", "E", "F", "I", "N", "PIE", "T201", "UP",]
+select = [ "B", "C4", "E", "F", "I", "N", "PIE", "SIM", "T201", "UP",]
 ignore = [ "UP007",]

 [tool.coverage.run]

View File

@@ -148,7 +148,7 @@ def test_pydantic_output_parser() -> None:
     result = pydantic_parser.parse(DEF_RESULT)
     print("parse_result:", result)  # noqa: T201

-    assert DEF_EXPECTED_RESULT == result
+    assert result == DEF_EXPECTED_RESULT
     assert pydantic_parser.OutputType is TestModel
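
This assert change is SIM300, the "Yoda condition" rule: the value under test belongs on the left of ==, the expected constant on the right, which also makes pytest's failure message read naturally. For example:

def compute() -> int:  # hypothetical function under test
    return 2 + 2

result = compute()
assert result == 4   # preferred: actual value on the left
# assert 4 == result  # flagged by SIM300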

View File

@@ -115,10 +115,7 @@ def _normalize_schema(obj: Any) -> dict[str, Any]:
     Args:
         obj: The object to generate the schema for
     """
-    if isinstance(obj, BaseModel):
-        data = obj.model_json_schema()
-    else:
-        data = obj
+    data = obj.model_json_schema() if isinstance(obj, BaseModel) else obj
     remove_all_none_default(data)
     replace_all_of_with_ref(data)
     _remove_enum(data)

View File

@@ -3345,9 +3345,9 @@ def test_bind_with_lambda() -> None:
         return 3 + kwargs.get("n", 0)

     runnable = RunnableLambda(my_function).bind(n=1)
-    assert 4 == runnable.invoke({})
+    assert runnable.invoke({}) == 4
     chunks = list(runnable.stream({}))
-    assert [4] == chunks
+    assert chunks == [4]


 async def test_bind_with_lambda_async() -> None:
@@ -3355,9 +3355,9 @@ async def test_bind_with_lambda_async() -> None:
         return 3 + kwargs.get("n", 0)

     runnable = RunnableLambda(my_function).bind(n=1)
-    assert 4 == await runnable.ainvoke({})
+    assert await runnable.ainvoke({}) == 4
     chunks = [item async for item in runnable.astream({})]
-    assert [4] == chunks
+    assert chunks == [4]


 def test_deep_stream() -> None:
@@ -5140,13 +5140,10 @@ async def test_astream_log_deep_copies() -> None:
     chain = RunnableLambda(add_one)
     chunks = []
-    final_output = None
+    final_output: Optional[RunLogPatch] = None
     async for chunk in chain.astream_log(1):
         chunks.append(chunk)
-        if final_output is None:
-            final_output = chunk
-        else:
-            final_output = final_output + chunk
+        final_output = chunk if final_output is None else final_output + chunk

     run_log = _get_run_log(chunks)
     state = run_log.state.copy()

View File

@@ -209,9 +209,11 @@ def test_tracing_enable_disable(
     get_env_var.cache_clear()
     env_on = env == "true"
-    with patch.dict("os.environ", {"LANGSMITH_TRACING": env}):
-        with tracing_context(enabled=enabled):
-            RunnableLambda(my_func).invoke(1)
+    with (
+        patch.dict("os.environ", {"LANGSMITH_TRACING": env}),
+        tracing_context(enabled=enabled),
+    ):
+        RunnableLambda(my_func).invoke(1)

     mock_posts = _get_posts(mock_client_)
     if enabled is True:
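
The test change above is a SIM117 fix: sibling with statements collapse into one parenthesized with (valid syntax on Python 3.10+). A minimal sketch:

import contextlib


@contextlib.contextmanager
def tag(name: str):
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


# Before (SIM117): needless nesting.
with tag("outer"):
    with tag("inner"):
        pass

# After: one `with`, multiple context managers.
with (
    tag("outer"),
    tag("inner"),
):
    pass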

View File

@@ -1,6 +1,6 @@
 import unittest
 import uuid
-from typing import Union
+from typing import Optional, Union

 import pytest
@@ -10,6 +10,7 @@ from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
     ChatMessageChunk,
     FunctionMessage,
@@ -630,14 +631,11 @@ def test_tool_calls_merge() -> None:
         {"content": ""},
     ]

-    final = None
+    final: Optional[BaseMessageChunk] = None

     for chunk in chunks:
         msg = AIMessageChunk(**chunk)
-        if final is None:
-            final = msg
-        else:
-            final = final + msg
+        final = msg if final is None else final + msg

     assert final == AIMessageChunk(
         content="",

View File

@@ -133,25 +133,24 @@ class LangChainProjectNameTest(unittest.TestCase):
         for case in cases:
             get_env_var.cache_clear()
             get_tracer_project.cache_clear()
-            with self.subTest(msg=case.test_name):
-                with pytest.MonkeyPatch.context() as mp:
-                    for k, v in case.envvars.items():
-                        mp.setenv(k, v)
-
-                    client = unittest.mock.MagicMock(spec=Client)
-                    tracer = LangChainTracer(client=client)
-                    projects = []
-
-                    def mock_create_run(**kwargs: Any) -> Any:
-                        projects.append(kwargs.get("project_name"))  # noqa: B023
-                        return unittest.mock.MagicMock()
-
-                    client.create_run = mock_create_run
-
-                    tracer.on_llm_start(
-                        {"name": "example_1"},
-                        ["foo"],
-                        run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
-                    )
-                    tracer.wait_for_futures()
-                    assert projects == [case.expected_project_name]
+            with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp:
+                for k, v in case.envvars.items():
+                    mp.setenv(k, v)
+
+                client = unittest.mock.MagicMock(spec=Client)
+                tracer = LangChainTracer(client=client)
+                projects = []
+
+                def mock_create_run(**kwargs: Any) -> Any:
+                    projects.append(kwargs.get("project_name"))  # noqa: B023
+                    return unittest.mock.MagicMock()
+
+                client.create_run = mock_create_run
+
+                tracer.on_llm_start(
+                    {"name": "example_1"},
+                    ["foo"],
+                    run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
+                )
+                tracer.wait_for_futures()
+                assert projects == [case.expected_project_name]

View File

@@ -513,14 +513,8 @@ def test_tool_outputs() -> None:
 def test__convert_typed_dict_to_openai_function(
     use_extension_typed_dict: bool, use_extension_annotated: bool
 ) -> None:
-    if use_extension_typed_dict:
-        typed_dict = ExtensionsTypedDict
-    else:
-        typed_dict = TypingTypedDict
-    if use_extension_annotated:
-        annotated = TypingAnnotated
-    else:
-        annotated = TypingAnnotated
+    typed_dict = ExtensionsTypedDict if use_extension_typed_dict else TypingTypedDict
+    annotated = TypingAnnotated if use_extension_annotated else TypingAnnotated

     class SubTool(typed_dict):
         """Subtool docstring"""

View File

@@ -116,10 +116,7 @@ def test_check_package_version(
 def test_merge_dicts(
     left: dict, right: dict, expected: Union[dict, AbstractContextManager]
 ) -> None:
-    if isinstance(expected, AbstractContextManager):
-        err = expected
-    else:
-        err = nullcontext()
+    err = expected if isinstance(expected, AbstractContextManager) else nullcontext()

     left_copy = deepcopy(left)
     right_copy = deepcopy(right)
@@ -147,10 +144,7 @@ def test_merge_dicts(
 def test_merge_dicts_0_3(
     left: dict, right: dict, expected: Union[dict, AbstractContextManager]
 ) -> None:
-    if isinstance(expected, AbstractContextManager):
-        err = expected
-    else:
-        err = nullcontext()
+    err = expected if isinstance(expected, AbstractContextManager) else nullcontext()

     left_copy = deepcopy(left)
     right_copy = deepcopy(right)