diff --git a/libs/langchain/langchain/_api/module_import.py b/libs/langchain/langchain/_api/module_import.py
index fdfc1cee08a..f196a86247c 100644
--- a/libs/langchain/langchain/_api/module_import.py
+++ b/libs/langchain/langchain/_api/module_import.py
@@ -84,28 +84,28 @@ def create_importer(
not is_interactive_env()
and deprecated_lookups
and name in deprecated_lookups
- ):
# Depth 3:
# internal.py
# module_import.py
# Module in langchain that uses this function
# [calling code] whose frame we want to inspect.
- if not internal.is_caller_internal(depth=3):
- warn_deprecated(
- since="0.1",
- pending=False,
- removal="1.0",
- message=(
- f"Importing {name} from {package} is deprecated. "
- f"Please replace deprecated imports:\n\n"
- f">> from {package} import {name}\n\n"
- "with new imports of:\n\n"
- f">> from {new_module} import {name}\n"
- "You can use the langchain cli to **automatically** "
- "upgrade many imports. Please see documentation here "
- ""
- ),
- )
+ and not internal.is_caller_internal(depth=3)
+ ):
+ warn_deprecated(
+ since="0.1",
+ pending=False,
+ removal="1.0",
+ message=(
+ f"Importing {name} from {package} is deprecated. "
+ f"Please replace deprecated imports:\n\n"
+ f">> from {package} import {name}\n\n"
+ "with new imports of:\n\n"
+ f">> from {new_module} import {name}\n"
+ "You can use the langchain cli to **automatically** "
+ "upgrade many imports. Please see documentation here "
+ ""
+ ),
+ )
return result
except Exception as e:
msg = f"module {new_module} has no attribute {name}"
@@ -115,28 +115,30 @@ def create_importer(
try:
module = importlib.import_module(fallback_module)
result = getattr(module, name)
- if not is_interactive_env():
+ if (
+ not is_interactive_env()
# Depth 3:
# internal.py
# module_import.py
# Module in langchain that uses this function
# [calling code] whose frame we want to inspect.
- if not internal.is_caller_internal(depth=3):
- warn_deprecated(
- since="0.1",
- pending=False,
- removal="1.0",
- message=(
- f"Importing {name} from {package} is deprecated. "
- f"Please replace deprecated imports:\n\n"
- f">> from {package} import {name}\n\n"
- "with new imports of:\n\n"
- f">> from {fallback_module} import {name}\n"
- "You can use the langchain cli to **automatically** "
- "upgrade many imports. Please see documentation here "
- ""
- ),
- )
+ and not internal.is_caller_internal(depth=3)
+ ):
+ warn_deprecated(
+ since="0.1",
+ pending=False,
+ removal="1.0",
+ message=(
+ f"Importing {name} from {package} is deprecated. "
+ f"Please replace deprecated imports:\n\n"
+ f">> from {package} import {name}\n\n"
+ "with new imports of:\n\n"
+ f">> from {fallback_module} import {name}\n"
+ "You can use the langchain cli to **automatically** "
+ "upgrade many imports. Please see documentation here "
+ ""
+ ),
+ )
return result
except Exception as e:
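
For reference: both hunks above are ruff SIM102 (collapsible-if) rewrites. Folding the nested `if` into the enclosing condition with `and` preserves short-circuit order, so the relatively expensive frame inspection in `is_caller_internal` still runs only when the cheaper checks have already passed. A minimal self-contained sketch (all names here are illustrative, not the langchain internals):

```python
# SIM102 sketch: a nested `if` inside another `if` is folded into one boolean
# expression. Short-circuiting keeps the evaluation order, so expensive_check
# still runs only when the cheap conditions already passed.


def is_interactive() -> bool:
    return False


def expensive_check() -> bool:
    print("frame inspection happens only when needed")
    return True


def warn_if_needed(name: str, lookups: dict) -> None:
    # Before: if not is_interactive(): / if name in lookups: / if expensive_check():
    if not is_interactive() and name in lookups and expensive_check():
        print(f"{name} is deprecated")


warn_if_needed("FooChain", {"FooChain": "langchain_community"})
```
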
diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py
index ba085daafd0..74f6d2af913 100644
--- a/libs/langchain/langchain/agents/agent.py
+++ b/libs/langchain/langchain/agents/agent.py
@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import builtins
+import contextlib
import json
import logging
import time
@@ -196,10 +197,7 @@ class BaseSingleActionAgent(BaseModel):
agent.agent.save(file_path="path/agent.yaml")
"""
# Convert file to Path object.
- if isinstance(file_path, str):
- save_path = Path(file_path)
- else:
- save_path = file_path
+ save_path = Path(file_path) if isinstance(file_path, str) else file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
@@ -322,10 +320,8 @@ class BaseMultiActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent."""
_dict = super().model_dump()
- try:
+ with contextlib.suppress(NotImplementedError):
_dict["_type"] = str(self._agent_type)
- except NotImplementedError:
- pass
return _dict
def save(self, file_path: Union[Path, str]) -> None:
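
`contextlib.suppress` (ruff SIM105) is behaviorally identical to the `try`/`except`/`pass` it replaces: the managed block simply ends at the first matching exception. A small sketch with a stand-in for the `_agent_type` lookup above:

```python
import contextlib

# Stand-in for the `_agent_type` property, which may raise NotImplementedError.
_dict = {"verbose": False}


def _agent_type() -> str:
    raise NotImplementedError


with contextlib.suppress(NotImplementedError):
    # The right-hand side raises before the key is assigned, so the dict is
    # left untouched -- exactly what the old try/except/pass did.
    _dict["_type"] = _agent_type()

print(_dict)  # {'verbose': False} -- no '_type' key was added
```
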
@@ -345,10 +341,7 @@ class BaseMultiActionAgent(BaseModel):
agent.agent.save(file_path="path/agent.yaml")
"""
# Convert file to Path object.
- if isinstance(file_path, str):
- save_path = Path(file_path)
- else:
- save_path = file_path
+ save_path = Path(file_path) if isinstance(file_path, str) else file_path
# Fetch dictionary to save
agent_dict = self.dict()
@@ -1133,13 +1126,14 @@ class AgentExecutor(Chain):
agent = self.agent
tools = self.tools
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
- if allowed_tools is not None:
- if set(allowed_tools) != set([tool.name for tool in tools]):
- msg = (
- f"Allowed tools ({allowed_tools}) different than "
- f"provided tools ({[tool.name for tool in tools]})"
- )
- raise ValueError(msg)
+ if allowed_tools is not None and set(allowed_tools) != set(
+ [tool.name for tool in tools]
+ ):
+ msg = (
+ f"Allowed tools ({allowed_tools}) different than "
+ f"provided tools ({[tool.name for tool in tools]})"
+ )
+ raise ValueError(msg)
return self
@model_validator(mode="before")
@@ -1272,13 +1266,7 @@ class AgentExecutor(Chain):
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
if self.max_iterations is not None and iterations >= self.max_iterations:
return False
- if (
- self.max_execution_time is not None
- and time_elapsed >= self.max_execution_time
- ):
- return False
-
- return True
+ return self.max_execution_time is None or time_elapsed < self.max_execution_time
def _return(
self,
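
The rewritten `_should_continue` returns the boolean expression directly (ruff SIM103). By De Morgan, "not (limit is set and elapsed >= limit)" is "limit is None or elapsed < limit", so behavior is unchanged, including at the boundary. A quick check over the edge cases:

```python
from typing import Optional


def should_continue(max_time: Optional[float], elapsed: float) -> bool:
    # Equivalent to: if max_time is not None and elapsed >= max_time: return
    # False / return True
    return max_time is None or elapsed < max_time


assert should_continue(None, 999.0)     # no limit configured
assert should_continue(10.0, 9.99)      # under the limit
assert not should_continue(10.0, 10.0)  # at the limit exactly: stop
```
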
@@ -1410,10 +1398,7 @@ class AgentExecutor(Chain):
return
actions: list[AgentAction]
- if isinstance(output, AgentAction):
- actions = [output]
- else:
- actions = output
+ actions = [output] if isinstance(output, AgentAction) else output
for agent_action in actions:
yield agent_action
for agent_action in actions:
@@ -1547,10 +1532,7 @@ class AgentExecutor(Chain):
return
actions: list[AgentAction]
- if isinstance(output, AgentAction):
- actions = [output]
- else:
- actions = output
+ actions = [output] if isinstance(output, AgentAction) else output
for agent_action in actions:
yield agent_action
@@ -1728,12 +1710,14 @@ class AgentExecutor(Chain):
if len(self._action_agent.return_values) > 0:
return_value_key = self._action_agent.return_values[0]
# Invalid tools won't be in the map, so we return False.
- if agent_action.tool in name_to_tool_map:
- if name_to_tool_map[agent_action.tool].return_direct:
- return AgentFinish(
- {return_value_key: observation},
- "",
- )
+ if (
+ agent_action.tool in name_to_tool_map
+ and name_to_tool_map[agent_action.tool].return_direct
+ ):
+ return AgentFinish(
+ {return_value_key: observation},
+ "",
+ )
return None
def _prepare_intermediate_steps(
diff --git a/libs/langchain/langchain/agents/initialize.py b/libs/langchain/langchain/agents/initialize.py
index 9490d7c9c35..b606c942e4c 100644
--- a/libs/langchain/langchain/agents/initialize.py
+++ b/libs/langchain/langchain/agents/initialize.py
@@ -1,5 +1,6 @@
"""Load agent."""
+import contextlib
from collections.abc import Sequence
from typing import Any, Optional
@@ -81,11 +82,9 @@ def initialize_agent(
agent_obj = load_agent(
agent_path, llm=llm, tools=tools, callback_manager=callback_manager
)
- try:
+ with contextlib.suppress(NotImplementedError):
# TODO: Add tags from the serialized object directly.
tags_.append(agent_obj._agent_type)
- except NotImplementedError:
- pass
else:
msg = (
"Somehow both `agent` and `agent_path` are None, this should never happen."
diff --git a/libs/langchain/langchain/agents/loading.py b/libs/langchain/langchain/agents/loading.py
index 34369cc0fb3..d9b1df5bd0e 100644
--- a/libs/langchain/langchain/agents/loading.py
+++ b/libs/langchain/langchain/agents/loading.py
@@ -128,10 +128,7 @@ def _load_agent_from_file(
"""Load agent from file."""
valid_suffixes = {"json", "yaml"}
# Convert file to Path object.
- if isinstance(file, str):
- file_path = Path(file)
- else:
- file_path = file
+ file_path = Path(file) if isinstance(file, str) else file
# Load from either json or yaml.
if file_path.suffix[1:] == "json":
with open(file_path) as f:
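
The one-line conditional expression (ruff SIM108) repeats across several files in this patch; it keeps the original behavior exactly. A sketch, with one hedged aside: `Path()` also accepts `Path` objects, so `Path(file_path)` alone would be equivalent, and the patch simply preserves the existing isinstance check.

```python
from pathlib import Path
from typing import Union


def to_path(file_path: Union[Path, str]) -> Path:
    # SIM108: the four-line if/else collapses to one conditional expression.
    return Path(file_path) if isinstance(file_path, str) else file_path


assert to_path("agents/agent.yaml") == to_path(Path("agents/agent.yaml"))
```
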
diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py
index f277bd93372..fa182072478 100644
--- a/libs/langchain/langchain/agents/openai_functions_agent/base.py
+++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py
@@ -230,10 +230,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
"""
_prompts = extra_prompt_messages or []
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
- if system_message:
- messages = [system_message]
- else:
- messages = []
+ messages = [system_message] if system_message else []
messages.extend(
[
diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py
index cc26a25f079..9a949a1ff3d 100644
--- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py
+++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py
@@ -278,10 +278,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
"""
_prompts = extra_prompt_messages or []
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
- if system_message:
- messages = [system_message]
- else:
- messages = []
+ messages = [system_message] if system_message else []
messages.extend(
[
diff --git a/libs/langchain/langchain/agents/output_parsers/tools.py b/libs/langchain/langchain/agents/output_parsers/tools.py
index ecf912e8389..4461eeced3f 100644
--- a/libs/langchain/langchain/agents/output_parsers/tools.py
+++ b/libs/langchain/langchain/agents/output_parsers/tools.py
@@ -60,10 +60,7 @@ def parse_ai_message_to_tool_action(
# Open AI does not support passing in a JSON array as an argument.
function_name = tool_call["name"]
_tool_input = tool_call["args"]
- if "__arg1" in _tool_input:
- tool_input = _tool_input["__arg1"]
- else:
- tool_input = _tool_input
+ tool_input = _tool_input.get("__arg1", _tool_input)
content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py
index 9ac58e3da96..a0cef42214e 100644
--- a/libs/langchain/langchain/chains/base.py
+++ b/libs/langchain/langchain/chains/base.py
@@ -1,6 +1,7 @@
"""Base interface that all chains should implement."""
import builtins
+import contextlib
import inspect
import json
import logging
@@ -727,10 +728,8 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
# -> {"_type": "foo", "verbose": False, ...}
"""
_dict = super().dict(**kwargs)
- try:
+ with contextlib.suppress(NotImplementedError):
_dict["_type"] = self._chain_type
- except NotImplementedError:
- pass
return _dict
def save(self, file_path: Union[Path, str]) -> None:
@@ -758,10 +757,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
raise NotImplementedError(msg)
# Convert file to Path object.
- if isinstance(file_path, str):
- save_path = Path(file_path)
- else:
- save_path = file_path
+ save_path = Path(file_path) if isinstance(file_path, str) else file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py
index 6879d632226..a16fd0eac54 100644
--- a/libs/langchain/langchain/chains/loading.py
+++ b/libs/langchain/langchain/chains/loading.py
@@ -683,10 +683,7 @@ def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain:
def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
"""Load chain from file."""
# Convert file to Path object.
- if isinstance(file, str):
- file_path = Path(file)
- else:
- file_path = file
+ file_path = Path(file) if isinstance(file, str) else file
# Load from either json or yaml.
if file_path.suffix == ".json":
with open(file_path) as f:
diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py
index e906de8c990..eadf98f9b06 100644
--- a/libs/langchain/langchain/chains/moderation.py
+++ b/libs/langchain/langchain/chains/moderation.py
@@ -95,10 +95,7 @@ class OpenAIModerationChain(Chain):
return [self.output_key]
def _moderate(self, text: str, results: Any) -> str:
- if self.openai_pre_1_0:
- condition = results["flagged"]
- else:
- condition = results.flagged
+ condition = results["flagged"] if self.openai_pre_1_0 else results.flagged
if condition:
error_str = "Text was found that violates OpenAI's content policy."
if self.error:
diff --git a/libs/langchain/langchain/chains/query_constructor/parser.py b/libs/langchain/langchain/chains/query_constructor/parser.py
index 62b6ea24fe0..f6002518de7 100644
--- a/libs/langchain/langchain/chains/query_constructor/parser.py
+++ b/libs/langchain/langchain/chains/query_constructor/parser.py
@@ -108,22 +108,26 @@ class QueryTransformer(Transformer):
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
if func_name in set(Comparator):
- if self.allowed_comparators is not None:
- if func_name not in self.allowed_comparators:
- msg = (
- f"Received disallowed comparator {func_name}. Allowed "
- f"comparators are {self.allowed_comparators}"
- )
- raise ValueError(msg)
+ if (
+ self.allowed_comparators is not None
+ and func_name not in self.allowed_comparators
+ ):
+ msg = (
+ f"Received disallowed comparator {func_name}. Allowed "
+ f"comparators are {self.allowed_comparators}"
+ )
+ raise ValueError(msg)
return Comparator(func_name)
elif func_name in set(Operator):
- if self.allowed_operators is not None:
- if func_name not in self.allowed_operators:
- msg = (
- f"Received disallowed operator {func_name}. Allowed operators"
- f" are {self.allowed_operators}"
- )
- raise ValueError(msg)
+ if (
+ self.allowed_operators is not None
+ and func_name not in self.allowed_operators
+ ):
+ msg = (
+ f"Received disallowed operator {func_name}. Allowed operators"
+ f" are {self.allowed_operators}"
+ )
+ raise ValueError(msg)
return Operator(func_name)
else:
msg = (
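
Both hunks above have the same SIM102 shape: the "is an allowlist even configured?" guard and the membership test merge into one condition, and short-circuiting still skips the membership test when the allowlist is `None`. A sketch with illustrative names, not the real `Comparator`/`Operator` enums:

```python
from typing import Optional


def check_allowed(func_name: str, allowed: Optional[list]) -> None:
    # Before: if allowed is not None: / if func_name not in allowed: raise
    if allowed is not None and func_name not in allowed:
        msg = f"Received disallowed value {func_name}. Allowed are {allowed}"
        raise ValueError(msg)


check_allowed("eq", None)          # no allowlist: everything passes
check_allowed("eq", ["eq", "gt"])  # explicitly allowed
# check_allowed("like", ["eq"])    # would raise ValueError
```
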
diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py
index 50c67e0ab95..d395d317b1a 100644
--- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py
+++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py
@@ -244,10 +244,7 @@ The following is the expected answer. Use this to measure correctness:
if not isinstance(llm, BaseChatModel):
msg = "Only chat models supported by the current trajectory eval"
raise NotImplementedError(msg)
- if agent_tools:
- prompt = EVAL_CHAT_PROMPT
- else:
- prompt = TOOL_FREE_EVAL_CHAT_PROMPT
+ prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
eval_chain = LLMChain(llm=llm, prompt=prompt)
return cls(
agent_tools=agent_tools, # type: ignore[arg-type]
diff --git a/libs/langchain/langchain/memory/combined.py b/libs/langchain/langchain/memory/combined.py
index 06fe6febfe8..ca36237807d 100644
--- a/libs/langchain/langchain/memory/combined.py
+++ b/libs/langchain/langchain/memory/combined.py
@@ -36,13 +36,12 @@ class CombinedMemory(BaseMemory):
def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]:
"""Check that if memories are of type BaseChatMemory that input keys exist."""
for val in value:
- if isinstance(val, BaseChatMemory):
- if val.input_key is None:
- warnings.warn(
- "When using CombinedMemory, "
- "input keys should be so the input is known. "
- f" Was not set on {val}"
- )
+ if isinstance(val, BaseChatMemory) and val.input_key is None:
+ warnings.warn(
+ "When using CombinedMemory, "
+ "input keys should be so the input is known. "
+ f" Was not set on {val}"
+ )
return value
@property
diff --git a/libs/langchain/langchain/model_laboratory.py b/libs/langchain/langchain/model_laboratory.py
index 30f7c6de1a4..d4552ab86d0 100644
--- a/libs/langchain/langchain/model_laboratory.py
+++ b/libs/langchain/langchain/model_laboratory.py
@@ -51,11 +51,10 @@ class ModelLaboratory:
"Currently only support chains with one output variable, "
f"got {chain.output_keys}"
)
raise ValueError(msg)
- if names is not None:
- if len(names) != len(chains):
- msg = "Length of chains does not match length of names."
- raise ValueError(msg)
+ if names is not None and len(names) != len(chains):
+ msg = "Length of chains does not match length of names."
+ raise ValueError(msg)
self.chains = chains
chain_range = [str(i) for i in range(len(self.chains))]
self.chain_colors = get_color_mapping(chain_range)
@@ -93,10 +91,7 @@ class ModelLaboratory:
"""
print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201
for i, chain in enumerate(self.chains):
- if self.names is not None:
- name = self.names[i]
- else:
- name = str(chain)
+ name = self.names[i] if self.names is not None else str(chain)
print_text(name, end="\n")
output = chain.run(text)
print_text(output, color=self.chain_colors[str(i)], end="\n\n")
diff --git a/libs/langchain/langchain/output_parsers/loading.py b/libs/langchain/langchain/output_parsers/loading.py
index 1cb82ae0e9f..34774341ca2 100644
--- a/libs/langchain/langchain/output_parsers/loading.py
+++ b/libs/langchain/langchain/output_parsers/loading.py
@@ -10,14 +10,13 @@ def load_output_parser(config: dict) -> dict:
Returns:
config dict with output parser loaded
"""
- if "output_parsers" in config:
- if config["output_parsers"] is not None:
- _config = config["output_parsers"]
- output_parser_type = _config["_type"]
- if output_parser_type == "regex_parser":
- output_parser = RegexParser(**_config)
- else:
- msg = f"Unsupported output parser {output_parser_type}"
- raise ValueError(msg)
- config["output_parsers"] = output_parser
+ if "output_parsers" in config and config["output_parsers"] is not None:
+ _config = config["output_parsers"]
+ output_parser_type = _config["_type"]
+ if output_parser_type == "regex_parser":
+ output_parser = RegexParser(**_config)
+ else:
+ msg = f"Unsupported output parser {output_parser_type}"
+ raise ValueError(msg)
+ config["output_parsers"] = output_parser
return config
diff --git a/libs/langchain/langchain/output_parsers/yaml.py b/libs/langchain/langchain/output_parsers/yaml.py
index 4efdf3fa73a..12a03385c4b 100644
--- a/libs/langchain/langchain/output_parsers/yaml.py
+++ b/libs/langchain/langchain/output_parsers/yaml.py
@@ -27,12 +27,8 @@ class YamlOutputParser(BaseOutputParser[T]):
try:
# Greedy search for 1st yaml candidate.
match = re.search(self.pattern, text.strip())
- yaml_str = ""
- if match:
- yaml_str = match.group("yaml")
- else:
- # If no backticks were present, try to parse the entire output as yaml.
- yaml_str = text
+ # If no backticks were present, try to parse the entire output as yaml.
+ yaml_str = match.group("yaml") if match else text
json_object = yaml.safe_load(yaml_str)
if hasattr(self.pydantic_object, "model_validate"):
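
The ternary here collapses the match-or-fallback: take the named group when a fenced block is found, otherwise feed the whole text to `yaml.safe_load`. A sketch with a stand-in pattern (the real one lives on the parser class):

```python
import re

# Stand-in pattern in the spirit of the parser's; the point is the
# match-or-fallback ternary, not the exact regex.
PATTERN = re.compile(r"```(?:ya?ml)?\n(?P<yaml>.*?)```", re.DOTALL)

for text in ("```yaml\nname: foo\n```", "name: foo"):
    match = PATTERN.search(text.strip())
    # If no backticks were present, fall back to parsing the raw text.
    yaml_str = match.group("yaml") if match else text
    print(repr(yaml_str))  # 'name: foo\n' then 'name: foo'
```
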
diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py
index 67911e0fbc1..2826090d6ee 100644
--- a/libs/langchain/langchain/retrievers/multi_query.py
+++ b/libs/langchain/langchain/retrievers/multi_query.py
@@ -123,10 +123,7 @@ class MultiQueryRetriever(BaseRetriever):
response = await self.llm_chain.ainvoke(
{"question": question}, config={"callbacks": run_manager.get_child()}
)
- if isinstance(self.llm_chain, LLMChain):
- lines = response["text"]
- else:
- lines = response
+ lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
if self.verbose:
logger.info(f"Generated queries: {lines}")
return lines
@@ -186,10 +183,7 @@ class MultiQueryRetriever(BaseRetriever):
response = self.llm_chain.invoke(
{"question": question}, config={"callbacks": run_manager.get_child()}
)
- if isinstance(self.llm_chain, LLMChain):
- lines = response["text"]
- else:
- lines = response
+ lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
if self.verbose:
logger.info(f"Generated queries: {lines}")
return lines
diff --git a/libs/langchain/langchain/smith/evaluation/config.py b/libs/langchain/langchain/smith/evaluation/config.py
index ceb8e6a4187..74b0c296d56 100644
--- a/libs/langchain/langchain/smith/evaluation/config.py
+++ b/libs/langchain/langchain/smith/evaluation/config.py
@@ -57,9 +57,7 @@ class EvalConfig(BaseModel):
"""
kwargs = {}
for field, val in self:
- if field == "evaluator_type":
- continue
- elif val is None:
+ if field == "evaluator_type" or val is None:
continue
kwargs[field] = val
return kwargs
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index d6ffed07c07..3425fd5b45e 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint]
-select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"]
+select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true
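
This one-line change is what drives every hunk in the PR: `SIM` is ruff's vendored flake8-simplify ruleset. A sketch of code that trips the rules appearing most often above; assuming a recent ruff, most of these carry autofixes via `ruff check --select SIM --fix`:

```python
def demo(d: dict, flag: bool, key: str) -> str:
    if flag:  # SIM108: collapses to `mode = "on" if flag else "off"`
        mode = "on"
    else:
        mode = "off"
    if key in d.keys():  # SIM118: `if key in d:`
        mode += "!"
    try:  # SIM105: `with contextlib.suppress(KeyError): del d[key]`
        del d[key]
    except KeyError:
        pass
    return mode
```
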
diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py
index 257fc8e3460..1c8cbcc4917 100644
--- a/libs/langchain/tests/unit_tests/agents/test_agent.py
+++ b/libs/langchain/tests/unit_tests/agents/test_agent.py
@@ -2,6 +2,8 @@
import asyncio
import json
+import operator
+from functools import reduce
from itertools import cycle
from typing import Any, Optional, Union, cast
@@ -520,14 +522,7 @@ async def test_runnable_agent() -> None:
assert messages != []
# Aggregate state
- run_log = None
-
- for result in results:
- if run_log is None:
- run_log = result
- else:
- # `+` is defined for RunLogPatch
- run_log = run_log + result
+ run_log = reduce(operator.add, results)
assert isinstance(run_log, RunLog)
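
`reduce(operator.add, results)` folds the sequence with `+` exactly like the deleted loop. A tiny self-contained check, with lists standing in for `RunLogPatch` objects (which also define `__add__`); note that `reduce` raises `TypeError` on an empty sequence where the old loop left `run_log` as `None`, which is fine here since the test always produces results:

```python
import operator
from functools import reduce

patches = [[1], [2], [3]]  # stand-in for RunLogPatch objects
assert reduce(operator.add, patches) == [1, 2, 3]
```
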
diff --git a/libs/langchain/tests/unit_tests/chains/test_memory.py b/libs/langchain/tests/unit_tests/chains/test_memory.py
index dae9cc36610..2959f15ae80 100644
--- a/libs/langchain/tests/unit_tests/chains/test_memory.py
+++ b/libs/langchain/tests/unit_tests/chains/test_memory.py
@@ -17,7 +17,7 @@ def test_simple_memory() -> None:
output = memory.load_memory_variables({})
assert output == {"baz": "foo"}
- assert ["baz"] == memory.memory_variables
+ assert memory.memory_variables == ["baz"]
@pytest.mark.parametrize(
diff --git a/libs/langchain/tests/unit_tests/embeddings/test_caching.py b/libs/langchain/tests/unit_tests/embeddings/test_caching.py
index a3d189209c4..1d06e71efb2 100644
--- a/libs/langchain/tests/unit_tests/embeddings/test_caching.py
+++ b/libs/langchain/tests/unit_tests/embeddings/test_caching.py
@@ -1,5 +1,6 @@
"""Embeddings tests."""
+import contextlib
import hashlib
import importlib
import warnings
@@ -75,10 +76,8 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
- try:
+ with contextlib.suppress(ValueError):
cache_embeddings_batch.embed_documents(texts)
- except ValueError:
- pass
keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
# only the first batch of three embeddings should exist
assert len(keys) == 3
@@ -122,10 +121,8 @@ async def test_aembed_documents_batch(
) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
- try:
+ with contextlib.suppress(ValueError):
await cache_embeddings_batch.aembed_documents(texts)
- except ValueError:
- pass
keys = [
key
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
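
In these tests, `suppress(ValueError)` replaces `try`/`except`/`pass` around a call that is *expected* to fail partway through; the assertions that follow inspect the partial side effects. `pytest.raises(ValueError)` would also work and would additionally assert that the exception really fired, but that is not what these tests check. A minimal sketch:

```python
import contextlib

with contextlib.suppress(ValueError):
    int("not a number")  # raises ValueError; execution continues below
print("still running")
```
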
diff --git a/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py b/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py
index 8e49f1f7d06..df13c2e5d5c 100644
--- a/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py
+++ b/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py
@@ -70,7 +70,7 @@ def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> None:
reference = "I like apples."
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
assert "score" in result
- assert 0 < result["score"]
+ assert result["score"] > 0
if normalize_score:
assert result["score"] < 1.0
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py
index b7498ef278e..e6264c2379e 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_combining_parser.py
@@ -45,7 +45,7 @@ def test_combining_dict_result() -> None:
]
combining_parser = CombiningOutputParser(parsers=parsers)
result_dict = combining_parser.parse(DEF_README)
- assert DEF_EXPECTED_RESULT == result_dict
+ assert result_dict == DEF_EXPECTED_RESULT
def test_combining_output_parser_output_type() -> None:
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py
index 1a25570e47a..8a0562e92e4 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_pandas_dataframe_parser.py
@@ -47,7 +47,7 @@ def test_pandas_output_parser_col_first_elem() -> None:
def test_pandas_output_parser_col_multi_elem() -> None:
expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")}
actual_output = parser.parse("column:chicken[0, 1]")
- for key in actual_output.keys():
+ for key in actual_output:
assert expected_output["chicken"].equals(actual_output[key])
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_regex.py b/libs/langchain/tests/unit_tests/output_parsers/test_regex.py
index cabf12b5e8a..a898eadfa5b 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_regex.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_regex.py
@@ -23,7 +23,7 @@ def test_regex_parser_parse() -> None:
output_keys=["confidence", "explanation"],
default_output_key="noConfidence",
)
- assert DEF_EXPECTED_RESULT == parser.parse(DEF_README)
+ assert parser.parse(DEF_README) == DEF_EXPECTED_RESULT
def test_regex_parser_output_type() -> None:
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py b/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py
index 5a604089398..88f1b5432df 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_regex_dict.py
@@ -35,7 +35,7 @@ def test_regex_dict_result() -> None:
)
result_dict = regex_dict_parser.parse(DEF_README)
print("parse_result:", result_dict) # noqa: T201
- assert DEF_EXPECTED_RESULT == result_dict
+ assert result_dict == DEF_EXPECTED_RESULT
def test_regex_dict_output_type() -> None:
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py
index fbf86f2305a..4e6cba09cad 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_yaml_parser.py
@@ -76,7 +76,7 @@ def test_yaml_output_parser(result: str) -> None:
model = yaml_parser.parse(result)
print("parse_result:", result) # noqa: T201
- assert DEF_EXPECTED_RESULT == model
+ assert model == DEF_EXPECTED_RESULT
def test_yaml_output_parser_fail() -> None:
diff --git a/libs/langchain/tests/unit_tests/test_imports.py b/libs/langchain/tests/unit_tests/test_imports.py
index f9d970f19fc..2f1f39ced01 100644
--- a/libs/langchain/tests/unit_tests/test_imports.py
+++ b/libs/langchain/tests/unit_tests/test_imports.py
@@ -124,9 +124,12 @@ def extract_deprecated_lookup(file_path: str) -> Optional[dict[str, Any]]:
for node in ast.walk(tree):
if isinstance(node, ast.Assign):
for target in node.targets:
- if isinstance(target, ast.Name) and target.id == "DEPRECATED_LOOKUP":
- if isinstance(node.value, ast.Dict):
- return _dict_from_ast(node.value)
+ if (
+ isinstance(target, ast.Name)
+ and target.id == "DEPRECATED_LOOKUP"
+ and isinstance(node.value, ast.Dict)
+ ):
+ return _dict_from_ast(node.value)
return None
@@ -156,8 +159,10 @@ def _literal_eval_str(node: ast.AST) -> str:
Returns:
str: The corresponding string value.
"""
- if isinstance(node, ast.Constant): # Python 3.8+
- if isinstance(node.value, str):
- return node.value
+ if (
+ isinstance(node, ast.Constant) # Python 3.8+
+ and isinstance(node.value, str)
+ ):
+ return node.value
msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}"
raise AssertionError(msg)
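
For readers unfamiliar with the helper these two hunks touch, here is a minimal sketch of the same AST walk: find the `DEPRECATED_LOOKUP` assignment and read its dict of string constants. Simplified relative to the real `extract_deprecated_lookup`/`_dict_from_ast` pair:

```python
import ast

SOURCE = 'DEPRECATED_LOOKUP = {"FooChain": "langchain_community.chains"}'

for node in ast.walk(ast.parse(SOURCE)):
    if (
        isinstance(node, ast.Assign)
        and any(
            isinstance(t, ast.Name) and t.id == "DEPRECATED_LOOKUP"
            for t in node.targets
        )
        and isinstance(node.value, ast.Dict)
    ):
        # Each key/value is an ast.Constant holding a str.
        lookup = {
            k.value: v.value
            for k, v in zip(node.value.keys, node.value.values)
        }
        print(lookup)  # {'FooChain': 'langchain_community.chains'}
```
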