mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-10 15:06:18 +00:00
langchain: Add ruff rules SIM (#31881)
See https://docs.astral.sh/ruff/rules/#flake8-simplify-sim Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent
f06380516f
commit
2df3fdf40d
@ -84,28 +84,28 @@ def create_importer(
|
||||
not is_interactive_env()
|
||||
and deprecated_lookups
|
||||
and name in deprecated_lookups
|
||||
):
|
||||
# Depth 3:
|
||||
# internal.py
|
||||
# module_import.py
|
||||
# Module in langchain that uses this function
|
||||
# [calling code] whose frame we want to inspect.
|
||||
if not internal.is_caller_internal(depth=3):
|
||||
warn_deprecated(
|
||||
since="0.1",
|
||||
pending=False,
|
||||
removal="1.0",
|
||||
message=(
|
||||
f"Importing {name} from {package} is deprecated. "
|
||||
f"Please replace deprecated imports:\n\n"
|
||||
f">> from {package} import {name}\n\n"
|
||||
"with new imports of:\n\n"
|
||||
f">> from {new_module} import {name}\n"
|
||||
"You can use the langchain cli to **automatically** "
|
||||
"upgrade many imports. Please see documentation here "
|
||||
"<https://python.langchain.com/docs/versions/v0_2/>"
|
||||
),
|
||||
)
|
||||
and not internal.is_caller_internal(depth=3)
|
||||
):
|
||||
warn_deprecated(
|
||||
since="0.1",
|
||||
pending=False,
|
||||
removal="1.0",
|
||||
message=(
|
||||
f"Importing {name} from {package} is deprecated. "
|
||||
f"Please replace deprecated imports:\n\n"
|
||||
f">> from {package} import {name}\n\n"
|
||||
"with new imports of:\n\n"
|
||||
f">> from {new_module} import {name}\n"
|
||||
"You can use the langchain cli to **automatically** "
|
||||
"upgrade many imports. Please see documentation here "
|
||||
"<https://python.langchain.com/docs/versions/v0_2/>"
|
||||
),
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
msg = f"module {new_module} has no attribute {name}"
|
||||
@ -115,28 +115,30 @@ def create_importer(
|
||||
try:
|
||||
module = importlib.import_module(fallback_module)
|
||||
result = getattr(module, name)
|
||||
if not is_interactive_env():
|
||||
if (
|
||||
not is_interactive_env()
|
||||
# Depth 3:
|
||||
# internal.py
|
||||
# module_import.py
|
||||
# Module in langchain that uses this function
|
||||
# [calling code] whose frame we want to inspect.
|
||||
if not internal.is_caller_internal(depth=3):
|
||||
warn_deprecated(
|
||||
since="0.1",
|
||||
pending=False,
|
||||
removal="1.0",
|
||||
message=(
|
||||
f"Importing {name} from {package} is deprecated. "
|
||||
f"Please replace deprecated imports:\n\n"
|
||||
f">> from {package} import {name}\n\n"
|
||||
"with new imports of:\n\n"
|
||||
f">> from {fallback_module} import {name}\n"
|
||||
"You can use the langchain cli to **automatically** "
|
||||
"upgrade many imports. Please see documentation here "
|
||||
"<https://python.langchain.com/docs/versions/v0_2/>"
|
||||
),
|
||||
)
|
||||
and not internal.is_caller_internal(depth=3)
|
||||
):
|
||||
warn_deprecated(
|
||||
since="0.1",
|
||||
pending=False,
|
||||
removal="1.0",
|
||||
message=(
|
||||
f"Importing {name} from {package} is deprecated. "
|
||||
f"Please replace deprecated imports:\n\n"
|
||||
f">> from {package} import {name}\n\n"
|
||||
"with new imports of:\n\n"
|
||||
f">> from {fallback_module} import {name}\n"
|
||||
"You can use the langchain cli to **automatically** "
|
||||
"upgrade many imports. Please see documentation here "
|
||||
"<https://python.langchain.com/docs/versions/v0_2/>"
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
@ -196,10 +197,7 @@ class BaseSingleActionAgent(BaseModel):
|
||||
agent.agent.save(file_path="path/agent.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
@ -322,10 +320,8 @@ class BaseMultiActionAgent(BaseModel):
|
||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().model_dump()
|
||||
try:
|
||||
with contextlib.suppress(NotImplementedError):
|
||||
_dict["_type"] = str(self._agent_type)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
@ -345,10 +341,7 @@ class BaseMultiActionAgent(BaseModel):
|
||||
agent.agent.save(file_path="path/agent.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||
|
||||
# Fetch dictionary to save
|
||||
agent_dict = self.dict()
|
||||
@ -1133,13 +1126,14 @@ class AgentExecutor(Chain):
|
||||
agent = self.agent
|
||||
tools = self.tools
|
||||
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
|
||||
if allowed_tools is not None:
|
||||
if set(allowed_tools) != set([tool.name for tool in tools]):
|
||||
msg = (
|
||||
f"Allowed tools ({allowed_tools}) different than "
|
||||
f"provided tools ({[tool.name for tool in tools]})"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if allowed_tools is not None and set(allowed_tools) != set(
|
||||
[tool.name for tool in tools]
|
||||
):
|
||||
msg = (
|
||||
f"Allowed tools ({allowed_tools}) different than "
|
||||
f"provided tools ({[tool.name for tool in tools]})"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@ -1272,13 +1266,7 @@ class AgentExecutor(Chain):
|
||||
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
|
||||
if self.max_iterations is not None and iterations >= self.max_iterations:
|
||||
return False
|
||||
if (
|
||||
self.max_execution_time is not None
|
||||
and time_elapsed >= self.max_execution_time
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
return self.max_execution_time is None or time_elapsed < self.max_execution_time
|
||||
|
||||
def _return(
|
||||
self,
|
||||
@ -1410,10 +1398,7 @@ class AgentExecutor(Chain):
|
||||
return
|
||||
|
||||
actions: list[AgentAction]
|
||||
if isinstance(output, AgentAction):
|
||||
actions = [output]
|
||||
else:
|
||||
actions = output
|
||||
actions = [output] if isinstance(output, AgentAction) else output
|
||||
for agent_action in actions:
|
||||
yield agent_action
|
||||
for agent_action in actions:
|
||||
@ -1547,10 +1532,7 @@ class AgentExecutor(Chain):
|
||||
return
|
||||
|
||||
actions: list[AgentAction]
|
||||
if isinstance(output, AgentAction):
|
||||
actions = [output]
|
||||
else:
|
||||
actions = output
|
||||
actions = [output] if isinstance(output, AgentAction) else output
|
||||
for agent_action in actions:
|
||||
yield agent_action
|
||||
|
||||
@ -1728,12 +1710,14 @@ class AgentExecutor(Chain):
|
||||
if len(self._action_agent.return_values) > 0:
|
||||
return_value_key = self._action_agent.return_values[0]
|
||||
# Invalid tools won't be in the map, so we return False.
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
if name_to_tool_map[agent_action.tool].return_direct:
|
||||
return AgentFinish(
|
||||
{return_value_key: observation},
|
||||
"",
|
||||
)
|
||||
if (
|
||||
agent_action.tool in name_to_tool_map
|
||||
and name_to_tool_map[agent_action.tool].return_direct
|
||||
):
|
||||
return AgentFinish(
|
||||
{return_value_key: observation},
|
||||
"",
|
||||
)
|
||||
return None
|
||||
|
||||
def _prepare_intermediate_steps(
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Load agent."""
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -81,11 +82,9 @@ def initialize_agent(
|
||||
agent_obj = load_agent(
|
||||
agent_path, llm=llm, tools=tools, callback_manager=callback_manager
|
||||
)
|
||||
try:
|
||||
with contextlib.suppress(NotImplementedError):
|
||||
# TODO: Add tags from the serialized object directly.
|
||||
tags_.append(agent_obj._agent_type)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
msg = (
|
||||
"Somehow both `agent` and `agent_path` are None, this should never happen."
|
||||
|
@ -128,10 +128,7 @@ def _load_agent_from_file(
|
||||
"""Load agent from file."""
|
||||
valid_suffixes = {"json", "yaml"}
|
||||
# Convert file to Path object.
|
||||
if isinstance(file, str):
|
||||
file_path = Path(file)
|
||||
else:
|
||||
file_path = file
|
||||
file_path = Path(file) if isinstance(file, str) else file
|
||||
# Load from either json or yaml.
|
||||
if file_path.suffix[1:] == "json":
|
||||
with open(file_path) as f:
|
||||
|
@ -230,10 +230,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
||||
"""
|
||||
_prompts = extra_prompt_messages or []
|
||||
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||
if system_message:
|
||||
messages = [system_message]
|
||||
else:
|
||||
messages = []
|
||||
messages = [system_message] if system_message else []
|
||||
|
||||
messages.extend(
|
||||
[
|
||||
|
@ -278,10 +278,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
||||
"""
|
||||
_prompts = extra_prompt_messages or []
|
||||
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||
if system_message:
|
||||
messages = [system_message]
|
||||
else:
|
||||
messages = []
|
||||
messages = [system_message] if system_message else []
|
||||
|
||||
messages.extend(
|
||||
[
|
||||
|
@ -60,10 +60,7 @@ def parse_ai_message_to_tool_action(
|
||||
# Open AI does not support passing in a JSON array as an argument.
|
||||
function_name = tool_call["name"]
|
||||
_tool_input = tool_call["args"]
|
||||
if "__arg1" in _tool_input:
|
||||
tool_input = _tool_input["__arg1"]
|
||||
else:
|
||||
tool_input = _tool_input
|
||||
tool_input = _tool_input.get("__arg1", _tool_input)
|
||||
|
||||
content_msg = f"responded: {message.content}\n" if message.content else "\n"
|
||||
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
|
||||
import builtins
|
||||
import contextlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
@ -727,10 +728,8 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
||||
# -> {"_type": "foo", "verbose": False, ...}
|
||||
"""
|
||||
_dict = super().dict(**kwargs)
|
||||
try:
|
||||
with contextlib.suppress(NotImplementedError):
|
||||
_dict["_type"] = self._chain_type
|
||||
except NotImplementedError:
|
||||
pass
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
@ -758,10 +757,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -683,10 +683,7 @@ def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain:
|
||||
def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
|
||||
"""Load chain from file."""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file, str):
|
||||
file_path = Path(file)
|
||||
else:
|
||||
file_path = file
|
||||
file_path = Path(file) if isinstance(file, str) else file
|
||||
# Load from either json or yaml.
|
||||
if file_path.suffix == ".json":
|
||||
with open(file_path) as f:
|
||||
|
@ -95,10 +95,7 @@ class OpenAIModerationChain(Chain):
|
||||
return [self.output_key]
|
||||
|
||||
def _moderate(self, text: str, results: Any) -> str:
|
||||
if self.openai_pre_1_0:
|
||||
condition = results["flagged"]
|
||||
else:
|
||||
condition = results.flagged
|
||||
condition = results["flagged"] if self.openai_pre_1_0 else results.flagged
|
||||
if condition:
|
||||
error_str = "Text was found that violates OpenAI's content policy."
|
||||
if self.error:
|
||||
|
@ -108,22 +108,26 @@ class QueryTransformer(Transformer):
|
||||
|
||||
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
|
||||
if func_name in set(Comparator):
|
||||
if self.allowed_comparators is not None:
|
||||
if func_name not in self.allowed_comparators:
|
||||
msg = (
|
||||
f"Received disallowed comparator {func_name}. Allowed "
|
||||
f"comparators are {self.allowed_comparators}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if (
|
||||
self.allowed_comparators is not None
|
||||
and func_name not in self.allowed_comparators
|
||||
):
|
||||
msg = (
|
||||
f"Received disallowed comparator {func_name}. Allowed "
|
||||
f"comparators are {self.allowed_comparators}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return Comparator(func_name)
|
||||
elif func_name in set(Operator):
|
||||
if self.allowed_operators is not None:
|
||||
if func_name not in self.allowed_operators:
|
||||
msg = (
|
||||
f"Received disallowed operator {func_name}. Allowed operators"
|
||||
f" are {self.allowed_operators}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if (
|
||||
self.allowed_operators is not None
|
||||
and func_name not in self.allowed_operators
|
||||
):
|
||||
msg = (
|
||||
f"Received disallowed operator {func_name}. Allowed operators"
|
||||
f" are {self.allowed_operators}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
return Operator(func_name)
|
||||
else:
|
||||
msg = (
|
||||
|
@ -244,10 +244,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
if not isinstance(llm, BaseChatModel):
|
||||
msg = "Only chat models supported by the current trajectory eval"
|
||||
raise NotImplementedError(msg)
|
||||
if agent_tools:
|
||||
prompt = EVAL_CHAT_PROMPT
|
||||
else:
|
||||
prompt = TOOL_FREE_EVAL_CHAT_PROMPT
|
||||
prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
|
||||
eval_chain = LLMChain(llm=llm, prompt=prompt)
|
||||
return cls(
|
||||
agent_tools=agent_tools, # type: ignore[arg-type]
|
||||
|
@ -36,13 +36,12 @@ class CombinedMemory(BaseMemory):
|
||||
def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]:
|
||||
"""Check that if memories are of type BaseChatMemory that input keys exist."""
|
||||
for val in value:
|
||||
if isinstance(val, BaseChatMemory):
|
||||
if val.input_key is None:
|
||||
warnings.warn(
|
||||
"When using CombinedMemory, "
|
||||
"input keys should be so the input is known. "
|
||||
f" Was not set on {val}"
|
||||
)
|
||||
if isinstance(val, BaseChatMemory) and val.input_key is None:
|
||||
warnings.warn(
|
||||
"When using CombinedMemory, "
|
||||
"input keys should be so the input is known. "
|
||||
f" Was not set on {val}"
|
||||
)
|
||||
return value
|
||||
|
||||
@property
|
||||
|
@ -51,11 +51,9 @@ class ModelLaboratory:
|
||||
"Currently only support chains with one output variable, "
|
||||
f"got {chain.output_keys}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
if names is not None:
|
||||
if len(names) != len(chains):
|
||||
msg = "Length of chains does not match length of names."
|
||||
raise ValueError(msg)
|
||||
if names is not None and len(names) != len(chains):
|
||||
msg = "Length of chains does not match length of names."
|
||||
raise ValueError(msg)
|
||||
self.chains = chains
|
||||
chain_range = [str(i) for i in range(len(self.chains))]
|
||||
self.chain_colors = get_color_mapping(chain_range)
|
||||
@ -93,10 +91,7 @@ class ModelLaboratory:
|
||||
"""
|
||||
print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201
|
||||
for i, chain in enumerate(self.chains):
|
||||
if self.names is not None:
|
||||
name = self.names[i]
|
||||
else:
|
||||
name = str(chain)
|
||||
name = self.names[i] if self.names is not None else str(chain)
|
||||
print_text(name, end="\n")
|
||||
output = chain.run(text)
|
||||
print_text(output, color=self.chain_colors[str(i)], end="\n\n")
|
||||
|
@ -10,14 +10,13 @@ def load_output_parser(config: dict) -> dict:
|
||||
Returns:
|
||||
config dict with output parser loaded
|
||||
"""
|
||||
if "output_parsers" in config:
|
||||
if config["output_parsers"] is not None:
|
||||
_config = config["output_parsers"]
|
||||
output_parser_type = _config["_type"]
|
||||
if output_parser_type == "regex_parser":
|
||||
output_parser = RegexParser(**_config)
|
||||
else:
|
||||
msg = f"Unsupported output parser {output_parser_type}"
|
||||
raise ValueError(msg)
|
||||
config["output_parsers"] = output_parser
|
||||
if "output_parsers" in config and config["output_parsers"] is not None:
|
||||
_config = config["output_parsers"]
|
||||
output_parser_type = _config["_type"]
|
||||
if output_parser_type == "regex_parser":
|
||||
output_parser = RegexParser(**_config)
|
||||
else:
|
||||
msg = f"Unsupported output parser {output_parser_type}"
|
||||
raise ValueError(msg)
|
||||
config["output_parsers"] = output_parser
|
||||
return config
|
||||
|
@ -27,12 +27,8 @@ class YamlOutputParser(BaseOutputParser[T]):
|
||||
try:
|
||||
# Greedy search for 1st yaml candidate.
|
||||
match = re.search(self.pattern, text.strip())
|
||||
yaml_str = ""
|
||||
if match:
|
||||
yaml_str = match.group("yaml")
|
||||
else:
|
||||
# If no backticks were present, try to parse the entire output as yaml.
|
||||
yaml_str = text
|
||||
# If no backticks were present, try to parse the entire output as yaml.
|
||||
yaml_str = match.group("yaml") if match else text
|
||||
|
||||
json_object = yaml.safe_load(yaml_str)
|
||||
if hasattr(self.pydantic_object, "model_validate"):
|
||||
|
@ -123,10 +123,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
response = await self.llm_chain.ainvoke(
|
||||
{"question": question}, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
if isinstance(self.llm_chain, LLMChain):
|
||||
lines = response["text"]
|
||||
else:
|
||||
lines = response
|
||||
lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
|
||||
if self.verbose:
|
||||
logger.info(f"Generated queries: {lines}")
|
||||
return lines
|
||||
@ -186,10 +183,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
response = self.llm_chain.invoke(
|
||||
{"question": question}, config={"callbacks": run_manager.get_child()}
|
||||
)
|
||||
if isinstance(self.llm_chain, LLMChain):
|
||||
lines = response["text"]
|
||||
else:
|
||||
lines = response
|
||||
lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
|
||||
if self.verbose:
|
||||
logger.info(f"Generated queries: {lines}")
|
||||
return lines
|
||||
|
@ -57,9 +57,7 @@ class EvalConfig(BaseModel):
|
||||
"""
|
||||
kwargs = {}
|
||||
for field, val in self:
|
||||
if field == "evaluator_type":
|
||||
continue
|
||||
elif val is None:
|
||||
if field == "evaluator_type" or val is None:
|
||||
continue
|
||||
kwargs[field] = val
|
||||
return kwargs
|
||||
|
@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
|
||||
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"]
|
||||
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
|
||||
pydocstyle.convention = "google"
|
||||
pyupgrade.keep-runtime-typing = true
|
||||
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import operator
|
||||
from functools import reduce
|
||||
from itertools import cycle
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
@ -520,14 +522,7 @@ async def test_runnable_agent() -> None:
|
||||
assert messages != []
|
||||
|
||||
# Aggregate state
|
||||
run_log = None
|
||||
|
||||
for result in results:
|
||||
if run_log is None:
|
||||
run_log = result
|
||||
else:
|
||||
# `+` is defined for RunLogPatch
|
||||
run_log = run_log + result
|
||||
run_log = reduce(operator.add, results)
|
||||
|
||||
assert isinstance(run_log, RunLog)
|
||||
|
||||
|
@ -17,7 +17,7 @@ def test_simple_memory() -> None:
|
||||
output = memory.load_memory_variables({})
|
||||
|
||||
assert output == {"baz": "foo"}
|
||||
assert ["baz"] == memory.memory_variables
|
||||
assert memory.memory_variables == ["baz"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Embeddings tests."""
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import importlib
|
||||
import warnings
|
||||
@ -75,10 +76,8 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
|
||||
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
|
||||
# "RAISE_EXCEPTION" forces a failure in batch 2
|
||||
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
cache_embeddings_batch.embed_documents(texts)
|
||||
except ValueError:
|
||||
pass
|
||||
keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
|
||||
# only the first batch of three embeddings should exist
|
||||
assert len(keys) == 3
|
||||
@ -122,10 +121,8 @@ async def test_aembed_documents_batch(
|
||||
) -> None:
|
||||
# "RAISE_EXCEPTION" forces a failure in batch 2
|
||||
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
await cache_embeddings_batch.aembed_documents(texts)
|
||||
except ValueError:
|
||||
pass
|
||||
keys = [
|
||||
key
|
||||
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
|
||||
|
@ -70,7 +70,7 @@ def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> N
|
||||
reference = "I like apples."
|
||||
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
|
||||
assert "score" in result
|
||||
assert 0 < result["score"]
|
||||
assert result["score"] > 0
|
||||
if normalize_score:
|
||||
assert result["score"] < 1.0
|
||||
|
||||
|
@ -45,7 +45,7 @@ def test_combining_dict_result() -> None:
|
||||
]
|
||||
combining_parser = CombiningOutputParser(parsers=parsers)
|
||||
result_dict = combining_parser.parse(DEF_README)
|
||||
assert DEF_EXPECTED_RESULT == result_dict
|
||||
assert result_dict == DEF_EXPECTED_RESULT
|
||||
|
||||
|
||||
def test_combining_output_parser_output_type() -> None:
|
||||
|
@ -47,7 +47,7 @@ def test_pandas_output_parser_col_first_elem() -> None:
|
||||
def test_pandas_output_parser_col_multi_elem() -> None:
|
||||
expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")}
|
||||
actual_output = parser.parse("column:chicken[0, 1]")
|
||||
for key in actual_output.keys():
|
||||
for key in actual_output:
|
||||
assert expected_output["chicken"].equals(actual_output[key])
|
||||
|
||||
|
||||
|
@ -23,7 +23,7 @@ def test_regex_parser_parse() -> None:
|
||||
output_keys=["confidence", "explanation"],
|
||||
default_output_key="noConfidence",
|
||||
)
|
||||
assert DEF_EXPECTED_RESULT == parser.parse(DEF_README)
|
||||
assert parser.parse(DEF_README) == DEF_EXPECTED_RESULT
|
||||
|
||||
|
||||
def test_regex_parser_output_type() -> None:
|
||||
|
@ -35,7 +35,7 @@ def test_regex_dict_result() -> None:
|
||||
)
|
||||
result_dict = regex_dict_parser.parse(DEF_README)
|
||||
print("parse_result:", result_dict) # noqa: T201
|
||||
assert DEF_EXPECTED_RESULT == result_dict
|
||||
assert result_dict == DEF_EXPECTED_RESULT
|
||||
|
||||
|
||||
def test_regex_dict_output_type() -> None:
|
||||
|
@ -76,7 +76,7 @@ def test_yaml_output_parser(result: str) -> None:
|
||||
|
||||
model = yaml_parser.parse(result)
|
||||
print("parse_result:", result) # noqa: T201
|
||||
assert DEF_EXPECTED_RESULT == model
|
||||
assert model == DEF_EXPECTED_RESULT
|
||||
|
||||
|
||||
def test_yaml_output_parser_fail() -> None:
|
||||
|
@ -124,9 +124,12 @@ def extract_deprecated_lookup(file_path: str) -> Optional[dict[str, Any]]:
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == "DEPRECATED_LOOKUP":
|
||||
if isinstance(node.value, ast.Dict):
|
||||
return _dict_from_ast(node.value)
|
||||
if (
|
||||
isinstance(target, ast.Name)
|
||||
and target.id == "DEPRECATED_LOOKUP"
|
||||
and isinstance(node.value, ast.Dict)
|
||||
):
|
||||
return _dict_from_ast(node.value)
|
||||
return None
|
||||
|
||||
|
||||
@ -156,8 +159,10 @@ def _literal_eval_str(node: ast.AST) -> str:
|
||||
Returns:
|
||||
str: The corresponding string value.
|
||||
"""
|
||||
if isinstance(node, ast.Constant): # Python 3.8+
|
||||
if isinstance(node.value, str):
|
||||
return node.value
|
||||
if (
|
||||
isinstance(node, ast.Constant) # Python 3.8+
|
||||
and isinstance(node.value, str)
|
||||
):
|
||||
return node.value
|
||||
msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}"
|
||||
raise AssertionError(msg)
|
||||
|
Loading…
Reference in New Issue
Block a user