langchain: Add ruff rules SIM (#31881)

See https://docs.astral.sh/ruff/rules/#flake8-simplify-sim

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet 2025-07-07 16:27:04 +02:00 committed by GitHub
parent f06380516f
commit 2df3fdf40d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 138 additions and 196 deletions

View File

@ -84,13 +84,13 @@ def create_importer(
not is_interactive_env() not is_interactive_env()
and deprecated_lookups and deprecated_lookups
and name in deprecated_lookups and name in deprecated_lookups
):
# Depth 3: # Depth 3:
# internal.py # internal.py
# module_import.py # module_import.py
# Module in langchain that uses this function # Module in langchain that uses this function
# [calling code] whose frame we want to inspect. # [calling code] whose frame we want to inspect.
if not internal.is_caller_internal(depth=3): and not internal.is_caller_internal(depth=3)
):
warn_deprecated( warn_deprecated(
since="0.1", since="0.1",
pending=False, pending=False,
@ -115,13 +115,15 @@ def create_importer(
try: try:
module = importlib.import_module(fallback_module) module = importlib.import_module(fallback_module)
result = getattr(module, name) result = getattr(module, name)
if not is_interactive_env(): if (
not is_interactive_env()
# Depth 3: # Depth 3:
# internal.py # internal.py
# module_import.py # module_import.py
# Module in langchain that uses this function # Module in langchain that uses this function
# [calling code] whose frame we want to inspect. # [calling code] whose frame we want to inspect.
if not internal.is_caller_internal(depth=3): and not internal.is_caller_internal(depth=3)
):
warn_deprecated( warn_deprecated(
since="0.1", since="0.1",
pending=False, pending=False,

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import builtins import builtins
import contextlib
import json import json
import logging import logging
import time import time
@ -196,10 +197,7 @@ class BaseSingleActionAgent(BaseModel):
agent.agent.save(file_path="path/agent.yaml") agent.agent.save(file_path="path/agent.yaml")
""" """
# Convert file to Path object. # Convert file to Path object.
if isinstance(file_path, str): save_path = Path(file_path) if isinstance(file_path, str) else file_path
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True) directory_path.mkdir(parents=True, exist_ok=True)
@ -322,10 +320,8 @@ class BaseMultiActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> builtins.dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().model_dump() _dict = super().model_dump()
try: with contextlib.suppress(NotImplementedError):
_dict["_type"] = str(self._agent_type) _dict["_type"] = str(self._agent_type)
except NotImplementedError:
pass
return _dict return _dict
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
@ -345,10 +341,7 @@ class BaseMultiActionAgent(BaseModel):
agent.agent.save(file_path="path/agent.yaml") agent.agent.save(file_path="path/agent.yaml")
""" """
# Convert file to Path object. # Convert file to Path object.
if isinstance(file_path, str): save_path = Path(file_path) if isinstance(file_path, str) else file_path
save_path = Path(file_path)
else:
save_path = file_path
# Fetch dictionary to save # Fetch dictionary to save
agent_dict = self.dict() agent_dict = self.dict()
@ -1133,8 +1126,9 @@ class AgentExecutor(Chain):
agent = self.agent agent = self.agent
tools = self.tools tools = self.tools
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr] allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
if allowed_tools is not None: if allowed_tools is not None and set(allowed_tools) != set(
if set(allowed_tools) != set([tool.name for tool in tools]): [tool.name for tool in tools]
):
msg = ( msg = (
f"Allowed tools ({allowed_tools}) different than " f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({[tool.name for tool in tools]})" f"provided tools ({[tool.name for tool in tools]})"
@ -1272,13 +1266,7 @@ class AgentExecutor(Chain):
def _should_continue(self, iterations: int, time_elapsed: float) -> bool: def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
if self.max_iterations is not None and iterations >= self.max_iterations: if self.max_iterations is not None and iterations >= self.max_iterations:
return False return False
if ( return self.max_execution_time is None or time_elapsed < self.max_execution_time
self.max_execution_time is not None
and time_elapsed >= self.max_execution_time
):
return False
return True
def _return( def _return(
self, self,
@ -1410,10 +1398,7 @@ class AgentExecutor(Chain):
return return
actions: list[AgentAction] actions: list[AgentAction]
if isinstance(output, AgentAction): actions = [output] if isinstance(output, AgentAction) else output
actions = [output]
else:
actions = output
for agent_action in actions: for agent_action in actions:
yield agent_action yield agent_action
for agent_action in actions: for agent_action in actions:
@ -1547,10 +1532,7 @@ class AgentExecutor(Chain):
return return
actions: list[AgentAction] actions: list[AgentAction]
if isinstance(output, AgentAction): actions = [output] if isinstance(output, AgentAction) else output
actions = [output]
else:
actions = output
for agent_action in actions: for agent_action in actions:
yield agent_action yield agent_action
@ -1728,8 +1710,10 @@ class AgentExecutor(Chain):
if len(self._action_agent.return_values) > 0: if len(self._action_agent.return_values) > 0:
return_value_key = self._action_agent.return_values[0] return_value_key = self._action_agent.return_values[0]
# Invalid tools won't be in the map, so we return False. # Invalid tools won't be in the map, so we return False.
if agent_action.tool in name_to_tool_map: if (
if name_to_tool_map[agent_action.tool].return_direct: agent_action.tool in name_to_tool_map
and name_to_tool_map[agent_action.tool].return_direct
):
return AgentFinish( return AgentFinish(
{return_value_key: observation}, {return_value_key: observation},
"", "",

View File

@ -1,5 +1,6 @@
"""Load agent.""" """Load agent."""
import contextlib
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Optional from typing import Any, Optional
@ -81,11 +82,9 @@ def initialize_agent(
agent_obj = load_agent( agent_obj = load_agent(
agent_path, llm=llm, tools=tools, callback_manager=callback_manager agent_path, llm=llm, tools=tools, callback_manager=callback_manager
) )
try: with contextlib.suppress(NotImplementedError):
# TODO: Add tags from the serialized object directly. # TODO: Add tags from the serialized object directly.
tags_.append(agent_obj._agent_type) tags_.append(agent_obj._agent_type)
except NotImplementedError:
pass
else: else:
msg = ( msg = (
"Somehow both `agent` and `agent_path` are None, this should never happen." "Somehow both `agent` and `agent_path` are None, this should never happen."

View File

@ -128,10 +128,7 @@ def _load_agent_from_file(
"""Load agent from file.""" """Load agent from file."""
valid_suffixes = {"json", "yaml"} valid_suffixes = {"json", "yaml"}
# Convert file to Path object. # Convert file to Path object.
if isinstance(file, str): file_path = Path(file) if isinstance(file, str) else file
file_path = Path(file)
else:
file_path = file
# Load from either json or yaml. # Load from either json or yaml.
if file_path.suffix[1:] == "json": if file_path.suffix[1:] == "json":
with open(file_path) as f: with open(file_path) as f:

View File

@ -230,10 +230,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
""" """
_prompts = extra_prompt_messages or [] _prompts = extra_prompt_messages or []
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]] messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message: messages = [system_message] if system_message else []
messages = [system_message]
else:
messages = []
messages.extend( messages.extend(
[ [

View File

@ -278,10 +278,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
""" """
_prompts = extra_prompt_messages or [] _prompts = extra_prompt_messages or []
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]] messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message: messages = [system_message] if system_message else []
messages = [system_message]
else:
messages = []
messages.extend( messages.extend(
[ [

View File

@ -60,10 +60,7 @@ def parse_ai_message_to_tool_action(
# Open AI does not support passing in a JSON array as an argument. # Open AI does not support passing in a JSON array as an argument.
function_name = tool_call["name"] function_name = tool_call["name"]
_tool_input = tool_call["args"] _tool_input = tool_call["args"]
if "__arg1" in _tool_input: tool_input = _tool_input.get("__arg1", _tool_input)
tool_input = _tool_input["__arg1"]
else:
tool_input = _tool_input
content_msg = f"responded: {message.content}\n" if message.content else "\n" content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n" log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"

View File

@ -1,6 +1,7 @@
"""Base interface that all chains should implement.""" """Base interface that all chains should implement."""
import builtins import builtins
import contextlib
import inspect import inspect
import json import json
import logging import logging
@ -727,10 +728,8 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
# -> {"_type": "foo", "verbose": False, ...} # -> {"_type": "foo", "verbose": False, ...}
""" """
_dict = super().dict(**kwargs) _dict = super().dict(**kwargs)
try: with contextlib.suppress(NotImplementedError):
_dict["_type"] = self._chain_type _dict["_type"] = self._chain_type
except NotImplementedError:
pass
return _dict return _dict
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
@ -758,10 +757,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
raise NotImplementedError(msg) raise NotImplementedError(msg)
# Convert file to Path object. # Convert file to Path object.
if isinstance(file_path, str): save_path = Path(file_path) if isinstance(file_path, str) else file_path
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True) directory_path.mkdir(parents=True, exist_ok=True)

View File

@ -683,10 +683,7 @@ def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain:
def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain: def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
"""Load chain from file.""" """Load chain from file."""
# Convert file to Path object. # Convert file to Path object.
if isinstance(file, str): file_path = Path(file) if isinstance(file, str) else file
file_path = Path(file)
else:
file_path = file
# Load from either json or yaml. # Load from either json or yaml.
if file_path.suffix == ".json": if file_path.suffix == ".json":
with open(file_path) as f: with open(file_path) as f:

View File

@ -95,10 +95,7 @@ class OpenAIModerationChain(Chain):
return [self.output_key] return [self.output_key]
def _moderate(self, text: str, results: Any) -> str: def _moderate(self, text: str, results: Any) -> str:
if self.openai_pre_1_0: condition = results["flagged"] if self.openai_pre_1_0 else results.flagged
condition = results["flagged"]
else:
condition = results.flagged
if condition: if condition:
error_str = "Text was found that violates OpenAI's content policy." error_str = "Text was found that violates OpenAI's content policy."
if self.error: if self.error:

View File

@ -108,8 +108,10 @@ class QueryTransformer(Transformer):
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]: def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
if func_name in set(Comparator): if func_name in set(Comparator):
if self.allowed_comparators is not None: if (
if func_name not in self.allowed_comparators: self.allowed_comparators is not None
and func_name not in self.allowed_comparators
):
msg = ( msg = (
f"Received disallowed comparator {func_name}. Allowed " f"Received disallowed comparator {func_name}. Allowed "
f"comparators are {self.allowed_comparators}" f"comparators are {self.allowed_comparators}"
@ -117,8 +119,10 @@ class QueryTransformer(Transformer):
raise ValueError(msg) raise ValueError(msg)
return Comparator(func_name) return Comparator(func_name)
elif func_name in set(Operator): elif func_name in set(Operator):
if self.allowed_operators is not None: if (
if func_name not in self.allowed_operators: self.allowed_operators is not None
and func_name not in self.allowed_operators
):
msg = ( msg = (
f"Received disallowed operator {func_name}. Allowed operators" f"Received disallowed operator {func_name}. Allowed operators"
f" are {self.allowed_operators}" f" are {self.allowed_operators}"

View File

@ -244,10 +244,7 @@ The following is the expected answer. Use this to measure correctness:
if not isinstance(llm, BaseChatModel): if not isinstance(llm, BaseChatModel):
msg = "Only chat models supported by the current trajectory eval" msg = "Only chat models supported by the current trajectory eval"
raise NotImplementedError(msg) raise NotImplementedError(msg)
if agent_tools: prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
prompt = EVAL_CHAT_PROMPT
else:
prompt = TOOL_FREE_EVAL_CHAT_PROMPT
eval_chain = LLMChain(llm=llm, prompt=prompt) eval_chain = LLMChain(llm=llm, prompt=prompt)
return cls( return cls(
agent_tools=agent_tools, # type: ignore[arg-type] agent_tools=agent_tools, # type: ignore[arg-type]

View File

@ -36,8 +36,7 @@ class CombinedMemory(BaseMemory):
def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]: def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]:
"""Check that if memories are of type BaseChatMemory that input keys exist.""" """Check that if memories are of type BaseChatMemory that input keys exist."""
for val in value: for val in value:
if isinstance(val, BaseChatMemory): if isinstance(val, BaseChatMemory) and val.input_key is None:
if val.input_key is None:
warnings.warn( warnings.warn(
"When using CombinedMemory, " "When using CombinedMemory, "
"input keys should be so the input is known. " "input keys should be so the input is known. "

View File

@ -51,9 +51,7 @@ class ModelLaboratory:
"Currently only support chains with one output variable, " "Currently only support chains with one output variable, "
f"got {chain.output_keys}" f"got {chain.output_keys}"
) )
raise ValueError(msg) if names is not None and len(names) != len(chains):
if names is not None:
if len(names) != len(chains):
msg = "Length of chains does not match length of names." msg = "Length of chains does not match length of names."
raise ValueError(msg) raise ValueError(msg)
self.chains = chains self.chains = chains
@ -93,10 +91,7 @@ class ModelLaboratory:
""" """
print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201 print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201
for i, chain in enumerate(self.chains): for i, chain in enumerate(self.chains):
if self.names is not None: name = self.names[i] if self.names is not None else str(chain)
name = self.names[i]
else:
name = str(chain)
print_text(name, end="\n") print_text(name, end="\n")
output = chain.run(text) output = chain.run(text)
print_text(output, color=self.chain_colors[str(i)], end="\n\n") print_text(output, color=self.chain_colors[str(i)], end="\n\n")

View File

@ -10,8 +10,7 @@ def load_output_parser(config: dict) -> dict:
Returns: Returns:
config dict with output parser loaded config dict with output parser loaded
""" """
if "output_parsers" in config: if "output_parsers" in config and config["output_parsers"] is not None:
if config["output_parsers"] is not None:
_config = config["output_parsers"] _config = config["output_parsers"]
output_parser_type = _config["_type"] output_parser_type = _config["_type"]
if output_parser_type == "regex_parser": if output_parser_type == "regex_parser":

View File

@ -27,12 +27,8 @@ class YamlOutputParser(BaseOutputParser[T]):
try: try:
# Greedy search for 1st yaml candidate. # Greedy search for 1st yaml candidate.
match = re.search(self.pattern, text.strip()) match = re.search(self.pattern, text.strip())
yaml_str = ""
if match:
yaml_str = match.group("yaml")
else:
# If no backticks were present, try to parse the entire output as yaml. # If no backticks were present, try to parse the entire output as yaml.
yaml_str = text yaml_str = match.group("yaml") if match else text
json_object = yaml.safe_load(yaml_str) json_object = yaml.safe_load(yaml_str)
if hasattr(self.pydantic_object, "model_validate"): if hasattr(self.pydantic_object, "model_validate"):

View File

@ -123,10 +123,7 @@ class MultiQueryRetriever(BaseRetriever):
response = await self.llm_chain.ainvoke( response = await self.llm_chain.ainvoke(
{"question": question}, config={"callbacks": run_manager.get_child()} {"question": question}, config={"callbacks": run_manager.get_child()}
) )
if isinstance(self.llm_chain, LLMChain): lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
lines = response["text"]
else:
lines = response
if self.verbose: if self.verbose:
logger.info(f"Generated queries: {lines}") logger.info(f"Generated queries: {lines}")
return lines return lines
@ -186,10 +183,7 @@ class MultiQueryRetriever(BaseRetriever):
response = self.llm_chain.invoke( response = self.llm_chain.invoke(
{"question": question}, config={"callbacks": run_manager.get_child()} {"question": question}, config={"callbacks": run_manager.get_child()}
) )
if isinstance(self.llm_chain, LLMChain): lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
lines = response["text"]
else:
lines = response
if self.verbose: if self.verbose:
logger.info(f"Generated queries: {lines}") logger.info(f"Generated queries: {lines}")
return lines return lines

View File

@ -57,9 +57,7 @@ class EvalConfig(BaseModel):
""" """
kwargs = {} kwargs = {}
for field, val in self: for field, val in self:
if field == "evaluator_type": if field == "evaluator_type" or val is None:
continue
elif val is None:
continue continue
kwargs[field] = val kwargs[field] = val
return kwargs return kwargs

View File

@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin" ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint] [tool.ruff.lint]
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"] select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
pydocstyle.convention = "google" pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true pyupgrade.keep-runtime-typing = true

View File

@ -2,6 +2,8 @@
import asyncio import asyncio
import json import json
import operator
from functools import reduce
from itertools import cycle from itertools import cycle
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
@ -520,14 +522,7 @@ async def test_runnable_agent() -> None:
assert messages != [] assert messages != []
# Aggregate state # Aggregate state
run_log = None run_log = reduce(operator.add, results)
for result in results:
if run_log is None:
run_log = result
else:
# `+` is defined for RunLogPatch
run_log = run_log + result
assert isinstance(run_log, RunLog) assert isinstance(run_log, RunLog)

View File

@ -17,7 +17,7 @@ def test_simple_memory() -> None:
output = memory.load_memory_variables({}) output = memory.load_memory_variables({})
assert output == {"baz": "foo"} assert output == {"baz": "foo"}
assert ["baz"] == memory.memory_variables assert memory.memory_variables == ["baz"]
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -1,5 +1,6 @@
"""Embeddings tests.""" """Embeddings tests."""
import contextlib
import hashlib import hashlib
import importlib import importlib
import warnings import warnings
@ -75,10 +76,8 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None: def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2 # "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"] texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
try: with contextlib.suppress(ValueError):
cache_embeddings_batch.embed_documents(texts) cache_embeddings_batch.embed_documents(texts)
except ValueError:
pass
keys = list(cache_embeddings_batch.document_embedding_store.yield_keys()) keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
# only the first batch of three embeddings should exist # only the first batch of three embeddings should exist
assert len(keys) == 3 assert len(keys) == 3
@ -122,10 +121,8 @@ async def test_aembed_documents_batch(
) -> None: ) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2 # "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"] texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
try: with contextlib.suppress(ValueError):
await cache_embeddings_batch.aembed_documents(texts) await cache_embeddings_batch.aembed_documents(texts)
except ValueError:
pass
keys = [ keys = [
key key
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys() async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()

View File

@ -70,7 +70,7 @@ def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> N
reference = "I like apples." reference = "I like apples."
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference) result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
assert "score" in result assert "score" in result
assert 0 < result["score"] assert result["score"] > 0
if normalize_score: if normalize_score:
assert result["score"] < 1.0 assert result["score"] < 1.0

View File

@ -45,7 +45,7 @@ def test_combining_dict_result() -> None:
] ]
combining_parser = CombiningOutputParser(parsers=parsers) combining_parser = CombiningOutputParser(parsers=parsers)
result_dict = combining_parser.parse(DEF_README) result_dict = combining_parser.parse(DEF_README)
assert DEF_EXPECTED_RESULT == result_dict assert result_dict == DEF_EXPECTED_RESULT
def test_combining_output_parser_output_type() -> None: def test_combining_output_parser_output_type() -> None:

View File

@ -47,7 +47,7 @@ def test_pandas_output_parser_col_first_elem() -> None:
def test_pandas_output_parser_col_multi_elem() -> None: def test_pandas_output_parser_col_multi_elem() -> None:
expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")} expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")}
actual_output = parser.parse("column:chicken[0, 1]") actual_output = parser.parse("column:chicken[0, 1]")
for key in actual_output.keys(): for key in actual_output:
assert expected_output["chicken"].equals(actual_output[key]) assert expected_output["chicken"].equals(actual_output[key])

View File

@ -23,7 +23,7 @@ def test_regex_parser_parse() -> None:
output_keys=["confidence", "explanation"], output_keys=["confidence", "explanation"],
default_output_key="noConfidence", default_output_key="noConfidence",
) )
assert DEF_EXPECTED_RESULT == parser.parse(DEF_README) assert parser.parse(DEF_README) == DEF_EXPECTED_RESULT
def test_regex_parser_output_type() -> None: def test_regex_parser_output_type() -> None:

View File

@ -35,7 +35,7 @@ def test_regex_dict_result() -> None:
) )
result_dict = regex_dict_parser.parse(DEF_README) result_dict = regex_dict_parser.parse(DEF_README)
print("parse_result:", result_dict) # noqa: T201 print("parse_result:", result_dict) # noqa: T201
assert DEF_EXPECTED_RESULT == result_dict assert result_dict == DEF_EXPECTED_RESULT
def test_regex_dict_output_type() -> None: def test_regex_dict_output_type() -> None:

View File

@ -76,7 +76,7 @@ def test_yaml_output_parser(result: str) -> None:
model = yaml_parser.parse(result) model = yaml_parser.parse(result)
print("parse_result:", result) # noqa: T201 print("parse_result:", result) # noqa: T201
assert DEF_EXPECTED_RESULT == model assert model == DEF_EXPECTED_RESULT
def test_yaml_output_parser_fail() -> None: def test_yaml_output_parser_fail() -> None:

View File

@ -124,8 +124,11 @@ def extract_deprecated_lookup(file_path: str) -> Optional[dict[str, Any]]:
for node in ast.walk(tree): for node in ast.walk(tree):
if isinstance(node, ast.Assign): if isinstance(node, ast.Assign):
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name) and target.id == "DEPRECATED_LOOKUP": if (
if isinstance(node.value, ast.Dict): isinstance(target, ast.Name)
and target.id == "DEPRECATED_LOOKUP"
and isinstance(node.value, ast.Dict)
):
return _dict_from_ast(node.value) return _dict_from_ast(node.value)
return None return None
@ -156,8 +159,10 @@ def _literal_eval_str(node: ast.AST) -> str:
Returns: Returns:
str: The corresponding string value. str: The corresponding string value.
""" """
if isinstance(node, ast.Constant): # Python 3.8+ if (
if isinstance(node.value, str): isinstance(node, ast.Constant) # Python 3.8+
and isinstance(node.value, str)
):
return node.value return node.value
msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}" msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}"
raise AssertionError(msg) raise AssertionError(msg)