langchain: Add ruff rules SIM (#31881)

See https://docs.astral.sh/ruff/rules/#flake8-simplify-sim

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet 2025-07-07 16:27:04 +02:00 committed by GitHub
parent f06380516f
commit 2df3fdf40d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 138 additions and 196 deletions

View File

@ -84,28 +84,28 @@ def create_importer(
not is_interactive_env() not is_interactive_env()
and deprecated_lookups and deprecated_lookups
and name in deprecated_lookups and name in deprecated_lookups
):
# Depth 3: # Depth 3:
# internal.py # internal.py
# module_import.py # module_import.py
# Module in langchain that uses this function # Module in langchain that uses this function
# [calling code] whose frame we want to inspect. # [calling code] whose frame we want to inspect.
if not internal.is_caller_internal(depth=3): and not internal.is_caller_internal(depth=3)
warn_deprecated( ):
since="0.1", warn_deprecated(
pending=False, since="0.1",
removal="1.0", pending=False,
message=( removal="1.0",
f"Importing {name} from {package} is deprecated. " message=(
f"Please replace deprecated imports:\n\n" f"Importing {name} from {package} is deprecated. "
f">> from {package} import {name}\n\n" f"Please replace deprecated imports:\n\n"
"with new imports of:\n\n" f">> from {package} import {name}\n\n"
f">> from {new_module} import {name}\n" "with new imports of:\n\n"
"You can use the langchain cli to **automatically** " f">> from {new_module} import {name}\n"
"upgrade many imports. Please see documentation here " "You can use the langchain cli to **automatically** "
"<https://python.langchain.com/docs/versions/v0_2/>" "upgrade many imports. Please see documentation here "
), "<https://python.langchain.com/docs/versions/v0_2/>"
) ),
)
return result return result
except Exception as e: except Exception as e:
msg = f"module {new_module} has no attribute {name}" msg = f"module {new_module} has no attribute {name}"
@ -115,28 +115,30 @@ def create_importer(
try: try:
module = importlib.import_module(fallback_module) module = importlib.import_module(fallback_module)
result = getattr(module, name) result = getattr(module, name)
if not is_interactive_env(): if (
not is_interactive_env()
# Depth 3: # Depth 3:
# internal.py # internal.py
# module_import.py # module_import.py
# Module in langchain that uses this function # Module in langchain that uses this function
# [calling code] whose frame we want to inspect. # [calling code] whose frame we want to inspect.
if not internal.is_caller_internal(depth=3): and not internal.is_caller_internal(depth=3)
warn_deprecated( ):
since="0.1", warn_deprecated(
pending=False, since="0.1",
removal="1.0", pending=False,
message=( removal="1.0",
f"Importing {name} from {package} is deprecated. " message=(
f"Please replace deprecated imports:\n\n" f"Importing {name} from {package} is deprecated. "
f">> from {package} import {name}\n\n" f"Please replace deprecated imports:\n\n"
"with new imports of:\n\n" f">> from {package} import {name}\n\n"
f">> from {fallback_module} import {name}\n" "with new imports of:\n\n"
"You can use the langchain cli to **automatically** " f">> from {fallback_module} import {name}\n"
"upgrade many imports. Please see documentation here " "You can use the langchain cli to **automatically** "
"<https://python.langchain.com/docs/versions/v0_2/>" "upgrade many imports. Please see documentation here "
), "<https://python.langchain.com/docs/versions/v0_2/>"
) ),
)
return result return result
except Exception as e: except Exception as e:

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
import builtins import builtins
import contextlib
import json import json
import logging import logging
import time import time
@ -196,10 +197,7 @@ class BaseSingleActionAgent(BaseModel):
agent.agent.save(file_path="path/agent.yaml") agent.agent.save(file_path="path/agent.yaml")
""" """
# Convert file to Path object. # Convert file to Path object.
if isinstance(file_path, str): save_path = Path(file_path) if isinstance(file_path, str) else file_path
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True) directory_path.mkdir(parents=True, exist_ok=True)
@ -322,10 +320,8 @@ class BaseMultiActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> builtins.dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().model_dump() _dict = super().model_dump()
try: with contextlib.suppress(NotImplementedError):
_dict["_type"] = str(self._agent_type) _dict["_type"] = str(self._agent_type)
except NotImplementedError:
pass
return _dict return _dict
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
@ -345,10 +341,7 @@ class BaseMultiActionAgent(BaseModel):
agent.agent.save(file_path="path/agent.yaml") agent.agent.save(file_path="path/agent.yaml")
""" """
# Convert file to Path object. # Convert file to Path object.
if isinstance(file_path, str): save_path = Path(file_path) if isinstance(file_path, str) else file_path
save_path = Path(file_path)
else:
save_path = file_path
# Fetch dictionary to save # Fetch dictionary to save
agent_dict = self.dict() agent_dict = self.dict()
@ -1133,13 +1126,14 @@ class AgentExecutor(Chain):
agent = self.agent agent = self.agent
tools = self.tools tools = self.tools
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr] allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
if allowed_tools is not None: if allowed_tools is not None and set(allowed_tools) != set(
if set(allowed_tools) != set([tool.name for tool in tools]): [tool.name for tool in tools]
msg = ( ):
f"Allowed tools ({allowed_tools}) different than " msg = (
f"provided tools ({[tool.name for tool in tools]})" f"Allowed tools ({allowed_tools}) different than "
) f"provided tools ({[tool.name for tool in tools]})"
raise ValueError(msg) )
raise ValueError(msg)
return self return self
@model_validator(mode="before") @model_validator(mode="before")
@ -1272,13 +1266,7 @@ class AgentExecutor(Chain):
def _should_continue(self, iterations: int, time_elapsed: float) -> bool: def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
if self.max_iterations is not None and iterations >= self.max_iterations: if self.max_iterations is not None and iterations >= self.max_iterations:
return False return False
if ( return self.max_execution_time is None or time_elapsed < self.max_execution_time
self.max_execution_time is not None
and time_elapsed >= self.max_execution_time
):
return False
return True
def _return( def _return(
self, self,
@ -1410,10 +1398,7 @@ class AgentExecutor(Chain):
return return
actions: list[AgentAction] actions: list[AgentAction]
if isinstance(output, AgentAction): actions = [output] if isinstance(output, AgentAction) else output
actions = [output]
else:
actions = output
for agent_action in actions: for agent_action in actions:
yield agent_action yield agent_action
for agent_action in actions: for agent_action in actions:
@ -1547,10 +1532,7 @@ class AgentExecutor(Chain):
return return
actions: list[AgentAction] actions: list[AgentAction]
if isinstance(output, AgentAction): actions = [output] if isinstance(output, AgentAction) else output
actions = [output]
else:
actions = output
for agent_action in actions: for agent_action in actions:
yield agent_action yield agent_action
@ -1728,12 +1710,14 @@ class AgentExecutor(Chain):
if len(self._action_agent.return_values) > 0: if len(self._action_agent.return_values) > 0:
return_value_key = self._action_agent.return_values[0] return_value_key = self._action_agent.return_values[0]
# Invalid tools won't be in the map, so we return False. # Invalid tools won't be in the map, so we return False.
if agent_action.tool in name_to_tool_map: if (
if name_to_tool_map[agent_action.tool].return_direct: agent_action.tool in name_to_tool_map
return AgentFinish( and name_to_tool_map[agent_action.tool].return_direct
{return_value_key: observation}, ):
"", return AgentFinish(
) {return_value_key: observation},
"",
)
return None return None
def _prepare_intermediate_steps( def _prepare_intermediate_steps(

View File

@ -1,5 +1,6 @@
"""Load agent.""" """Load agent."""
import contextlib
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Optional from typing import Any, Optional
@ -81,11 +82,9 @@ def initialize_agent(
agent_obj = load_agent( agent_obj = load_agent(
agent_path, llm=llm, tools=tools, callback_manager=callback_manager agent_path, llm=llm, tools=tools, callback_manager=callback_manager
) )
try: with contextlib.suppress(NotImplementedError):
# TODO: Add tags from the serialized object directly. # TODO: Add tags from the serialized object directly.
tags_.append(agent_obj._agent_type) tags_.append(agent_obj._agent_type)
except NotImplementedError:
pass
else: else:
msg = ( msg = (
"Somehow both `agent` and `agent_path` are None, this should never happen." "Somehow both `agent` and `agent_path` are None, this should never happen."

View File

@ -128,10 +128,7 @@ def _load_agent_from_file(
"""Load agent from file.""" """Load agent from file."""
valid_suffixes = {"json", "yaml"} valid_suffixes = {"json", "yaml"}
# Convert file to Path object. # Convert file to Path object.
if isinstance(file, str): file_path = Path(file) if isinstance(file, str) else file
file_path = Path(file)
else:
file_path = file
# Load from either json or yaml. # Load from either json or yaml.
if file_path.suffix[1:] == "json": if file_path.suffix[1:] == "json":
with open(file_path) as f: with open(file_path) as f:

View File

@ -230,10 +230,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
""" """
_prompts = extra_prompt_messages or [] _prompts = extra_prompt_messages or []
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]] messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message: messages = [system_message] if system_message else []
messages = [system_message]
else:
messages = []
messages.extend( messages.extend(
[ [

View File

@ -278,10 +278,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
""" """
_prompts = extra_prompt_messages or [] _prompts = extra_prompt_messages or []
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]] messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message: messages = [system_message] if system_message else []
messages = [system_message]
else:
messages = []
messages.extend( messages.extend(
[ [

View File

@ -60,10 +60,7 @@ def parse_ai_message_to_tool_action(
# Open AI does not support passing in a JSON array as an argument. # Open AI does not support passing in a JSON array as an argument.
function_name = tool_call["name"] function_name = tool_call["name"]
_tool_input = tool_call["args"] _tool_input = tool_call["args"]
if "__arg1" in _tool_input: tool_input = _tool_input.get("__arg1", _tool_input)
tool_input = _tool_input["__arg1"]
else:
tool_input = _tool_input
content_msg = f"responded: {message.content}\n" if message.content else "\n" content_msg = f"responded: {message.content}\n" if message.content else "\n"
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n" log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"

View File

@ -1,6 +1,7 @@
"""Base interface that all chains should implement.""" """Base interface that all chains should implement."""
import builtins import builtins
import contextlib
import inspect import inspect
import json import json
import logging import logging
@ -727,10 +728,8 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
# -> {"_type": "foo", "verbose": False, ...} # -> {"_type": "foo", "verbose": False, ...}
""" """
_dict = super().dict(**kwargs) _dict = super().dict(**kwargs)
try: with contextlib.suppress(NotImplementedError):
_dict["_type"] = self._chain_type _dict["_type"] = self._chain_type
except NotImplementedError:
pass
return _dict return _dict
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
@ -758,10 +757,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
raise NotImplementedError(msg) raise NotImplementedError(msg)
# Convert file to Path object. # Convert file to Path object.
if isinstance(file_path, str): save_path = Path(file_path) if isinstance(file_path, str) else file_path
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True) directory_path.mkdir(parents=True, exist_ok=True)

View File

@ -683,10 +683,7 @@ def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain:
def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain: def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
"""Load chain from file.""" """Load chain from file."""
# Convert file to Path object. # Convert file to Path object.
if isinstance(file, str): file_path = Path(file) if isinstance(file, str) else file
file_path = Path(file)
else:
file_path = file
# Load from either json or yaml. # Load from either json or yaml.
if file_path.suffix == ".json": if file_path.suffix == ".json":
with open(file_path) as f: with open(file_path) as f:

View File

@ -95,10 +95,7 @@ class OpenAIModerationChain(Chain):
return [self.output_key] return [self.output_key]
def _moderate(self, text: str, results: Any) -> str: def _moderate(self, text: str, results: Any) -> str:
if self.openai_pre_1_0: condition = results["flagged"] if self.openai_pre_1_0 else results.flagged
condition = results["flagged"]
else:
condition = results.flagged
if condition: if condition:
error_str = "Text was found that violates OpenAI's content policy." error_str = "Text was found that violates OpenAI's content policy."
if self.error: if self.error:

View File

@ -108,22 +108,26 @@ class QueryTransformer(Transformer):
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]: def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
if func_name in set(Comparator): if func_name in set(Comparator):
if self.allowed_comparators is not None: if (
if func_name not in self.allowed_comparators: self.allowed_comparators is not None
msg = ( and func_name not in self.allowed_comparators
f"Received disallowed comparator {func_name}. Allowed " ):
f"comparators are {self.allowed_comparators}" msg = (
) f"Received disallowed comparator {func_name}. Allowed "
raise ValueError(msg) f"comparators are {self.allowed_comparators}"
)
raise ValueError(msg)
return Comparator(func_name) return Comparator(func_name)
elif func_name in set(Operator): elif func_name in set(Operator):
if self.allowed_operators is not None: if (
if func_name not in self.allowed_operators: self.allowed_operators is not None
msg = ( and func_name not in self.allowed_operators
f"Received disallowed operator {func_name}. Allowed operators" ):
f" are {self.allowed_operators}" msg = (
) f"Received disallowed operator {func_name}. Allowed operators"
raise ValueError(msg) f" are {self.allowed_operators}"
)
raise ValueError(msg)
return Operator(func_name) return Operator(func_name)
else: else:
msg = ( msg = (

View File

@ -244,10 +244,7 @@ The following is the expected answer. Use this to measure correctness:
if not isinstance(llm, BaseChatModel): if not isinstance(llm, BaseChatModel):
msg = "Only chat models supported by the current trajectory eval" msg = "Only chat models supported by the current trajectory eval"
raise NotImplementedError(msg) raise NotImplementedError(msg)
if agent_tools: prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
prompt = EVAL_CHAT_PROMPT
else:
prompt = TOOL_FREE_EVAL_CHAT_PROMPT
eval_chain = LLMChain(llm=llm, prompt=prompt) eval_chain = LLMChain(llm=llm, prompt=prompt)
return cls( return cls(
agent_tools=agent_tools, # type: ignore[arg-type] agent_tools=agent_tools, # type: ignore[arg-type]

View File

@ -36,13 +36,12 @@ class CombinedMemory(BaseMemory):
def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]: def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]:
"""Check that if memories are of type BaseChatMemory that input keys exist.""" """Check that if memories are of type BaseChatMemory that input keys exist."""
for val in value: for val in value:
if isinstance(val, BaseChatMemory): if isinstance(val, BaseChatMemory) and val.input_key is None:
if val.input_key is None: warnings.warn(
warnings.warn( "When using CombinedMemory, "
"When using CombinedMemory, " "input keys should be so the input is known. "
"input keys should be so the input is known. " f" Was not set on {val}"
f" Was not set on {val}" )
)
return value return value
@property @property

View File

@ -51,11 +51,9 @@ class ModelLaboratory:
"Currently only support chains with one output variable, " "Currently only support chains with one output variable, "
f"got {chain.output_keys}" f"got {chain.output_keys}"
) )
raise ValueError(msg) if names is not None and len(names) != len(chains):
if names is not None: msg = "Length of chains does not match length of names."
if len(names) != len(chains): raise ValueError(msg)
msg = "Length of chains does not match length of names."
raise ValueError(msg)
self.chains = chains self.chains = chains
chain_range = [str(i) for i in range(len(self.chains))] chain_range = [str(i) for i in range(len(self.chains))]
self.chain_colors = get_color_mapping(chain_range) self.chain_colors = get_color_mapping(chain_range)
@ -93,10 +91,7 @@ class ModelLaboratory:
""" """
print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201 print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201
for i, chain in enumerate(self.chains): for i, chain in enumerate(self.chains):
if self.names is not None: name = self.names[i] if self.names is not None else str(chain)
name = self.names[i]
else:
name = str(chain)
print_text(name, end="\n") print_text(name, end="\n")
output = chain.run(text) output = chain.run(text)
print_text(output, color=self.chain_colors[str(i)], end="\n\n") print_text(output, color=self.chain_colors[str(i)], end="\n\n")

View File

@ -10,14 +10,13 @@ def load_output_parser(config: dict) -> dict:
Returns: Returns:
config dict with output parser loaded config dict with output parser loaded
""" """
if "output_parsers" in config: if "output_parsers" in config and config["output_parsers"] is not None:
if config["output_parsers"] is not None: _config = config["output_parsers"]
_config = config["output_parsers"] output_parser_type = _config["_type"]
output_parser_type = _config["_type"] if output_parser_type == "regex_parser":
if output_parser_type == "regex_parser": output_parser = RegexParser(**_config)
output_parser = RegexParser(**_config) else:
else: msg = f"Unsupported output parser {output_parser_type}"
msg = f"Unsupported output parser {output_parser_type}" raise ValueError(msg)
raise ValueError(msg) config["output_parsers"] = output_parser
config["output_parsers"] = output_parser
return config return config

View File

@ -27,12 +27,8 @@ class YamlOutputParser(BaseOutputParser[T]):
try: try:
# Greedy search for 1st yaml candidate. # Greedy search for 1st yaml candidate.
match = re.search(self.pattern, text.strip()) match = re.search(self.pattern, text.strip())
yaml_str = "" # If no backticks were present, try to parse the entire output as yaml.
if match: yaml_str = match.group("yaml") if match else text
yaml_str = match.group("yaml")
else:
# If no backticks were present, try to parse the entire output as yaml.
yaml_str = text
json_object = yaml.safe_load(yaml_str) json_object = yaml.safe_load(yaml_str)
if hasattr(self.pydantic_object, "model_validate"): if hasattr(self.pydantic_object, "model_validate"):

View File

@ -123,10 +123,7 @@ class MultiQueryRetriever(BaseRetriever):
response = await self.llm_chain.ainvoke( response = await self.llm_chain.ainvoke(
{"question": question}, config={"callbacks": run_manager.get_child()} {"question": question}, config={"callbacks": run_manager.get_child()}
) )
if isinstance(self.llm_chain, LLMChain): lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
lines = response["text"]
else:
lines = response
if self.verbose: if self.verbose:
logger.info(f"Generated queries: {lines}") logger.info(f"Generated queries: {lines}")
return lines return lines
@ -186,10 +183,7 @@ class MultiQueryRetriever(BaseRetriever):
response = self.llm_chain.invoke( response = self.llm_chain.invoke(
{"question": question}, config={"callbacks": run_manager.get_child()} {"question": question}, config={"callbacks": run_manager.get_child()}
) )
if isinstance(self.llm_chain, LLMChain): lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
lines = response["text"]
else:
lines = response
if self.verbose: if self.verbose:
logger.info(f"Generated queries: {lines}") logger.info(f"Generated queries: {lines}")
return lines return lines

View File

@ -57,9 +57,7 @@ class EvalConfig(BaseModel):
""" """
kwargs = {} kwargs = {}
for field, val in self: for field, val in self:
if field == "evaluator_type": if field == "evaluator_type" or val is None:
continue
elif val is None:
continue continue
kwargs[field] = val kwargs[field] = val
return kwargs return kwargs

View File

@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin" ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
[tool.ruff.lint] [tool.ruff.lint]
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"] select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
pydocstyle.convention = "google" pydocstyle.convention = "google"
pyupgrade.keep-runtime-typing = true pyupgrade.keep-runtime-typing = true

View File

@ -2,6 +2,8 @@
import asyncio import asyncio
import json import json
import operator
from functools import reduce
from itertools import cycle from itertools import cycle
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
@ -520,14 +522,7 @@ async def test_runnable_agent() -> None:
assert messages != [] assert messages != []
# Aggregate state # Aggregate state
run_log = None run_log = reduce(operator.add, results)
for result in results:
if run_log is None:
run_log = result
else:
# `+` is defined for RunLogPatch
run_log = run_log + result
assert isinstance(run_log, RunLog) assert isinstance(run_log, RunLog)

View File

@ -17,7 +17,7 @@ def test_simple_memory() -> None:
output = memory.load_memory_variables({}) output = memory.load_memory_variables({})
assert output == {"baz": "foo"} assert output == {"baz": "foo"}
assert ["baz"] == memory.memory_variables assert memory.memory_variables == ["baz"]
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@ -1,5 +1,6 @@
"""Embeddings tests.""" """Embeddings tests."""
import contextlib
import hashlib import hashlib
import importlib import importlib
import warnings import warnings
@ -75,10 +76,8 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None: def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2 # "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"] texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
try: with contextlib.suppress(ValueError):
cache_embeddings_batch.embed_documents(texts) cache_embeddings_batch.embed_documents(texts)
except ValueError:
pass
keys = list(cache_embeddings_batch.document_embedding_store.yield_keys()) keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
# only the first batch of three embeddings should exist # only the first batch of three embeddings should exist
assert len(keys) == 3 assert len(keys) == 3
@ -122,10 +121,8 @@ async def test_aembed_documents_batch(
) -> None: ) -> None:
# "RAISE_EXCEPTION" forces a failure in batch 2 # "RAISE_EXCEPTION" forces a failure in batch 2
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"] texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
try: with contextlib.suppress(ValueError):
await cache_embeddings_batch.aembed_documents(texts) await cache_embeddings_batch.aembed_documents(texts)
except ValueError:
pass
keys = [ keys = [
key key
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys() async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()

View File

@ -70,7 +70,7 @@ def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> N
reference = "I like apples." reference = "I like apples."
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference) result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
assert "score" in result assert "score" in result
assert 0 < result["score"] assert result["score"] > 0
if normalize_score: if normalize_score:
assert result["score"] < 1.0 assert result["score"] < 1.0

View File

@ -45,7 +45,7 @@ def test_combining_dict_result() -> None:
] ]
combining_parser = CombiningOutputParser(parsers=parsers) combining_parser = CombiningOutputParser(parsers=parsers)
result_dict = combining_parser.parse(DEF_README) result_dict = combining_parser.parse(DEF_README)
assert DEF_EXPECTED_RESULT == result_dict assert result_dict == DEF_EXPECTED_RESULT
def test_combining_output_parser_output_type() -> None: def test_combining_output_parser_output_type() -> None:

View File

@ -47,7 +47,7 @@ def test_pandas_output_parser_col_first_elem() -> None:
def test_pandas_output_parser_col_multi_elem() -> None: def test_pandas_output_parser_col_multi_elem() -> None:
expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")} expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")}
actual_output = parser.parse("column:chicken[0, 1]") actual_output = parser.parse("column:chicken[0, 1]")
for key in actual_output.keys(): for key in actual_output:
assert expected_output["chicken"].equals(actual_output[key]) assert expected_output["chicken"].equals(actual_output[key])

View File

@ -23,7 +23,7 @@ def test_regex_parser_parse() -> None:
output_keys=["confidence", "explanation"], output_keys=["confidence", "explanation"],
default_output_key="noConfidence", default_output_key="noConfidence",
) )
assert DEF_EXPECTED_RESULT == parser.parse(DEF_README) assert parser.parse(DEF_README) == DEF_EXPECTED_RESULT
def test_regex_parser_output_type() -> None: def test_regex_parser_output_type() -> None:

View File

@ -35,7 +35,7 @@ def test_regex_dict_result() -> None:
) )
result_dict = regex_dict_parser.parse(DEF_README) result_dict = regex_dict_parser.parse(DEF_README)
print("parse_result:", result_dict) # noqa: T201 print("parse_result:", result_dict) # noqa: T201
assert DEF_EXPECTED_RESULT == result_dict assert result_dict == DEF_EXPECTED_RESULT
def test_regex_dict_output_type() -> None: def test_regex_dict_output_type() -> None:

View File

@ -76,7 +76,7 @@ def test_yaml_output_parser(result: str) -> None:
model = yaml_parser.parse(result) model = yaml_parser.parse(result)
print("parse_result:", result) # noqa: T201 print("parse_result:", result) # noqa: T201
assert DEF_EXPECTED_RESULT == model assert model == DEF_EXPECTED_RESULT
def test_yaml_output_parser_fail() -> None: def test_yaml_output_parser_fail() -> None:

View File

@ -124,9 +124,12 @@ def extract_deprecated_lookup(file_path: str) -> Optional[dict[str, Any]]:
for node in ast.walk(tree): for node in ast.walk(tree):
if isinstance(node, ast.Assign): if isinstance(node, ast.Assign):
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name) and target.id == "DEPRECATED_LOOKUP": if (
if isinstance(node.value, ast.Dict): isinstance(target, ast.Name)
return _dict_from_ast(node.value) and target.id == "DEPRECATED_LOOKUP"
and isinstance(node.value, ast.Dict)
):
return _dict_from_ast(node.value)
return None return None
@ -156,8 +159,10 @@ def _literal_eval_str(node: ast.AST) -> str:
Returns: Returns:
str: The corresponding string value. str: The corresponding string value.
""" """
if isinstance(node, ast.Constant): # Python 3.8+ if (
if isinstance(node.value, str): isinstance(node, ast.Constant) # Python 3.8+
return node.value and isinstance(node.value, str)
):
return node.value
msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}" msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}"
raise AssertionError(msg) raise AssertionError(msg)