diff --git a/libs/core/langchain_core/_api/beta_decorator.py b/libs/core/langchain_core/_api/beta_decorator.py index 46514760549..4d6810dbec3 100644 --- a/libs/core/langchain_core/_api/beta_decorator.py +++ b/libs/core/langchain_core/_api/beta_decorator.py @@ -127,10 +127,9 @@ def beta( def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: """Finalize the annotation of a class.""" - try: + # Can't set new_doc on some extension objects. + with contextlib.suppress(AttributeError): obj.__doc__ = new_doc - except AttributeError: # Can't set on some extension objects. - pass def warn_if_direct_instance( self: Any, *args: Any, **kwargs: Any diff --git a/libs/core/langchain_core/_api/deprecation.py b/libs/core/langchain_core/_api/deprecation.py index 5153e8acdac..58a97a416af 100644 --- a/libs/core/langchain_core/_api/deprecation.py +++ b/libs/core/langchain_core/_api/deprecation.py @@ -198,10 +198,9 @@ def deprecated( def finalize(wrapper: Callable[..., Any], new_doc: str) -> T: """Finalize the deprecation of a class.""" - try: + # Can't set new_doc on some extension objects. + with contextlib.suppress(AttributeError): obj.__doc__ = new_doc - except AttributeError: # Can't set on some extension objects. - pass def warn_if_direct_instance( self: Any, *args: Any, **kwargs: Any diff --git a/libs/core/langchain_core/callbacks/file.py b/libs/core/langchain_core/callbacks/file.py index f3f695ccfa7..7ea1ff76f8a 100644 --- a/libs/core/langchain_core/callbacks/file.py +++ b/libs/core/langchain_core/callbacks/file.py @@ -27,7 +27,7 @@ class FileCallbackHandler(BaseCallbackHandler): mode: The mode to open the file in. Defaults to "a". color: The color to use for the text. Defaults to None. """ - self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) + self.file = cast(TextIO, open(filename, mode, encoding="utf-8")) # noqa: SIM115 self.color = color def __del__(self) -> None: diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index c3214ea1123..05dd786591d 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -2252,14 +2252,14 @@ def _configure( else: parent_run_id_ = inheritable_callbacks.parent_run_id # Break ties between the external tracing context and inherited context - if parent_run_id is not None: - if parent_run_id_ is None: - parent_run_id_ = parent_run_id + if parent_run_id is not None and ( + parent_run_id_ is None # If the LC parent has already been reflected # in the run tree, we know the run_tree is either the # same parent or a child of the parent. - elif run_tree and str(parent_run_id_) in run_tree.dotted_order: - parent_run_id_ = parent_run_id + or (run_tree and str(parent_run_id_) in run_tree.dotted_order) + ): + parent_run_id_ = parent_run_id # Otherwise, we assume the LC context has progressed # beyond the run tree and we should not inherit the parent. callback_manager = callback_manager_cls( diff --git a/libs/core/langchain_core/exceptions.py b/libs/core/langchain_core/exceptions.py index 7c60ccfa4db..c7ab7419f2a 100644 --- a/libs/core/langchain_core/exceptions.py +++ b/libs/core/langchain_core/exceptions.py @@ -40,12 +40,11 @@ class OutputParserException(ValueError, LangChainException): # noqa: N818 send_to_llm: bool = False, ): super().__init__(error) - if send_to_llm: - if observation is None or llm_output is None: - raise ValueError( - "Arguments 'observation' & 'llm_output'" - " are required if 'send_to_llm' is True" - ) + if send_to_llm and (observation is None or llm_output is None): + raise ValueError( + "Arguments 'observation' & 'llm_output'" + " are required if 'send_to_llm' is True" + ) self.observation = observation self.llm_output = llm_output self.send_to_llm = send_to_llm diff --git a/libs/core/langchain_core/indexing/api.py b/libs/core/langchain_core/indexing/api.py index 388feb5ef02..87baed948c7 100644 --- a/libs/core/langchain_core/indexing/api.py +++ b/libs/core/langchain_core/indexing/api.py @@ -92,7 +92,7 @@ class _HashedDocument(Document): values["metadata_hash"] = metadata_hash values["hash_"] = str(_hash_string_to_uuid(content_hash + metadata_hash)) - _uid = values.get("uid", None) + _uid = values.get("uid") if _uid is None: values["uid"] = values["hash_"] diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 7d46f490bbb..8b19a510d2a 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -802,10 +802,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - if isinstance(self.cache, BaseCache): - llm_cache = self.cache - else: - llm_cache = get_llm_cache() + llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache() # We should check the cache unless it's explicitly set to False # A None cache means we should use the default global cache # if it's configured. @@ -879,10 +876,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - if isinstance(self.cache, BaseCache): - llm_cache = self.cache - else: - llm_cache = get_llm_cache() + llm_cache = self.cache if isinstance(self.cache, BaseCache) else get_llm_cache() # We should check the cache unless it's explicitly set to False # A None cache means we should use the default global cache # if it's configured. @@ -1054,10 +1048,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): def predict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: - if stop is None: - _stop = None - else: - _stop = list(stop) + _stop = None if stop is None else list(stop) result = self([HumanMessage(content=text)], stop=_stop, **kwargs) if isinstance(result.content, str): return result.content @@ -1072,20 +1063,14 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): stop: Optional[Sequence[str]] = None, **kwargs: Any, ) -> BaseMessage: - if stop is None: - _stop = None - else: - _stop = list(stop) + _stop = None if stop is None else list(stop) return self(messages, stop=_stop, **kwargs) @deprecated("0.1.7", alternative="ainvoke", removal="1.0") async def apredict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: - if stop is None: - _stop = None - else: - _stop = list(stop) + _stop = None if stop is None else list(stop) result = await self._call_async( [HumanMessage(content=text)], stop=_stop, **kwargs ) @@ -1102,10 +1087,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): stop: Optional[Sequence[str]] = None, **kwargs: Any, ) -> BaseMessage: - if stop is None: - _stop = None - else: - _stop = list(stop) + _stop = None if stop is None else list(stop) return await self._call_async(messages, stop=_stop, **kwargs) @property @@ -1333,9 +1315,12 @@ def _cleanup_llm_representation(serialized: Any, depth: int) -> None: if not isinstance(serialized, dict): return - if "type" in serialized and serialized["type"] == "not_implemented": - if "repr" in serialized: - del serialized["repr"] + if ( + "type" in serialized + and serialized["type"] == "not_implemented" + and "repr" in serialized + ): + del serialized["repr"] if "graph" in serialized: del serialized["graph"] diff --git a/libs/core/langchain_core/language_models/fake_chat_models.py b/libs/core/langchain_core/language_models/fake_chat_models.py index 3c41c1d462f..6476db5f825 100644 --- a/libs/core/langchain_core/language_models/fake_chat_models.py +++ b/libs/core/langchain_core/language_models/fake_chat_models.py @@ -194,10 +194,7 @@ class GenericFakeChatModel(BaseChatModel): ) -> ChatResult: """Top Level call""" message = next(self.messages) - if isinstance(message, str): - message_ = AIMessage(content=message) - else: - message_ = message + message_ = AIMessage(content=message) if isinstance(message, str) else message generation = ChatGeneration(message=message_) return ChatResult(generations=[generation]) diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index af2b5ef0d8b..f4ba7791eaf 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -1305,10 +1305,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): def predict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: - if stop is None: - _stop = None - else: - _stop = list(stop) + _stop = None if stop is None else list(stop) return self(text, stop=_stop, **kwargs) @deprecated("0.1.7", alternative="invoke", removal="1.0") @@ -1320,10 +1317,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): **kwargs: Any, ) -> BaseMessage: text = get_buffer_string(messages) - if stop is None: - _stop = None - else: - _stop = list(stop) + _stop = None if stop is None else list(stop) content = self(text, stop=_stop, **kwargs) return AIMessage(content=content) @@ -1331,10 +1325,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): async def apredict( self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any ) -> str: - if stop is None: - _stop = None - else: - _stop = list(stop) + _stop = None if stop is None else list(stop) return await self._call_async(text, stop=_stop, **kwargs) @deprecated("0.1.7", alternative="ainvoke", removal="1.0") @@ -1346,10 +1337,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): **kwargs: Any, ) -> BaseMessage: text = get_buffer_string(messages) - if stop is None: - _stop = None - else: - _stop = list(stop) + _stop = None if stop is None else list(stop) content = await self._call_async(text, stop=_stop, **kwargs) return AIMessage(content=content) @@ -1384,10 +1372,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): llm.save(file_path="path/llm.yaml") """ # Convert file to Path object. - if isinstance(file_path, str): - save_path = Path(file_path) - else: - save_path = file_path + save_path = Path(file_path) if isinstance(file_path, str) else file_path directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) diff --git a/libs/core/langchain_core/load/load.py b/libs/core/langchain_core/load/load.py index d050d4696c9..875ddf92f90 100644 --- a/libs/core/langchain_core/load/load.py +++ b/libs/core/langchain_core/load/load.py @@ -86,9 +86,9 @@ class Reviver: def __call__(self, value: dict[str, Any]) -> Any: if ( - value.get("lc", None) == 1 - and value.get("type", None) == "secret" - and value.get("id", None) is not None + value.get("lc") == 1 + and value.get("type") == "secret" + and value.get("id") is not None ): [key] = value["id"] if key in self.secrets_map: @@ -99,9 +99,9 @@ class Reviver: raise KeyError(f'Missing key "{key}" in load(secrets_map)') if ( - value.get("lc", None) == 1 - and value.get("type", None) == "not_implemented" - and value.get("id", None) is not None + value.get("lc") == 1 + and value.get("type") == "not_implemented" + and value.get("id") is not None ): raise NotImplementedError( "Trying to load an object that doesn't implement " @@ -109,17 +109,18 @@ class Reviver: ) if ( - value.get("lc", None) == 1 - and value.get("type", None) == "constructor" - and value.get("id", None) is not None + value.get("lc") == 1 + and value.get("type") == "constructor" + and value.get("id") is not None ): [*namespace, name] = value["id"] mapping_key = tuple(value["id"]) - if namespace[0] not in self.valid_namespaces: - raise ValueError(f"Invalid namespace: {value}") - # The root namespace ["langchain"] is not a valid identifier. - elif namespace == ["langchain"]: + if ( + namespace[0] not in self.valid_namespaces + # The root namespace ["langchain"] is not a valid identifier. + or namespace == ["langchain"] + ): raise ValueError(f"Invalid namespace: {value}") # Has explicit import path. elif mapping_key in self.import_mappings: diff --git a/libs/core/langchain_core/load/serializable.py b/libs/core/langchain_core/load/serializable.py index 9158c1e5b8b..02d410adb85 100644 --- a/libs/core/langchain_core/load/serializable.py +++ b/libs/core/langchain_core/load/serializable.py @@ -1,3 +1,4 @@ +import contextlib from abc import ABC from typing import ( Any, @@ -238,7 +239,7 @@ class Serializable(BaseModel, ABC): # include all secrets, even if not specified in kwargs # as these secrets may be passed as an environment variable instead - for key in secrets.keys(): + for key in secrets: secret_value = getattr(self, key, None) or lc_kwargs.get(key) if secret_value is not None: lc_kwargs.update({key: secret_value}) @@ -357,8 +358,6 @@ def to_json_not_implemented(obj: object) -> SerializedNotImplemented: "id": _id, "repr": None, } - try: + with contextlib.suppress(Exception): result["repr"] = repr(obj) - except Exception: - pass return result diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 05795e6c438..d6092530831 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -435,23 +435,22 @@ def filter_messages( messages = convert_to_messages(messages) filtered: list[BaseMessage] = [] for msg in messages: - if exclude_names and msg.name in exclude_names: - continue - elif exclude_types and _is_message_type(msg, exclude_types): - continue - elif exclude_ids and msg.id in exclude_ids: + if ( + (exclude_names and msg.name in exclude_names) + or (exclude_types and _is_message_type(msg, exclude_types)) + or (exclude_ids and msg.id in exclude_ids) + ): continue else: pass # default to inclusion when no inclusion criteria given. - if not (include_types or include_ids or include_names): - filtered.append(msg) - elif include_names and msg.name in include_names: - filtered.append(msg) - elif include_types and _is_message_type(msg, include_types): - filtered.append(msg) - elif include_ids and msg.id in include_ids: + if ( + not (include_types or include_ids or include_names) + or (include_names and msg.name in include_names) + or (include_types and _is_message_type(msg, include_types)) + or (include_ids and msg.id in include_ids) + ): filtered.append(msg) else: pass @@ -961,10 +960,7 @@ def _last_max_tokens( while messages and not _is_message_type(messages[-1], end_on): messages.pop() swapped_system = include_system and isinstance(messages[0], SystemMessage) - if swapped_system: - reversed_ = messages[:1] + messages[1:][::-1] - else: - reversed_ = messages[::-1] + reversed_ = messages[:1] + messages[1:][::-1] if swapped_system else messages[::-1] reversed_ = _first_max_tokens( reversed_, diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 16464806dd3..bbecc94ee38 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, @@ -311,8 +312,6 @@ class BaseOutputParser( def dict(self, **kwargs: Any) -> dict: """Return dictionary representation of output parser.""" output_parser_dict = super().dict(**kwargs) - try: + with contextlib.suppress(NotImplementedError): output_parser_dict["_type"] = self._type - except NotImplementedError: - pass return output_parser_dict diff --git a/libs/core/langchain_core/output_parsers/json.py b/libs/core/langchain_core/output_parsers/json.py index ff174767f1c..5bc43048a47 100644 --- a/libs/core/langchain_core/output_parsers/json.py +++ b/libs/core/langchain_core/output_parsers/json.py @@ -49,12 +49,10 @@ class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]): return jsonpatch.make_patch(prev, next).patch def _get_schema(self, pydantic_object: type[TBaseModel]) -> dict[str, Any]: - if PYDANTIC_MAJOR_VERSION == 2: - if issubclass(pydantic_object, pydantic.BaseModel): - return pydantic_object.model_json_schema() - elif issubclass(pydantic_object, pydantic.v1.BaseModel): - return pydantic_object.model_json_schema() - return pydantic_object.model_json_schema() + if issubclass(pydantic_object, pydantic.BaseModel): + return pydantic_object.model_json_schema() + elif issubclass(pydantic_object, pydantic.v1.BaseModel): + return pydantic_object.schema() def parse_result(self, result: list[Generation], *, partial: bool = False) -> Any: """Parse the result of an LLM call to a JSON object. diff --git a/libs/core/langchain_core/output_parsers/transform.py b/libs/core/langchain_core/output_parsers/transform.py index be292738710..8367681d56b 100644 --- a/libs/core/langchain_core/output_parsers/transform.py +++ b/libs/core/langchain_core/output_parsers/transform.py @@ -110,10 +110,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: prev_parsed = None - acc_gen = None + acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None for chunk in input: + chunk_gen: Union[GenerationChunk, ChatGenerationChunk] if isinstance(chunk, BaseMessageChunk): - chunk_gen: Generation = ChatGenerationChunk(message=chunk) + chunk_gen = ChatGenerationChunk(message=chunk) elif isinstance(chunk, BaseMessage): chunk_gen = ChatGenerationChunk( message=BaseMessageChunk(**chunk.dict()) @@ -121,10 +122,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): else: chunk_gen = GenerationChunk(text=chunk) - if acc_gen is None: - acc_gen = chunk_gen - else: - acc_gen = acc_gen + chunk_gen + acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator] parsed = self.parse_result([acc_gen], partial=True) if parsed is not None and parsed != prev_parsed: @@ -138,10 +136,11 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): self, input: AsyncIterator[Union[str, BaseMessage]] ) -> AsyncIterator[T]: prev_parsed = None - acc_gen = None + acc_gen: Union[GenerationChunk, ChatGenerationChunk, None] = None async for chunk in input: + chunk_gen: Union[GenerationChunk, ChatGenerationChunk] if isinstance(chunk, BaseMessageChunk): - chunk_gen: Generation = ChatGenerationChunk(message=chunk) + chunk_gen = ChatGenerationChunk(message=chunk) elif isinstance(chunk, BaseMessage): chunk_gen = ChatGenerationChunk( message=BaseMessageChunk(**chunk.dict()) @@ -149,10 +148,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): else: chunk_gen = GenerationChunk(text=chunk) - if acc_gen is None: - acc_gen = chunk_gen - else: - acc_gen = acc_gen + chunk_gen + acc_gen = chunk_gen if acc_gen is None else acc_gen + chunk_gen # type: ignore[operator] parsed = await self.aparse_result([acc_gen], partial=True) if parsed is not None and parsed != prev_parsed: diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 3d55bed32d6..d2b14f05d06 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -1,3 +1,4 @@ +import contextlib import re import xml import xml.etree.ElementTree as ET # noqa: N817 @@ -131,11 +132,9 @@ class _StreamingParser: Raises: xml.etree.ElementTree.ParseError: If the XML is not well-formed. """ - try: + # Ignore ParseError. This will ignore any incomplete XML at the end of the input + with contextlib.suppress(xml.etree.ElementTree.ParseError): self.pull_parser.close() - except xml.etree.ElementTree.ParseError: - # Ignore. This will ignore any incomplete XML at the end of the input - pass class XMLOutputParser(BaseTransformOutputParser): diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index f98d87d780c..093b7938d47 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import json import typing from abc import ABC, abstractmethod @@ -319,10 +320,8 @@ class BasePromptTemplate( NotImplementedError: If the prompt type is not implemented. """ prompt_dict = super().model_dump(**kwargs) - try: + with contextlib.suppress(NotImplementedError): prompt_dict["_type"] = self._prompt_type - except NotImplementedError: - pass return prompt_dict def save(self, file_path: Union[Path, str]) -> None: @@ -350,10 +349,7 @@ class BasePromptTemplate( raise NotImplementedError(f"Prompt {self} does not support saving.") # Convert file to Path object. - if isinstance(file_path, str): - save_path = Path(file_path) - else: - save_path = file_path + save_path = Path(file_path) if isinstance(file_path, str) else file_path directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py index ac20af37845..d02c885aa85 100644 --- a/libs/core/langchain_core/prompts/few_shot.py +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -59,8 +59,8 @@ class _FewShotPromptTemplateMixin(BaseModel): ValueError: If neither or both examples and example_selector are provided. ValueError: If both examples and example_selector are provided. """ - examples = values.get("examples", None) - example_selector = values.get("example_selector", None) + examples = values.get("examples") + example_selector = values.get("example_selector") if examples and example_selector: raise ValueError( "Only one of 'examples' and 'example_selector' should be provided" diff --git a/libs/core/langchain_core/prompts/few_shot_with_templates.py b/libs/core/langchain_core/prompts/few_shot_with_templates.py index 65a417046c6..75e6344aa7a 100644 --- a/libs/core/langchain_core/prompts/few_shot_with_templates.py +++ b/libs/core/langchain_core/prompts/few_shot_with_templates.py @@ -51,8 +51,8 @@ class FewShotPromptWithTemplates(StringPromptTemplate): @classmethod def check_examples_and_selector(cls, values: dict) -> Any: """Check that one and only one of examples/example_selector are provided.""" - examples = values.get("examples", None) - example_selector = values.get("example_selector", None) + examples = values.get("examples") + example_selector = values.get("example_selector") if examples and example_selector: raise ValueError( "Only one of 'examples' and 'example_selector' should be provided" @@ -138,7 +138,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): prefix_kwargs = { k: v for k, v in kwargs.items() if k in self.prefix.input_variables } - for k in prefix_kwargs.keys(): + for k in prefix_kwargs: kwargs.pop(k) prefix = self.prefix.format(**prefix_kwargs) @@ -146,7 +146,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): suffix_kwargs = { k: v for k, v in kwargs.items() if k in self.suffix.input_variables } - for k in suffix_kwargs.keys(): + for k in suffix_kwargs: kwargs.pop(k) suffix = self.suffix.format( **suffix_kwargs, @@ -182,7 +182,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): prefix_kwargs = { k: v for k, v in kwargs.items() if k in self.prefix.input_variables } - for k in prefix_kwargs.keys(): + for k in prefix_kwargs: kwargs.pop(k) prefix = await self.prefix.aformat(**prefix_kwargs) @@ -190,7 +190,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): suffix_kwargs = { k: v for k, v in kwargs.items() if k in self.suffix.input_variables } - for k in suffix_kwargs.keys(): + for k in suffix_kwargs: kwargs.pop(k) suffix = await self.suffix.aformat( **suffix_kwargs, diff --git a/libs/core/langchain_core/prompts/loading.py b/libs/core/langchain_core/prompts/loading.py index a4a1843786b..39fce9fc301 100644 --- a/libs/core/langchain_core/prompts/loading.py +++ b/libs/core/langchain_core/prompts/loading.py @@ -164,10 +164,7 @@ def _load_prompt_from_file( ) -> BasePromptTemplate: """Load prompt from file.""" # Convert file to a Path object. - if isinstance(file, str): - file_path = Path(file) - else: - file_path = file + file_path = Path(file) if isinstance(file, str) else file # Load from either json or yaml. if file_path.suffix == ".json": with open(file_path, encoding=encoding) as f: diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 4aa8681a391..cab3c40f7c6 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import collections +import contextlib import functools import inspect import threading @@ -2466,10 +2467,8 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): A JSON-serializable representation of the Runnable. """ dumped = super().to_json() - try: + with contextlib.suppress(Exception): dumped["name"] = self.get_name() - except Exception: - pass return dumped def configurable_fields( @@ -2763,9 +2762,8 @@ class RunnableSequence(RunnableSerializable[Input, Output]): ValueError: If the sequence has less than 2 steps. """ steps_flat: list[Runnable] = [] - if not steps: - if first is not None and last is not None: - steps_flat = [first] + (middle or []) + [last] + if not steps and first is not None and last is not None: + steps_flat = [first] + (middle or []) + [last] for step in steps: if isinstance(step, RunnableSequence): steps_flat.extend(step.steps) @@ -4180,12 +4178,9 @@ class RunnableGenerator(Runnable[Input, Output]): def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - final = None + final: Optional[Output] = None for output in self.stream(input, config, **kwargs): - if final is None: - final = output - else: - final = final + output + final = output if final is None else final + output # type: ignore[operator] return cast(Output, final) def atransform( @@ -4215,12 +4210,9 @@ class RunnableGenerator(Runnable[Input, Output]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - final = None + final: Optional[Output] = None async for output in self.astream(input, config, **kwargs): - if final is None: - final = output - else: - final = final + output + final = output if final is None else final + output # type: ignore[operator] return cast(Output, final) diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index b1d40bd5510..765bf79a2dd 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -139,11 +139,11 @@ def _set_config_context(config: RunnableConfig) -> None: None, ) ) + and (run := tracer.run_map.get(str(parent_run_id))) ): - if run := tracer.run_map.get(str(parent_run_id)): - from langsmith.run_helpers import _set_tracing_context + from langsmith.run_helpers import _set_tracing_context - _set_tracing_context({"parent": run}) + _set_tracing_context({"parent": run}) def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: diff --git a/libs/core/langchain_core/runnables/graph_ascii.py b/libs/core/langchain_core/runnables/graph_ascii.py index 91a647f20e2..bdd1b5e2e4d 100644 --- a/libs/core/langchain_core/runnables/graph_ascii.py +++ b/libs/core/langchain_core/runnables/graph_ascii.py @@ -99,24 +99,15 @@ class AsciiCanvas: self.point(x0, y0, char) elif abs(dx) >= abs(dy): for x in range(x0, x1 + 1): - if dx == 0: - y = y0 - else: - y = y0 + int(round((x - x0) * dy / float(dx))) + y = y0 if dx == 0 else y0 + int(round((x - x0) * dy / float(dx))) self.point(x, y, char) elif y0 < y1: for y in range(y0, y1 + 1): - if dy == 0: - x = x0 - else: - x = x0 + int(round((y - y0) * dx / float(dy))) + x = x0 if dy == 0 else x0 + int(round((y - y0) * dx / float(dy))) self.point(x, y, char) else: for y in range(y1, y0 + 1): - if dy == 0: - x = x0 - else: - x = x1 + int(round((y - y1) * dx / float(dy))) + x = x0 if dy == 0 else x1 + int(round((y - y1) * dx / float(dy))) self.point(x, y, char) def text(self, x: int, y: int, text: str) -> None: diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index 9eb83b6a537..175c53eb6a4 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -131,10 +131,7 @@ def draw_mermaid( else: edge_label = f" --  {edge_data}  --> " else: - if edge.conditional: - edge_label = " -.-> " - else: - edge_label = " --> " + edge_label = " -.-> " if edge.conditional else " --> " mermaid_graph += ( f"\t{_escape_node_label(source)}{edge_label}" @@ -142,7 +139,7 @@ def draw_mermaid( ) # Recursively add nested subgraphs - for nested_prefix in edge_groups.keys(): + for nested_prefix in edge_groups: if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix: continue add_subgraph(edge_groups[nested_prefix], nested_prefix) @@ -154,7 +151,7 @@ def draw_mermaid( add_subgraph(edge_groups.get("", []), "") # Add remaining subgraphs - for prefix in edge_groups.keys(): + for prefix in edge_groups: if ":" in prefix or prefix == "": continue add_subgraph(edge_groups[prefix], prefix) diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index e464c971b71..189207ff2e4 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -496,12 +496,9 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]: Returns: Optional[Addable]: The result of adding the addable objects. """ - final = None + final: Optional[Addable] = None for chunk in addables: - if final is None: - final = chunk - else: - final = final + chunk + final = chunk if final is None else final + chunk return final @@ -514,12 +511,9 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]: Returns: Optional[Addable]: The result of adding the addable objects. """ - final = None + final: Optional[Addable] = None async for chunk in addables: - if final is None: - final = chunk - else: - final = final + chunk + final = chunk if final is None else final + chunk return final @@ -642,9 +636,7 @@ def get_unique_config_specs( for id, dupes in grouped: first = next(dupes) others = list(dupes) - if len(others) == 0: - unique.append(first) - elif all(o == first for o in others): + if len(others) == 0 or all(o == first for o in others): unique.append(first) else: raise ValueError( diff --git a/libs/core/langchain_core/stores.py b/libs/core/langchain_core/stores.py index 127d3c2e04a..4de878af615 100644 --- a/libs/core/langchain_core/stores.py +++ b/libs/core/langchain_core/stores.py @@ -258,7 +258,7 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]): if prefix is None: yield from self.store.keys() else: - for key in self.store.keys(): + for key in self.store: if key.startswith(prefix): yield key @@ -272,10 +272,10 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]): AsyncIterator[str]: An async iterator over keys that match the given prefix. """ if prefix is None: - for key in self.store.keys(): + for key in self.store: yield key else: - for key in self.store.keys(): + for key in self.store: if key.startswith(prefix): yield key diff --git a/libs/core/langchain_core/structured_query.py b/libs/core/langchain_core/structured_query.py index 4214d367e40..9c46278b77c 100644 --- a/libs/core/langchain_core/structured_query.py +++ b/libs/core/langchain_core/structured_query.py @@ -19,18 +19,24 @@ class Visitor(ABC): """Allowed operators for the visitor.""" def _validate_func(self, func: Union[Operator, Comparator]) -> None: - if isinstance(func, Operator) and self.allowed_operators is not None: - if func not in self.allowed_operators: - raise ValueError( - f"Received disallowed operator {func}. Allowed " - f"comparators are {self.allowed_operators}" - ) - if isinstance(func, Comparator) and self.allowed_comparators is not None: - if func not in self.allowed_comparators: - raise ValueError( - f"Received disallowed comparator {func}. Allowed " - f"comparators are {self.allowed_comparators}" - ) + if ( + isinstance(func, Operator) + and self.allowed_operators is not None + and func not in self.allowed_operators + ): + raise ValueError( + f"Received disallowed operator {func}. Allowed " + f"comparators are {self.allowed_operators}" + ) + if ( + isinstance(func, Comparator) + and self.allowed_comparators is not None + and func not in self.allowed_comparators + ): + raise ValueError( + f"Received disallowed comparator {func}. Allowed " + f"comparators are {self.allowed_comparators}" + ) @abstractmethod def visit_operation(self, operation: Operation) -> Any: diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index d8f63c069eb..98c45dd5bb9 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -248,11 +248,8 @@ def create_schema_from_function( validated = validate_arguments(func, config=_SchemaConfig) # type: ignore # Let's ignore `self` and `cls` arguments for class and instance methods - if func.__qualname__ and "." in func.__qualname__: - # Then it likely belongs in a class namespace - in_class = True - else: - in_class = False + # If qualified name has a ".", then it likely belongs in a class namespace + in_class = bool(func.__qualname__ and "." in func.__qualname__) has_args = False has_kwargs = False @@ -289,12 +286,10 @@ def create_schema_from_function( # Pydantic adds placeholder virtual fields we need to strip valid_properties = [] for field in get_fields(inferred_model): - if not has_args: - if field == "args": - continue - if not has_kwargs: - if field == "kwargs": - continue + if not has_args and field == "args": + continue + if not has_kwargs and field == "kwargs": + continue if field == "v__duplicate_kwargs": # Internal pydantic field continue @@ -422,12 +417,15 @@ class ChildTool(BaseTool): def __init__(self, **kwargs: Any) -> None: """Initialize the tool.""" - if "args_schema" in kwargs and kwargs["args_schema"] is not None: - if not is_basemodel_subclass(kwargs["args_schema"]): - raise TypeError( - f"args_schema must be a subclass of pydantic BaseModel. " - f"Got: {kwargs['args_schema']}." - ) + if ( + "args_schema" in kwargs + and kwargs["args_schema"] is not None + and not is_basemodel_subclass(kwargs["args_schema"]) + ): + raise TypeError( + f"args_schema must be a subclass of pydantic BaseModel. " + f"Got: {kwargs['args_schema']}." + ) super().__init__(**kwargs) model_config = ConfigDict( @@ -840,10 +838,7 @@ def _handle_tool_error( flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], ) -> str: if isinstance(flag, bool): - if e.args: - content = e.args[0] - else: - content = "Tool execution error" + content = e.args[0] if e.args else "Tool execution error" elif isinstance(flag, str): content = flag elif callable(flag): @@ -902,12 +897,11 @@ def _format_output( def _is_message_content_type(obj: Any) -> bool: """Check for OpenAI or Anthropic format tool message content.""" - if isinstance(obj, str): - return True - elif isinstance(obj, list) and all(_is_message_content_block(e) for e in obj): - return True - else: - return False + return ( + isinstance(obj, str) + or isinstance(obj, list) + and all(_is_message_content_block(e) for e in obj) + ) def _is_message_content_block(obj: Any) -> bool: diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index 03d7bc95616..57cd0dc5f42 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import contextlib import logging from collections.abc import AsyncIterator, Iterator, Sequence from typing import ( @@ -814,10 +815,7 @@ async def _astream_events_implementation_v1( data: EventData = {} log_entry: LogEntry = run_log.state["logs"][path] if log_entry["end_time"] is None: - if log_entry["streamed_output"]: - event_type = "stream" - else: - event_type = "start" + event_type = "stream" if log_entry["streamed_output"] else "start" else: event_type = "end" @@ -983,14 +981,15 @@ async def _astream_events_implementation_v2( yield event continue - if event["run_id"] == first_event_run_id and event["event"].endswith( - "_end" + # If it's the end event corresponding to the root runnable + # we dont include the input in the event since it's guaranteed + # to be included in the first event. + if ( + event["run_id"] == first_event_run_id + and event["event"].endswith("_end") + and "input" in event["data"] ): - # If it's the end event corresponding to the root runnable - # we dont include the input in the event since it's guaranteed - # to be included in the first event. - if "input" in event["data"]: - del event["data"]["input"] + del event["data"]["input"] yield event except asyncio.CancelledError as exc: @@ -1001,7 +1000,5 @@ async def _astream_events_implementation_v2( # Cancel the task if it's still running task.cancel() # Await it anyway, to run any cleanup code, and propagate any exceptions - try: + with contextlib.suppress(asyncio.CancelledError): await task - except asyncio.CancelledError: - pass diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index e6422b0682d..b4f5b4593b9 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import copy import threading from collections import defaultdict @@ -253,18 +254,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): """ async for chunk in output: # root run is handled in .astream_log() - if run_id != self.root_id: - # if we can't find the run silently ignore - # eg. because this run wasn't included in the log - if key := self._key_map_by_run_id.get(run_id): - if not self.send( + # if we can't find the run silently ignore + # eg. because this run wasn't included in the log + if ( + run_id != self.root_id + and (key := self._key_map_by_run_id.get(run_id)) + and ( + not self.send( { "op": "add", "path": f"/logs/{key}/streamed_output/-", "value": chunk, } - ): - break + ) + ) + ): + break yield chunk @@ -280,18 +285,22 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): """ for chunk in output: # root run is handled in .astream_log() - if run_id != self.root_id: - # if we can't find the run silently ignore - # eg. because this run wasn't included in the log - if key := self._key_map_by_run_id.get(run_id): - if not self.send( + # if we can't find the run silently ignore + # eg. because this run wasn't included in the log + if ( + run_id != self.root_id + and (key := self._key_map_by_run_id.get(run_id)) + and ( + not self.send( { "op": "add", "path": f"/logs/{key}/streamed_output/-", "value": chunk, } - ): - break + ) + ) + ): + break yield chunk @@ -439,9 +448,8 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): self.send(*ops) finally: - if run.id == self.root_id: - if self.auto_close: - self.send_stream.close() + if run.id == self.root_id and self.auto_close: + self.send_stream.close() def _on_llm_new_token( self, @@ -662,7 +670,5 @@ async def _astream_log_implementation( yield state finally: # Wait for the runnable to finish, if not cancelled (eg. by break) - try: + with contextlib.suppress(asyncio.CancelledError): await task - except asyncio.CancelledError: - pass diff --git a/libs/core/langchain_core/utils/_merge.py b/libs/core/langchain_core/utils/_merge.py index 36e823c9144..32261399d00 100644 --- a/libs/core/langchain_core/utils/_merge.py +++ b/libs/core/langchain_core/utils/_merge.py @@ -29,9 +29,7 @@ def merge_dicts(left: dict[str, Any], *others: dict[str, Any]) -> dict[str, Any] merged = left.copy() for right in others: for right_k, right_v in right.items(): - if right_k not in merged: - merged[right_k] = right_v - elif right_v is not None and merged[right_k] is None: + if right_k not in merged or right_v is not None and merged[right_k] is None: merged[right_k] = right_v elif right_v is None: continue diff --git a/libs/core/langchain_core/utils/env.py b/libs/core/langchain_core/utils/env.py index ee03ce9d3a9..8319c9abb98 100644 --- a/libs/core/langchain_core/utils/env.py +++ b/libs/core/langchain_core/utils/env.py @@ -43,14 +43,10 @@ def get_from_dict_or_env( if k in data and data[k]: return data[k] - if isinstance(key, str): - if key in data and data[key]: - return data[key] + if isinstance(key, str) and key in data and data[key]: + return data[key] - if isinstance(key, (list, tuple)): - key_for_err = key[0] - else: - key_for_err = key + key_for_err = key[0] if isinstance(key, (list, tuple)) else key return get_from_env(key_for_err, env_key, default=default) diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index f8227a06a6f..d93f5cb0098 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -64,7 +64,7 @@ def _rm_titles(kv: dict, prev_key: str = "") -> dict: new_kv = {} for k, v in kv.items(): if k == "title": - if isinstance(v, dict) and prev_key == "properties" and "title" in v.keys(): + if isinstance(v, dict) and prev_key == "properties" and "title" in v: new_kv[k] = _rm_titles(v, k) else: continue diff --git a/libs/core/langchain_core/utils/json.py b/libs/core/langchain_core/utils/json.py index c721f5fcae3..0ef93c05abe 100644 --- a/libs/core/langchain_core/utils/json.py +++ b/libs/core/langchain_core/utils/json.py @@ -139,11 +139,8 @@ def parse_json_markdown( match = _json_markdown_re.search(json_string) # If no match found, assume the entire string is a JSON string - if match is None: - json_str = json_string - else: - # If match found, use the content within the backticks - json_str = match.group(2) + # Else, use the content within the backticks + json_str = json_string if match is None else match.group(2) return _parse_json(json_str, parser=parser) diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py index a64ec346c5c..56538d945f6 100644 --- a/libs/core/langchain_core/utils/mustache.py +++ b/libs/core/langchain_core/utils/mustache.py @@ -80,12 +80,9 @@ def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: padding = literal.split("\n")[-1] # If all the characters since the last newline are spaces - if padding.isspace() or padding == "": - # Then the next tag could be a standalone - return True - else: - # Otherwise it can't be - return False + # Then the next tag could be a standalone + # Otherwise it can't be + return padding.isspace() or padding == "" else: return False @@ -107,10 +104,7 @@ def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: on_newline = template.split("\n", 1) # If the stuff to the right of us are spaces we're a standalone - if on_newline[0].isspace() or not on_newline[0]: - return True - else: - return False + return on_newline[0].isspace() or not on_newline[0] # If we're a tag can't be a standalone else: @@ -174,14 +168,18 @@ def parse_tag(template: str, l_del: str, r_del: str) -> tuple[tuple[str, str], s "unclosed set delimiter tag\n" f"at line {_CURRENT_LINE}" ) - # If we might be a no html escape tag - elif tag_type == "no escape?": + elif ( + # If we might be a no html escape tag + tag_type == "no escape?" # And we have a third curly brace # (And are using curly braces as delimiters) - if l_del == "{{" and r_del == "}}" and template.startswith("}"): - # Then we are a no html escape tag - template = template[1:] - tag_type = "no escape" + and l_del == "{{" + and r_del == "}}" + and template.startswith("}") + ): + # Then we are a no html escape tag + template = template[1:] + tag_type = "no escape" # Strip the whitespace off the key and return return ((tag_type, tag.strip()), template) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index b7d03651667..1352bcfafff 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -179,22 +179,27 @@ def pre_init(func: Callable) -> Any: for name, field_info in fields.items(): # Check if allow_population_by_field_name is enabled # If yes, then set the field name to the alias - if hasattr(cls, "Config"): - if hasattr(cls.Config, "allow_population_by_field_name"): - if cls.Config.allow_population_by_field_name: - if field_info.alias in values: - values[name] = values.pop(field_info.alias) - if hasattr(cls, "model_config"): - if cls.model_config.get("populate_by_name"): - if field_info.alias in values: - values[name] = values.pop(field_info.alias) + if ( + hasattr(cls, "Config") + and hasattr(cls.Config, "allow_population_by_field_name") + and cls.Config.allow_population_by_field_name + and field_info.alias in values + ): + values[name] = values.pop(field_info.alias) + if ( + hasattr(cls, "model_config") + and cls.model_config.get("populate_by_name") + and field_info.alias in values + ): + values[name] = values.pop(field_info.alias) - if name not in values or values[name] is None: - if not field_info.is_required(): - if field_info.default_factory is not None: - values[name] = field_info.default_factory() - else: - values[name] = field_info.default + if ( + name not in values or values[name] is None + ) and not field_info.is_required(): + if field_info.default_factory is not None: + values[name] = field_info.default_factory() + else: + values[name] = field_info.default # Call the decorated function return func(cls, values) diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 8236e7fba10..e8d7b34bd26 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -332,9 +332,8 @@ def from_env( for k in key: if k in os.environ: return os.environ[k] - if isinstance(key, str): - if key in os.environ: - return os.environ[key] + if isinstance(key, str) and key in os.environ: + return os.environ[key] if isinstance(default, (str, type(None))): return default @@ -395,9 +394,8 @@ def secret_from_env( for k in key: if k in os.environ: return SecretStr(os.environ[k]) - if isinstance(key, str): - if key in os.environ: - return SecretStr(os.environ[key]) + if isinstance(key, str) and key in os.environ: + return SecretStr(os.environ[key]) if isinstance(default, str): return SecretStr(default) elif isinstance(default, type(None)): diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 9bb2146b7bc..6991754029c 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -44,7 +44,7 @@ python = ">=3.12.4" [tool.poetry.extras] [tool.ruff.lint] -select = [ "B", "C4", "E", "F", "I", "N", "PIE", "T201", "UP",] +select = [ "B", "C4", "E", "F", "I", "N", "PIE", "SIM", "T201", "UP",] ignore = [ "UP007",] [tool.coverage.run] diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py index 6a7adff650b..b8bdd0ce252 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -148,7 +148,7 @@ def test_pydantic_output_parser() -> None: result = pydantic_parser.parse(DEF_RESULT) print("parse_result:", result) # noqa: T201 - assert DEF_EXPECTED_RESULT == result + assert result == DEF_EXPECTED_RESULT assert pydantic_parser.OutputType is TestModel diff --git a/libs/core/tests/unit_tests/pydantic_utils.py b/libs/core/tests/unit_tests/pydantic_utils.py index 9f78feb3f5e..e64e1e24dfe 100644 --- a/libs/core/tests/unit_tests/pydantic_utils.py +++ b/libs/core/tests/unit_tests/pydantic_utils.py @@ -115,10 +115,7 @@ def _normalize_schema(obj: Any) -> dict[str, Any]: Args: obj: The object to generate the schema for """ - if isinstance(obj, BaseModel): - data = obj.model_json_schema() - else: - data = obj + data = obj.model_json_schema() if isinstance(obj, BaseModel) else obj remove_all_none_default(data) replace_all_of_with_ref(data) _remove_enum(data) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 39bafbd6e51..09de13b36bd 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -3345,9 +3345,9 @@ def test_bind_with_lambda() -> None: return 3 + kwargs.get("n", 0) runnable = RunnableLambda(my_function).bind(n=1) - assert 4 == runnable.invoke({}) + assert runnable.invoke({}) == 4 chunks = list(runnable.stream({})) - assert [4] == chunks + assert chunks == [4] async def test_bind_with_lambda_async() -> None: @@ -3355,9 +3355,9 @@ async def test_bind_with_lambda_async() -> None: return 3 + kwargs.get("n", 0) runnable = RunnableLambda(my_function).bind(n=1) - assert 4 == await runnable.ainvoke({}) + assert await runnable.ainvoke({}) == 4 chunks = [item async for item in runnable.astream({})] - assert [4] == chunks + assert chunks == [4] def test_deep_stream() -> None: @@ -5140,13 +5140,10 @@ async def test_astream_log_deep_copies() -> None: chain = RunnableLambda(add_one) chunks = [] - final_output = None + final_output: Optional[RunLogPatch] = None async for chunk in chain.astream_log(1): chunks.append(chunk) - if final_output is None: - final_output = chunk - else: - final_output = final_output + chunk + final_output = chunk if final_output is None else final_output + chunk run_log = _get_run_log(chunks) state = run_log.state.copy() diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 3054b04f89b..6cf8b8ea612 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -209,9 +209,11 @@ def test_tracing_enable_disable( get_env_var.cache_clear() env_on = env == "true" - with patch.dict("os.environ", {"LANGSMITH_TRACING": env}): - with tracing_context(enabled=enabled): - RunnableLambda(my_func).invoke(1) + with ( + patch.dict("os.environ", {"LANGSMITH_TRACING": env}), + tracing_context(enabled=enabled), + ): + RunnableLambda(my_func).invoke(1) mock_posts = _get_posts(mock_client_) if enabled is True: diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py index 9d659d6bd8c..1a6e0f8f9ed 100644 --- a/libs/core/tests/unit_tests/test_messages.py +++ b/libs/core/tests/unit_tests/test_messages.py @@ -1,6 +1,6 @@ import unittest import uuid -from typing import Union +from typing import Optional, Union import pytest @@ -10,6 +10,7 @@ from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, + BaseMessageChunk, ChatMessage, ChatMessageChunk, FunctionMessage, @@ -630,14 +631,11 @@ def test_tool_calls_merge() -> None: {"content": ""}, ] - final = None + final: Optional[BaseMessageChunk] = None for chunk in chunks: msg = AIMessageChunk(**chunk) - if final is None: - final = msg - else: - final = final + msg + final = msg if final is None else final + msg assert final == AIMessageChunk( content="", diff --git a/libs/core/tests/unit_tests/tracers/test_langchain.py b/libs/core/tests/unit_tests/tracers/test_langchain.py index 7e03717bd08..d71783f6bc8 100644 --- a/libs/core/tests/unit_tests/tracers/test_langchain.py +++ b/libs/core/tests/unit_tests/tracers/test_langchain.py @@ -133,25 +133,24 @@ class LangChainProjectNameTest(unittest.TestCase): for case in cases: get_env_var.cache_clear() get_tracer_project.cache_clear() - with self.subTest(msg=case.test_name): - with pytest.MonkeyPatch.context() as mp: - for k, v in case.envvars.items(): - mp.setenv(k, v) + with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp: + for k, v in case.envvars.items(): + mp.setenv(k, v) - client = unittest.mock.MagicMock(spec=Client) - tracer = LangChainTracer(client=client) - projects = [] + client = unittest.mock.MagicMock(spec=Client) + tracer = LangChainTracer(client=client) + projects = [] - def mock_create_run(**kwargs: Any) -> Any: - projects.append(kwargs.get("project_name")) # noqa: B023 - return unittest.mock.MagicMock() + def mock_create_run(**kwargs: Any) -> Any: + projects.append(kwargs.get("project_name")) # noqa: B023 + return unittest.mock.MagicMock() - client.create_run = mock_create_run + client.create_run = mock_create_run - tracer.on_llm_start( - {"name": "example_1"}, - ["foo"], - run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"), - ) - tracer.wait_for_futures() - assert projects == [case.expected_project_name] + tracer.on_llm_start( + {"name": "example_1"}, + ["foo"], + run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"), + ) + tracer.wait_for_futures() + assert projects == [case.expected_project_name] diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index 01943ab4db3..bd4694e0d52 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -513,14 +513,8 @@ def test_tool_outputs() -> None: def test__convert_typed_dict_to_openai_function( use_extension_typed_dict: bool, use_extension_annotated: bool ) -> None: - if use_extension_typed_dict: - typed_dict = ExtensionsTypedDict - else: - typed_dict = TypingTypedDict - if use_extension_annotated: - annotated = TypingAnnotated - else: - annotated = TypingAnnotated + typed_dict = ExtensionsTypedDict if use_extension_typed_dict else TypingTypedDict + annotated = TypingAnnotated if use_extension_annotated else TypingAnnotated class SubTool(typed_dict): """Subtool docstring""" diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 2bd897db606..68ad11b3e96 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -116,10 +116,7 @@ def test_check_package_version( def test_merge_dicts( left: dict, right: dict, expected: Union[dict, AbstractContextManager] ) -> None: - if isinstance(expected, AbstractContextManager): - err = expected - else: - err = nullcontext() + err = expected if isinstance(expected, AbstractContextManager) else nullcontext() left_copy = deepcopy(left) right_copy = deepcopy(right) @@ -147,10 +144,7 @@ def test_merge_dicts( def test_merge_dicts_0_3( left: dict, right: dict, expected: Union[dict, AbstractContextManager] ) -> None: - if isinstance(expected, AbstractContextManager): - err = expected - else: - err = nullcontext() + err = expected if isinstance(expected, AbstractContextManager) else nullcontext() left_copy = deepcopy(left) right_copy = deepcopy(right)