diff --git a/libs/langchain/langchain/__init__.py b/libs/langchain/langchain/__init__.py
index edc20ad2b3a..2d441e75ced 100644
--- a/libs/langchain/langchain/__init__.py
+++ b/libs/langchain/langchain/__init__.py
@@ -357,7 +357,7 @@ def __getattr__(name: str) -> Any:
return ElasticVectorSearch
# For backwards compatibility
- if name == "SerpAPIChain" or name == "SerpAPIWrapper":
+ if name in {"SerpAPIChain", "SerpAPIWrapper"}:
from langchain_community.utilities import SerpAPIWrapper
_warn_on_import(
diff --git a/libs/langchain/langchain/agents/format_scratchpad/xml.py b/libs/langchain/langchain/agents/format_scratchpad/xml.py
index 3fb73637790..942b2ec8fa9 100644
--- a/libs/langchain/langchain/agents/format_scratchpad/xml.py
+++ b/libs/langchain/langchain/agents/format_scratchpad/xml.py
@@ -40,13 +40,13 @@ def format_xml(
# Escape XML tags in tool names and inputs using custom delimiters
tool = _escape(action.tool)
tool_input = _escape(str(action.tool_input))
- observation = _escape(str(observation))
+ observation_ = _escape(str(observation))
else:
tool = action.tool
tool_input = str(action.tool_input)
- observation = str(observation)
+ observation_ = str(observation)
log += (
f"{tool}{tool_input}"
- f"{observation}"
+ f"{observation_}"
)
return log
diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py
index 5a014069133..c1e862b1d67 100644
--- a/libs/langchain/langchain/agents/openai_assistant/base.py
+++ b/libs/langchain/langchain/agents/openai_assistant/base.py
@@ -582,7 +582,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
major_version = int(openai.version.VERSION.split(".")[0])
minor_version = int(openai.version.VERSION.split(".")[1])
version_gte_1_14 = (major_version > 1) or (
- major_version == 1 and minor_version >= 14
+ major_version == 1 and minor_version >= 14 # noqa: PLR2004
)
messages = self.client.beta.threads.messages.list(
@@ -739,7 +739,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
major_version = int(openai.version.VERSION.split(".")[0])
minor_version = int(openai.version.VERSION.split(".")[1])
version_gte_1_14 = (major_version > 1) or (
- major_version == 1 and minor_version >= 14
+ major_version == 1 and minor_version >= 14 # noqa: PLR2004
)
messages = await self.async_client.beta.threads.messages.list(
diff --git a/libs/langchain/langchain/agents/react/base.py b/libs/langchain/langchain/agents/react/base.py
index 7430e639650..33c3099815c 100644
--- a/libs/langchain/langchain/agents/react/base.py
+++ b/libs/langchain/langchain/agents/react/base.py
@@ -24,6 +24,9 @@ if TYPE_CHECKING:
from langchain_community.docstore.base import Docstore
+_LOOKUP_AND_SEARCH_TOOLS = {"Lookup", "Search"}
+
+
@deprecated(
"0.1.0",
message=AGENT_DEPRECATION_WARNING,
@@ -52,11 +55,11 @@ class ReActDocstoreAgent(Agent):
def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
validate_tools_single_input(cls.__name__, tools)
super()._validate_tools(tools)
- if len(tools) != 2:
+ if len(tools) != len(_LOOKUP_AND_SEARCH_TOOLS):
msg = f"Exactly two tools must be specified, but got {tools}"
raise ValueError(msg)
tool_names = {tool.name for tool in tools}
- if tool_names != {"Lookup", "Search"}:
+ if tool_names != _LOOKUP_AND_SEARCH_TOOLS:
msg = f"Tool names should be Lookup and Search, got {tool_names}"
raise ValueError(msg)
diff --git a/libs/langchain/langchain/chains/api/base.py b/libs/langchain/langchain/chains/api/base.py
index df2f9623e90..454993c068d 100644
--- a/libs/langchain/langchain/chains/api/base.py
+++ b/libs/langchain/langchain/chains/api/base.py
@@ -47,8 +47,8 @@ def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
scheme, domain = _extract_scheme_and_domain(url)
for allowed_domain in limit_to_domains:
- allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_domain)
- if scheme == allowed_scheme and domain == allowed_domain:
+ allowed_scheme, allowed_domain_ = _extract_scheme_and_domain(allowed_domain)
+ if scheme == allowed_scheme and domain == allowed_domain_:
return True
return False
diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py
index 8bd373ebd82..bab85195cd9 100644
--- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py
+++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py
@@ -193,13 +193,12 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
"multiple llm_chain input_variables"
)
raise ValueError(msg)
- else:
- if values["document_variable_name"] not in llm_chain_variables:
- msg = (
- f"document_variable_name {values['document_variable_name']} was "
- f"not found in llm_chain input_variables: {llm_chain_variables}"
- )
- raise ValueError(msg)
+ elif values["document_variable_name"] not in llm_chain_variables:
+ msg = (
+ f"document_variable_name {values['document_variable_name']} was "
+ f"not found in llm_chain input_variables: {llm_chain_variables}"
+ )
+ raise ValueError(msg)
return values
@property
diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py
index c6ad0c611c7..743c2f7282b 100644
--- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py
+++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py
@@ -161,13 +161,12 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"multiple llm_chain input_variables"
)
raise ValueError(msg)
- else:
- if values["document_variable_name"] not in llm_chain_variables:
- msg = (
- f"document_variable_name {values['document_variable_name']} was "
- f"not found in llm_chain input_variables: {llm_chain_variables}"
- )
- raise ValueError(msg)
+ elif values["document_variable_name"] not in llm_chain_variables:
+ msg = (
+ f"document_variable_name {values['document_variable_name']} was "
+ f"not found in llm_chain input_variables: {llm_chain_variables}"
+ )
+ raise ValueError(msg)
return values
def combine_docs(
diff --git a/libs/langchain/langchain/chains/combine_documents/reduce.py b/libs/langchain/langchain/chains/combine_documents/reduce.py
index eb8648685bb..18c2cba7b84 100644
--- a/libs/langchain/langchain/chains/combine_documents/reduce.py
+++ b/libs/langchain/langchain/chains/combine_documents/reduce.py
@@ -325,10 +325,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
_token_max,
**kwargs,
)
- result_docs = []
- for docs in new_result_doc_list:
- new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs)
- result_docs.append(new_doc)
+ result_docs = [
+ collapse_docs(docs_, _collapse_docs_func, **kwargs)
+ for docs_ in new_result_doc_list
+ ]
num_tokens = length_func(result_docs, **kwargs)
retries += 1
if self.collapse_max_retries and retries == self.collapse_max_retries:
@@ -364,10 +364,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
_token_max,
**kwargs,
)
- result_docs = []
- for docs in new_result_doc_list:
- new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs)
- result_docs.append(new_doc)
+ result_docs = [
+ await acollapse_docs(docs_, _collapse_docs_func, **kwargs)
+ for docs_ in new_result_doc_list
+ ]
num_tokens = length_func(result_docs, **kwargs)
retries += 1
if self.collapse_max_retries and retries == self.collapse_max_retries:
diff --git a/libs/langchain/langchain/chains/combine_documents/refine.py b/libs/langchain/langchain/chains/combine_documents/refine.py
index b904b909191..5e5e3e74d05 100644
--- a/libs/langchain/langchain/chains/combine_documents/refine.py
+++ b/libs/langchain/langchain/chains/combine_documents/refine.py
@@ -140,13 +140,12 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
"multiple llm_chain input_variables"
)
raise ValueError(msg)
- else:
- if values["document_variable_name"] not in llm_chain_variables:
- msg = (
- f"document_variable_name {values['document_variable_name']} was "
- f"not found in llm_chain input_variables: {llm_chain_variables}"
- )
- raise ValueError(msg)
+ elif values["document_variable_name"] not in llm_chain_variables:
+ msg = (
+ f"document_variable_name {values['document_variable_name']} was "
+ f"not found in llm_chain input_variables: {llm_chain_variables}"
+ )
+ raise ValueError(msg)
return values
def combine_docs(
diff --git a/libs/langchain/langchain/chains/combine_documents/stuff.py b/libs/langchain/langchain/chains/combine_documents/stuff.py
index 138629ab143..0b380dd8bf6 100644
--- a/libs/langchain/langchain/chains/combine_documents/stuff.py
+++ b/libs/langchain/langchain/chains/combine_documents/stuff.py
@@ -180,13 +180,12 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
"multiple llm_chain_variables"
)
raise ValueError(msg)
- else:
- if values["document_variable_name"] not in llm_chain_variables:
- msg = (
- f"document_variable_name {values['document_variable_name']} was "
- f"not found in llm_chain input_variables: {llm_chain_variables}"
- )
- raise ValueError(msg)
+ elif values["document_variable_name"] not in llm_chain_variables:
+ msg = (
+ f"document_variable_name {values['document_variable_name']} was "
+ f"not found in llm_chain input_variables: {llm_chain_variables}"
+ )
+ raise ValueError(msg)
return values
@property
diff --git a/libs/langchain/langchain/chains/natbot/crawler.py b/libs/langchain/langchain/chains/natbot/crawler.py
index ab5348d40c7..00a5dd38a6b 100644
--- a/libs/langchain/langchain/chains/natbot/crawler.py
+++ b/libs/langchain/langchain/chains/natbot/crawler.py
@@ -322,7 +322,7 @@ class Crawler:
if node_name == "#text" and ancestor_exception and ancestor_node:
text = strings[node_value[index]]
- if text == "|" or text == "•":
+ if text in {"|", "•"}:
continue
ancestor_node.append({"type": "type", "value": text})
else:
@@ -367,7 +367,7 @@ class Crawler:
element_node_value = strings[text_index]
# remove redundant elements
- if ancestor_exception and (node_name != "a" and node_name != "button"):
+ if ancestor_exception and (node_name not in {"a", "button"}):
continue
elements_in_view_port.append(
@@ -423,10 +423,7 @@ class Crawler:
# not very elegant, more like a placeholder
if (
(converted_node_name != "button" or meta == "")
- and converted_node_name != "link"
- and converted_node_name != "input"
- and converted_node_name != "img"
- and converted_node_name != "textarea"
+ and converted_node_name not in {"link", "input", "img", "textarea"}
) and inner_text.strip() == "":
continue
diff --git a/libs/langchain/langchain/chains/openai_functions/openapi.py b/libs/langchain/langchain/chains/openai_functions/openapi.py
index fe7a265d98c..a981d9fddb5 100644
--- a/libs/langchain/langchain/chains/openai_functions/openapi.py
+++ b/libs/langchain/langchain/chains/openai_functions/openapi.py
@@ -51,13 +51,12 @@ def _format_url(url: str, path_params: dict) -> str:
sep = ","
new_val = ""
new_val += sep.join(kv_strs)
+ elif param[0] == ".":
+ new_val = f".{val}"
+ elif param[0] == ";":
+ new_val = f";{clean_param}={val}"
else:
- if param[0] == ".":
- new_val = f".{val}"
- elif param[0] == ";":
- new_val = f";{clean_param}={val}"
- else:
- new_val = val
+ new_val = val
new_params[param] = new_val
return url.format(**new_params)
@@ -224,7 +223,7 @@ class SimpleRequestChain(Chain):
_text = f"Calling endpoint {_pretty_name} with arguments:\n" + _pretty_args
_run_manager.on_text(_text)
api_response: Response = self.request_method(name, args)
- if api_response.status_code != 200:
+ if api_response.status_code != requests.codes.ok:
response = (
f"{api_response.status_code}: {api_response.reason}"
f"\nFor {name} "
diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py
index 7975ae1f1c9..d377dc181cf 100644
--- a/libs/langchain/langchain/embeddings/__init__.py
+++ b/libs/langchain/langchain/embeddings/__init__.py
@@ -87,7 +87,7 @@ class HypotheticalDocumentEmbedder:
)
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder as H
- return H(*args, **kwargs) # type: ignore[return-value]
+ return H(*args, **kwargs) # type: ignore[return-value] # noqa: PLE0101
@classmethod
def from_llm(cls, *args: Any, **kwargs: Any) -> Any:
diff --git a/libs/langchain/langchain/embeddings/base.py b/libs/langchain/langchain/embeddings/base.py
index c0a6c8b891d..6eb47a5f2d2 100644
--- a/libs/langchain/langchain/embeddings/base.py
+++ b/libs/langchain/langchain/embeddings/base.py
@@ -89,7 +89,6 @@ def _infer_model_and_provider(
if provider is None and ":" in model:
provider, model_name = _parse_model_string(model)
else:
- provider = provider
model_name = model
if not provider:
diff --git a/libs/langchain/langchain/embeddings/cache.py b/libs/langchain/langchain/embeddings/cache.py
index b42bd959312..fb200c76847 100644
--- a/libs/langchain/langchain/embeddings/cache.py
+++ b/libs/langchain/langchain/embeddings/cache.py
@@ -89,7 +89,7 @@ _warned_about_sha1: bool = False
def _warn_about_sha1_encoder() -> None:
"""Emit a one-time warning about SHA-1 collision weaknesses."""
- global _warned_about_sha1
+ global _warned_about_sha1 # noqa: PLW0603
if not _warned_about_sha1:
warnings.warn(
"Using default key encoder: SHA-1 is *not* collision-resistant. "
diff --git a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py
index e31be77d2b8..f80ab15b973 100644
--- a/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py
+++ b/libs/langchain/langchain/evaluation/agents/trajectory_eval_chain.py
@@ -36,6 +36,8 @@ from langchain.evaluation.agents.trajectory_eval_prompt import (
)
from langchain.evaluation.schema import AgentTrajectoryEvaluator, LLMEvalChain
+_MAX_SCORE = 5
+
class TrajectoryEval(TypedDict):
"""A named tuple containing the score and reasoning for a trajectory."""
@@ -86,10 +88,10 @@ class TrajectoryOutputParser(BaseOutputParser):
raise OutputParserException(msg)
score = int(_score.group(1))
# If the score is not in the range 1-5, raise an exception.
- if not 1 <= score <= 5:
+ if not 1 <= score <= _MAX_SCORE:
msg = f"Score is not a digit in the range 1-5: {text}"
raise OutputParserException(msg)
- normalized_score = (score - 1) / 4
+ normalized_score = (score - 1) / (_MAX_SCORE - 1) # map 1.._MAX_SCORE onto 0..1
return TrajectoryEval(score=normalized_score, reasoning=reasoning)
diff --git a/libs/langchain/langchain/globals.py b/libs/langchain/langchain/globals.py
index 7426e06d2fa..34ad94cd0ff 100644
--- a/libs/langchain/langchain/globals.py
+++ b/libs/langchain/langchain/globals.py
@@ -39,7 +39,7 @@ def set_verbose(
# have migrated to using `set_verbose()` here.
langchain.verbose = value
- global _verbose
+ global _verbose # noqa: PLW0603
_verbose = value
@@ -69,7 +69,6 @@ def get_verbose() -> bool:
# directing them to use `set_verbose()` when they import `langchain.verbose`.
old_verbose = langchain.verbose
- global _verbose
return _verbose or old_verbose
@@ -94,7 +93,7 @@ def set_debug(
# have migrated to using `set_debug()` here.
langchain.debug = value
- global _debug
+ global _debug # noqa: PLW0603
_debug = value
@@ -122,7 +121,6 @@ def get_debug() -> bool:
# directing them to use `set_debug()` when they import `langchain.debug`.
old_debug = langchain.debug
- global _debug
return _debug or old_debug
@@ -147,7 +145,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
# once all users have migrated to using `set_llm_cache()` here.
langchain.llm_cache = value
- global _llm_cache
+ global _llm_cache # noqa: PLW0603
_llm_cache = value
@@ -179,5 +177,4 @@ def get_llm_cache() -> "BaseCache":
# to use `set_llm_cache()` when they import `langchain.llm_cache`.
old_llm_cache = langchain.llm_cache
- global _llm_cache
return _llm_cache or old_llm_cache
diff --git a/libs/langchain/langchain/output_parsers/combining.py b/libs/langchain/langchain/output_parsers/combining.py
index ed52a89ce75..bf112818184 100644
--- a/libs/langchain/langchain/output_parsers/combining.py
+++ b/libs/langchain/langchain/output_parsers/combining.py
@@ -5,6 +5,8 @@ from typing import Any
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.utils import pre_init
+_MIN_PARSERS = 2
+
class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
"""Combine multiple output parsers into one."""
@@ -19,7 +21,7 @@ class CombiningOutputParser(BaseOutputParser[dict[str, Any]]):
def validate_parsers(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Validate the parsers."""
parsers = values["parsers"]
- if len(parsers) < 2:
+ if len(parsers) < _MIN_PARSERS:
msg = "Must have at least two parsers"
raise ValueError(msg)
for parser in parsers:
diff --git a/libs/langchain/langchain/output_parsers/pandas_dataframe.py b/libs/langchain/langchain/output_parsers/pandas_dataframe.py
index 78b16ba0910..d360865e324 100644
--- a/libs/langchain/langchain/output_parsers/pandas_dataframe.py
+++ b/libs/langchain/langchain/output_parsers/pandas_dataframe.py
@@ -79,7 +79,7 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
def parse(self, request: str) -> dict[str, Any]:
stripped_request_params = None
splitted_request = request.strip().split(":")
- if len(splitted_request) != 2:
+ if len(splitted_request) != 2: # noqa: PLR2004
msg = f"Request '{request}' is not correctly formatted. \
Please refer to the format instructions."
raise OutputParserException(msg)
@@ -127,16 +127,15 @@ class PandasDataFrameOutputParser(BaseOutputParser[dict[str, Any]]):
filtered_df[stripped_request_params],
request_type,
)()
+ elif request_type == "column":
+ result[request_params] = self.dataframe[request_params]
+ elif request_type == "row":
+ result[request_params] = self.dataframe.iloc[int(request_params)]
else:
- if request_type == "column":
- result[request_params] = self.dataframe[request_params]
- elif request_type == "row":
- result[request_params] = self.dataframe.iloc[int(request_params)]
- else:
- result[request_type] = getattr(
- self.dataframe[request_params],
- request_type,
- )()
+ result[request_type] = getattr(
+ self.dataframe[request_params],
+ request_type,
+ )()
except (AttributeError, IndexError, KeyError) as e:
if request_type not in {"column", "row"}:
msg = f"Unsupported request type '{request_type}'. \
diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py
index b5c6f9d1879..e26d8847c2a 100644
--- a/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py
+++ b/libs/langchain/langchain/retrievers/document_compressors/chain_filter.py
@@ -70,9 +70,8 @@ class LLMChainFilter(BaseDocumentCompressor):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
- else:
- if isinstance(output_, bool):
- include_doc = output_
+ elif isinstance(output_, bool):
+ include_doc = output_
if include_doc:
filtered_docs.append(doc)
@@ -101,9 +100,8 @@ class LLMChainFilter(BaseDocumentCompressor):
output = output_[self.llm_chain.output_key]
if self.llm_chain.prompt.output_parser is not None:
include_doc = self.llm_chain.prompt.output_parser.parse(output)
- else:
- if isinstance(output_, bool):
- include_doc = output_
+ elif isinstance(output_, bool):
+ include_doc = output_
if include_doc:
filtered_docs.append(doc)
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index 3b1add6dca5..c237945ea4e 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -147,11 +147,9 @@ select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"ASYNC", # flake8-async
- "B", # flake8-bugbear
"C4", # flake8-comprehensions
"COM", # flake8-commas
"D", # pydocstyle
- "DOC", # pydoclint
"DTZ", # flake8-datetimez
"E", # pycodestyle error
"EM", # flake8-errmsg
@@ -164,9 +162,10 @@ select = [
"ICN", # flake8-import-conventions
"INT", # flake8-gettext
"ISC", # isort-comprehensions
+ "PERF", # flake8-perf
"PGH", # pygrep-hooks
"PIE", # flake8-pie
- "PERF", # flake8-perf
+ "PL", # pylint
"PT", # flake8-pytest-style
"PTH", # flake8-use-pathlib
"PYI", # flake8-pyi
@@ -175,9 +174,9 @@ select = [
"RSE", # flake8-rst-docstrings
"RUF", # ruff
"S", # flake8-bandit
- "SLF", # flake8-self
"SLOT", # flake8-slots
"SIM", # flake8-simplify
+ "SLF", # flake8-self
"T10", # flake8-debugger
"T20", # flake8-print
"TID", # flake8-tidy-imports
@@ -198,13 +197,15 @@ ignore = [
"COM812", # Messes with the formatter
"ISC001", # Messes with the formatter
"PERF203", # Rarely useful
+ "PLR09", # Too many something (args, statements, etc)
"S112", # Rarely useful
"RUF012", # Doesn't play well with Pydantic
"SLF001", # Private member access
"UP007", # pyupgrade: non-pep604-annotation-union
- # TODO
- "TRY301", # tryceratops: raise-within-try
+ # TODO rules
+ "PLC0415", # pylint: import-outside-top-level
+ "TRY301", # tryceratops: raise-within-try
]
unfixable = ["B028"] # People should intentionally tune the stacklevel
@@ -217,9 +218,10 @@ pyupgrade.keep-runtime-typing = true
[tool.ruff.lint.extend-per-file-ignores]
"tests/**/*.py" = [
- "S101", # Tests need assertions
- "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
- "SLF001", # Private member access in tests
+ "S101", # Tests need assertions
+ "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
+ "SLF001", # Private member access in tests
+ "PLR2004", # Magic value comparisons
]
"langchain/chains/constitutional_ai/principles.py" = [
"E501", # Line too long
diff --git a/libs/langchain/tests/unit_tests/conftest.py b/libs/langchain/tests/unit_tests/conftest.py
index fe0b62e766f..3a620b411fe 100644
--- a/libs/langchain/tests/unit_tests/conftest.py
+++ b/libs/langchain/tests/unit_tests/conftest.py
@@ -130,8 +130,7 @@ def pytest_collection_modifyitems(
pytest.mark.skip(reason=f"Requires pkg: `{pkg}`"),
)
break
- else:
- if only_extended:
- item.add_marker(
- pytest.mark.skip(reason="Skipping not an extended test."),
- )
+ elif only_extended:
+ item.add_marker(
+ pytest.mark.skip(reason="Skipping not an extended test."),
+ )
diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py
index c9bdea85f28..b492c1b93c6 100644
--- a/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py
+++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_query.py
@@ -1,4 +1,4 @@
-import pytest as pytest
+import pytest
from langchain_core.documents import Document
from langchain.retrievers.multi_query import LineListOutputParser, _unique_documents
diff --git a/libs/langchain/tests/unit_tests/stubs.py b/libs/langchain/tests/unit_tests/stubs.py
index 44450d946bd..742d2df5813 100644
--- a/libs/langchain/tests/unit_tests/stubs.py
+++ b/libs/langchain/tests/unit_tests/stubs.py
@@ -10,6 +10,8 @@ class AnyStr(str):
def __eq__(self, other: object) -> bool:
return isinstance(other, str)
+ __hash__ = str.__hash__ # defining __eq__ sets __hash__ to None; restore str's
+
# The code below creates versions of pydantic models
# that will work in unit tests with AnyStr as id field