diff --git a/libs/langchain/langchain/_api/module_import.py b/libs/langchain/langchain/_api/module_import.py index fdfc1cee08a..f196a86247c 100644 --- a/libs/langchain/langchain/_api/module_import.py +++ b/libs/langchain/langchain/_api/module_import.py @@ -84,28 +84,28 @@ def create_importer( not is_interactive_env() and deprecated_lookups and name in deprecated_lookups - ): # Depth 3: # internal.py # module_import.py # Module in langchain that uses this function # [calling code] whose frame we want to inspect. - if not internal.is_caller_internal(depth=3): - warn_deprecated( - since="0.1", - pending=False, - removal="1.0", - message=( - f"Importing {name} from {package} is deprecated. " - f"Please replace deprecated imports:\n\n" - f">> from {package} import {name}\n\n" - "with new imports of:\n\n" - f">> from {new_module} import {name}\n" - "You can use the langchain cli to **automatically** " - "upgrade many imports. Please see documentation here " - "" - ), - ) + and not internal.is_caller_internal(depth=3) + ): + warn_deprecated( + since="0.1", + pending=False, + removal="1.0", + message=( + f"Importing {name} from {package} is deprecated. " + f"Please replace deprecated imports:\n\n" + f">> from {package} import {name}\n\n" + "with new imports of:\n\n" + f">> from {new_module} import {name}\n" + "You can use the langchain cli to **automatically** " + "upgrade many imports. Please see documentation here " + "" + ), + ) return result except Exception as e: msg = f"module {new_module} has no attribute {name}" @@ -115,28 +115,30 @@ def create_importer( try: module = importlib.import_module(fallback_module) result = getattr(module, name) - if not is_interactive_env(): + if ( + not is_interactive_env() # Depth 3: # internal.py # module_import.py # Module in langchain that uses this function # [calling code] whose frame we want to inspect. - if not internal.is_caller_internal(depth=3): - warn_deprecated( - since="0.1", - pending=False, - removal="1.0", - message=( - f"Importing {name} from {package} is deprecated. " - f"Please replace deprecated imports:\n\n" - f">> from {package} import {name}\n\n" - "with new imports of:\n\n" - f">> from {fallback_module} import {name}\n" - "You can use the langchain cli to **automatically** " - "upgrade many imports. Please see documentation here " - "" - ), - ) + and not internal.is_caller_internal(depth=3) + ): + warn_deprecated( + since="0.1", + pending=False, + removal="1.0", + message=( + f"Importing {name} from {package} is deprecated. " + f"Please replace deprecated imports:\n\n" + f">> from {package} import {name}\n\n" + "with new imports of:\n\n" + f">> from {fallback_module} import {name}\n" + "You can use the langchain cli to **automatically** " + "upgrade many imports. Please see documentation here " + "" + ), + ) return result except Exception as e: diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index ba085daafd0..74f6d2af913 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio import builtins +import contextlib import json import logging import time @@ -196,10 +197,7 @@ class BaseSingleActionAgent(BaseModel): agent.agent.save(file_path="path/agent.yaml") """ # Convert file to Path object. - if isinstance(file_path, str): - save_path = Path(file_path) - else: - save_path = file_path + save_path = Path(file_path) if isinstance(file_path, str) else file_path directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) @@ -322,10 +320,8 @@ class BaseMultiActionAgent(BaseModel): def dict(self, **kwargs: Any) -> builtins.dict: """Return dictionary representation of agent.""" _dict = super().model_dump() - try: + with contextlib.suppress(NotImplementedError): _dict["_type"] = str(self._agent_type) - except NotImplementedError: - pass return _dict def save(self, file_path: Union[Path, str]) -> None: @@ -345,10 +341,7 @@ class BaseMultiActionAgent(BaseModel): agent.agent.save(file_path="path/agent.yaml") """ # Convert file to Path object. - if isinstance(file_path, str): - save_path = Path(file_path) - else: - save_path = file_path + save_path = Path(file_path) if isinstance(file_path, str) else file_path # Fetch dictionary to save agent_dict = self.dict() @@ -1133,13 +1126,14 @@ class AgentExecutor(Chain): agent = self.agent tools = self.tools allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr] - if allowed_tools is not None: - if set(allowed_tools) != set([tool.name for tool in tools]): - msg = ( - f"Allowed tools ({allowed_tools}) different than " - f"provided tools ({[tool.name for tool in tools]})" - ) - raise ValueError(msg) + if allowed_tools is not None and set(allowed_tools) != set( + [tool.name for tool in tools] + ): + msg = ( + f"Allowed tools ({allowed_tools}) different than " + f"provided tools ({[tool.name for tool in tools]})" + ) + raise ValueError(msg) return self @model_validator(mode="before") @@ -1272,13 +1266,7 @@ class AgentExecutor(Chain): def _should_continue(self, iterations: int, time_elapsed: float) -> bool: if self.max_iterations is not None and iterations >= self.max_iterations: return False - if ( - self.max_execution_time is not None - and time_elapsed >= self.max_execution_time - ): - return False - - return True + return self.max_execution_time is None or time_elapsed < self.max_execution_time def _return( self, @@ -1410,10 +1398,7 @@ class AgentExecutor(Chain): return actions: list[AgentAction] - if isinstance(output, AgentAction): - actions = [output] - else: - actions = output + actions = [output] if isinstance(output, AgentAction) else output for agent_action in actions: yield agent_action for agent_action in actions: @@ -1547,10 +1532,7 @@ class AgentExecutor(Chain): return actions: list[AgentAction] - if isinstance(output, AgentAction): - actions = [output] - else: - actions = output + actions = [output] if isinstance(output, AgentAction) else output for agent_action in actions: yield agent_action @@ -1728,12 +1710,14 @@ class AgentExecutor(Chain): if len(self._action_agent.return_values) > 0: return_value_key = self._action_agent.return_values[0] # Invalid tools won't be in the map, so we return False. - if agent_action.tool in name_to_tool_map: - if name_to_tool_map[agent_action.tool].return_direct: - return AgentFinish( - {return_value_key: observation}, - "", - ) + if ( + agent_action.tool in name_to_tool_map + and name_to_tool_map[agent_action.tool].return_direct + ): + return AgentFinish( + {return_value_key: observation}, + "", + ) return None def _prepare_intermediate_steps( diff --git a/libs/langchain/langchain/agents/initialize.py b/libs/langchain/langchain/agents/initialize.py index 9490d7c9c35..b606c942e4c 100644 --- a/libs/langchain/langchain/agents/initialize.py +++ b/libs/langchain/langchain/agents/initialize.py @@ -1,5 +1,6 @@ """Load agent.""" +import contextlib from collections.abc import Sequence from typing import Any, Optional @@ -81,11 +82,9 @@ def initialize_agent( agent_obj = load_agent( agent_path, llm=llm, tools=tools, callback_manager=callback_manager ) - try: + with contextlib.suppress(NotImplementedError): # TODO: Add tags from the serialized object directly. tags_.append(agent_obj._agent_type) - except NotImplementedError: - pass else: msg = ( "Somehow both `agent` and `agent_path` are None, this should never happen." diff --git a/libs/langchain/langchain/agents/loading.py b/libs/langchain/langchain/agents/loading.py index 34369cc0fb3..d9b1df5bd0e 100644 --- a/libs/langchain/langchain/agents/loading.py +++ b/libs/langchain/langchain/agents/loading.py @@ -128,10 +128,7 @@ def _load_agent_from_file( """Load agent from file.""" valid_suffixes = {"json", "yaml"} # Convert file to Path object. - if isinstance(file, str): - file_path = Path(file) - else: - file_path = file + file_path = Path(file) if isinstance(file, str) else file # Load from either json or yaml. if file_path.suffix[1:] == "json": with open(file_path) as f: diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index f277bd93372..fa182072478 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -230,10 +230,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent): """ _prompts = extra_prompt_messages or [] messages: list[Union[BaseMessagePromptTemplate, BaseMessage]] - if system_message: - messages = [system_message] - else: - messages = [] + messages = [system_message] if system_message else [] messages.extend( [ diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index cc26a25f079..9a949a1ff3d 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -278,10 +278,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent): """ _prompts = extra_prompt_messages or [] messages: list[Union[BaseMessagePromptTemplate, BaseMessage]] - if system_message: - messages = [system_message] - else: - messages = [] + messages = [system_message] if system_message else [] messages.extend( [ diff --git a/libs/langchain/langchain/agents/output_parsers/tools.py b/libs/langchain/langchain/agents/output_parsers/tools.py index ecf912e8389..4461eeced3f 100644 --- a/libs/langchain/langchain/agents/output_parsers/tools.py +++ b/libs/langchain/langchain/agents/output_parsers/tools.py @@ -60,10 +60,7 @@ def parse_ai_message_to_tool_action( # Open AI does not support passing in a JSON array as an argument. function_name = tool_call["name"] _tool_input = tool_call["args"] - if "__arg1" in _tool_input: - tool_input = _tool_input["__arg1"] - else: - tool_input = _tool_input + tool_input = _tool_input.get("__arg1", _tool_input) content_msg = f"responded: {message.content}\n" if message.content else "\n" log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n" diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 9ac58e3da96..a0cef42214e 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -1,6 +1,7 @@ """Base interface that all chains should implement.""" import builtins +import contextlib import inspect import json import logging @@ -727,10 +728,8 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): # -> {"_type": "foo", "verbose": False, ...} """ _dict = super().dict(**kwargs) - try: + with contextlib.suppress(NotImplementedError): _dict["_type"] = self._chain_type - except NotImplementedError: - pass return _dict def save(self, file_path: Union[Path, str]) -> None: @@ -758,10 +757,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC): raise NotImplementedError(msg) # Convert file to Path object. - if isinstance(file_path, str): - save_path = Path(file_path) - else: - save_path = file_path + save_path = Path(file_path) if isinstance(file_path, str) else file_path directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index 6879d632226..a16fd0eac54 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -683,10 +683,7 @@ def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain: def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain: """Load chain from file.""" # Convert file to Path object. - if isinstance(file, str): - file_path = Path(file) - else: - file_path = file + file_path = Path(file) if isinstance(file, str) else file # Load from either json or yaml. if file_path.suffix == ".json": with open(file_path) as f: diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index e906de8c990..eadf98f9b06 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -95,10 +95,7 @@ class OpenAIModerationChain(Chain): return [self.output_key] def _moderate(self, text: str, results: Any) -> str: - if self.openai_pre_1_0: - condition = results["flagged"] - else: - condition = results.flagged + condition = results["flagged"] if self.openai_pre_1_0 else results.flagged if condition: error_str = "Text was found that violates OpenAI's content policy." if self.error: diff --git a/libs/langchain/langchain/chains/query_constructor/parser.py b/libs/langchain/langchain/chains/query_constructor/parser.py index 62b6ea24fe0..f6002518de7 100644 --- a/libs/langchain/langchain/chains/query_constructor/parser.py +++ b/libs/langchain/langchain/chains/query_constructor/parser.py @@ -108,22 +108,26 @@ class QueryTransformer(Transformer): def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]: if func_name in set(Comparator): - if self.allowed_comparators is not None: - if func_name not in self.allowed_comparators: - msg = ( - f"Received disallowed comparator {func_name}. Allowed " - f"comparators are {self.allowed_comparators}" - ) - raise ValueError(msg) + if ( + self.allowed_comparators is not None + and func_name not in self.allowed_comparators + ): + msg = ( + f"Received disallowed comparator {func_name}. Allowed " + f"comparators are {self.allowed_comparators}" + ) + raise ValueError(msg) return Comparator(func_name) elif func_name in set(Operator): - if self.allowed_operators is not None: - if func_name not in self.allowed_operators: - msg = ( - f"Received disallowed operator {func_name}. Allowed operators" - f" are {self.allowed_operators}" - ) - raise ValueError(msg) + if ( + self.allowed_operators is not None + and func_name not in self.allowed_operators + ): + msg = ( + f"Received disallowed operator {func_name}. Allowed operators" + f" are {self.allowed_operators}" + ) + raise ValueError(msg) return Operator(func_name) else: msg = ( diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py index 50c67e0ab95..d395d317b1a 100644 --- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py @@ -244,10 +244,7 @@ The following is the expected answer. Use this to measure correctness: if not isinstance(llm, BaseChatModel): msg = "Only chat models supported by the current trajectory eval" raise NotImplementedError(msg) - if agent_tools: - prompt = EVAL_CHAT_PROMPT - else: - prompt = TOOL_FREE_EVAL_CHAT_PROMPT + prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT eval_chain = LLMChain(llm=llm, prompt=prompt) return cls( agent_tools=agent_tools, # type: ignore[arg-type] diff --git a/libs/langchain/langchain/memory/combined.py b/libs/langchain/langchain/memory/combined.py index 06fe6febfe8..ca36237807d 100644 --- a/libs/langchain/langchain/memory/combined.py +++ b/libs/langchain/langchain/memory/combined.py @@ -36,13 +36,12 @@ class CombinedMemory(BaseMemory): def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]: """Check that if memories are of type BaseChatMemory that input keys exist.""" for val in value: - if isinstance(val, BaseChatMemory): - if val.input_key is None: - warnings.warn( - "When using CombinedMemory, " - "input keys should be so the input is known. " - f" Was not set on {val}" - ) + if isinstance(val, BaseChatMemory) and val.input_key is None: + warnings.warn( + "When using CombinedMemory, " + "input keys should be so the input is known. " + f" Was not set on {val}" + ) return value @property diff --git a/libs/langchain/langchain/model_laboratory.py b/libs/langchain/langchain/model_laboratory.py index 30f7c6de1a4..d4552ab86d0 100644 --- a/libs/langchain/langchain/model_laboratory.py +++ b/libs/langchain/langchain/model_laboratory.py @@ -51,11 +51,9 @@ class ModelLaboratory: "Currently only support chains with one output variable, " f"got {chain.output_keys}" ) - raise ValueError(msg) - if names is not None: - if len(names) != len(chains): - msg = "Length of chains does not match length of names." - raise ValueError(msg) + if names is not None and len(names) != len(chains): + msg = "Length of chains does not match length of names." + raise ValueError(msg) self.chains = chains chain_range = [str(i) for i in range(len(self.chains))] self.chain_colors = get_color_mapping(chain_range) @@ -93,10 +91,7 @@ class ModelLaboratory: """ print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201 for i, chain in enumerate(self.chains): - if self.names is not None: - name = self.names[i] - else: - name = str(chain) + name = self.names[i] if self.names is not None else str(chain) print_text(name, end="\n") output = chain.run(text) print_text(output, color=self.chain_colors[str(i)], end="\n\n") diff --git a/libs/langchain/langchain/output_parsers/loading.py b/libs/langchain/langchain/output_parsers/loading.py index 1cb82ae0e9f..34774341ca2 100644 --- a/libs/langchain/langchain/output_parsers/loading.py +++ b/libs/langchain/langchain/output_parsers/loading.py @@ -10,14 +10,13 @@ def load_output_parser(config: dict) -> dict: Returns: config dict with output parser loaded """ - if "output_parsers" in config: - if config["output_parsers"] is not None: - _config = config["output_parsers"] - output_parser_type = _config["_type"] - if output_parser_type == "regex_parser": - output_parser = RegexParser(**_config) - else: - msg = f"Unsupported output parser {output_parser_type}" - raise ValueError(msg) - config["output_parsers"] = output_parser + if "output_parsers" in config and config["output_parsers"] is not None: + _config = config["output_parsers"] + output_parser_type = _config["_type"] + if output_parser_type == "regex_parser": + output_parser = RegexParser(**_config) + else: + msg = f"Unsupported output parser {output_parser_type}" + raise ValueError(msg) + config["output_parsers"] = output_parser return config diff --git a/libs/langchain/langchain/output_parsers/yaml.py b/libs/langchain/langchain/output_parsers/yaml.py index 4efdf3fa73a..12a03385c4b 100644 --- a/libs/langchain/langchain/output_parsers/yaml.py +++ b/libs/langchain/langchain/output_parsers/yaml.py @@ -27,12 +27,8 @@ class YamlOutputParser(BaseOutputParser[T]): try: # Greedy search for 1st yaml candidate. match = re.search(self.pattern, text.strip()) - yaml_str = "" - if match: - yaml_str = match.group("yaml") - else: - # If no backticks were present, try to parse the entire output as yaml. - yaml_str = text + # If no backticks were present, try to parse the entire output as yaml. + yaml_str = match.group("yaml") if match else text json_object = yaml.safe_load(yaml_str) if hasattr(self.pydantic_object, "model_validate"): diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index 67911e0fbc1..2826090d6ee 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py @@ -123,10 +123,7 @@ class MultiQueryRetriever(BaseRetriever): response = await self.llm_chain.ainvoke( {"question": question}, config={"callbacks": run_manager.get_child()} ) - if isinstance(self.llm_chain, LLMChain): - lines = response["text"] - else: - lines = response + lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response if self.verbose: logger.info(f"Generated queries: {lines}") return lines @@ -186,10 +183,7 @@ class MultiQueryRetriever(BaseRetriever): response = self.llm_chain.invoke( {"question": question}, config={"callbacks": run_manager.get_child()} ) - if isinstance(self.llm_chain, LLMChain): - lines = response["text"] - else: - lines = response + lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response if self.verbose: logger.info(f"Generated queries: {lines}") return lines diff --git a/libs/langchain/langchain/smith/evaluation/config.py b/libs/langchain/langchain/smith/evaluation/config.py index ceb8e6a4187..74b0c296d56 100644 --- a/libs/langchain/langchain/smith/evaluation/config.py +++ b/libs/langchain/langchain/smith/evaluation/config.py @@ -57,9 +57,7 @@ class EvalConfig(BaseModel): """ kwargs = {} for field, val in self: - if field == "evaluator_type": - continue - elif val is None: + if field == "evaluator_type" or val is None: continue kwargs[field] = val return kwargs diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index d6ffed07c07..3425fd5b45e 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*" ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin" [tool.ruff.lint] -select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"] +select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"] pydocstyle.convention = "google" pyupgrade.keep-runtime-typing = true diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index 257fc8e3460..1c8cbcc4917 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -2,6 +2,8 @@ import asyncio import json +import operator +from functools import reduce from itertools import cycle from typing import Any, Optional, Union, cast @@ -520,14 +522,7 @@ async def test_runnable_agent() -> None: assert messages != [] # Aggregate state - run_log = None - - for result in results: - if run_log is None: - run_log = result - else: - # `+` is defined for RunLogPatch - run_log = run_log + result + run_log = reduce(operator.add, results) assert isinstance(run_log, RunLog) diff --git a/libs/langchain/tests/unit_tests/chains/test_memory.py b/libs/langchain/tests/unit_tests/chains/test_memory.py index dae9cc36610..2959f15ae80 100644 --- a/libs/langchain/tests/unit_tests/chains/test_memory.py +++ b/libs/langchain/tests/unit_tests/chains/test_memory.py @@ -17,7 +17,7 @@ def test_simple_memory() -> None: output = memory.load_memory_variables({}) assert output == {"baz": "foo"} - assert ["baz"] == memory.memory_variables + assert memory.memory_variables == ["baz"] @pytest.mark.parametrize( diff --git a/libs/langchain/tests/unit_tests/embeddings/test_caching.py b/libs/langchain/tests/unit_tests/embeddings/test_caching.py index a3d189209c4..1d06e71efb2 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_caching.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_caching.py @@ -1,5 +1,6 @@ """Embeddings tests.""" +import contextlib import hashlib import importlib import warnings @@ -75,10 +76,8 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None: def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None: # "RAISE_EXCEPTION" forces a failure in batch 2 texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"] - try: + with contextlib.suppress(ValueError): cache_embeddings_batch.embed_documents(texts) - except ValueError: - pass keys = list(cache_embeddings_batch.document_embedding_store.yield_keys()) # only the first batch of three embeddings should exist assert len(keys) == 3 @@ -122,10 +121,8 @@ async def test_aembed_documents_batch( ) -> None: # "RAISE_EXCEPTION" forces a failure in batch 2 texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"] - try: + with contextlib.suppress(ValueError): await cache_embeddings_batch.aembed_documents(texts) - except ValueError: - pass keys = [ key async for key in cache_embeddings_batch.document_embedding_store.ayield_keys() diff --git a/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py b/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py index 8e49f1f7d06..df13c2e5d5c 100644 --- a/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py +++ b/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py @@ -70,7 +70,7 @@ def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> N reference = "I like apples." result = eval_chain.evaluate_strings(prediction=prediction, reference=reference) assert "score" in result - assert 0 < result["score"] + assert result["score"] > 0 if normalize_score: assert result["score"] < 1.0 diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py index b7498ef278e..e6264c2379e 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py @@ -45,7 +45,7 @@ def test_combining_dict_result() -> None: ] combining_parser = CombiningOutputParser(parsers=parsers) result_dict = combining_parser.parse(DEF_README) - assert DEF_EXPECTED_RESULT == result_dict + assert result_dict == DEF_EXPECTED_RESULT def test_combining_output_parser_output_type() -> None: diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py index 1a25570e47a..8a0562e92e4 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py @@ -47,7 +47,7 @@ def test_pandas_output_parser_col_first_elem() -> None: def test_pandas_output_parser_col_multi_elem() -> None: expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")} actual_output = parser.parse("column:chicken[0, 1]") - for key in actual_output.keys(): + for key in actual_output: assert expected_output["chicken"].equals(actual_output[key]) diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_regex.py b/libs/langchain/tests/unit_tests/output_parsers/test_regex.py index cabf12b5e8a..a898eadfa5b 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_regex.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_regex.py @@ -23,7 +23,7 @@ def test_regex_parser_parse() -> None: output_keys=["confidence", "explanation"], default_output_key="noConfidence", ) - assert DEF_EXPECTED_RESULT == parser.parse(DEF_README) + assert parser.parse(DEF_README) == DEF_EXPECTED_RESULT def test_regex_parser_output_type() -> None: diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py b/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py index 5a604089398..88f1b5432df 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py @@ -35,7 +35,7 @@ def test_regex_dict_result() -> None: ) result_dict = regex_dict_parser.parse(DEF_README) print("parse_result:", result_dict) # noqa: T201 - assert DEF_EXPECTED_RESULT == result_dict + assert result_dict == DEF_EXPECTED_RESULT def test_regex_dict_output_type() -> None: diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py index fbf86f2305a..4e6cba09cad 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py @@ -76,7 +76,7 @@ def test_yaml_output_parser(result: str) -> None: model = yaml_parser.parse(result) print("parse_result:", result) # noqa: T201 - assert DEF_EXPECTED_RESULT == model + assert model == DEF_EXPECTED_RESULT def test_yaml_output_parser_fail() -> None: diff --git a/libs/langchain/tests/unit_tests/test_imports.py b/libs/langchain/tests/unit_tests/test_imports.py index f9d970f19fc..2f1f39ced01 100644 --- a/libs/langchain/tests/unit_tests/test_imports.py +++ b/libs/langchain/tests/unit_tests/test_imports.py @@ -124,9 +124,12 @@ def extract_deprecated_lookup(file_path: str) -> Optional[dict[str, Any]]: for node in ast.walk(tree): if isinstance(node, ast.Assign): for target in node.targets: - if isinstance(target, ast.Name) and target.id == "DEPRECATED_LOOKUP": - if isinstance(node.value, ast.Dict): - return _dict_from_ast(node.value) + if ( + isinstance(target, ast.Name) + and target.id == "DEPRECATED_LOOKUP" + and isinstance(node.value, ast.Dict) + ): + return _dict_from_ast(node.value) return None @@ -156,8 +159,10 @@ def _literal_eval_str(node: ast.AST) -> str: Returns: str: The corresponding string value. """ - if isinstance(node, ast.Constant): # Python 3.8+ - if isinstance(node.value, str): - return node.value + if ( + isinstance(node, ast.Constant) # Python 3.8+ + and isinstance(node.value, str) + ): + return node.value msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}" raise AssertionError(msg)