mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-10 15:06:18 +00:00
langchain: Add ruff rules SIM (#31881)
See https://docs.astral.sh/ruff/rules/#flake8-simplify-sim Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent
f06380516f
commit
2df3fdf40d
@ -84,28 +84,28 @@ def create_importer(
|
|||||||
not is_interactive_env()
|
not is_interactive_env()
|
||||||
and deprecated_lookups
|
and deprecated_lookups
|
||||||
and name in deprecated_lookups
|
and name in deprecated_lookups
|
||||||
):
|
|
||||||
# Depth 3:
|
# Depth 3:
|
||||||
# internal.py
|
# internal.py
|
||||||
# module_import.py
|
# module_import.py
|
||||||
# Module in langchain that uses this function
|
# Module in langchain that uses this function
|
||||||
# [calling code] whose frame we want to inspect.
|
# [calling code] whose frame we want to inspect.
|
||||||
if not internal.is_caller_internal(depth=3):
|
and not internal.is_caller_internal(depth=3)
|
||||||
warn_deprecated(
|
):
|
||||||
since="0.1",
|
warn_deprecated(
|
||||||
pending=False,
|
since="0.1",
|
||||||
removal="1.0",
|
pending=False,
|
||||||
message=(
|
removal="1.0",
|
||||||
f"Importing {name} from {package} is deprecated. "
|
message=(
|
||||||
f"Please replace deprecated imports:\n\n"
|
f"Importing {name} from {package} is deprecated. "
|
||||||
f">> from {package} import {name}\n\n"
|
f"Please replace deprecated imports:\n\n"
|
||||||
"with new imports of:\n\n"
|
f">> from {package} import {name}\n\n"
|
||||||
f">> from {new_module} import {name}\n"
|
"with new imports of:\n\n"
|
||||||
"You can use the langchain cli to **automatically** "
|
f">> from {new_module} import {name}\n"
|
||||||
"upgrade many imports. Please see documentation here "
|
"You can use the langchain cli to **automatically** "
|
||||||
"<https://python.langchain.com/docs/versions/v0_2/>"
|
"upgrade many imports. Please see documentation here "
|
||||||
),
|
"<https://python.langchain.com/docs/versions/v0_2/>"
|
||||||
)
|
),
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg = f"module {new_module} has no attribute {name}"
|
msg = f"module {new_module} has no attribute {name}"
|
||||||
@ -115,28 +115,30 @@ def create_importer(
|
|||||||
try:
|
try:
|
||||||
module = importlib.import_module(fallback_module)
|
module = importlib.import_module(fallback_module)
|
||||||
result = getattr(module, name)
|
result = getattr(module, name)
|
||||||
if not is_interactive_env():
|
if (
|
||||||
|
not is_interactive_env()
|
||||||
# Depth 3:
|
# Depth 3:
|
||||||
# internal.py
|
# internal.py
|
||||||
# module_import.py
|
# module_import.py
|
||||||
# Module in langchain that uses this function
|
# Module in langchain that uses this function
|
||||||
# [calling code] whose frame we want to inspect.
|
# [calling code] whose frame we want to inspect.
|
||||||
if not internal.is_caller_internal(depth=3):
|
and not internal.is_caller_internal(depth=3)
|
||||||
warn_deprecated(
|
):
|
||||||
since="0.1",
|
warn_deprecated(
|
||||||
pending=False,
|
since="0.1",
|
||||||
removal="1.0",
|
pending=False,
|
||||||
message=(
|
removal="1.0",
|
||||||
f"Importing {name} from {package} is deprecated. "
|
message=(
|
||||||
f"Please replace deprecated imports:\n\n"
|
f"Importing {name} from {package} is deprecated. "
|
||||||
f">> from {package} import {name}\n\n"
|
f"Please replace deprecated imports:\n\n"
|
||||||
"with new imports of:\n\n"
|
f">> from {package} import {name}\n\n"
|
||||||
f">> from {fallback_module} import {name}\n"
|
"with new imports of:\n\n"
|
||||||
"You can use the langchain cli to **automatically** "
|
f">> from {fallback_module} import {name}\n"
|
||||||
"upgrade many imports. Please see documentation here "
|
"You can use the langchain cli to **automatically** "
|
||||||
"<https://python.langchain.com/docs/versions/v0_2/>"
|
"upgrade many imports. Please see documentation here "
|
||||||
),
|
"<https://python.langchain.com/docs/versions/v0_2/>"
|
||||||
)
|
),
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import builtins
|
import builtins
|
||||||
|
import contextlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@ -196,10 +197,7 @@ class BaseSingleActionAgent(BaseModel):
|
|||||||
agent.agent.save(file_path="path/agent.yaml")
|
agent.agent.save(file_path="path/agent.yaml")
|
||||||
"""
|
"""
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file_path, str):
|
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||||
save_path = Path(file_path)
|
|
||||||
else:
|
|
||||||
save_path = file_path
|
|
||||||
|
|
||||||
directory_path = save_path.parent
|
directory_path = save_path.parent
|
||||||
directory_path.mkdir(parents=True, exist_ok=True)
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
@ -322,10 +320,8 @@ class BaseMultiActionAgent(BaseModel):
|
|||||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||||
"""Return dictionary representation of agent."""
|
"""Return dictionary representation of agent."""
|
||||||
_dict = super().model_dump()
|
_dict = super().model_dump()
|
||||||
try:
|
with contextlib.suppress(NotImplementedError):
|
||||||
_dict["_type"] = str(self._agent_type)
|
_dict["_type"] = str(self._agent_type)
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
@ -345,10 +341,7 @@ class BaseMultiActionAgent(BaseModel):
|
|||||||
agent.agent.save(file_path="path/agent.yaml")
|
agent.agent.save(file_path="path/agent.yaml")
|
||||||
"""
|
"""
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file_path, str):
|
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||||
save_path = Path(file_path)
|
|
||||||
else:
|
|
||||||
save_path = file_path
|
|
||||||
|
|
||||||
# Fetch dictionary to save
|
# Fetch dictionary to save
|
||||||
agent_dict = self.dict()
|
agent_dict = self.dict()
|
||||||
@ -1133,13 +1126,14 @@ class AgentExecutor(Chain):
|
|||||||
agent = self.agent
|
agent = self.agent
|
||||||
tools = self.tools
|
tools = self.tools
|
||||||
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
|
allowed_tools = agent.get_allowed_tools() # type: ignore[union-attr]
|
||||||
if allowed_tools is not None:
|
if allowed_tools is not None and set(allowed_tools) != set(
|
||||||
if set(allowed_tools) != set([tool.name for tool in tools]):
|
[tool.name for tool in tools]
|
||||||
msg = (
|
):
|
||||||
f"Allowed tools ({allowed_tools}) different than "
|
msg = (
|
||||||
f"provided tools ({[tool.name for tool in tools]})"
|
f"Allowed tools ({allowed_tools}) different than "
|
||||||
)
|
f"provided tools ({[tool.name for tool in tools]})"
|
||||||
raise ValueError(msg)
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@ -1272,13 +1266,7 @@ class AgentExecutor(Chain):
|
|||||||
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
|
def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
|
||||||
if self.max_iterations is not None and iterations >= self.max_iterations:
|
if self.max_iterations is not None and iterations >= self.max_iterations:
|
||||||
return False
|
return False
|
||||||
if (
|
return self.max_execution_time is None or time_elapsed < self.max_execution_time
|
||||||
self.max_execution_time is not None
|
|
||||||
and time_elapsed >= self.max_execution_time
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _return(
|
def _return(
|
||||||
self,
|
self,
|
||||||
@ -1410,10 +1398,7 @@ class AgentExecutor(Chain):
|
|||||||
return
|
return
|
||||||
|
|
||||||
actions: list[AgentAction]
|
actions: list[AgentAction]
|
||||||
if isinstance(output, AgentAction):
|
actions = [output] if isinstance(output, AgentAction) else output
|
||||||
actions = [output]
|
|
||||||
else:
|
|
||||||
actions = output
|
|
||||||
for agent_action in actions:
|
for agent_action in actions:
|
||||||
yield agent_action
|
yield agent_action
|
||||||
for agent_action in actions:
|
for agent_action in actions:
|
||||||
@ -1547,10 +1532,7 @@ class AgentExecutor(Chain):
|
|||||||
return
|
return
|
||||||
|
|
||||||
actions: list[AgentAction]
|
actions: list[AgentAction]
|
||||||
if isinstance(output, AgentAction):
|
actions = [output] if isinstance(output, AgentAction) else output
|
||||||
actions = [output]
|
|
||||||
else:
|
|
||||||
actions = output
|
|
||||||
for agent_action in actions:
|
for agent_action in actions:
|
||||||
yield agent_action
|
yield agent_action
|
||||||
|
|
||||||
@ -1728,12 +1710,14 @@ class AgentExecutor(Chain):
|
|||||||
if len(self._action_agent.return_values) > 0:
|
if len(self._action_agent.return_values) > 0:
|
||||||
return_value_key = self._action_agent.return_values[0]
|
return_value_key = self._action_agent.return_values[0]
|
||||||
# Invalid tools won't be in the map, so we return False.
|
# Invalid tools won't be in the map, so we return False.
|
||||||
if agent_action.tool in name_to_tool_map:
|
if (
|
||||||
if name_to_tool_map[agent_action.tool].return_direct:
|
agent_action.tool in name_to_tool_map
|
||||||
return AgentFinish(
|
and name_to_tool_map[agent_action.tool].return_direct
|
||||||
{return_value_key: observation},
|
):
|
||||||
"",
|
return AgentFinish(
|
||||||
)
|
{return_value_key: observation},
|
||||||
|
"",
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _prepare_intermediate_steps(
|
def _prepare_intermediate_steps(
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Load agent."""
|
"""Load agent."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
@ -81,11 +82,9 @@ def initialize_agent(
|
|||||||
agent_obj = load_agent(
|
agent_obj = load_agent(
|
||||||
agent_path, llm=llm, tools=tools, callback_manager=callback_manager
|
agent_path, llm=llm, tools=tools, callback_manager=callback_manager
|
||||||
)
|
)
|
||||||
try:
|
with contextlib.suppress(NotImplementedError):
|
||||||
# TODO: Add tags from the serialized object directly.
|
# TODO: Add tags from the serialized object directly.
|
||||||
tags_.append(agent_obj._agent_type)
|
tags_.append(agent_obj._agent_type)
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
"Somehow both `agent` and `agent_path` are None, this should never happen."
|
"Somehow both `agent` and `agent_path` are None, this should never happen."
|
||||||
|
@ -128,10 +128,7 @@ def _load_agent_from_file(
|
|||||||
"""Load agent from file."""
|
"""Load agent from file."""
|
||||||
valid_suffixes = {"json", "yaml"}
|
valid_suffixes = {"json", "yaml"}
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file, str):
|
file_path = Path(file) if isinstance(file, str) else file
|
||||||
file_path = Path(file)
|
|
||||||
else:
|
|
||||||
file_path = file
|
|
||||||
# Load from either json or yaml.
|
# Load from either json or yaml.
|
||||||
if file_path.suffix[1:] == "json":
|
if file_path.suffix[1:] == "json":
|
||||||
with open(file_path) as f:
|
with open(file_path) as f:
|
||||||
|
@ -230,10 +230,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
|||||||
"""
|
"""
|
||||||
_prompts = extra_prompt_messages or []
|
_prompts = extra_prompt_messages or []
|
||||||
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
|
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||||
if system_message:
|
messages = [system_message] if system_message else []
|
||||||
messages = [system_message]
|
|
||||||
else:
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
messages.extend(
|
messages.extend(
|
||||||
[
|
[
|
||||||
|
@ -278,10 +278,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
|||||||
"""
|
"""
|
||||||
_prompts = extra_prompt_messages or []
|
_prompts = extra_prompt_messages or []
|
||||||
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
|
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||||
if system_message:
|
messages = [system_message] if system_message else []
|
||||||
messages = [system_message]
|
|
||||||
else:
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
messages.extend(
|
messages.extend(
|
||||||
[
|
[
|
||||||
|
@ -60,10 +60,7 @@ def parse_ai_message_to_tool_action(
|
|||||||
# Open AI does not support passing in a JSON array as an argument.
|
# Open AI does not support passing in a JSON array as an argument.
|
||||||
function_name = tool_call["name"]
|
function_name = tool_call["name"]
|
||||||
_tool_input = tool_call["args"]
|
_tool_input = tool_call["args"]
|
||||||
if "__arg1" in _tool_input:
|
tool_input = _tool_input.get("__arg1", _tool_input)
|
||||||
tool_input = _tool_input["__arg1"]
|
|
||||||
else:
|
|
||||||
tool_input = _tool_input
|
|
||||||
|
|
||||||
content_msg = f"responded: {message.content}\n" if message.content else "\n"
|
content_msg = f"responded: {message.content}\n" if message.content else "\n"
|
||||||
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
|
log = f"\nInvoking: `{function_name}` with `{tool_input}`\n{content_msg}\n"
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Base interface that all chains should implement."""
|
"""Base interface that all chains should implement."""
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import contextlib
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -727,10 +728,8 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
# -> {"_type": "foo", "verbose": False, ...}
|
# -> {"_type": "foo", "verbose": False, ...}
|
||||||
"""
|
"""
|
||||||
_dict = super().dict(**kwargs)
|
_dict = super().dict(**kwargs)
|
||||||
try:
|
with contextlib.suppress(NotImplementedError):
|
||||||
_dict["_type"] = self._chain_type
|
_dict["_type"] = self._chain_type
|
||||||
except NotImplementedError:
|
|
||||||
pass
|
|
||||||
return _dict
|
return _dict
|
||||||
|
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
@ -758,10 +757,7 @@ class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
|
|||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file_path, str):
|
save_path = Path(file_path) if isinstance(file_path, str) else file_path
|
||||||
save_path = Path(file_path)
|
|
||||||
else:
|
|
||||||
save_path = file_path
|
|
||||||
|
|
||||||
directory_path = save_path.parent
|
directory_path = save_path.parent
|
||||||
directory_path.mkdir(parents=True, exist_ok=True)
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -683,10 +683,7 @@ def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain:
|
|||||||
def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
|
def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
|
||||||
"""Load chain from file."""
|
"""Load chain from file."""
|
||||||
# Convert file to Path object.
|
# Convert file to Path object.
|
||||||
if isinstance(file, str):
|
file_path = Path(file) if isinstance(file, str) else file
|
||||||
file_path = Path(file)
|
|
||||||
else:
|
|
||||||
file_path = file
|
|
||||||
# Load from either json or yaml.
|
# Load from either json or yaml.
|
||||||
if file_path.suffix == ".json":
|
if file_path.suffix == ".json":
|
||||||
with open(file_path) as f:
|
with open(file_path) as f:
|
||||||
|
@ -95,10 +95,7 @@ class OpenAIModerationChain(Chain):
|
|||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
|
|
||||||
def _moderate(self, text: str, results: Any) -> str:
|
def _moderate(self, text: str, results: Any) -> str:
|
||||||
if self.openai_pre_1_0:
|
condition = results["flagged"] if self.openai_pre_1_0 else results.flagged
|
||||||
condition = results["flagged"]
|
|
||||||
else:
|
|
||||||
condition = results.flagged
|
|
||||||
if condition:
|
if condition:
|
||||||
error_str = "Text was found that violates OpenAI's content policy."
|
error_str = "Text was found that violates OpenAI's content policy."
|
||||||
if self.error:
|
if self.error:
|
||||||
|
@ -108,22 +108,26 @@ class QueryTransformer(Transformer):
|
|||||||
|
|
||||||
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
|
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
|
||||||
if func_name in set(Comparator):
|
if func_name in set(Comparator):
|
||||||
if self.allowed_comparators is not None:
|
if (
|
||||||
if func_name not in self.allowed_comparators:
|
self.allowed_comparators is not None
|
||||||
msg = (
|
and func_name not in self.allowed_comparators
|
||||||
f"Received disallowed comparator {func_name}. Allowed "
|
):
|
||||||
f"comparators are {self.allowed_comparators}"
|
msg = (
|
||||||
)
|
f"Received disallowed comparator {func_name}. Allowed "
|
||||||
raise ValueError(msg)
|
f"comparators are {self.allowed_comparators}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
return Comparator(func_name)
|
return Comparator(func_name)
|
||||||
elif func_name in set(Operator):
|
elif func_name in set(Operator):
|
||||||
if self.allowed_operators is not None:
|
if (
|
||||||
if func_name not in self.allowed_operators:
|
self.allowed_operators is not None
|
||||||
msg = (
|
and func_name not in self.allowed_operators
|
||||||
f"Received disallowed operator {func_name}. Allowed operators"
|
):
|
||||||
f" are {self.allowed_operators}"
|
msg = (
|
||||||
)
|
f"Received disallowed operator {func_name}. Allowed operators"
|
||||||
raise ValueError(msg)
|
f" are {self.allowed_operators}"
|
||||||
|
)
|
||||||
|
raise ValueError(msg)
|
||||||
return Operator(func_name)
|
return Operator(func_name)
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
|
@ -244,10 +244,7 @@ The following is the expected answer. Use this to measure correctness:
|
|||||||
if not isinstance(llm, BaseChatModel):
|
if not isinstance(llm, BaseChatModel):
|
||||||
msg = "Only chat models supported by the current trajectory eval"
|
msg = "Only chat models supported by the current trajectory eval"
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
if agent_tools:
|
prompt = EVAL_CHAT_PROMPT if agent_tools else TOOL_FREE_EVAL_CHAT_PROMPT
|
||||||
prompt = EVAL_CHAT_PROMPT
|
|
||||||
else:
|
|
||||||
prompt = TOOL_FREE_EVAL_CHAT_PROMPT
|
|
||||||
eval_chain = LLMChain(llm=llm, prompt=prompt)
|
eval_chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
return cls(
|
return cls(
|
||||||
agent_tools=agent_tools, # type: ignore[arg-type]
|
agent_tools=agent_tools, # type: ignore[arg-type]
|
||||||
|
@ -36,13 +36,12 @@ class CombinedMemory(BaseMemory):
|
|||||||
def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]:
|
def check_input_key(cls, value: list[BaseMemory]) -> list[BaseMemory]:
|
||||||
"""Check that if memories are of type BaseChatMemory that input keys exist."""
|
"""Check that if memories are of type BaseChatMemory that input keys exist."""
|
||||||
for val in value:
|
for val in value:
|
||||||
if isinstance(val, BaseChatMemory):
|
if isinstance(val, BaseChatMemory) and val.input_key is None:
|
||||||
if val.input_key is None:
|
warnings.warn(
|
||||||
warnings.warn(
|
"When using CombinedMemory, "
|
||||||
"When using CombinedMemory, "
|
"input keys should be so the input is known. "
|
||||||
"input keys should be so the input is known. "
|
f" Was not set on {val}"
|
||||||
f" Was not set on {val}"
|
)
|
||||||
)
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -51,11 +51,9 @@ class ModelLaboratory:
|
|||||||
"Currently only support chains with one output variable, "
|
"Currently only support chains with one output variable, "
|
||||||
f"got {chain.output_keys}"
|
f"got {chain.output_keys}"
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
if names is not None and len(names) != len(chains):
|
||||||
if names is not None:
|
msg = "Length of chains does not match length of names."
|
||||||
if len(names) != len(chains):
|
raise ValueError(msg)
|
||||||
msg = "Length of chains does not match length of names."
|
|
||||||
raise ValueError(msg)
|
|
||||||
self.chains = chains
|
self.chains = chains
|
||||||
chain_range = [str(i) for i in range(len(self.chains))]
|
chain_range = [str(i) for i in range(len(self.chains))]
|
||||||
self.chain_colors = get_color_mapping(chain_range)
|
self.chain_colors = get_color_mapping(chain_range)
|
||||||
@ -93,10 +91,7 @@ class ModelLaboratory:
|
|||||||
"""
|
"""
|
||||||
print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201
|
print(f"\033[1mInput:\033[0m\n{text}\n") # noqa: T201
|
||||||
for i, chain in enumerate(self.chains):
|
for i, chain in enumerate(self.chains):
|
||||||
if self.names is not None:
|
name = self.names[i] if self.names is not None else str(chain)
|
||||||
name = self.names[i]
|
|
||||||
else:
|
|
||||||
name = str(chain)
|
|
||||||
print_text(name, end="\n")
|
print_text(name, end="\n")
|
||||||
output = chain.run(text)
|
output = chain.run(text)
|
||||||
print_text(output, color=self.chain_colors[str(i)], end="\n\n")
|
print_text(output, color=self.chain_colors[str(i)], end="\n\n")
|
||||||
|
@ -10,14 +10,13 @@ def load_output_parser(config: dict) -> dict:
|
|||||||
Returns:
|
Returns:
|
||||||
config dict with output parser loaded
|
config dict with output parser loaded
|
||||||
"""
|
"""
|
||||||
if "output_parsers" in config:
|
if "output_parsers" in config and config["output_parsers"] is not None:
|
||||||
if config["output_parsers"] is not None:
|
_config = config["output_parsers"]
|
||||||
_config = config["output_parsers"]
|
output_parser_type = _config["_type"]
|
||||||
output_parser_type = _config["_type"]
|
if output_parser_type == "regex_parser":
|
||||||
if output_parser_type == "regex_parser":
|
output_parser = RegexParser(**_config)
|
||||||
output_parser = RegexParser(**_config)
|
else:
|
||||||
else:
|
msg = f"Unsupported output parser {output_parser_type}"
|
||||||
msg = f"Unsupported output parser {output_parser_type}"
|
raise ValueError(msg)
|
||||||
raise ValueError(msg)
|
config["output_parsers"] = output_parser
|
||||||
config["output_parsers"] = output_parser
|
|
||||||
return config
|
return config
|
||||||
|
@ -27,12 +27,8 @@ class YamlOutputParser(BaseOutputParser[T]):
|
|||||||
try:
|
try:
|
||||||
# Greedy search for 1st yaml candidate.
|
# Greedy search for 1st yaml candidate.
|
||||||
match = re.search(self.pattern, text.strip())
|
match = re.search(self.pattern, text.strip())
|
||||||
yaml_str = ""
|
# If no backticks were present, try to parse the entire output as yaml.
|
||||||
if match:
|
yaml_str = match.group("yaml") if match else text
|
||||||
yaml_str = match.group("yaml")
|
|
||||||
else:
|
|
||||||
# If no backticks were present, try to parse the entire output as yaml.
|
|
||||||
yaml_str = text
|
|
||||||
|
|
||||||
json_object = yaml.safe_load(yaml_str)
|
json_object = yaml.safe_load(yaml_str)
|
||||||
if hasattr(self.pydantic_object, "model_validate"):
|
if hasattr(self.pydantic_object, "model_validate"):
|
||||||
|
@ -123,10 +123,7 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
response = await self.llm_chain.ainvoke(
|
response = await self.llm_chain.ainvoke(
|
||||||
{"question": question}, config={"callbacks": run_manager.get_child()}
|
{"question": question}, config={"callbacks": run_manager.get_child()}
|
||||||
)
|
)
|
||||||
if isinstance(self.llm_chain, LLMChain):
|
lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
|
||||||
lines = response["text"]
|
|
||||||
else:
|
|
||||||
lines = response
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.info(f"Generated queries: {lines}")
|
logger.info(f"Generated queries: {lines}")
|
||||||
return lines
|
return lines
|
||||||
@ -186,10 +183,7 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
response = self.llm_chain.invoke(
|
response = self.llm_chain.invoke(
|
||||||
{"question": question}, config={"callbacks": run_manager.get_child()}
|
{"question": question}, config={"callbacks": run_manager.get_child()}
|
||||||
)
|
)
|
||||||
if isinstance(self.llm_chain, LLMChain):
|
lines = response["text"] if isinstance(self.llm_chain, LLMChain) else response
|
||||||
lines = response["text"]
|
|
||||||
else:
|
|
||||||
lines = response
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logger.info(f"Generated queries: {lines}")
|
logger.info(f"Generated queries: {lines}")
|
||||||
return lines
|
return lines
|
||||||
|
@ -57,9 +57,7 @@ class EvalConfig(BaseModel):
|
|||||||
"""
|
"""
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
for field, val in self:
|
for field, val in self:
|
||||||
if field == "evaluator_type":
|
if field == "evaluator_type" or val is None:
|
||||||
continue
|
|
||||||
elif val is None:
|
|
||||||
continue
|
continue
|
||||||
kwargs[field] = val
|
kwargs[field] = val
|
||||||
return kwargs
|
return kwargs
|
||||||
|
@ -143,7 +143,7 @@ ignore-regex = ".*(Stati Uniti|Tense=Pres).*"
|
|||||||
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
|
ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "T201", "D", "UP", "S", "W"]
|
select = ["A", "E", "F", "I", "EM", "PGH003", "PIE", "SIM", "T201", "D", "UP", "S", "W"]
|
||||||
pydocstyle.convention = "google"
|
pydocstyle.convention = "google"
|
||||||
pyupgrade.keep-runtime-typing = true
|
pyupgrade.keep-runtime-typing = true
|
||||||
|
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import operator
|
||||||
|
from functools import reduce
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import Any, Optional, Union, cast
|
from typing import Any, Optional, Union, cast
|
||||||
|
|
||||||
@ -520,14 +522,7 @@ async def test_runnable_agent() -> None:
|
|||||||
assert messages != []
|
assert messages != []
|
||||||
|
|
||||||
# Aggregate state
|
# Aggregate state
|
||||||
run_log = None
|
run_log = reduce(operator.add, results)
|
||||||
|
|
||||||
for result in results:
|
|
||||||
if run_log is None:
|
|
||||||
run_log = result
|
|
||||||
else:
|
|
||||||
# `+` is defined for RunLogPatch
|
|
||||||
run_log = run_log + result
|
|
||||||
|
|
||||||
assert isinstance(run_log, RunLog)
|
assert isinstance(run_log, RunLog)
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ def test_simple_memory() -> None:
|
|||||||
output = memory.load_memory_variables({})
|
output = memory.load_memory_variables({})
|
||||||
|
|
||||||
assert output == {"baz": "foo"}
|
assert output == {"baz": "foo"}
|
||||||
assert ["baz"] == memory.memory_variables
|
assert memory.memory_variables == ["baz"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Embeddings tests."""
|
"""Embeddings tests."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import hashlib
|
import hashlib
|
||||||
import importlib
|
import importlib
|
||||||
import warnings
|
import warnings
|
||||||
@ -75,10 +76,8 @@ def test_embed_documents(cache_embeddings: CacheBackedEmbeddings) -> None:
|
|||||||
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
|
def test_embed_documents_batch(cache_embeddings_batch: CacheBackedEmbeddings) -> None:
|
||||||
# "RAISE_EXCEPTION" forces a failure in batch 2
|
# "RAISE_EXCEPTION" forces a failure in batch 2
|
||||||
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
|
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
|
||||||
try:
|
with contextlib.suppress(ValueError):
|
||||||
cache_embeddings_batch.embed_documents(texts)
|
cache_embeddings_batch.embed_documents(texts)
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
|
keys = list(cache_embeddings_batch.document_embedding_store.yield_keys())
|
||||||
# only the first batch of three embeddings should exist
|
# only the first batch of three embeddings should exist
|
||||||
assert len(keys) == 3
|
assert len(keys) == 3
|
||||||
@ -122,10 +121,8 @@ async def test_aembed_documents_batch(
|
|||||||
) -> None:
|
) -> None:
|
||||||
# "RAISE_EXCEPTION" forces a failure in batch 2
|
# "RAISE_EXCEPTION" forces a failure in batch 2
|
||||||
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
|
texts = ["1", "22", "a", "333", "RAISE_EXCEPTION"]
|
||||||
try:
|
with contextlib.suppress(ValueError):
|
||||||
await cache_embeddings_batch.aembed_documents(texts)
|
await cache_embeddings_batch.aembed_documents(texts)
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
keys = [
|
keys = [
|
||||||
key
|
key
|
||||||
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
|
async for key in cache_embeddings_batch.document_embedding_store.ayield_keys()
|
||||||
|
@ -70,7 +70,7 @@ def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> N
|
|||||||
reference = "I like apples."
|
reference = "I like apples."
|
||||||
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
|
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
|
||||||
assert "score" in result
|
assert "score" in result
|
||||||
assert 0 < result["score"]
|
assert result["score"] > 0
|
||||||
if normalize_score:
|
if normalize_score:
|
||||||
assert result["score"] < 1.0
|
assert result["score"] < 1.0
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ def test_combining_dict_result() -> None:
|
|||||||
]
|
]
|
||||||
combining_parser = CombiningOutputParser(parsers=parsers)
|
combining_parser = CombiningOutputParser(parsers=parsers)
|
||||||
result_dict = combining_parser.parse(DEF_README)
|
result_dict = combining_parser.parse(DEF_README)
|
||||||
assert DEF_EXPECTED_RESULT == result_dict
|
assert result_dict == DEF_EXPECTED_RESULT
|
||||||
|
|
||||||
|
|
||||||
def test_combining_output_parser_output_type() -> None:
|
def test_combining_output_parser_output_type() -> None:
|
||||||
|
@ -47,7 +47,7 @@ def test_pandas_output_parser_col_first_elem() -> None:
|
|||||||
def test_pandas_output_parser_col_multi_elem() -> None:
|
def test_pandas_output_parser_col_multi_elem() -> None:
|
||||||
expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")}
|
expected_output = {"chicken": pd.Series([1, 2], name="chicken", dtype="int64")}
|
||||||
actual_output = parser.parse("column:chicken[0, 1]")
|
actual_output = parser.parse("column:chicken[0, 1]")
|
||||||
for key in actual_output.keys():
|
for key in actual_output:
|
||||||
assert expected_output["chicken"].equals(actual_output[key])
|
assert expected_output["chicken"].equals(actual_output[key])
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ def test_regex_parser_parse() -> None:
|
|||||||
output_keys=["confidence", "explanation"],
|
output_keys=["confidence", "explanation"],
|
||||||
default_output_key="noConfidence",
|
default_output_key="noConfidence",
|
||||||
)
|
)
|
||||||
assert DEF_EXPECTED_RESULT == parser.parse(DEF_README)
|
assert parser.parse(DEF_README) == DEF_EXPECTED_RESULT
|
||||||
|
|
||||||
|
|
||||||
def test_regex_parser_output_type() -> None:
|
def test_regex_parser_output_type() -> None:
|
||||||
|
@ -35,7 +35,7 @@ def test_regex_dict_result() -> None:
|
|||||||
)
|
)
|
||||||
result_dict = regex_dict_parser.parse(DEF_README)
|
result_dict = regex_dict_parser.parse(DEF_README)
|
||||||
print("parse_result:", result_dict) # noqa: T201
|
print("parse_result:", result_dict) # noqa: T201
|
||||||
assert DEF_EXPECTED_RESULT == result_dict
|
assert result_dict == DEF_EXPECTED_RESULT
|
||||||
|
|
||||||
|
|
||||||
def test_regex_dict_output_type() -> None:
|
def test_regex_dict_output_type() -> None:
|
||||||
|
@ -76,7 +76,7 @@ def test_yaml_output_parser(result: str) -> None:
|
|||||||
|
|
||||||
model = yaml_parser.parse(result)
|
model = yaml_parser.parse(result)
|
||||||
print("parse_result:", result) # noqa: T201
|
print("parse_result:", result) # noqa: T201
|
||||||
assert DEF_EXPECTED_RESULT == model
|
assert model == DEF_EXPECTED_RESULT
|
||||||
|
|
||||||
|
|
||||||
def test_yaml_output_parser_fail() -> None:
|
def test_yaml_output_parser_fail() -> None:
|
||||||
|
@ -124,9 +124,12 @@ def extract_deprecated_lookup(file_path: str) -> Optional[dict[str, Any]]:
|
|||||||
for node in ast.walk(tree):
|
for node in ast.walk(tree):
|
||||||
if isinstance(node, ast.Assign):
|
if isinstance(node, ast.Assign):
|
||||||
for target in node.targets:
|
for target in node.targets:
|
||||||
if isinstance(target, ast.Name) and target.id == "DEPRECATED_LOOKUP":
|
if (
|
||||||
if isinstance(node.value, ast.Dict):
|
isinstance(target, ast.Name)
|
||||||
return _dict_from_ast(node.value)
|
and target.id == "DEPRECATED_LOOKUP"
|
||||||
|
and isinstance(node.value, ast.Dict)
|
||||||
|
):
|
||||||
|
return _dict_from_ast(node.value)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -156,8 +159,10 @@ def _literal_eval_str(node: ast.AST) -> str:
|
|||||||
Returns:
|
Returns:
|
||||||
str: The corresponding string value.
|
str: The corresponding string value.
|
||||||
"""
|
"""
|
||||||
if isinstance(node, ast.Constant): # Python 3.8+
|
if (
|
||||||
if isinstance(node.value, str):
|
isinstance(node, ast.Constant) # Python 3.8+
|
||||||
return node.value
|
and isinstance(node.value, str)
|
||||||
|
):
|
||||||
|
return node.value
|
||||||
msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}"
|
msg = f"Invalid DEPRECATED_LOOKUP format: expected str, got {type(node).__name__}"
|
||||||
raise AssertionError(msg)
|
raise AssertionError(msg)
|
||||||
|
Loading…
Reference in New Issue
Block a user