mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-25 16:13:25 +00:00)

commit 48affc498b (parent d9b628e764)

    langchain[lint]: use pyupgrade to get to 3.9 standards (#30782)
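The diff below appears to be the mechanical output of a pyupgrade run with a 3.9 minimum target (presumably `--py39-plus`; the exact invocation is not shown in this excerpt): `typing.List`, `typing.Dict`, and `typing.Tuple` annotations become the PEP 585 builtin generics, and `AsyncIterator`, `Iterator`, `Sequence`, and `Pattern` move from `typing` to their canonical homes in `collections.abc` and `re`. `Optional` and `Union` are kept, since the PEP 604 `X | Y` spelling would require Python 3.10. One wrinkle visible below: classes that define a `dict()` method shadow the builtin, so their return annotations are spelled `builtins.dict`. A minimal sketch of the rewrite, with a hypothetical function that is not taken from the commit:

# Hypothetical "before", in the pre-3.9 style this commit removes:
#     from typing import Dict, List, Optional, Sequence, Tuple
#     def plan(steps: List[Tuple[str, str]], tools: Sequence[str]) -> Optional[Dict[str, str]]: ...

# After the pyupgrade-style rewrite (PEP 585 builtin generics):
from collections.abc import Sequence  # ABCs now come from collections.abc, not typing
from typing import Optional  # kept: `dict[str, str] | None` would need 3.10+


def plan(
    steps: list[tuple[str, str]],  # typing.List/Tuple -> builtin list/tuple
    tools: Sequence[str],
) -> Optional[dict[str, str]]:
    """Only the annotations change; runtime behavior is identical."""
    return {name: "" for name in tools} if steps else None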
@@ -1,5 +1,5 @@
 import importlib
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional
 
 from langchain_core._api import internal, warn_deprecated
 
@@ -15,8 +15,8 @@ ALLOWED_TOP_LEVEL_PKGS = {
 def create_importer(
     package: str,
     *,
-    module_lookup: Optional[Dict[str, str]] = None,
-    deprecated_lookups: Optional[Dict[str, str]] = None,
+    module_lookup: Optional[dict[str, str]] = None,
+    deprecated_lookups: Optional[dict[str, str]] = None,
     fallback_module: Optional[str] = None,
 ) -> Callable[[str], Any]:
     """Create a function that helps retrieve objects from their new locations.
@@ -3,21 +3,17 @@
 from __future__ import annotations
 
 import asyncio
+import builtins
 import json
 import logging
 import time
 from abc import abstractmethod
+from collections.abc import AsyncIterator, Iterator, Sequence
 from pathlib import Path
 from typing import (
     Any,
-    AsyncIterator,
     Callable,
-    Dict,
-    Iterator,
-    List,
     Optional,
-    Sequence,
-    Tuple,
     Union,
     cast,
 )
@@ -62,17 +58,17 @@ class BaseSingleActionAgent(BaseModel):
     """Base Single Action Agent class."""
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return ["output"]
 
-    def get_allowed_tools(self) -> Optional[List[str]]:
+    def get_allowed_tools(self) -> Optional[list[str]]:
         return None
 
     @abstractmethod
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -91,7 +87,7 @@ class BaseSingleActionAgent(BaseModel):
     @abstractmethod
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -109,7 +105,7 @@ class BaseSingleActionAgent(BaseModel):
 
     @property
     @abstractmethod
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         :meta private:
@@ -118,7 +114,7 @@ class BaseSingleActionAgent(BaseModel):
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         """Return response when agent has been stopped due to max iterations.
@@ -171,7 +167,7 @@ class BaseSingleActionAgent(BaseModel):
         """Return Identifier of an agent type."""
         raise NotImplementedError
 
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> builtins.dict:
         """Return dictionary representation of agent.
 
         Returns:
@@ -223,7 +219,7 @@ class BaseSingleActionAgent(BaseModel):
         else:
             raise ValueError(f"{save_path} must be json or yaml")
 
-    def tool_run_logging_kwargs(self) -> Dict:
+    def tool_run_logging_kwargs(self) -> builtins.dict:
         """Return logging kwargs for tool run."""
         return {}
 
@@ -232,11 +228,11 @@ class BaseMultiActionAgent(BaseModel):
     """Base Multi Action Agent class."""
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return ["output"]
 
-    def get_allowed_tools(self) -> Optional[List[str]]:
+    def get_allowed_tools(self) -> Optional[list[str]]:
         """Get allowed tools.
 
         Returns:
@@ -247,10 +243,10 @@ class BaseMultiActionAgent(BaseModel):
     @abstractmethod
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Union[List[AgentAction], AgentFinish]:
+    ) -> Union[list[AgentAction], AgentFinish]:
         """Given input, decided what to do.
 
         Args:
@@ -266,10 +262,10 @@ class BaseMultiActionAgent(BaseModel):
     @abstractmethod
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Union[List[AgentAction], AgentFinish]:
+    ) -> Union[list[AgentAction], AgentFinish]:
         """Async given input, decided what to do.
 
         Args:
@@ -284,7 +280,7 @@ class BaseMultiActionAgent(BaseModel):
 
     @property
     @abstractmethod
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         :meta private:
@@ -293,7 +289,7 @@ class BaseMultiActionAgent(BaseModel):
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         """Return response when agent has been stopped due to max iterations.
@@ -323,7 +319,7 @@ class BaseMultiActionAgent(BaseModel):
         """Return Identifier of an agent type."""
         raise NotImplementedError
 
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> builtins.dict:
         """Return dictionary representation of agent."""
         _dict = super().model_dump()
         try:
@@ -371,7 +367,7 @@ class BaseMultiActionAgent(BaseModel):
         else:
             raise ValueError(f"{save_path} must be json or yaml")
 
-    def tool_run_logging_kwargs(self) -> Dict:
+    def tool_run_logging_kwargs(self) -> builtins.dict:
         """Return logging kwargs for tool run."""
 
         return {}
@@ -386,7 +382,7 @@ class AgentOutputParser(BaseOutputParser[Union[AgentAction, AgentFinish]]):
 
 
 class MultiActionAgentOutputParser(
-    BaseOutputParser[Union[List[AgentAction], AgentFinish]]
+    BaseOutputParser[Union[list[AgentAction], AgentFinish]]
 ):
     """Base class for parsing agent output into agent actions/finish.
 
@@ -394,7 +390,7 @@ class MultiActionAgentOutputParser(
     """
 
     @abstractmethod
-    def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
+    def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
         """Parse text into agent actions/finish.
 
         Args:
@@ -411,8 +407,8 @@ class RunnableAgent(BaseSingleActionAgent):
 
     runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
     """Runnable to call to get agent action."""
-    input_keys_arg: List[str] = []
-    return_keys_arg: List[str] = []
+    input_keys_arg: list[str] = []
+    return_keys_arg: list[str] = []
     stream_runnable: bool = True
     """Whether to stream from the runnable or not.
 
@@ -427,18 +423,18 @@ class RunnableAgent(BaseSingleActionAgent):
         )
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return self.return_keys_arg
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys."""
         return self.input_keys_arg
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -474,7 +470,7 @@ class RunnableAgent(BaseSingleActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[
@@ -518,10 +514,10 @@ class RunnableAgent(BaseSingleActionAgent):
 class RunnableMultiActionAgent(BaseMultiActionAgent):
     """Agent powered by Runnables."""
 
-    runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]]
+    runnable: Runnable[dict, Union[list[AgentAction], AgentFinish]]
     """Runnable to call to get agent actions."""
-    input_keys_arg: List[str] = []
-    return_keys_arg: List[str] = []
+    input_keys_arg: list[str] = []
+    return_keys_arg: list[str] = []
     stream_runnable: bool = True
     """Whether to stream from the runnable or not.
 
@@ -536,12 +532,12 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
         )
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return self.return_keys_arg
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         Returns:
@@ -551,11 +547,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[
-        List[AgentAction],
+        list[AgentAction],
         AgentFinish,
     ]:
         """Based on past history and current inputs, decide what to do.
@@ -590,11 +586,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[
-        List[AgentAction],
+        list[AgentAction],
         AgentFinish,
     ]:
         """Async based on past history and current inputs, decide what to do.
@@ -644,11 +640,11 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
     """LLMChain to use for agent."""
     output_parser: AgentOutputParser
     """Output parser to use for agent."""
-    stop: List[str]
+    stop: list[str]
     """List of strings to stop on."""
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         Returns:
@@ -656,7 +652,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
         """
         return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
 
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> builtins.dict:
         """Return dictionary representation of agent."""
         _dict = super().dict()
         del _dict["output_parser"]
@@ -664,7 +660,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -689,7 +685,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -712,7 +708,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
         )
         return self.output_parser.parse(output)
 
-    def tool_run_logging_kwargs(self) -> Dict:
+    def tool_run_logging_kwargs(self) -> builtins.dict:
         """Return logging kwargs for tool run."""
         return {
             "llm_prefix": "",
@@ -737,21 +733,21 @@ class Agent(BaseSingleActionAgent):
     """LLMChain to use for agent."""
     output_parser: AgentOutputParser
     """Output parser to use for agent."""
-    allowed_tools: Optional[List[str]] = None
+    allowed_tools: Optional[list[str]] = None
    """Allowed tools for the agent. If None, all tools are allowed."""
 
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> builtins.dict:
         """Return dictionary representation of agent."""
         _dict = super().dict()
         del _dict["output_parser"]
         return _dict
 
-    def get_allowed_tools(self) -> Optional[List[str]]:
+    def get_allowed_tools(self) -> Optional[list[str]]:
         """Get allowed tools."""
         return self.allowed_tools
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return ["output"]
 
@@ -767,15 +763,15 @@ class Agent(BaseSingleActionAgent):
         raise ValueError("fix_text not implemented for this agent.")
 
     @property
-    def _stop(self) -> List[str]:
+    def _stop(self) -> list[str]:
         return [
             f"\n{self.observation_prefix.rstrip()}",
             f"\n\t{self.observation_prefix.rstrip()}",
         ]
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
-    ) -> Union[str, List[BaseMessage]]:
+        self, intermediate_steps: list[tuple[AgentAction, str]]
+    ) -> Union[str, list[BaseMessage]]:
         """Construct the scratchpad that lets the agent continue its thought process."""
         thoughts = ""
         for action, observation in intermediate_steps:
@@ -785,7 +781,7 @@ class Agent(BaseSingleActionAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -806,7 +802,7 @@ class Agent(BaseSingleActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -827,8 +823,8 @@ class Agent(BaseSingleActionAgent):
         return agent_output
 
     def get_full_inputs(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
-    ) -> Dict[str, Any]:
+        self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any
+    ) -> builtins.dict[str, Any]:
         """Create the full inputs for the LLMChain from intermediate steps.
 
         Args:
@@ -845,7 +841,7 @@ class Agent(BaseSingleActionAgent):
         return full_inputs
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         :meta private:
@@ -957,7 +953,7 @@ class Agent(BaseSingleActionAgent):
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         """Return response when agent has been stopped due to max iterations.
@@ -1009,7 +1005,7 @@ class Agent(BaseSingleActionAgent):
                 f"got {early_stopping_method}"
             )
 
-    def tool_run_logging_kwargs(self) -> Dict:
+    def tool_run_logging_kwargs(self) -> builtins.dict:
         """Return logging kwargs for tool run."""
         return {
             "llm_prefix": self.llm_prefix,
@@ -1040,7 +1036,7 @@ class ExceptionTool(BaseTool):  # type: ignore[override]
         return query
 
 
-NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
+NextStepOutput = list[Union[AgentFinish, AgentAction, AgentStep]]
 RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
 
 
@@ -1086,7 +1082,7 @@ class AgentExecutor(Chain):
         as an observation.
     """
     trim_intermediate_steps: Union[
-        int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]]
+        int, Callable[[list[tuple[AgentAction, str]]], list[tuple[AgentAction, str]]]
     ] = -1
     """How to trim the intermediate steps before returning them.
     Defaults to -1, which means no trimming.
@@ -1144,7 +1140,7 @@ class AgentExecutor(Chain):
 
     @model_validator(mode="before")
     @classmethod
-    def validate_runnable_agent(cls, values: Dict) -> Any:
+    def validate_runnable_agent(cls, values: dict) -> Any:
         """Convert runnable to agent if passed in.
 
         Args:
@@ -1160,7 +1156,7 @@ class AgentExecutor(Chain):
         except Exception as _:
             multi_action = False
         else:
-            multi_action = output_type == Union[List[AgentAction], AgentFinish]
+            multi_action = output_type == Union[list[AgentAction], AgentFinish]
 
         stream_runnable = values.pop("stream_runnable", True)
         if multi_action:
@@ -1239,7 +1235,7 @@ class AgentExecutor(Chain):
         )
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         :meta private:
@@ -1247,7 +1243,7 @@ class AgentExecutor(Chain):
         return self._action_agent.input_keys
 
     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return the singular output key.
 
         :meta private:
@@ -1284,7 +1280,7 @@ class AgentExecutor(Chain):
         output: AgentFinish,
         intermediate_steps: list,
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         if run_manager:
             run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
         final_output = output.return_values
@@ -1297,7 +1293,7 @@ class AgentExecutor(Chain):
         output: AgentFinish,
         intermediate_steps: list,
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         if run_manager:
             await run_manager.on_agent_finish(
                 output, color="green", verbose=self.verbose
@@ -1309,7 +1305,7 @@ class AgentExecutor(Chain):
 
     def _consume_next_step(
         self, values: NextStepOutput
-    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
+    ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
         if isinstance(values[-1], AgentFinish):
             assert len(values) == 1
             return values[-1]
@@ -1320,12 +1316,12 @@ class AgentExecutor(Chain):
 
     def _take_next_step(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        inputs: Dict[str, str],
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
+        inputs: dict[str, str],
+        intermediate_steps: list[tuple[AgentAction, str]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
+    ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
         return self._consume_next_step(
             [
                 a
@@ -1341,10 +1337,10 @@ class AgentExecutor(Chain):
 
     def _iter_next_step(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        inputs: Dict[str, str],
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
+        inputs: dict[str, str],
+        intermediate_steps: list[tuple[AgentAction, str]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Iterator[Union[AgentFinish, AgentAction, AgentStep]]:
         """Take a single step in the thought-action-observation loop.
@@ -1404,7 +1400,7 @@ class AgentExecutor(Chain):
             yield output
             return
 
-        actions: List[AgentAction]
+        actions: list[AgentAction]
         if isinstance(output, AgentAction):
             actions = [output]
         else:
@@ -1418,8 +1414,8 @@ class AgentExecutor(Chain):
 
     def _perform_agent_action(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
         agent_action: AgentAction,
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> AgentStep:
@@ -1457,12 +1453,12 @@ class AgentExecutor(Chain):
 
     async def _atake_next_step(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        inputs: Dict[str, str],
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
+        inputs: dict[str, str],
+        intermediate_steps: list[tuple[AgentAction, str]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
+    ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
         return self._consume_next_step(
             [
                 a
@@ -1478,10 +1474,10 @@ class AgentExecutor(Chain):
 
     async def _aiter_next_step(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        inputs: Dict[str, str],
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
+        inputs: dict[str, str],
+        intermediate_steps: list[tuple[AgentAction, str]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> AsyncIterator[Union[AgentFinish, AgentAction, AgentStep]]:
         """Take a single step in the thought-action-observation loop.
@@ -1539,7 +1535,7 @@ class AgentExecutor(Chain):
             yield output
             return
 
-        actions: List[AgentAction]
+        actions: list[AgentAction]
         if isinstance(output, AgentAction):
             actions = [output]
         else:
@@ -1563,8 +1559,8 @@ class AgentExecutor(Chain):
 
     async def _aperform_agent_action(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
         agent_action: AgentAction,
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> AgentStep:
@@ -1604,9 +1600,9 @@ class AgentExecutor(Chain):
 
     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Run text through and get agent response."""
         # Construct a mapping of tool name to tool for easy lookup
         name_to_tool_map = {tool.name: tool for tool in self.tools}
@@ -1614,7 +1610,7 @@ class AgentExecutor(Chain):
         color_mapping = get_color_mapping(
             [tool.name for tool in self.tools], excluded_colors=["green", "red"]
         )
-        intermediate_steps: List[Tuple[AgentAction, str]] = []
+        intermediate_steps: list[tuple[AgentAction, str]] = []
         # Let's start tracking the number of iterations and time elapsed
         iterations = 0
         time_elapsed = 0.0
@@ -1651,9 +1647,9 @@ class AgentExecutor(Chain):
 
     async def _acall(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         """Async run text through and get agent response."""
         # Construct a mapping of tool name to tool for easy lookup
         name_to_tool_map = {tool.name: tool for tool in self.tools}
@@ -1661,7 +1657,7 @@ class AgentExecutor(Chain):
         color_mapping = get_color_mapping(
             [tool.name for tool in self.tools], excluded_colors=["green"]
         )
-        intermediate_steps: List[Tuple[AgentAction, str]] = []
+        intermediate_steps: list[tuple[AgentAction, str]] = []
         # Let's start tracking the number of iterations and time elapsed
         iterations = 0
         time_elapsed = 0.0
@@ -1712,7 +1708,7 @@ class AgentExecutor(Chain):
         )
 
     def _get_tool_return(
-        self, next_step_output: Tuple[AgentAction, str]
+        self, next_step_output: tuple[AgentAction, str]
     ) -> Optional[AgentFinish]:
         """Check if the tool is a returning tool."""
         agent_action, observation = next_step_output
@@ -1730,8 +1726,8 @@ class AgentExecutor(Chain):
         return None
 
     def _prepare_intermediate_steps(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
-    ) -> List[Tuple[AgentAction, str]]:
+        self, intermediate_steps: list[tuple[AgentAction, str]]
+    ) -> list[tuple[AgentAction, str]]:
         if (
             isinstance(self.trim_intermediate_steps, int)
             and self.trim_intermediate_steps > 0
@@ -1744,7 +1740,7 @@ class AgentExecutor(Chain):
 
     def stream(
         self,
-        input: Union[Dict[str, Any], Any],
+        input: Union[dict[str, Any], Any],
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Iterator[AddableDict]:
@@ -1770,12 +1766,11 @@ class AgentExecutor(Chain):
             yield_actions=True,
             **kwargs,
         )
-        for step in iterator:
-            yield step
+        yield from iterator
 
     async def astream(
         self,
-        input: Union[Dict[str, Any], Any],
+        input: Union[dict[str, Any], Any],
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AddableDict]:
@@ -3,15 +3,11 @@ from __future__ import annotations
 import asyncio
 import logging
 import time
+from collections.abc import AsyncIterator, Iterator
 from typing import (
     TYPE_CHECKING,
     Any,
-    AsyncIterator,
-    Dict,
-    Iterator,
-    List,
     Optional,
-    Tuple,
     Union,
 )
 from uuid import UUID
@@ -53,7 +49,7 @@ class AgentExecutorIterator:
         callbacks: Callbacks = None,
         *,
         tags: Optional[list[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         run_name: Optional[str] = None,
         run_id: Optional[UUID] = None,
         include_run_info: bool = False,
@@ -90,17 +86,17 @@ class AgentExecutorIterator:
         self.yield_actions = yield_actions
         self.reset()
 
-    _inputs: Dict[str, str]
+    _inputs: dict[str, str]
     callbacks: Callbacks
     tags: Optional[list[str]]
-    metadata: Optional[Dict[str, Any]]
+    metadata: Optional[dict[str, Any]]
     run_name: Optional[str]
     run_id: Optional[UUID]
     include_run_info: bool
     yield_actions: bool
 
     @property
-    def inputs(self) -> Dict[str, str]:
+    def inputs(self) -> dict[str, str]:
         """The inputs to the AgentExecutor."""
         return self._inputs
 
@@ -120,12 +116,12 @@ class AgentExecutorIterator:
         self.inputs = self.inputs
 
     @property
-    def name_to_tool_map(self) -> Dict[str, BaseTool]:
+    def name_to_tool_map(self) -> dict[str, BaseTool]:
         """A mapping of tool names to tools."""
         return {tool.name: tool for tool in self.agent_executor.tools}
 
     @property
-    def color_mapping(self) -> Dict[str, str]:
+    def color_mapping(self) -> dict[str, str]:
         """A mapping of tool names to colors."""
         return get_color_mapping(
             [tool.name for tool in self.agent_executor.tools],
@@ -156,7 +152,7 @@ class AgentExecutorIterator:
 
     def make_final_outputs(
         self,
-        outputs: Dict[str, Any],
+        outputs: dict[str, Any],
         run_manager: Union[CallbackManagerForChainRun, AsyncCallbackManagerForChainRun],
     ) -> AddableDict:
         # have access to intermediate steps by design in iterator,
@@ -171,7 +167,7 @@ class AgentExecutorIterator:
             prepared_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
         return prepared_outputs
 
-    def __iter__(self: "AgentExecutorIterator") -> Iterator[AddableDict]:
+    def __iter__(self: AgentExecutorIterator) -> Iterator[AddableDict]:
         logger.debug("Initialising AgentExecutorIterator")
         self.reset()
         callback_manager = CallbackManager.configure(
@@ -311,7 +307,7 @@ class AgentExecutorIterator:
 
     def _process_next_step_output(
         self,
-        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
+        next_step_output: Union[AgentFinish, list[tuple[AgentAction, str]]],
         run_manager: CallbackManagerForChainRun,
     ) -> AddableDict:
         """
@@ -339,7 +335,7 @@ class AgentExecutorIterator:
 
     async def _aprocess_next_step_output(
         self,
-        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
+        next_step_output: Union[AgentFinish, list[tuple[AgentAction, str]]],
         run_manager: AsyncCallbackManagerForChainRun,
     ) -> AddableDict:
         """
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any, Optional
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.memory import BaseMemory
@@ -26,7 +26,7 @@ def _get_default_system_message() -> SystemMessage:
 
 def create_conversational_retrieval_agent(
     llm: BaseLanguageModel,
-    tools: List[BaseTool],
+    tools: list[BaseTool],
     remember_intermediate_steps: bool = True,
     memory_key: str = "chat_history",
     system_message: Optional[SystemMessage] = None,
@@ -1,6 +1,6 @@
 """VectorStore agent."""
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks.base import BaseCallbackManager
@@ -36,7 +36,7 @@ def create_vectorstore_agent(
     callback_manager: Optional[BaseCallbackManager] = None,
     prefix: str = PREFIX,
     verbose: bool = False,
-    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+    agent_executor_kwargs: Optional[dict[str, Any]] = None,
     **kwargs: Any,
 ) -> AgentExecutor:
     """Construct a VectorStore agent from an LLM and tools.
@@ -129,7 +129,7 @@ def create_vectorstore_router_agent(
     callback_manager: Optional[BaseCallbackManager] = None,
     prefix: str = ROUTER_PREFIX,
     verbose: bool = False,
-    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+    agent_executor_kwargs: Optional[dict[str, Any]] = None,
     **kwargs: Any,
 ) -> AgentExecutor:
     """Construct a VectorStore router agent from an LLM and tools.
@@ -1,7 +1,5 @@
 """Toolkit for interacting with a vector store."""
 
-from typing import List
-
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.tools import BaseTool
 from langchain_core.tools.base import BaseToolkit
@@ -31,7 +29,7 @@ class VectorStoreToolkit(BaseToolkit):
         arbitrary_types_allowed=True,
     )
 
-    def get_tools(self) -> List[BaseTool]:
+    def get_tools(self) -> list[BaseTool]:
         """Get the tools in the toolkit."""
         try:
             from langchain_community.tools.vectorstore.tool import (
@@ -66,16 +64,16 @@ class VectorStoreToolkit(BaseToolkit):
 class VectorStoreRouterToolkit(BaseToolkit):
     """Toolkit for routing between Vector Stores."""
 
-    vectorstores: List[VectorStoreInfo] = Field(exclude=True)
+    vectorstores: list[VectorStoreInfo] = Field(exclude=True)
     llm: BaseLanguageModel
 
     model_config = ConfigDict(
         arbitrary_types_allowed=True,
     )
 
-    def get_tools(self) -> List[BaseTool]:
+    def get_tools(self) -> list[BaseTool]:
         """Get the tools in the toolkit."""
-        tools: List[BaseTool] = []
+        tools: list[BaseTool] = []
         try:
             from langchain_community.tools.vectorstore.tool import (
                 VectorStoreQATool,
@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.agents import AgentAction
@@ -48,7 +49,7 @@ class ChatAgent(Agent):
         return "Thought:"
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+        self, intermediate_steps: list[tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
         if not isinstance(agent_scratchpad, str):
@@ -72,7 +73,7 @@ class ChatAgent(Agent):
         validate_tools_single_input(class_name=cls.__name__, tools=tools)
 
     @property
-    def _stop(self) -> List[str]:
+    def _stop(self) -> list[str]:
         return ["Observation:"]
 
     @classmethod
@@ -83,7 +84,7 @@ class ChatAgent(Agent):
         system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
         human_message: str = HUMAN_MESSAGE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
     ) -> BasePromptTemplate:
         """Create a prompt from a list of tools.
 
@@ -132,7 +133,7 @@ class ChatAgent(Agent):
         system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
         human_message: str = HUMAN_MESSAGE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools.
@@ -1,6 +1,7 @@
 import json
 import re
-from typing import Pattern, Union
+from re import Pattern
+from typing import Union
 
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.exceptions import OutputParserException
@@ -2,7 +2,8 @@
 
 from __future__ import annotations
 
-from typing import Any, List, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import BaseCallbackManager
@@ -71,7 +72,7 @@ class ConversationalAgent(Agent):
         format_instructions: str = FORMAT_INSTRUCTIONS,
         ai_prefix: str = "AI",
         human_prefix: str = "Human",
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
     ) -> PromptTemplate:
         """Create prompt in the style of the zero-shot agent.
 
@@ -120,7 +121,7 @@ class ConversationalAgent(Agent):
         format_instructions: str = FORMAT_INSTRUCTIONS,
         ai_prefix: str = "AI",
         human_prefix: str = "Human",
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools.
@@ -2,7 +2,8 @@
 
 from __future__ import annotations
 
-from typing import Any, List, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.agents import AgentAction
@@ -77,7 +78,7 @@ class ConversationalChatAgent(Agent):
         tools: Sequence[BaseTool],
         system_message: str = PREFIX,
         human_message: str = SUFFIX,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
         output_parser: Optional[BaseOutputParser] = None,
     ) -> BasePromptTemplate:
         """Create a prompt for the agent.
@@ -116,10 +117,10 @@ class ConversationalChatAgent(Agent):
         return ChatPromptTemplate(input_variables=input_variables, messages=messages)  # type: ignore[arg-type]
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
-    ) -> List[BaseMessage]:
+        self, intermediate_steps: list[tuple[AgentAction, str]]
+    ) -> list[BaseMessage]:
         """Construct the scratchpad that lets the agent continue its thought process."""
-        thoughts: List[BaseMessage] = []
+        thoughts: list[BaseMessage] = []
         for action, observation in intermediate_steps:
             thoughts.append(AIMessage(content=action.log))
             human_message = HumanMessage(
@@ -137,7 +138,7 @@ class ConversationalChatAgent(Agent):
         output_parser: Optional[AgentOutputParser] = None,
         system_message: str = PREFIX,
         human_message: str = SUFFIX,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools.
@@ -1,10 +1,8 @@
-from typing import List, Tuple
-
 from langchain_core.agents import AgentAction
 
 
 def format_log_to_str(
-    intermediate_steps: List[Tuple[AgentAction, str]],
+    intermediate_steps: list[tuple[AgentAction, str]],
     observation_prefix: str = "Observation: ",
     llm_prefix: str = "Thought: ",
 ) -> str:
@@ -1,13 +1,11 @@
-from typing import List, Tuple
-
 from langchain_core.agents import AgentAction
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 
 
 def format_log_to_messages(
-    intermediate_steps: List[Tuple[AgentAction, str]],
+    intermediate_steps: list[tuple[AgentAction, str]],
     template_tool_response: str = "{observation}",
-) -> List[BaseMessage]:
+) -> list[BaseMessage]:
     """Construct the scratchpad that lets the agent continue its thought process.
 
     Args:
@@ -18,7 +16,7 @@ def format_log_to_messages(
     Returns:
         List[BaseMessage]: The scratchpad.
     """
-    thoughts: List[BaseMessage] = []
+    thoughts: list[BaseMessage] = []
     for action, observation in intermediate_steps:
         thoughts.append(AIMessage(content=action.log))
         human_message = HumanMessage(
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import List, Sequence, Tuple
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from langchain_core.agents import AgentAction, AgentActionMessageLog
|
from langchain_core.agents import AgentAction, AgentActionMessageLog
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
|
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
|
||||||
@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
|
|||||||
|
|
||||||
def _convert_agent_action_to_messages(
|
def _convert_agent_action_to_messages(
|
||||||
agent_action: AgentAction, observation: str
|
agent_action: AgentAction, observation: str
|
||||||
) -> List[BaseMessage]:
|
) -> list[BaseMessage]:
|
||||||
"""Convert an agent action to a message.
|
"""Convert an agent action to a message.
|
||||||
|
|
||||||
This code is used to reconstruct the original AI message from the agent action.
|
This code is used to reconstruct the original AI message from the agent action.
|
||||||
@ -54,8 +54,8 @@ def _create_function_message(
|
|||||||
|
|
||||||
|
|
||||||
def format_to_openai_function_messages(
|
def format_to_openai_function_messages(
|
||||||
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
intermediate_steps: Sequence[tuple[AgentAction, str]],
|
||||||
) -> List[BaseMessage]:
|
) -> list[BaseMessage]:
|
||||||
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
|
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import List, Sequence, Tuple
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from langchain_core.agents import AgentAction
|
from langchain_core.agents import AgentAction
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
@ -40,8 +40,8 @@ def _create_tool_message(
|
|||||||
|
|
||||||
|
|
||||||
def format_to_tool_messages(
|
def format_to_tool_messages(
|
||||||
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
intermediate_steps: Sequence[tuple[AgentAction, str]],
|
||||||
) -> List[BaseMessage]:
|
) -> list[BaseMessage]:
|
||||||
"""Convert (AgentAction, tool output) tuples into ToolMessages.
|
"""Convert (AgentAction, tool output) tuples into ToolMessages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from langchain_core.agents import AgentAction
|
from langchain_core.agents import AgentAction
|
||||||
|
|
||||||
|
|
||||||
def format_xml(
|
def format_xml(
|
||||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
intermediate_steps: list[tuple[AgentAction, str]],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Format the intermediate steps as XML.
|
"""Format the intermediate steps as XML.
|
||||||
|
|
||||||
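PEP 585 also deprecated the typing aliases for the abstract container protocols, which is why `Sequence`, `Iterator`, and `AsyncIterator` move from `typing` to `collections.abc` throughout these files. A small sketch of the same idiom (names here are illustrative, not from the diff):

    # On 3.9+, the ABCs come from collections.abc and are subscriptable there.
    from collections.abc import Iterator, Sequence


    def first_observations(pairs: Sequence[tuple[str, str]], n: int) -> Iterator[str]:
        """Yield the observation half of the first n (action, observation) pairs."""
        for _action, observation in pairs[:n]:
            yield observation


    print(list(first_observations([("search", "ok"), ("lookup", "hit")], 1)))  # ['ok']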
@@ -1,6 +1,7 @@
"""Load agent."""

-from typing import Any, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager

@@ -1,4 +1,5 @@
-from typing import List, Sequence, Union
+from collections.abc import Sequence
+from typing import Union

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
@@ -15,7 +16,7 @@ def create_json_chat_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
-stop_sequence: Union[bool, List[str]] = True,
+stop_sequence: Union[bool, list[str]] = True,
tools_renderer: ToolsRenderer = render_text_description,
template_tool_response: str = TEMPLATE_TOOL_RESPONSE,
) -> Runnable:
@@ -3,7 +3,7 @@
import json
import logging
from pathlib import Path
-from typing import Any, List, Optional, Union
+from typing import Any, Optional, Union

import yaml
from langchain_core._api import deprecated
@@ -20,7 +20,7 @@ URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/age


def _load_agent_from_tools(
-config: dict, llm: BaseLanguageModel, tools: List[Tool], **kwargs: Any
+config: dict, llm: BaseLanguageModel, tools: list[Tool], **kwargs: Any
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
config_type = config.pop("_type")
if config_type not in AGENT_TO_CLASS:
@@ -35,7 +35,7 @@ def _load_agent_from_tools(
def load_agent_from_config(
config: dict,
llm: Optional[BaseLanguageModel] = None,
-tools: Optional[List[Tool]] = None,
+tools: Optional[list[Tool]] = None,
**kwargs: Any,
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
"""Load agent from Config Dict.
@@ -130,7 +130,7 @@ def _load_agent_from_file(
with open(file_path) as f:
config = json.load(f)
elif file_path.suffix[1:] == "yaml":
-with open(file_path, "r") as f:
+with open(file_path) as f:
config = yaml.safe_load(f)
else:
raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
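One non-typing cleanup rides along here: pyupgrade also drops the redundant `"r"` argument to `open()`, since reading text is already the default mode. A sketch, using a hypothetical file name:

    from pathlib import Path

    path = Path("config.yaml")  # hypothetical file for illustration
    path.write_text("key: value\n")

    # Before: the mode argument restates the default.
    with open(path, "r") as f:
        before = f.read()

    # After: open() defaults to mode="rt", so the argument is simply dropped.
    with open(path) as f:
        after = f.read()

    assert before == after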
@@ -2,7 +2,8 @@

from __future__ import annotations

-from typing import Any, Callable, List, NamedTuple, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Callable, NamedTuple, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager
@@ -83,7 +84,7 @@ class ZeroShotAgent(Agent):
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
-input_variables: Optional[List[str]] = None,
+input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.

@@ -118,7 +119,7 @@ class ZeroShotAgent(Agent):
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
-input_variables: Optional[List[str]] = None,
+input_variables: Optional[list[str]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools.
@@ -183,7 +184,7 @@ class MRKLChain(AgentExecutor):

@classmethod
def from_chains(
-cls, llm: BaseLanguageModel, chains: List[ChainConfig], **kwargs: Any
+cls, llm: BaseLanguageModel, chains: list[ChainConfig], **kwargs: Any
) -> AgentExecutor:
"""User-friendly way to initialize the MRKL chain.

@@ -2,18 +2,14 @@ from __future__ import annotations

import asyncio
import json
+from collections.abc import Sequence
from json import JSONDecodeError
from time import sleep
from typing import (
TYPE_CHECKING,
Any,
Callable,
-Dict,
-List,
Optional,
-Sequence,
-Tuple,
-Type,
Union,
)

@@ -111,7 +107,7 @@ def _get_openai_async_client() -> openai.AsyncOpenAI:


def _is_assistants_builtin_tool(
-tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
+tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
) -> bool:
"""Determine if tool corresponds to OpenAI Assistants built-in."""
assistants_builtin_tools = ("code_interpreter", "file_search")
@@ -123,8 +119,8 @@ def _is_assistants_builtin_tool(


def _get_assistants_tool(
-tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
-) -> Dict[str, Any]:
+tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
+) -> dict[str, Any]:
"""Convert a raw function/class to an OpenAI tool.

Note that OpenAI assistants supports several built-in tools,
@@ -137,14 +133,14 @@ def _get_assistants_tool(


OutputType = Union[
-List[OpenAIAssistantAction],
+list[OpenAIAssistantAction],
OpenAIAssistantFinish,
-List["ThreadMessage"],
-List["RequiredActionFunctionToolCall"],
+list["ThreadMessage"],
+list["RequiredActionFunctionToolCall"],
]


-class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
+class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
"""Run an OpenAI Assistant.

Example using OpenAI tools:
@@ -498,7 +494,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
return response

def _parse_intermediate_steps(
-self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
+self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]]
) -> dict:
last_action, last_output = intermediate_steps[-1]
run = self._wait_for_run(last_action.run_id, last_action.thread_id)
@@ -652,7 +648,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
return run

async def _aparse_intermediate_steps(
-self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
+self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]]
) -> dict:
last_action, last_output = intermediate_steps[-1]
run = self._wait_for_run(last_action.run_id, last_action.thread_id)
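Note that the rewrite is safe outside annotations too: on Python 3.9, `dict[str, Any]` and `type[BaseModel]` are real subscriptable objects at runtime, which is what allows them inside `Union[...]` aliases like `OutputType` above. A quick demonstration of the runtime behavior, standard library only:

    from typing import Union, get_args, get_origin

    # Builtin generics are ordinary objects on 3.9+, usable in runtime expressions.
    Payload = Union[dict[str, int], list[str]]

    print(get_origin(dict[str, int]))   # <class 'dict'>
    print(get_args(Payload))            # (dict[str, int], list[str])

    # Subscripting does not change the class itself: instances are plain dicts.
    assert isinstance({"a": 1}, dict)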
@@ -1,6 +1,6 @@
"""Memory used to save agent output AND intermediate steps."""

-from typing import Any, Dict, List
+from typing import Any

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string
@@ -43,19 +43,19 @@ class AgentTokenBufferMemory(BaseChatMemory):  # type: ignore[override]
format_as_tools: bool = False

@property
-def buffer(self) -> List[BaseMessage]:
+def buffer(self) -> list[BaseMessage]:
"""String buffer of memory."""
return self.chat_memory.messages

@property
-def memory_variables(self) -> List[str]:
+def memory_variables(self) -> list[str]:
"""Always return list of memory variables.

:meta private:
"""
return [self.memory_key]

-def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer.

Args:
@@ -74,7 +74,7 @@ class AgentTokenBufferMemory(BaseChatMemory):  # type: ignore[override]
)
return {self.memory_key: final_buffer}

-def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
+def save_context(self, inputs: dict[str, Any], outputs: dict[str, Any]) -> None:
"""Save context from this conversation to buffer. Pruned.

Args:
@@ -1,6 +1,7 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API."""

-from typing import Any, List, Optional, Sequence, Tuple, Type, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Union

from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentFinish
@@ -51,11 +52,11 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
llm: BaseLanguageModel
tools: Sequence[BaseTool]
prompt: BasePromptTemplate
-output_parser: Type[OpenAIFunctionsAgentOutputParser] = (
+output_parser: type[OpenAIFunctionsAgentOutputParser] = (
OpenAIFunctionsAgentOutputParser
)

-def get_allowed_tools(self) -> List[str]:
+def get_allowed_tools(self) -> list[str]:
"""Get allowed tools."""
return [t.name for t in self.tools]

@@ -81,19 +82,19 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
return self

@property
-def input_keys(self) -> List[str]:
+def input_keys(self) -> list[str]:
"""Get input keys. Input refers to user input here."""
return ["input"]

@property
-def functions(self) -> List[dict]:
+def functions(self) -> list[dict]:
"""Get functions."""

return [dict(convert_to_openai_function(t)) for t in self.tools]

def plan(
self,
-intermediate_steps: List[Tuple[AgentAction, str]],
+intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
with_functions: bool = True,
**kwargs: Any,
@@ -135,7 +136,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):

async def aplan(
self,
-intermediate_steps: List[Tuple[AgentAction, str]],
+intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@@ -168,7 +169,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
def return_stopped_response(
self,
early_stopping_method: str,
-intermediate_steps: List[Tuple[AgentAction, str]],
+intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.
@@ -213,7 +214,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
-extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
) -> ChatPromptTemplate:
"""Create prompt for this agent.

@@ -227,7 +228,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
A prompt template to pass into this agent.
"""
_prompts = extra_prompt_messages or []
-messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
+messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message:
messages = [system_message]
else:
@@ -248,7 +249,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
-extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
@@ -1,8 +1,9 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API."""

import json
+from collections.abc import Sequence
from json import JSONDecodeError
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Union

from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
@@ -34,7 +35,7 @@ from langchain.agents.format_scratchpad.openai_functions import (
_FunctionsAgentAction = AgentActionMessageLog


-def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]:
+def _parse_ai_message(message: BaseMessage) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")
@@ -58,7 +59,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFin
f"the `arguments` JSON does not contain `actions` key."
)

-final_tools: List[AgentAction] = []
+final_tools: list[AgentAction] = []
for tool_schema in tools:
if "action" in tool_schema:
_tool_input = tool_schema["action"]
@@ -112,7 +113,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
tools: Sequence[BaseTool]
prompt: BasePromptTemplate

-def get_allowed_tools(self) -> List[str]:
+def get_allowed_tools(self) -> list[str]:
"""Get allowed tools."""
return [t.name for t in self.tools]

@@ -127,12 +128,12 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
return self

@property
-def input_keys(self) -> List[str]:
+def input_keys(self) -> list[str]:
"""Get input keys. Input refers to user input here."""
return ["input"]

@property
-def functions(self) -> List[dict]:
+def functions(self) -> list[dict]:
"""Get the functions for the agent."""
enum_vals = [t.name for t in self.tools]
tool_selection = {
@@ -194,10 +195,10 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):

def plan(
self,
-intermediate_steps: List[Tuple[AgentAction, str]],
+intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
-) -> Union[List[AgentAction], AgentFinish]:
+) -> Union[list[AgentAction], AgentFinish]:
"""Given input, decided what to do.

Args:
@@ -224,10 +225,10 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):

async def aplan(
self,
-intermediate_steps: List[Tuple[AgentAction, str]],
+intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
-) -> Union[List[AgentAction], AgentFinish]:
+) -> Union[list[AgentAction], AgentFinish]:
"""Async given input, decided what to do.

Args:
@@ -258,7 +259,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
-extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
) -> BasePromptTemplate:
"""Create prompt for this agent.

@@ -272,7 +273,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
A prompt template to pass into this agent.
"""
_prompts = extra_prompt_messages or []
-messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
+messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message:
messages = [system_message]
else:
@@ -293,7 +294,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
-extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
@@ -1,4 +1,5 @@
-from typing import Optional, Sequence
+from collections.abc import Sequence
+from typing import Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
@@ -1,6 +1,6 @@
import json
from json import JSONDecodeError
-from typing import List, Union
+from typing import Union

from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
@@ -77,7 +77,7 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
)

def parse_result(
-self, result: List[Generation], *, partial: bool = False
+self, result: list[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")

@@ -1,4 +1,4 @@
-from typing import List, Union
+from typing import Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
@@ -15,12 +15,12 @@ OpenAIToolAgentAction = ToolAgentAction

def parse_ai_message_to_openai_tool_action(
message: BaseMessage,
-) -> Union[List[AgentAction], AgentFinish]:
+) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
tool_actions = parse_ai_message_to_tool_action(message)
if isinstance(tool_actions, AgentFinish):
return tool_actions
-final_actions: List[AgentAction] = []
+final_actions: list[AgentAction] = []
for action in tool_actions:
if isinstance(action, ToolAgentAction):
final_actions.append(
@@ -54,12 +54,12 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
return "openai-tools-agent-output-parser"

def parse_result(
-self, result: List[Generation], *, partial: bool = False
-) -> Union[List[AgentAction], AgentFinish]:
+self, result: list[Generation], *, partial: bool = False
+) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return parse_ai_message_to_openai_tool_action(message)

-def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
+def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")
@@ -1,6 +1,7 @@
import json
import re
-from typing import Pattern, Union
+from re import Pattern
+from typing import Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
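`typing.Pattern` and `typing.Match` were likewise deprecated in favor of the classes in `re` itself, which are subscriptable on 3.9+, hence the import move in the two output parsers. A sketch of the target style (regex and names are illustrative):

    import re
    from re import Pattern

    # re.Pattern is generic over the string type on 3.9+.
    ACTION_RE: Pattern[str] = re.compile(r"Action:\s*(\w+)")

    match = ACTION_RE.search("Thought: done\nAction: search")
    if match is not None:
        print(match.group(1))  # search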
@@ -1,4 +1,5 @@
-from typing import Sequence, Union
+from collections.abc import Sequence
+from typing import Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
@@ -1,6 +1,6 @@
import json
from json import JSONDecodeError
-from typing import List, Union
+from typing import Union

from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
@@ -21,12 +21,12 @@ class ToolAgentAction(AgentActionMessageLog):  # type: ignore[override]

def parse_ai_message_to_tool_action(
message: BaseMessage,
-) -> Union[List[AgentAction], AgentFinish]:
+) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")

-actions: List = []
+actions: list = []
if message.tool_calls:
tool_calls = message.tool_calls
else:
@@ -91,12 +91,12 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
return "tools-agent-output-parser"

def parse_result(
-self, result: List[Generation], *, partial: bool = False
-) -> Union[List[AgentAction], AgentFinish]:
+self, result: list[Generation], *, partial: bool = False
+) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return parse_ai_message_to_tool_action(message)

-def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
+def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")
@@ -1,6 +1,7 @@
from __future__ import annotations

-from typing import List, Optional, Sequence, Union
+from collections.abc import Sequence
+from typing import Optional, Union

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
@@ -20,7 +21,7 @@ def create_react_agent(
output_parser: Optional[AgentOutputParser] = None,
tools_renderer: ToolsRenderer = render_text_description,
*,
-stop_sequence: Union[bool, List[str]] = True,
+stop_sequence: Union[bool, list[str]] = True,
) -> Runnable:
"""Create an agent that uses ReAct prompting.

@@ -2,7 +2,8 @@

from __future__ import annotations

-from typing import TYPE_CHECKING, Any, List, Optional, Sequence
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Optional

from langchain_core._api import deprecated
from langchain_core.documents import Document
@@ -65,7 +66,7 @@ class ReActDocstoreAgent(Agent):
return "Observation: "

@property
-def _stop(self) -> List[str]:
+def _stop(self) -> list[str]:
return ["\nObservation:"]

@property
@@ -122,7 +123,7 @@ class DocstoreExplorer:
return self._paragraphs[0]

@property
-def _paragraphs(self) -> List[str]:
+def _paragraphs(self) -> list[str]:
if self.document is None:
raise ValueError("Cannot get paragraphs without a document")
return self.document.page_content.split("\n\n")
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Tuple
+from typing import Any

from langchain_core.agents import AgentAction
from langchain_core.prompts.chat import ChatPromptTemplate
@@ -12,7 +12,7 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
return False

def _construct_agent_scratchpad(
-self, intermediate_steps: List[Tuple[AgentAction, str]]
+self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
if len(intermediate_steps) == 0:
return ""
@@ -26,7 +26,7 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
f"you return as final answer):\n{thoughts}"
)

-def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
+def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
intermediate_steps = kwargs.pop("intermediate_steps")
kwargs["agent_scratchpad"] = self._construct_agent_scratchpad(
intermediate_steps
@@ -2,7 +2,8 @@

from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Sequence, Union
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, Union

from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@@ -1,5 +1,6 @@
import re
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Union

from langchain_core._api import deprecated
from langchain_core.agents import AgentAction
@@ -49,7 +50,7 @@ class StructuredChatAgent(Agent):
return "Thought:"

def _construct_scratchpad(
-self, intermediate_steps: List[Tuple[AgentAction, str]]
+self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str):
@@ -74,7 +75,7 @@ class StructuredChatAgent(Agent):
return StructuredChatOutputParserWithRetries.from_llm(llm=llm)

@property
-def _stop(self) -> List[str]:
+def _stop(self) -> list[str]:
return ["Observation:"]

@classmethod
@@ -85,8 +86,8 @@ class StructuredChatAgent(Agent):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
-input_variables: Optional[List[str]] = None,
-memory_prompts: Optional[List[BasePromptTemplate]] = None,
+input_variables: Optional[list[str]] = None,
+memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
@@ -117,8 +118,8 @@ class StructuredChatAgent(Agent):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
-input_variables: Optional[List[str]] = None,
-memory_prompts: Optional[List[BasePromptTemplate]] = None,
+input_variables: Optional[list[str]] = None,
+memory_prompts: Optional[list[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
@@ -157,7 +158,7 @@ def create_structured_chat_agent(
prompt: ChatPromptTemplate,
tools_renderer: ToolsRenderer = render_text_description_and_args,
*,
-stop_sequence: Union[bool, List[str]] = True,
+stop_sequence: Union[bool, list[str]] = True,
) -> Runnable:
"""Create an agent aimed at supporting tools with multiple inputs.

@@ -3,7 +3,8 @@ from __future__ import annotations
import json
import logging
import re
-from typing import Optional, Pattern, Union
+from re import Pattern
+from typing import Optional, Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
@@ -1,4 +1,5 @@
-from typing import Callable, List, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Callable

from langchain_core.agents import AgentAction
from langchain_core.language_models import BaseLanguageModel
@@ -12,7 +13,7 @@ from langchain.agents.format_scratchpad.tools import (
)
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser

-MessageFormatter = Callable[[Sequence[Tuple[AgentAction, str]]], List[BaseMessage]]
+MessageFormatter = Callable[[Sequence[tuple[AgentAction, str]]], list[BaseMessage]]


def create_tool_calling_agent(
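The `MessageFormatter` alias above shows that the rewrite composes: builtin generics nest inside `Callable` and `Sequence` exactly like the typing forms did. A reduced sketch of the same shape, with simplified stand-in types (illustrative only):

    from collections.abc import Sequence
    from typing import Callable

    # An alias mirroring MessageFormatter's shape, with str standing in for the
    # AgentAction and BaseMessage types used in the real code.
    Formatter = Callable[[Sequence[tuple[str, str]]], list[str]]


    def to_lines(steps: Sequence[tuple[str, str]]) -> list[str]:
        return [f"{action}: {observation}" for action, observation in steps]


    fmt: Formatter = to_lines
    print(fmt([("search", "found 3 results")]))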
@@ -1,6 +1,6 @@
"""Interface for tools."""

-from typing import List, Optional
+from typing import Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
@@ -20,7 +20,7 @@ class InvalidTool(BaseTool):  # type: ignore[override]
def _run(
self,
requested_tool_name: str,
-available_tool_names: List[str],
+available_tool_names: list[str],
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
@@ -33,7 +33,7 @@ class InvalidTool(BaseTool):  # type: ignore[override]
async def _arun(
self,
requested_tool_name: str,
-available_tool_names: List[str],
+available_tool_names: list[str],
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""
@@ -1,4 +1,4 @@
-from typing import Dict, Type, Union
+from typing import Union

from langchain.agents.agent import BaseSingleActionAgent
from langchain.agents.agent_types import AgentType
@@ -12,9 +12,9 @@ from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.structured_chat.base import StructuredChatAgent

-AGENT_TYPE = Union[Type[BaseSingleActionAgent], Type[OpenAIMultiFunctionsAgent]]
+AGENT_TYPE = Union[type[BaseSingleActionAgent], type[OpenAIMultiFunctionsAgent]]

-AGENT_TO_CLASS: Dict[AgentType, AGENT_TYPE] = {
+AGENT_TO_CLASS: dict[AgentType, AGENT_TYPE] = {
AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent,
AgentType.REACT_DOCSTORE: ReActDocstoreAgent,
AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent,
@@ -1,4 +1,4 @@
-from typing import Sequence
+from collections.abc import Sequence

from langchain_core.tools import BaseTool

@@ -1,4 +1,5 @@
-from typing import Any, List, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Union

from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentFinish
@@ -38,13 +39,13 @@ class XMLAgent(BaseSingleActionAgent):

"""

-tools: List[BaseTool]
+tools: list[BaseTool]
"""List of tools this agent has access to."""
llm_chain: LLMChain
"""Chain to use to predict action."""

@property
-def input_keys(self) -> List[str]:
+def input_keys(self) -> list[str]:
return ["input"]

@staticmethod
@@ -60,7 +61,7 @@ class XMLAgent(BaseSingleActionAgent):

def plan(
self,
-intermediate_steps: List[Tuple[AgentAction, str]],
+intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@@ -84,7 +85,7 @@ class XMLAgent(BaseSingleActionAgent):

async def aplan(
self,
-intermediate_steps: List[Tuple[AgentAction, str]],
+intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@@ -113,7 +114,7 @@ def create_xml_agent(
prompt: BasePromptTemplate,
tools_renderer: ToolsRenderer = render_text_description,
*,
-stop_sequence: Union[bool, List[str]] = True,
+stop_sequence: Union[bool, list[str]] = True,
) -> Runnable:
"""Create an agent that uses XML to format its logic.

@@ -1,7 +1,8 @@
from __future__ import annotations

import asyncio
-from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
+from collections.abc import AsyncIterator
+from typing import Any, Literal, Union, cast

from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.outputs import LLMResult
@@ -25,7 +26,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
self.done = asyncio.Event()

async def on_llm_start(
-self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()
@@ -1,6 +1,6 @@
from __future__ import annotations

-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

from langchain_core.outputs import LLMResult

@@ -30,7 +30,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def __init__(
self,
*,
-answer_prefix_tokens: Optional[List[str]] = None,
+answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
) -> None:
@@ -62,7 +62,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.answer_reached = False

async def on_llm_start(
-self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()
@@ -1,7 +1,7 @@
"""Callback Handler streams to stdout on new llm token."""

import sys
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

from langchain_core.callbacks import StreamingStdOutCallbackHandler

@@ -31,7 +31,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def __init__(
self,
*,
-answer_prefix_tokens: Optional[List[str]] = None,
+answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
) -> None:
@@ -63,7 +63,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
self.answer_reached = False

def on_llm_start(
-self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.answer_reached = False
@@ -2,7 +2,8 @@

from __future__ import annotations

-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional
from urllib.parse import urlparse

from langchain_core._api import deprecated
@@ -20,7 +21,7 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain


-def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
+def _extract_scheme_and_domain(url: str) -> tuple[str, str]:
"""Extract the scheme + domain from a given URL.

Args:
@@ -215,7 +216,7 @@ try:
"""

@property
-def input_keys(self) -> List[str]:
+def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@@ -223,7 +224,7 @@ try:
return [self.question_key]

@property
-def output_keys(self) -> List[str]:
+def output_keys(self) -> list[str]:
"""Expect output key.

:meta private:
@@ -243,7 +244,7 @@ try:

@model_validator(mode="before")
@classmethod
-def validate_limit_to_domains(cls, values: Dict) -> Any:
+def validate_limit_to_domains(cls, values: dict) -> Any:
"""Check that allowed domains are valid."""
# This check must be a pre=True check, so that a default of None
# won't be set to limit_to_domains if it's not provided.
@@ -275,9 +276,9 @@ try:

def _call(
self,
-inputs: Dict[str, Any],
+inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
-) -> Dict[str, str]:
+) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.question_key]
api_url = self.api_request_chain.predict(
@@ -308,9 +309,9 @@ try:

async def _acall(
self,
-inputs: Dict[str, Any],
+inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-) -> Dict[str, str]:
+) -> dict[str, str]:
_run_manager = (
run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
)
|
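
Alongside the builtin generics, the import hunks above show the second pyupgrade rule applied throughout this commit: abstract container types such as `Sequence` move from `typing` to `collections.abc`, their canonical home since Python 3.9. A sketch under that assumption (names invented for illustration):

from collections.abc import Iterator, Sequence

# typing.Sequence and typing.Iterator are deprecated aliases of the
# collections.abc classes on Python 3.9+.
def windows(xs: Sequence[int], size: int) -> Iterator[tuple[int, ...]]:
    """Yield overlapping windows of `size` consecutive items."""
    for i in range(len(xs) - size + 1):
        yield tuple(xs[i : i + size])

list(windows([1, 2, 3, 4], 2)) evaluates to [(1, 2), (2, 3), (3, 4)].
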
@@ -1,12 +1,13 @@
 """Base interface that all chains should implement."""
 
+import builtins
 import inspect
 import json
 import logging
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Type, Union, cast
+from typing import Any, Optional, Union, cast
 
 import yaml
 from langchain_core._api import deprecated
@@ -46,7 +47,7 @@ def _get_verbosity() -> bool:
     return get_verbose()
 
 
-class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
+class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
     """Abstract base class for creating structured sequences of calls to components.
 
     Chains should be used to encode a sequence of calls to components like
@@ -86,13 +87,13 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     """Whether or not run in verbose mode. In verbose mode, some intermediate logs
     will be printed to the console. Defaults to the global `verbose` value,
     accessible via `langchain.globals.get_verbose()`."""
-    tags: Optional[List[str]] = None
+    tags: Optional[list[str]] = None
     """Optional list of tags associated with the chain. Defaults to None.
     These tags will be associated with each call to this chain,
     and passed as arguments to the handlers defined in `callbacks`.
     You can use these to eg identify a specific instance of a chain with its use case.
     """
-    metadata: Optional[Dict[str, Any]] = None
+    metadata: Optional[dict[str, Any]] = None
     """Optional metadata associated with the chain. Defaults to None.
     This metadata will be associated with each call to this chain,
     and passed as arguments to the handlers defined in `callbacks`.
@@ -107,7 +108,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    ) -> type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "ChainInput", **{k: (Any, None) for k in self.input_keys}
@@ -115,7 +116,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    ) -> type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "ChainOutput", **{k: (Any, None) for k in self.output_keys}
@@ -123,10 +124,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     def invoke(
         self,
-        input: Dict[str, Any],
+        input: dict[str, Any],
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         config = ensure_config(config)
         callbacks = config.get("callbacks")
         tags = config.get("tags")
@@ -162,7 +163,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
                 else self._call(inputs)
             )
 
-            final_outputs: Dict[str, Any] = self.prep_outputs(
+            final_outputs: dict[str, Any] = self.prep_outputs(
                 inputs, outputs, return_only_outputs
             )
         except BaseException as e:
@@ -176,10 +177,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     async def ainvoke(
         self,
-        input: Dict[str, Any],
+        input: dict[str, Any],
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         config = ensure_config(config)
         callbacks = config.get("callbacks")
         tags = config.get("tags")
@@ -213,7 +214,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
                 if new_arg_supported
                 else await self._acall(inputs)
             )
-            final_outputs: Dict[str, Any] = await self.aprep_outputs(
+            final_outputs: dict[str, Any] = await self.aprep_outputs(
                 inputs, outputs, return_only_outputs
             )
         except BaseException as e:
@@ -231,7 +232,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     @model_validator(mode="before")
     @classmethod
-    def raise_callback_manager_deprecation(cls, values: Dict) -> Any:
+    def raise_callback_manager_deprecation(cls, values: dict) -> Any:
         """Raise deprecation warning if callback_manager is used."""
         if values.get("callback_manager") is not None:
             if values.get("callbacks") is not None:
@@ -261,15 +262,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     @property
     @abstractmethod
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Keys expected to be in the chain input."""
 
     @property
     @abstractmethod
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Keys expected to be in the chain output."""
 
-    def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
+    def _validate_inputs(self, inputs: dict[str, Any]) -> None:
         """Check that all inputs are present."""
         if not isinstance(inputs, dict):
             _input_keys = set(self.input_keys)
@@ -289,7 +290,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         if missing_keys:
             raise ValueError(f"Missing some input keys: {missing_keys}")
 
-    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
+    def _validate_outputs(self, outputs: dict[str, Any]) -> None:
         missing_keys = set(self.output_keys).difference(outputs)
         if missing_keys:
             raise ValueError(f"Missing some output keys: {missing_keys}")
@@ -297,9 +298,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     @abstractmethod
     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Execute the chain.
 
         This is a private method that is not user-facing. It is only called within
@@ -319,9 +320,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Asynchronously execute the chain.
 
         This is a private method that is not user-facing. It is only called within
@@ -345,15 +346,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     @deprecated("0.1.0", alternative="invoke", removal="1.0")
     def __call__(
         self,
-        inputs: Union[Dict[str, Any], Any],
+        inputs: Union[dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
         *,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         run_name: Optional[str] = None,
         include_run_info: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Execute the chain.
 
         Args:
@@ -396,15 +397,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
     @deprecated("0.1.0", alternative="ainvoke", removal="1.0")
     async def acall(
         self,
-        inputs: Union[Dict[str, Any], Any],
+        inputs: Union[dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
         *,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         run_name: Optional[str] = None,
         include_run_info: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Asynchronously execute the chain.
 
         Args:
@@ -445,10 +446,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     def prep_outputs(
         self,
-        inputs: Dict[str, str],
-        outputs: Dict[str, str],
+        inputs: dict[str, str],
+        outputs: dict[str, str],
         return_only_outputs: bool = False,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         """Validate and prepare chain outputs, and save info about this run to memory.
 
         Args:
@@ -471,10 +472,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     async def aprep_outputs(
         self,
-        inputs: Dict[str, str],
-        outputs: Dict[str, str],
+        inputs: dict[str, str],
+        outputs: dict[str, str],
         return_only_outputs: bool = False,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         """Validate and prepare chain outputs, and save info about this run to memory.
 
         Args:
@@ -495,7 +496,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         else:
             return {**inputs, **outputs}
 
-    def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
+    def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
         """Prepare chain inputs, including adding inputs from memory.
 
         Args:
@@ -519,7 +520,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             inputs = dict(inputs, **external_context)
         return inputs
 
-    async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
+    async def aprep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
         """Prepare chain inputs, including adding inputs from memory.
 
         Args:
@@ -557,8 +558,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         self,
         *args: Any,
         callbacks: Callbacks = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Convenience method for executing chain.
@@ -628,8 +629,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         self,
         *args: Any,
         callbacks: Callbacks = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> Any:
         """Convenience method for executing chain.
@@ -695,7 +696,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
             f" but not both. Got args: {args} and kwargs: {kwargs}."
         )
 
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> dict:
         """Dictionary representation of chain.
 
         Expects `Chain._chain_type` property to be implemented and for memory to be
@@ -763,7 +764,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
 
     @deprecated("0.1.0", alternative="batch", removal="1.0")
     def apply(
-        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
-    ) -> List[Dict[str, str]]:
+        self, input_list: list[builtins.dict[str, Any]], callbacks: Callbacks = None
+    ) -> list[builtins.dict[str, str]]:
         """Call the chain on all inputs in the list."""
         return [self(inputs, callbacks=callbacks) for inputs in input_list]
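
The one non-mechanical change in the file above is the new `import builtins`. `Chain` defines a `dict()` method, so in class-body annotations written after that method, the bare name `dict` resolves to the method rather than the builtin; `apply` therefore spells its types `builtins.dict[...]` once `typing.Dict` is gone. A minimal reproduction of the pitfall (the `Record` class is hypothetical):

import builtins

class Record:
    def __init__(self, ident: int) -> None:
        self.ident = ident

    def dict(self) -> dict:
        # This annotation is evaluated while `dict` is still unbound in
        # the class namespace, so it still means the builtin.
        return {"ident": self.ident}

    def counts(self) -> builtins.dict[str, int]:
        # After `def dict` above, a bare `dict[str, int]` here would try
        # to subscript the method and raise TypeError at class-definition
        # time; builtins.dict sidesteps the shadowing.
        return {"calls": 0}
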
@@ -1,7 +1,7 @@
 """Base interface for chains combining documents."""
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -47,22 +47,22 @@ class BaseCombineDocumentsChain(Chain, ABC):
 
     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    ) -> type[BaseModel]:
         return create_model(
             "CombineDocumentsInput",
-            **{self.input_key: (List[Document], None)},  # type: ignore[call-overload]
+            **{self.input_key: (list[Document], None)},  # type: ignore[call-overload]
         )
 
     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    ) -> type[BaseModel]:
         return create_model(
             "CombineDocumentsOutput",
             **{self.output_key: (str, None)},  # type: ignore[call-overload]
         )
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect input key.
 
         :meta private:
@@ -70,14 +70,14 @@ class BaseCombineDocumentsChain(Chain, ABC):
         return [self.input_key]
 
     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return output key.
 
         :meta private:
         """
         return [self.output_key]
 
-    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
+    def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
         """Return the prompt length given the documents passed in.
 
         This can be used by a caller to determine whether passing in a list
@@ -96,7 +96,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
         return None
 
     @abstractmethod
-    def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
+    def combine_docs(self, docs: list[Document], **kwargs: Any) -> tuple[str, dict]:
         """Combine documents into a single string.
 
         Args:
@@ -111,8 +111,8 @@ class BaseCombineDocumentsChain(Chain, ABC):
 
     @abstractmethod
     async def acombine_docs(
-        self, docs: List[Document], **kwargs: Any
-    ) -> Tuple[str, dict]:
+        self, docs: list[Document], **kwargs: Any
+    ) -> tuple[str, dict]:
         """Combine documents into a single string.
 
         Args:
@@ -127,9 +127,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
 
     def _call(
         self,
-        inputs: Dict[str, List[Document]],
+        inputs: dict[str, list[Document]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
@@ -143,9 +143,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
 
     async def _acall(
         self,
-        inputs: Dict[str, List[Document]],
+        inputs: dict[str, list[Document]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         """Prepare inputs, call combine docs, prepare outputs."""
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         docs = inputs[self.input_key]
@@ -229,7 +229,7 @@ class AnalyzeDocumentChain(Chain):
     combine_docs_chain: BaseCombineDocumentsChain
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect input key.
 
         :meta private:
@@ -237,7 +237,7 @@ class AnalyzeDocumentChain(Chain):
         return [self.input_key]
 
     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return output key.
 
         :meta private:
@@ -246,7 +246,7 @@ class AnalyzeDocumentChain(Chain):
 
     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    ) -> type[BaseModel]:
         return create_model(
             "AnalyzeDocumentChain",
             **{self.input_key: (str, None)},  # type: ignore[call-overload]
@@ -254,20 +254,20 @@ class AnalyzeDocumentChain(Chain):
 
     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    ) -> type[BaseModel]:
         return self.combine_docs_chain.get_output_schema(config)
 
     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         """Split document into chunks and pass to CombineDocumentsChain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         document = inputs[self.input_key]
         docs = self.text_splitter.create_documents([document])
         # Other keys are assumed to be needed for LLM prediction
-        other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
+        other_keys: dict = {k: v for k, v in inputs.items() if k != self.input_key}
         other_keys[self.combine_docs_chain.input_key] = docs
         return self.combine_docs_chain(
             other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
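
Note that the `create_model` calls above pass types like `list[Document]` as runtime values, not as annotations: pydantic builds the schema when the method runs, so the subscripted builtin must actually exist as an object, which PEP 585 guarantees on 3.9+ (deferred annotations would not help here). A sketch of the same pattern, assuming pydantic v2 is installed:

from typing import Any

from pydantic import create_model

# Field types handed to create_model are evaluated eagerly at runtime.
DynamicInput = create_model(
    "DynamicInput",
    tags=(list[str], ...),         # required list of strings
    extra=(dict[str, Any], None),  # optional mapping defaulting to None
)

print(DynamicInput(tags=["a", "b"]).tags)  # ['a', 'b']
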
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import Callbacks
@@ -113,20 +113,20 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
 
     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    ) -> type[BaseModel]:
         if self.return_intermediate_steps:
             return create_model(
                 "MapReduceDocumentsOutput",
                 **{
                     self.output_key: (str, None),
-                    "intermediate_steps": (List[str], None),
+                    "intermediate_steps": (list[str], None),
                 },  # type: ignore[call-overload]
             )
 
         return super().get_output_schema(config)
 
     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Expect input key.
 
         :meta private:
@@ -143,7 +143,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
 
     @model_validator(mode="before")
     @classmethod
-    def get_reduce_chain(cls, values: Dict) -> Any:
+    def get_reduce_chain(cls, values: dict) -> Any:
         """For backwards compatibility."""
         if "combine_document_chain" in values:
             if "reduce_documents_chain" in values:
@@ -167,7 +167,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
 
     @model_validator(mode="before")
     @classmethod
-    def get_return_intermediate_steps(cls, values: Dict) -> Any:
+    def get_return_intermediate_steps(cls, values: dict) -> Any:
         """For backwards compatibility."""
         if "return_map_steps" in values:
             values["return_intermediate_steps"] = values["return_map_steps"]
@@ -176,7 +176,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
 
     @model_validator(mode="before")
     @classmethod
-    def get_default_document_variable_name(cls, values: Dict) -> Any:
+    def get_default_document_variable_name(cls, values: dict) -> Any:
         """Get default document variable name, if not provided."""
         if "llm_chain" not in values:
             raise ValueError("llm_chain must be provided")
@@ -227,11 +227,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
 
     def combine_docs(
         self,
-        docs: List[Document],
+        docs: list[Document],
         token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Tuple[str, dict]:
+    ) -> tuple[str, dict]:
         """Combine documents in a map reduce manner.
 
         Combine by mapping first chain over all documents, then reducing the results.
@@ -258,11 +258,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
 
     async def acombine_docs(
         self,
-        docs: List[Document],
+        docs: list[Document],
         token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Tuple[str, dict]:
+    ) -> tuple[str, dict]:
         """Combine documents in a map reduce manner.
 
         Combine by mapping first chain over all documents, then reducing the results.
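
This module, like several others in the commit, opens with `from __future__ import annotations` (PEP 563), under which annotations are stored as strings and never evaluated at definition time; signatures would have tolerated `list[Document]` even on 3.8. The upgrade still matters for value positions such as the `create_model` field tuples above. An illustrative contrast:

from __future__ import annotations

def heads(grouped: dict[str, list[int]]) -> list[int]:
    # With PEP 563 this annotation is a lazily-kept string; it is never
    # executed when the function is defined.
    return [v[0] for v in grouped.values() if v]

# A value position is evaluated eagerly; this line only works because
# list really is subscriptable at runtime on Python 3.9+.
IntRows = list[tuple[int, ...]]
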
@@ -2,7 +2,8 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
+from collections.abc import Sequence
+from typing import Any, Optional, Union, cast
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import Callbacks
@@ -79,7 +80,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
     """Key in output of llm_chain to rank on."""
     answer_key: str
     """Key in output of llm_chain to return as answer."""
-    metadata_keys: Optional[List[str]] = None
+    metadata_keys: Optional[list[str]] = None
     """Additional metadata from the chosen document to return."""
     return_intermediate_steps: bool = False
     """Return intermediate steps.
@@ -92,19 +93,19 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
 
     def get_output_schema(
         self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        schema: Dict[str, Any] = {
+    ) -> type[BaseModel]:
+        schema: dict[str, Any] = {
             self.output_key: (str, None),
         }
         if self.return_intermediate_steps:
-            schema["intermediate_steps"] = (List[str], None)
+            schema["intermediate_steps"] = (list[str], None)
         if self.metadata_keys:
             schema.update({key: (Any, None) for key in self.metadata_keys})
 
         return create_model("MapRerankOutput", **schema)
 
     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Expect input key.
 
         :meta private:
@@ -140,7 +141,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
 
     @model_validator(mode="before")
     @classmethod
-    def get_default_document_variable_name(cls, values: Dict) -> Any:
+    def get_default_document_variable_name(cls, values: dict) -> Any:
         """Get default document variable name, if not provided."""
         if "llm_chain" not in values:
             raise ValueError("llm_chain must be provided")
@@ -163,8 +164,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         return values
 
     def combine_docs(
-        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
-    ) -> Tuple[str, dict]:
+        self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> tuple[str, dict]:
         """Combine documents in a map rerank manner.
 
         Combine by mapping first chain over all documents, then reranking the results.
@@ -187,8 +188,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         return self._process_results(docs, results)
 
     async def acombine_docs(
-        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
-    ) -> Tuple[str, dict]:
+        self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> tuple[str, dict]:
         """Combine documents in a map rerank manner.
 
         Combine by mapping first chain over all documents, then reranking the results.
@@ -212,10 +213,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
 
     def _process_results(
         self,
-        docs: List[Document],
-        results: Sequence[Union[str, List[str], Dict[str, str]]],
-    ) -> Tuple[str, dict]:
-        typed_results = cast(List[dict], results)
+        docs: list[Document],
+        results: Sequence[Union[str, list[str], dict[str, str]]],
+    ) -> tuple[str, dict]:
+        typed_results = cast(list[dict], results)
         sorted_res = sorted(
             zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
         )
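
`_process_results` above also shows that `typing.cast` takes the builtin generics exactly as it took the old aliases; the cast is purely for the type checker and is a no-op at runtime. A small sketch with invented names:

from typing import cast

def best(results: list[object]) -> dict:
    # cast() returns its argument unchanged; it only tells the checker
    # to treat the value as list[dict].
    typed = cast(list[dict], results)
    return typed[0]

print(best([{"score": 3}]))  # {'score': 3}
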
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, List, Optional, Protocol, Tuple
+from typing import Any, Callable, Optional, Protocol
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import Callbacks
@@ -15,20 +15,20 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 class CombineDocsProtocol(Protocol):
     """Interface for the combine_docs method."""
 
-    def __call__(self, docs: List[Document], **kwargs: Any) -> str:
+    def __call__(self, docs: list[Document], **kwargs: Any) -> str:
         """Interface for the combine_docs method."""
 
 
 class AsyncCombineDocsProtocol(Protocol):
     """Interface for the combine_docs method."""
 
-    async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
+    async def __call__(self, docs: list[Document], **kwargs: Any) -> str:
         """Async interface for the combine_docs method."""
 
 
 def split_list_of_docs(
-    docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
-) -> List[List[Document]]:
+    docs: list[Document], length_func: Callable, token_max: int, **kwargs: Any
+) -> list[list[Document]]:
     """Split Documents into subsets that each meet a cumulative length constraint.
 
     Args:
@@ -59,7 +59,7 @@ def split_list_of_docs(
 
 
 def collapse_docs(
-    docs: List[Document],
+    docs: list[Document],
     combine_document_func: CombineDocsProtocol,
     **kwargs: Any,
 ) -> Document:
@@ -91,7 +91,7 @@ def collapse_docs(
 
 
 async def acollapse_docs(
-    docs: List[Document],
+    docs: list[Document],
     combine_document_func: AsyncCombineDocsProtocol,
     **kwargs: Any,
 ) -> Document:
@@ -229,11 +229,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
 
     def combine_docs(
         self,
-        docs: List[Document],
+        docs: list[Document],
         token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Tuple[str, dict]:
+    ) -> tuple[str, dict]:
         """Combine multiple documents recursively.
 
         Args:
@@ -258,11 +258,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
 
     async def acombine_docs(
         self,
-        docs: List[Document],
+        docs: list[Document],
         token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Tuple[str, dict]:
+    ) -> tuple[str, dict]:
         """Async combine multiple documents recursively.
 
         Args:
@@ -287,16 +287,16 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
 
     def _collapse(
         self,
-        docs: List[Document],
+        docs: list[Document],
         token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Tuple[List[Document], dict]:
+    ) -> tuple[list[Document], dict]:
         result_docs = docs
         length_func = self.combine_documents_chain.prompt_length
         num_tokens = length_func(result_docs, **kwargs)
 
-        def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
+        def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str:
             return self._collapse_chain.run(
                 input_documents=docs, callbacks=callbacks, **kwargs
             )
@@ -322,16 +322,16 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
 
     async def _acollapse(
         self,
-        docs: List[Document],
+        docs: list[Document],
         token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Tuple[List[Document], dict]:
+    ) -> tuple[list[Document], dict]:
         result_docs = docs
         length_func = self.combine_documents_chain.prompt_length
         num_tokens = length_func(result_docs, **kwargs)
 
-        async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
+        async def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str:
             return await self._collapse_chain.arun(
                 input_documents=docs, callbacks=callbacks, **kwargs
             )
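
The `Protocol` classes at the top of this file typecheck any callable with a matching signature, including the `_collapse_docs_func` closures, and their `__call__` annotations now use builtin generics like everything else. A minimal illustration of the structural-typing pattern (all names hypothetical):

from typing import Any, Protocol

class Combine(Protocol):
    def __call__(self, docs: list[str], **kwargs: Any) -> str: ...

def run(combine: Combine, docs: list[str]) -> str:
    # Any function matching the protocol's shape is accepted.
    return combine(docs)

def join_docs(docs: list[str], **kwargs: Any) -> str:
    return "\n".join(docs)

print(run(join_docs, ["a", "b"]))  # a\nb
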
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import Callbacks
@@ -98,7 +98,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
     """Return the results of the refine steps in the output."""
 
     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Expect input key.
 
         :meta private:
@@ -115,7 +115,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
 
     @model_validator(mode="before")
     @classmethod
-    def get_return_intermediate_steps(cls, values: Dict) -> Any:
+    def get_return_intermediate_steps(cls, values: dict) -> Any:
         """For backwards compatibility."""
         if "return_refine_steps" in values:
             values["return_intermediate_steps"] = values["return_refine_steps"]
@@ -124,7 +124,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
 
     @model_validator(mode="before")
     @classmethod
-    def get_default_document_variable_name(cls, values: Dict) -> Any:
+    def get_default_document_variable_name(cls, values: dict) -> Any:
         """Get default document variable name, if not provided."""
         if "initial_llm_chain" not in values:
             raise ValueError("initial_llm_chain must be provided")
@@ -147,8 +147,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
         return values
 
     def combine_docs(
-        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
-    ) -> Tuple[str, dict]:
+        self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> tuple[str, dict]:
         """Combine by mapping first chain over all, then stuffing into final chain.
 
         Args:
@@ -172,8 +172,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
         return self._construct_result(refine_steps, res)
 
     async def acombine_docs(
-        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
-    ) -> Tuple[str, dict]:
+        self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> tuple[str, dict]:
         """Async combine by mapping a first chain over all, then stuffing
         into a final chain.
 
@@ -197,22 +197,22 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
         refine_steps.append(res)
         return self._construct_result(refine_steps, res)
 
-    def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
+    def _construct_result(self, refine_steps: list[str], res: str) -> tuple[str, dict]:
         if self.return_intermediate_steps:
             extra_return_dict = {"intermediate_steps": refine_steps}
         else:
             extra_return_dict = {}
         return res, extra_return_dict
 
-    def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
+    def _construct_refine_inputs(self, doc: Document, res: str) -> dict[str, Any]:
         return {
             self.document_variable_name: format_document(doc, self.document_prompt),
             self.initial_response_name: res,
         }
 
     def _construct_initial_inputs(
-        self, docs: List[Document], **kwargs: Any
-    ) -> Dict[str, Any]:
+        self, docs: list[Document], **kwargs: Any
+    ) -> dict[str, Any]:
         base_info = {"page_content": docs[0].page_content}
         base_info.update(docs[0].metadata)
         document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
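
Return types like `tuple[str, dict]` above mix a parameterized outer type with a bare inner `dict`; bare builtins were always legal in annotations, and pyupgrade leaves them alone while lowercasing the subscripted `Tuple`. Illustrative:

def combine(parts: list[str]) -> tuple[str, dict]:
    # An unparameterized `dict` never needed typing.Dict; only the
    # subscripted forms did before Python 3.9.
    return " ".join(parts), {"n": len(parts)}
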
@@ -1,6 +1,6 @@
 """Chain that combines documents by stuffing into context."""
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import Callbacks
@@ -29,7 +29,7 @@ def create_stuff_documents_chain(
     document_prompt: Optional[BasePromptTemplate] = None,
     document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
     document_variable_name: str = DOCUMENTS_KEY,
-) -> Runnable[Dict[str, Any], Any]:
+) -> Runnable[dict[str, Any], Any]:
     """Create a chain for passing a list of Documents to a model.
 
     Args:
@@ -163,7 +163,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
 
     @model_validator(mode="before")
     @classmethod
-    def get_default_document_variable_name(cls, values: Dict) -> Any:
+    def get_default_document_variable_name(cls, values: dict) -> Any:
         """Get default document variable name, if not provided.
 
         If only one variable is present in the llm_chain.prompt,
@@ -188,13 +188,13 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         return values
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         extra_keys = [
             k for k in self.llm_chain.input_keys if k != self.document_variable_name
         ]
         return super().input_keys + extra_keys
 
-    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
+    def _get_inputs(self, docs: list[Document], **kwargs: Any) -> dict:
         """Construct inputs from kwargs and docs.
 
         Format and then join all the documents together into one input with name
@@ -220,7 +220,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
         return inputs
 
-    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
+    def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
         """Return the prompt length given the documents passed in.
 
         This can be used by a caller to determine whether passing in a list
@@ -241,8 +241,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         return self.llm_chain._get_num_tokens(prompt)
 
     def combine_docs(
-        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
-    ) -> Tuple[str, dict]:
+        self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> tuple[str, dict]:
         """Stuff all documents into one prompt and pass to LLM.
 
         Args:
@@ -259,8 +259,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
 
     async def acombine_docs(
-        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
-    ) -> Tuple[str, dict]:
+        self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
+    ) -> tuple[str, dict]:
         """Async stuff all documents into one prompt and pass to LLM.
 
         Args:
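
The `Runnable[dict[str, Any], Any]` return annotation above shows that user-defined and third-party generics parameterize with the builtin forms just as they did with `typing.Dict`. A self-contained sketch of the same idea (the `Step` class is invented):

from typing import Any, Generic, TypeVar

In = TypeVar("In")
Out = TypeVar("Out")

class Step(Generic[In, Out]):
    """A stand-in for a generic runnable-like wrapper."""

    def __init__(self, fn):
        self.fn = fn

    def invoke(self, value: In) -> Out:
        return self.fn(value)

# Builtin generics slot into the type parameters directly:
render: Step[dict[str, Any], str] = Step(lambda d: ", ".join(sorted(d)))
print(render.invoke({"b": 1, "a": 2}))  # a, b
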
@@ -1,6 +1,6 @@
 """Chain for applying constitutional principles to the outputs of another chain."""
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
@@ -190,15 +190,15 @@ class ConstitutionalChain(Chain):
     """  # noqa: E501
 
     chain: LLMChain
-    constitutional_principles: List[ConstitutionalPrinciple]
+    constitutional_principles: list[ConstitutionalPrinciple]
     critique_chain: LLMChain
     revision_chain: LLMChain
     return_intermediate_steps: bool = False
 
     @classmethod
     def get_principles(
-        cls, names: Optional[List[str]] = None
-    ) -> List[ConstitutionalPrinciple]:
+        cls, names: Optional[list[str]] = None
+    ) -> list[ConstitutionalPrinciple]:
         if names is None:
             return list(PRINCIPLES.values())
         else:
@@ -224,12 +224,12 @@ class ConstitutionalChain(Chain):
         )
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Input keys."""
         return self.chain.input_keys
 
     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Output keys."""
         if self.return_intermediate_steps:
             return ["output", "critiques_and_revisions", "initial_output"]
@@ -237,9 +237,9 @@ class ConstitutionalChain(Chain):
 
     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         response = self.chain.run(
             **inputs,
@@ -305,7 +305,7 @@ class ConstitutionalChain(Chain):
             color="yellow",
         )
 
-        final_output: Dict[str, Any] = {"output": response}
+        final_output: dict[str, Any] = {"output": response}
         if self.return_intermediate_steps:
             final_output["initial_output"] = initial_response
             final_output["critiques_and_revisions"] = critiques_and_revisions
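
The `values: dict` annotation on the `@model_validator(mode="before")` methods throughout these files needs nothing more specific, since before-validators see the raw input mapping. A runnable sketch of the pattern, assuming pydantic v2 (model and field names invented):

from typing import Any

from pydantic import BaseModel, model_validator

class Config(BaseModel):
    name: str

    @model_validator(mode="before")
    @classmethod
    def fill_name(cls, values: dict) -> Any:
        # The raw input is an ordinary dict at this stage.
        if isinstance(values, dict):
            values.setdefault("name", "anonymous")
        return values

print(Config().name)  # anonymous
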
@@ -1,7 +1,5 @@
 """Chain that carries on a conversation and calls an LLM."""
 
-from typing import List
-
 from langchain_core._api import deprecated
 from langchain_core.memory import BaseMemory
 from langchain_core.prompts import BasePromptTemplate
@@ -121,7 +119,7 @@ class ConversationChain(LLMChain):  # type: ignore[override, override]
         return False
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Use this since so some prompt vars come from history."""
         return [self.input_key]
 
@@ -6,7 +6,8 @@ import inspect
 import warnings
 from abc import abstractmethod
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Optional, Union

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -32,13 +32,13 @@ from langchain.chains.question_answering import load_qa_chain

 # Depending on the memory type and configuration, the chat history format may differ.
 # This needs to be consolidated.
-CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
+CHAT_TURN_TYPE = Union[tuple[str, str], BaseMessage]


 _ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}


-def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
+def _get_chat_history(chat_history: list[CHAT_TURN_TYPE]) -> str:
     buffer = ""
     for dialogue_turn in chat_history:
         if isinstance(dialogue_turn, BaseMessage):
@@ -64,7 +64,7 @@ class InputType(BaseModel):

     question: str
     """The question to answer."""
-    chat_history: List[CHAT_TURN_TYPE] = Field(default_factory=list)
+    chat_history: list[CHAT_TURN_TYPE] = Field(default_factory=list)
     """The chat history to use for retrieval."""


@@ -89,7 +89,7 @@ class BaseConversationalRetrievalChain(Chain):
     """Return the retrieved source documents as part of the final result."""
     return_generated_question: bool = False
     """Return the generated question as part of the final result."""
-    get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None
+    get_chat_history: Optional[Callable[[list[CHAT_TURN_TYPE]], str]] = None
     """An optional function to get a string of the chat history.
     If None is provided, will use a default."""
     response_if_no_docs_found: Optional[str] = None
@@ -103,17 +103,17 @@ class BaseConversationalRetrievalChain(Chain):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Input keys."""
         return ["question", "chat_history"]

     def get_input_schema(
         self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    ) -> type[BaseModel]:
         return InputType

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return the output keys.

         :meta private:
@@ -129,17 +129,17 @@ class BaseConversationalRetrievalChain(Chain):
     def _get_docs(
         self,
         question: str,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: CallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         question = inputs["question"]
         get_chat_history = self.get_chat_history or _get_chat_history
@@ -159,7 +159,7 @@ class BaseConversationalRetrievalChain(Chain):
             docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
         else:
             docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]
-        output: Dict[str, Any] = {}
+        output: dict[str, Any] = {}
         if self.response_if_no_docs_found is not None and len(docs) == 0:
             output[self.output_key] = self.response_if_no_docs_found
         else:
@@ -182,17 +182,17 @@ class BaseConversationalRetrievalChain(Chain):
     async def _aget_docs(
         self,
         question: str,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: AsyncCallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         question = inputs["question"]
         get_chat_history = self.get_chat_history or _get_chat_history
@@ -212,7 +212,7 @@ class BaseConversationalRetrievalChain(Chain):
         else:
             docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]

-        output: Dict[str, Any] = {}
+        output: dict[str, Any] = {}
         if self.response_if_no_docs_found is not None and len(docs) == 0:
             output[self.output_key] = self.response_if_no_docs_found
         else:
@@ -368,7 +368,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
     """If set, enforces that the documents returned are less than this limit.
     This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""

-    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
+    def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
         num_docs = len(docs)

         if self.max_tokens_limit and isinstance(
@@ -388,10 +388,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
     def _get_docs(
         self,
         question: str,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: CallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""
         docs = self.retriever.invoke(
             question, config={"callbacks": run_manager.get_child()}
@@ -401,10 +401,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
     async def _aget_docs(
         self,
         question: str,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: AsyncCallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""
         docs = await self.retriever.ainvoke(
             question, config={"callbacks": run_manager.get_child()}
@@ -420,7 +420,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
         chain_type: str = "stuff",
         verbose: bool = False,
         condense_question_llm: Optional[BaseLanguageModel] = None,
-        combine_docs_chain_kwargs: Optional[Dict] = None,
+        combine_docs_chain_kwargs: Optional[dict] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> BaseConversationalRetrievalChain:
@@ -485,7 +485,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):

     @model_validator(mode="before")
     @classmethod
-    def raise_deprecation(cls, values: Dict) -> Any:
+    def raise_deprecation(cls, values: dict) -> Any:
         warnings.warn(
             "`ChatVectorDBChain` is deprecated - "
             "please use `from langchain.chains import ConversationalRetrievalChain`"
@@ -495,10 +495,10 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
     def _get_docs(
         self,
         question: str,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: CallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""
         vectordbkwargs = inputs.get("vectordbkwargs", {})
         full_kwargs = {**self.search_kwargs, **vectordbkwargs}
@@ -509,10 +509,10 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
     async def _aget_docs(
         self,
         question: str,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: AsyncCallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""
         raise NotImplementedError("ChatVectorDBChain does not support async")

@@ -523,7 +523,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
         vectorstore: VectorStore,
         condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
         chain_type: str = "stuff",
-        combine_docs_chain_kwargs: Optional[Dict] = None,
+        combine_docs_chain_kwargs: Optional[dict] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> BaseConversationalRetrievalChain:
@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.language_models import BaseLanguageModel
@@ -44,8 +44,8 @@ class ElasticsearchDatabaseChain(Chain):
     """Elasticsearch database to connect to of type elasticsearch.Elasticsearch."""
     top_k: int = 10
     """Number of results to return from the query"""
-    ignore_indices: Optional[List[str]] = None
-    include_indices: Optional[List[str]] = None
+    ignore_indices: Optional[list[str]] = None
+    include_indices: Optional[list[str]] = None
     input_key: str = "question"  #: :meta private:
     output_key: str = "result"  #: :meta private:
     sample_documents_in_index_info: int = 3
@@ -66,7 +66,7 @@ class ElasticsearchDatabaseChain(Chain):
         return self

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the singular input key.

         :meta private:
@@ -74,7 +74,7 @@ class ElasticsearchDatabaseChain(Chain):
         return [self.input_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return the singular output key.

         :meta private:
@@ -84,7 +84,7 @@ class ElasticsearchDatabaseChain(Chain):
         else:
             return [self.output_key, INTERMEDIATE_STEPS_KEY]

-    def _list_indices(self) -> List[str]:
+    def _list_indices(self) -> list[str]:
         all_indices = [
             index["index"] for index in self.database.cat.indices(format="json")
         ]
@@ -96,7 +96,7 @@ class ElasticsearchDatabaseChain(Chain):

         return all_indices

-    def _get_indices_infos(self, indices: List[str]) -> str:
+    def _get_indices_infos(self, indices: list[str]) -> str:
         mappings = self.database.indices.get_mapping(index=",".join(indices))
         if self.sample_documents_in_index_info > 0:
             for k, v in mappings.items():
@@ -114,15 +114,15 @@ class ElasticsearchDatabaseChain(Chain):
             ]
         )

-    def _search(self, indices: List[str], query: str) -> str:
+    def _search(self, indices: list[str], query: str) -> str:
         result = self.database.search(index=",".join(indices), body=query)
         return str(result)

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         input_text = f"{inputs[self.input_key]}\nESQuery:"
         _run_manager.on_text(input_text, verbose=self.verbose)
@@ -134,7 +134,7 @@ class ElasticsearchDatabaseChain(Chain):
             "indices_info": indices_info,
             "stop": ["\nESResult:"],
         }
-        intermediate_steps: List = []
+        intermediate_steps: list = []
         try:
             intermediate_steps.append(query_inputs)  # input: es generation
             es_cmd = self.query_chain.invoke(
@@ -163,7 +163,7 @@ class ElasticsearchDatabaseChain(Chain):

         intermediate_steps.append(final_result)  # output: final answer
         _run_manager.on_text(final_result, color="green", verbose=self.verbose)
-        chain_result: Dict[str, Any] = {self.output_key: final_result}
+        chain_result: dict[str, Any] = {self.output_key: final_result}
         if self.return_intermediate_steps:
             chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
         return chain_result
@@ -1,5 +1,3 @@
-from typing import List
-
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts.few_shot import FewShotPromptTemplate
@@ -9,7 +7,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."


 def generate_example(
-    examples: List[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate
+    examples: list[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate
 ) -> str:
     """Return another example given a list of examples for a prompt."""
     prompt = FewShotPromptTemplate(
@@ -2,7 +2,8 @@ from __future__ import annotations

 import logging
 import re
-from typing import Any, Dict, List, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional

 from langchain_core.callbacks import (
     CallbackManagerForChainRun,
@@ -26,7 +27,7 @@ from langchain.chains.llm import LLMChain
 logger = logging.getLogger(__name__)


-def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
+def _extract_tokens_and_log_probs(response: AIMessage) -> tuple[list[str], list[float]]:
     """Extract tokens and log probabilities from chat model response."""
     tokens = []
     log_probs = []
@@ -47,7 +48,7 @@ class QuestionGeneratorChain(LLMChain):
         return False

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Input keys for the chain."""
         return ["user_input", "context", "response"]

@@ -58,7 +59,7 @@ def _low_confidence_spans(
     min_prob: float,
     min_token_gap: int,
     num_pad_tokens: int,
-) -> List[str]:
+) -> list[str]:
     try:
         import numpy as np

@@ -117,22 +118,22 @@ class FlareChain(Chain):
     """Whether to start with retrieval."""

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Input keys for the chain."""
         return ["user_input"]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Output keys for the chain."""
         return ["response"]

     def _do_generation(
         self,
-        questions: List[str],
+        questions: list[str],
         user_input: str,
         response: str,
         _run_manager: CallbackManagerForChainRun,
-    ) -> Tuple[str, bool]:
+    ) -> tuple[str, bool]:
         callbacks = _run_manager.get_child()
         docs = []
         for question in questions:
@@ -153,12 +154,12 @@ class FlareChain(Chain):

     def _do_retrieval(
         self,
-        low_confidence_spans: List[str],
+        low_confidence_spans: list[str],
         _run_manager: CallbackManagerForChainRun,
         user_input: str,
         response: str,
         initial_response: str,
-    ) -> Tuple[str, bool]:
+    ) -> tuple[str, bool]:
         question_gen_inputs = [
             {
                 "user_input": user_input,
@@ -187,9 +188,9 @@ class FlareChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

         user_input = inputs[self.input_keys[0]]
@@ -1,16 +1,14 @@
-from typing import Tuple
-
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.prompts import PromptTemplate


-class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):
+class FinishedOutputParser(BaseOutputParser[tuple[str, bool]]):
     """Output parser that checks if the output is finished."""

     finished_value: str = "FINISHED"
     """Value that indicates the output is finished."""

-    def parse(self, text: str) -> Tuple[str, bool]:
+    def parse(self, text: str) -> tuple[str, bool]:
         cleaned = text.strip()
         finished = self.finished_value in cleaned
         return cleaned.replace(self.finished_value, ""), finished
@@ -6,7 +6,7 @@ https://arxiv.org/abs/2212.10496
 from __future__ import annotations

 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core.callbacks import CallbackManagerForChainRun
 from langchain_core.embeddings import Embeddings
@@ -38,23 +38,23 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Input keys for Hyde's LLM chain."""
         return self.llm_chain.input_schema.model_json_schema()["required"]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Output keys for Hyde's LLM chain."""
         if isinstance(self.llm_chain, LLMChain):
             return self.llm_chain.output_keys
         else:
             return ["text"]

-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Call the base embeddings."""
         return self.base_embeddings.embed_documents(texts)

-    def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
+    def combine_embeddings(self, embeddings: list[list[float]]) -> list[float]:
         """Combine embeddings into final embeddings."""
         try:
             import numpy as np
@@ -73,7 +73,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
         num_vectors = len(embeddings)
         return [sum(dim_values) / num_vectors for dim_values in zip(*embeddings)]

-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text: str) -> list[float]:
         """Generate a hypothetical document and embedded it."""
         var_name = self.input_keys[0]
         result = self.llm_chain.invoke({var_name: text})
@@ -86,9 +86,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         """Call the internal llm chain."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         return self.llm_chain.invoke(
@@ -3,7 +3,8 @@
 from __future__ import annotations

 import warnings
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+from collections.abc import Sequence
+from typing import Any, Optional, Union, cast

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -100,7 +101,7 @@ class LLMChain(Chain):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Will be whatever keys the prompt expects.

         :meta private:
@@ -108,7 +109,7 @@ class LLMChain(Chain):
         return self.prompt.input_variables

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Will always return text key.

         :meta private:
@@ -120,15 +121,15 @@ class LLMChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         response = self.generate([inputs], run_manager=run_manager)
         return self.create_outputs(response)[0]

     def generate(
         self,
-        input_list: List[Dict[str, Any]],
+        input_list: list[dict[str, Any]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> LLMResult:
         """Generate LLM result from inputs."""
@@ -143,9 +144,9 @@ class LLMChain(Chain):
             )
         else:
             results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
-                cast(List, prompts), {"callbacks": callbacks}
+                cast(list, prompts), {"callbacks": callbacks}
             )
-        generations: List[List[Generation]] = []
+        generations: list[list[Generation]] = []
         for res in results:
             if isinstance(res, BaseMessage):
                 generations.append([ChatGeneration(message=res)])
@@ -155,7 +156,7 @@ class LLMChain(Chain):

     async def agenerate(
         self,
-        input_list: List[Dict[str, Any]],
+        input_list: list[dict[str, Any]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> LLMResult:
         """Generate LLM result from inputs."""
@@ -170,9 +171,9 @@ class LLMChain(Chain):
             )
         else:
             results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
-                cast(List, prompts), {"callbacks": callbacks}
+                cast(list, prompts), {"callbacks": callbacks}
             )
-        generations: List[List[Generation]] = []
+        generations: list[list[Generation]] = []
         for res in results:
             if isinstance(res, BaseMessage):
                 generations.append([ChatGeneration(message=res)])
@@ -182,9 +183,9 @@ class LLMChain(Chain):

     def prep_prompts(
         self,
-        input_list: List[Dict[str, Any]],
+        input_list: list[dict[str, Any]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Tuple[List[PromptValue], Optional[List[str]]]:
+    ) -> tuple[list[PromptValue], Optional[list[str]]]:
         """Prepare prompts from inputs."""
         stop = None
         if len(input_list) == 0:
@@ -208,9 +209,9 @@ class LLMChain(Chain):

     async def aprep_prompts(
         self,
-        input_list: List[Dict[str, Any]],
+        input_list: list[dict[str, Any]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Tuple[List[PromptValue], Optional[List[str]]]:
+    ) -> tuple[list[PromptValue], Optional[list[str]]]:
         """Prepare prompts from inputs."""
         stop = None
         if len(input_list) == 0:
@@ -233,8 +234,8 @@ class LLMChain(Chain):
         return prompts, stop

     def apply(
-        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
-    ) -> List[Dict[str, str]]:
+        self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
+    ) -> list[dict[str, str]]:
         """Utilize the LLM generate method for speed gains."""
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
@@ -254,8 +255,8 @@ class LLMChain(Chain):
         return outputs

     async def aapply(
-        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
-    ) -> List[Dict[str, str]]:
+        self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
+    ) -> list[dict[str, str]]:
         """Utilize the LLM generate method for speed gains."""
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
@@ -278,7 +279,7 @@ class LLMChain(Chain):
     def _run_output_key(self) -> str:
         return self.output_key

-    def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
+    def create_outputs(self, llm_result: LLMResult) -> list[dict[str, Any]]:
         """Create outputs from response."""
         result = [
             # Get the text of the top generated string.
@@ -294,9 +295,9 @@ class LLMChain(Chain):

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         response = await self.agenerate([inputs], run_manager=run_manager)
         return self.create_outputs(response)[0]

@@ -336,7 +337,7 @@ class LLMChain(Chain):

     def predict_and_parse(
         self, callbacks: Callbacks = None, **kwargs: Any
-    ) -> Union[str, List[str], Dict[str, Any]]:
+    ) -> Union[str, list[str], dict[str, Any]]:
         """Call predict and then parse the results."""
         warnings.warn(
             "The predict_and_parse method is deprecated, "
@@ -350,7 +351,7 @@ class LLMChain(Chain):

     async def apredict_and_parse(
         self, callbacks: Callbacks = None, **kwargs: Any
-    ) -> Union[str, List[str], Dict[str, str]]:
+    ) -> Union[str, list[str], dict[str, str]]:
         """Call apredict and then parse the results."""
         warnings.warn(
             "The apredict_and_parse method is deprecated, "
@@ -363,8 +364,8 @@ class LLMChain(Chain):
         return result

     def apply_and_parse(
-        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
-    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
+        self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
+    ) -> Sequence[Union[str, list[str], dict[str, str]]]:
         """Call apply and then parse the results."""
         warnings.warn(
             "The apply_and_parse method is deprecated, "
@@ -374,8 +375,8 @@ class LLMChain(Chain):
         return self._parse_generation(result)

     def _parse_generation(
-        self, generation: List[Dict[str, str]]
-    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
+        self, generation: list[dict[str, str]]
+    ) -> Sequence[Union[str, list[str], dict[str, str]]]:
         if self.prompt.output_parser is not None:
             return [
                 self.prompt.output_parser.parse(res[self.output_key])
@@ -385,8 +386,8 @@ class LLMChain(Chain):
         return generation

     async def aapply_and_parse(
-        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
-    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
+        self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
+    ) -> Sequence[Union[str, list[str], dict[str, str]]]:
         """Call apply and then parse the results."""
         warnings.warn(
             "The aapply_and_parse method is deprecated, "
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import warnings
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
@@ -107,7 +107,7 @@ class LLMCheckerChain(Chain):

     @model_validator(mode="before")
     @classmethod
-    def raise_deprecation(cls, values: Dict) -> Any:
+    def raise_deprecation(cls, values: dict) -> Any:
         if "llm" in values:
             warnings.warn(
                 "Directly instantiating an LLMCheckerChain with an llm is deprecated. "
@@ -135,7 +135,7 @@ class LLMCheckerChain(Chain):
         return values

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the singular input key.

         :meta private:
@@ -143,7 +143,7 @@ class LLMCheckerChain(Chain):
         return [self.input_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return the singular output key.

         :meta private:
@@ -152,9 +152,9 @@ class LLMCheckerChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         question = inputs[self.input_key]

@@ -5,7 +5,7 @@ from __future__ import annotations
 import math
 import re
 import warnings
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -163,7 +163,7 @@ class LLMMathChain(Chain):

     @model_validator(mode="before")
     @classmethod
-    def raise_deprecation(cls, values: Dict) -> Any:
+    def raise_deprecation(cls, values: dict) -> Any:
         try:
             import numexpr  # noqa: F401
         except ImportError:
@@ -183,7 +183,7 @@ class LLMMathChain(Chain):
         return values

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect input key.

         :meta private:
@@ -191,7 +191,7 @@ class LLMMathChain(Chain):
         return [self.input_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Expect output key.

         :meta private:
@@ -221,7 +221,7 @@ class LLMMathChain(Chain):

     def _process_llm_result(
         self, llm_output: str, run_manager: CallbackManagerForChainRun
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         run_manager.on_text(llm_output, color="green", verbose=self.verbose)
         llm_output = llm_output.strip()
         text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
@@ -243,7 +243,7 @@ class LLMMathChain(Chain):
         self,
         llm_output: str,
         run_manager: AsyncCallbackManagerForChainRun,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
         llm_output = llm_output.strip()
         text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
@@ -263,9 +263,9 @@ class LLMMathChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         _run_manager.on_text(inputs[self.input_key])
         llm_output = self.llm_chain.predict(
@@ -277,9 +277,9 @@ class LLMMathChain(Chain):

     async def _acall(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         await _run_manager.on_text(inputs[self.input_key])
         llm_output = await self.llm_chain.apredict(
@@ -4,7 +4,7 @@ from __future__ import annotations

 import warnings
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
@@ -112,7 +112,7 @@ class LLMSummarizationCheckerChain(Chain):

     @model_validator(mode="before")
     @classmethod
-    def raise_deprecation(cls, values: Dict) -> Any:
+    def raise_deprecation(cls, values: dict) -> Any:
         if "llm" in values:
             warnings.warn(
                 "Directly instantiating an LLMSummarizationCheckerChain with an llm is "
@@ -131,7 +131,7 @@ class LLMSummarizationCheckerChain(Chain):
         return values

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the singular input key.

         :meta private:
@@ -139,7 +139,7 @@ class LLMSummarizationCheckerChain(Chain):
         return [self.input_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return the singular output key.

         :meta private:
@@ -148,9 +148,9 @@ class LLMSummarizationCheckerChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         all_true = False
         count = 0
@@ -702,7 +702,7 @@ def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
         with open(file_path) as f:
             config = json.load(f)
     elif file_path.suffix.endswith((".yaml", ".yml")):
-        with open(file_path, "r") as f:
+        with open(file_path) as f:
             config = yaml.safe_load(f)
     else:
         raise ValueError("File type must be json or yaml")
@@ -6,7 +6,8 @@ then combines the results with another one.

 from __future__ import annotations

-from typing import Any, Dict, List, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
@@ -84,7 +85,7 @@ class MapReduceChain(Chain):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect input key.

         :meta private:
@@ -92,7 +93,7 @@ class MapReduceChain(Chain):
         return [self.input_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return output key.

         :meta private:
@@ -101,15 +102,15 @@ class MapReduceChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         # Split the larger text into smaller chunks.
         doc_text = inputs.pop(self.input_key)
         texts = self.text_splitter.split_text(doc_text)
         docs = [Document(page_content=text) for text in texts]
-        _inputs: Dict[str, Any] = {
+        _inputs: dict[str, Any] = {
             **inputs,
             self.combine_documents_chain.input_key: docs,
         }
@@ -1,6 +1,6 @@
 """Pass input through a moderation endpoint."""

-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
@@ -42,7 +42,7 @@ class OpenAIModerationChain(Chain):

     @model_validator(mode="before")
     @classmethod
-    def validate_environment(cls, values: Dict) -> Any:
+    def validate_environment(cls, values: dict) -> Any:
         """Validate that api key and python package exists in environment."""
         openai_api_key = get_from_dict_or_env(
             values, "openai_api_key", "OPENAI_API_KEY"
@@ -78,7 +78,7 @@ class OpenAIModerationChain(Chain):
         return values

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect input key.

         :meta private:
@@ -86,7 +86,7 @@ class OpenAIModerationChain(Chain):
         return [self.input_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return output key.

         :meta private:
@@ -108,9 +108,9 @@ class OpenAIModerationChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         text = inputs[self.input_key]
         if self.openai_pre_1_0:
             results = self.client.create(text)
@@ -122,9 +122,9 @@ class OpenAIModerationChain(Chain):

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         if self.openai_pre_1_0:
             return await super()._acall(inputs, run_manager=run_manager)
         text = inputs[self.input_key]
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import warnings
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.caches import BaseCache as BaseCache
@@ -68,7 +68,7 @@ class NatBotChain(Chain):

     @model_validator(mode="before")
     @classmethod
-    def raise_deprecation(cls, values: Dict) -> Any:
+    def raise_deprecation(cls, values: dict) -> Any:
         if "llm" in values:
             warnings.warn(
                 "Directly instantiating an NatBotChain with an llm is deprecated. "
@@ -97,7 +97,7 @@ class NatBotChain(Chain):
         return cls(llm_chain=llm_chain, objective=objective, **kwargs)

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect url and browser content.

         :meta private:
@@ -105,7 +105,7 @@ class NatBotChain(Chain):
         return [self.input_url_key, self.input_browser_content_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return command.

         :meta private:
@@ -114,9 +114,9 @@ class NatBotChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         url = inputs[self.input_url_key]
         browser_content = inputs[self.input_browser_content_key]
@@ -1,12 +1,10 @@
 """Methods for creating chains that use OpenAI function-calling APIs."""

+from collections.abc import Sequence
 from typing import (
     Any,
     Callable,
-    Dict,
     Optional,
-    Sequence,
-    Type,
     Union,
 )

@@ -45,7 +43,7 @@ __all__ = [

 @deprecated(since="0.1.1", removal="1.0", alternative="create_openai_fn_runnable")
 def create_openai_fn_chain(
-    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
+    functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
     llm: BaseLanguageModel,
     prompt: BasePromptTemplate,
     *,
@@ -128,7 +126,7 @@ def create_openai_fn_chain(
         raise ValueError("Need to pass in at least one function. Received zero.")
     openai_functions = [convert_to_openai_function(f) for f in functions]
     output_parser = output_parser or get_openai_output_parser(functions)
-    llm_kwargs: Dict[str, Any] = {
+    llm_kwargs: dict[str, Any] = {
         "functions": openai_functions,
     }
     if len(openai_functions) == 1 and enforce_single_function_usage:
@@ -148,7 +146,7 @@ def create_openai_fn_chain(
     since="0.1.1", removal="1.0", alternative="ChatOpenAI.with_structured_output"
 )
 def create_structured_output_chain(
-    output_schema: Union[Dict[str, Any], Type[BaseModel]],
+    output_schema: Union[dict[str, Any], type[BaseModel]],
     llm: BaseLanguageModel,
     prompt: BasePromptTemplate,
     *,
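
This file also shows the second half of the pyupgrade rewrite: abstract container types such as `Sequence` move from `typing` to `collections.abc`, since the 3.9 versions of those classes accept subscripts directly and the `typing` aliases are deprecated. A rough sketch of the same idiom under those assumptions (names invented for illustration):

    from collections.abc import Iterator, Sequence
    from typing import Union

    def function_names(functions: Sequence[Union[dict, str]]) -> Iterator[str]:
        """Yield a name for each function spec, whatever its shape."""
        for spec in functions:
            yield spec["name"] if isinstance(spec, dict) else spec

    print(list(function_names([{"name": "get_weather"}, "search"])))
    # ['get_weather', 'search']
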
@@ -1,4 +1,4 @@
-from typing import Iterator, List
+from collections.abc import Iterator

 from langchain_core._api import deprecated
 from langchain_core.language_models import BaseChatModel, BaseLanguageModel
@@ -21,7 +21,7 @@ class FactWithEvidence(BaseModel):
     """

     fact: str = Field(..., description="Body of the sentence, as part of a response")
-    substring_quote: List[str] = Field(
+    substring_quote: list[str] = Field(
         ...,
         description=(
             "Each source should be a direct quote from the context, "
@@ -54,7 +54,7 @@ class QuestionAnswer(BaseModel):
     each sentence contains a body and a list of sources."""

     question: str = Field(..., description="Question that was asked")
-    answer: List[FactWithEvidence] = Field(
+    answer: list[FactWithEvidence] = Field(
         ...,
         description=(
             "Body of the answer, each fact should be "
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.language_models import BaseLanguageModel
@@ -83,7 +83,7 @@ def create_extraction_chain(
     schema: dict,
     llm: BaseLanguageModel,
     prompt: Optional[BasePromptTemplate] = None,
-    tags: Optional[List[str]] = None,
+    tags: Optional[list[str]] = None,
     verbose: bool = False,
 ) -> Chain:
     """Creates a chain that extracts information from a passage.
@@ -170,7 +170,7 @@ def create_extraction_chain_pydantic(
     """

     class PydanticSchema(BaseModel):
-        info: List[pydantic_schema]  # type: ignore
+        info: list[pydantic_schema]  # type: ignore

     if hasattr(pydantic_schema, "model_json_schema"):
         openai_schema = pydantic_schema.model_json_schema()
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import re
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import requests
 from langchain_core._api import deprecated
@@ -70,7 +70,7 @@ def _format_url(url: str, path_params: dict) -> str:
     return url.format(**new_params)


-def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -> dict:
+def _openapi_params_to_json_schema(params: list[Parameter], spec: OpenAPISpec) -> dict:
     properties = {}
     required = []
     for p in params:
@@ -89,7 +89,7 @@ def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -

 def openapi_spec_to_openai_fn(
     spec: OpenAPISpec,
-) -> Tuple[List[Dict[str, Any]], Callable]:
+) -> tuple[list[dict[str, Any]], Callable]:
     """Convert a valid OpenAPI spec to the JSON Schema format expected for OpenAI
     functions.

@@ -208,18 +208,18 @@ class SimpleRequestChain(Chain):
     """Key to use for the input of the request."""

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         return [self.input_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         return [self.output_key]

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Run the logic of this chain and return the output."""
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         name = inputs[self.input_key].pop("name")
@@ -257,10 +257,10 @@ def get_openapi_chain(
     llm: Optional[BaseLanguageModel] = None,
     prompt: Optional[BasePromptTemplate] = None,
     request_chain: Optional[Chain] = None,
-    llm_chain_kwargs: Optional[Dict] = None,
+    llm_chain_kwargs: Optional[dict] = None,
     verbose: bool = False,
-    headers: Optional[Dict] = None,
-    params: Optional[Dict] = None,
+    headers: Optional[dict] = None,
+    params: Optional[dict] = None,
     **kwargs: Any,
 ) -> SequentialChain:
     """Create a chain for querying an API from a OpenAPI spec.
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Type, Union, cast
+from typing import Any, Optional, Union, cast

 from langchain_core._api import deprecated
 from langchain_core.language_models import BaseLanguageModel
@@ -21,7 +21,7 @@ class AnswerWithSources(BaseModel):
     """An answer to the question, with sources."""

     answer: str = Field(..., description="Answer to the question that was asked")
-    sources: List[str] = Field(
+    sources: list[str] = Field(
         ..., description="List of sources used to answer the question"
     )

@@ -37,7 +37,7 @@ class AnswerWithSources(BaseModel):
 )
 def create_qa_with_structure_chain(
     llm: BaseLanguageModel,
-    schema: Union[dict, Type[BaseModel]],
+    schema: Union[dict, type[BaseModel]],
     output_parser: str = "base",
     prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
     verbose: bool = False,
@@ -1,7 +1,7 @@
-from typing import Any, Dict
+from typing import Any


-def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
+def _resolve_schema_references(schema: Any, definitions: dict[str, Any]) -> Any:
     """
     Resolve the $ref keys in a JSON schema object using the provided definitions.
     """
@@ -1,4 +1,4 @@
-from typing import List, Type, Union
+from typing import Union

 from langchain_core._api import deprecated
 from langchain_core.language_models import BaseLanguageModel
@@ -51,7 +51,7 @@ If a property is not present and is not required in the function parameters, do
     ),
 )
 def create_extraction_chain_pydantic(
-    pydantic_schemas: Union[List[Type[BaseModel]], Type[BaseModel]],
+    pydantic_schemas: Union[list[type[BaseModel]], type[BaseModel]],
     llm: BaseLanguageModel,
     system_message: str = _EXTRACTION_TEMPLATE,
 ) -> Runnable:
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Tuple
+from typing import Callable

 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -21,8 +21,8 @@ class ConditionalPromptSelector(BasePromptSelector):

     default_prompt: BasePromptTemplate
     """Default prompt to use if no conditionals match."""
-    conditionals: List[
-        Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
+    conditionals: list[
+        tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
     ] = Field(default_factory=list)
     """List of conditionals and prompts to use if the conditionals match."""

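The `conditionals` field is the most nested rewrite in the commit: builtin generics compose under `Callable` and inside pydantic `Field` declarations exactly as the `typing` spellings did. A small standalone sketch of the idea (a toy model, not the LangChain class; pydantic v2 assumed):

    from collections.abc import Callable

    from pydantic import BaseModel, Field

    class Selector(BaseModel):
        # A list of (predicate, prompt-name) pairs, spelled with
        # builtin generics throughout.
        conditionals: list[tuple[Callable[[str], bool], str]] = Field(
            default_factory=list
        )

    s = Selector(conditionals=[(str.islower, "casual"), (str.isupper, "shouty")])
    print([name for pred, name in s.conditionals if pred("hello")])  # ['casual']
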
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.callbacks import CallbackManagerForChainRun
@@ -103,18 +103,18 @@ class QAGenerationChain(Chain):
         raise NotImplementedError

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         return [self.input_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         return [self.output_key]

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, List]:
+    ) -> dict[str, list]:
         docs = self.text_splitter.create_documents([inputs[self.input_key]])
         results = self.llm_chain.generate(
             [{"text": d.page_content} for d in docs], run_manager=run_manager
@@ -5,7 +5,7 @@ from __future__ import annotations
 import inspect
 import re
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -103,7 +103,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect input key.

         :meta private:
@@ -111,7 +111,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
         return [self.question_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return output key.

         :meta private:
@@ -123,13 +123,13 @@ class BaseQAWithSourcesChain(Chain, ABC):

     @model_validator(mode="before")
     @classmethod
-    def validate_naming(cls, values: Dict) -> Any:
+    def validate_naming(cls, values: dict) -> Any:
         """Fix backwards compatibility in naming."""
         if "combine_document_chain" in values:
             values["combine_documents_chain"] = values.pop("combine_document_chain")
         return values

-    def _split_sources(self, answer: str) -> Tuple[str, str]:
+    def _split_sources(self, answer: str) -> tuple[str, str]:
         """Split sources from answer."""
         if re.search(r"SOURCES?:", answer, re.IGNORECASE):
             answer, sources = re.split(
@@ -143,17 +143,17 @@ class BaseQAWithSourcesChain(Chain, ABC):
     @abstractmethod
     def _get_docs(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: CallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs to run questioning over."""

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         accepts_run_manager = (
             "run_manager" in inspect.signature(self._get_docs).parameters
@@ -167,7 +167,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
             input_documents=docs, callbacks=_run_manager.get_child(), **inputs
         )
         answer, sources = self._split_sources(answer)
-        result: Dict[str, Any] = {
+        result: dict[str, Any] = {
             self.answer_key: answer,
             self.sources_answer_key: sources,
         }
@@ -178,17 +178,17 @@ class BaseQAWithSourcesChain(Chain, ABC):
     @abstractmethod
     async def _aget_docs(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: AsyncCallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs to run questioning over."""

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         accepts_run_manager = (
             "run_manager" in inspect.signature(self._aget_docs).parameters
@@ -201,7 +201,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
             input_documents=docs, callbacks=_run_manager.get_child(), **inputs
         )
         answer, sources = self._split_sources(answer)
-        result: Dict[str, Any] = {
+        result: dict[str, Any] = {
             self.answer_key: answer,
             self.sources_answer_key: sources,
         }
@@ -225,7 +225,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
     input_docs_key: str = "docs"  #: :meta private:

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect input key.

         :meta private:
@@ -234,19 +234,19 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):

     def _get_docs(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: CallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs to run questioning over."""
         return inputs.pop(self.input_docs_key)

     async def _aget_docs(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         *,
         run_manager: AsyncCallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs to run questioning over."""
         return inputs.pop(self.input_docs_key)

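Worth noting what every hunk leaves untouched: `Optional[...]` and `Union[...]` keep their `typing` spellings. The PEP 604 alternative (`str | None`) is only evaluable at runtime from Python 3.10, which is presumably why a 3.9 target stops at the PEP 585 generics. A sketch of where the boundary falls (hypothetical helpers):

    from typing import Optional

    # Fine on 3.9+: PEP 585 builtin generics in annotations.
    def split_sources(answer: str) -> tuple[str, str]:
        head, _, tail = answer.partition("SOURCES:")
        return head.strip(), tail.strip()

    # Still typing.Optional: evaluating `str | None` raises TypeError on 3.9
    # unless the module defers evaluation with `from __future__ import annotations`.
    def sources_or_none(answer: str) -> Optional[str]:
        _, sources = split_sources(answer)
        return sources or None

    print(sources_or_none("42 is the answer. SOURCES: deep-thought"))  # deep-thought
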
@@ -2,7 +2,8 @@

 from __future__ import annotations

-from typing import Any, Mapping, Optional, Protocol
+from collections.abc import Mapping
+from typing import Any, Optional, Protocol

 from langchain_core._api import deprecated
 from langchain_core.language_models import BaseLanguageModel
@@ -1,6 +1,6 @@
 """Question-answering with sources over an index."""

-from typing import Any, Dict, List
+from typing import Any

 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
@@ -25,7 +25,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
     """Restrict the docs to return from store based on tokens,
     enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true"""

-    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
+    def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
         num_docs = len(docs)

         if self.reduce_k_below_max_tokens and isinstance(
@@ -43,8 +43,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
         return docs[:num_docs]

     def _get_docs(
-        self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
-    ) -> List[Document]:
+        self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun
+    ) -> list[Document]:
         question = inputs[self.question_key]
         docs = self.retriever.invoke(
             question, config={"callbacks": run_manager.get_child()}
@@ -52,8 +52,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
         return self._reduce_tokens_below_limit(docs)

     async def _aget_docs(
-        self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
-    ) -> List[Document]:
+        self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
+    ) -> list[Document]:
         question = inputs[self.question_key]
         docs = await self.retriever.ainvoke(
             question, config={"callbacks": run_manager.get_child()}
@@ -1,7 +1,7 @@
 """Question-answering with sources over a vector database."""

 import warnings
-from typing import Any, Dict, List
+from typing import Any

 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
@@ -27,10 +27,10 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
     max_tokens_limit: int = 3375
     """Restrict the docs to return from store based on tokens,
     enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true"""
-    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    search_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Extra search args."""

-    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
+    def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
         num_docs = len(docs)

         if self.reduce_k_below_max_tokens and isinstance(
@@ -48,8 +48,8 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
         return docs[:num_docs]

     def _get_docs(
-        self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
-    ) -> List[Document]:
+        self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun
+    ) -> list[Document]:
         question = inputs[self.question_key]
         docs = self.vectorstore.similarity_search(
             question, k=self.k, **self.search_kwargs
@@ -57,13 +57,13 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
         return self._reduce_tokens_below_limit(docs)

     async def _aget_docs(
-        self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
-    ) -> List[Document]:
+        self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
+    ) -> list[Document]:
         raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")

     @model_validator(mode="before")
     @classmethod
-    def raise_deprecation(cls, values: Dict) -> Any:
+    def raise_deprecation(cls, values: dict) -> Any:
         warnings.warn(
             "`VectorDBQAWithSourcesChain` is deprecated - "
             "please use `from langchain.chains import RetrievalQAWithSourcesChain`"
@@ -3,7 +3,8 @@
 from __future__ import annotations

 import json
-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union, cast

 from langchain_core._api import deprecated
 from langchain_core.exceptions import OutputParserException
@@ -172,7 +173,7 @@ def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
     return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")


-def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
+def construct_examples(input_output_pairs: Sequence[tuple[str, dict]]) -> list[dict]:
     """Construct examples from input-output pairs.

     Args:
@@ -267,7 +268,7 @@ def load_query_constructor_chain(
     llm: BaseLanguageModel,
     document_contents: str,
     attribute_info: Sequence[Union[AttributeInfo, dict]],
-    examples: Optional[List] = None,
+    examples: Optional[list] = None,
     allowed_comparators: Sequence[Comparator] = tuple(Comparator),
     allowed_operators: Sequence[Operator] = tuple(Operator),
     enable_limit: bool = False,
@@ -1,6 +1,7 @@
 import datetime
 import warnings
-from typing import Any, Literal, Optional, Sequence, Union
+from collections.abc import Sequence
+from typing import Any, Literal, Optional, Union

 from langchain_core.utils import check_package_version
 from typing_extensions import TypedDict
@@ -1,6 +1,7 @@
 """Load question answering chains."""

-from typing import Any, Mapping, Optional, Protocol
+from collections.abc import Mapping
+from typing import Any, Optional, Protocol

 from langchain_core._api import deprecated
 from langchain_core.callbacks import BaseCallbackManager, Callbacks
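This hunk shows why many files end up importing from both modules: `Mapping` has a `collections.abc` home, but `Protocol` (like `Literal` in the parser change above) exists only in `typing` and has no `collections.abc` counterpart, so the two imports now sit side by side. A hedged sketch of the resulting split (invented names):

    from collections.abc import Mapping
    from typing import Any, Optional, Protocol

    class ChainFactory(Protocol):
        """Anything that builds a chain name from a config mapping."""

        def __call__(self, config: Mapping[str, Any]) -> Optional[str]: ...

    def load(factory: ChainFactory) -> Optional[str]:
        return factory({"chain_type": "stuff"})

    print(load(lambda config: config.get("chain_type")))  # stuff
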
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any, Dict, Union
+from typing import Any, Union

 from langchain_core.retrievers import (
     BaseRetriever,
@@ -11,7 +11,7 @@ from langchain_core.runnables import Runnable, RunnablePassthrough

 def create_retrieval_chain(
     retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
-    combine_docs_chain: Runnable[Dict[str, Any], str],
+    combine_docs_chain: Runnable[dict[str, Any], str],
 ) -> Runnable:
     """Create retrieval chain that retrieves documents and then passes them on.

@@ -5,7 +5,7 @@ from __future__ import annotations
 import inspect
 import warnings
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -54,7 +54,7 @@ class BaseRetrievalQA(Chain):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Input keys.

         :meta private:
@@ -62,7 +62,7 @@ class BaseRetrievalQA(Chain):
         return [self.input_key]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Output keys.

         :meta private:
@@ -123,14 +123,14 @@ class BaseRetrievalQA(Chain):
         question: str,
         *,
         run_manager: CallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get documents to do question answering over."""

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Run get_relevant_text and llm on input query.

         If chain has 'return_source_documents' as 'True', returns
@@ -166,14 +166,14 @@ class BaseRetrievalQA(Chain):
         question: str,
         *,
         run_manager: AsyncCallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get documents to do question answering over."""

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Run get_relevant_text and llm on input query.

         If chain has 'return_source_documents' as 'True', returns
@@ -266,7 +266,7 @@ class RetrievalQA(BaseRetrievalQA):
         question: str,
         *,
         run_manager: CallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""
         return self.retriever.invoke(
             question, config={"callbacks": run_manager.get_child()}
@@ -277,7 +277,7 @@ class RetrievalQA(BaseRetrievalQA):
         question: str,
         *,
         run_manager: AsyncCallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""
         return await self.retriever.ainvoke(
             question, config={"callbacks": run_manager.get_child()}
@@ -307,12 +307,12 @@ class VectorDBQA(BaseRetrievalQA):
     """Number of documents to query for."""
     search_type: str = "similarity"
     """Search type to use over vectorstore. `similarity` or `mmr`."""
-    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    search_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Extra search args."""

     @model_validator(mode="before")
     @classmethod
-    def raise_deprecation(cls, values: Dict) -> Any:
+    def raise_deprecation(cls, values: dict) -> Any:
         warnings.warn(
             "`VectorDBQA` is deprecated - "
             "please use `from langchain.chains import RetrievalQA`"
@@ -321,7 +321,7 @@ class VectorDBQA(BaseRetrievalQA):

     @model_validator(mode="before")
     @classmethod
-    def validate_search_type(cls, values: Dict) -> Any:
+    def validate_search_type(cls, values: dict) -> Any:
         """Validate search type."""
         if "search_type" in values:
             search_type = values["search_type"]
@@ -334,7 +334,7 @@ class VectorDBQA(BaseRetrievalQA):
         question: str,
         *,
         run_manager: CallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""
         if self.search_type == "similarity":
             docs = self.vectorstore.similarity_search(
@@ -353,7 +353,7 @@ class VectorDBQA(BaseRetrievalQA):
         question: str,
         *,
         run_manager: AsyncCallbackManagerForChainRun,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Get docs."""
         raise NotImplementedError("VectorDBQA does not support async")

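The `raise_deprecation` and `validate_search_type` hooks show the rewrite inside pydantic validators: a before-mode `model_validator` classmethod receives the raw input, so the bare `dict` annotation replaces `Dict` with no behavioural change. A minimal sketch of the pattern (a toy model under pydantic v2, not the LangChain class):

    from typing import Any

    from pydantic import BaseModel, model_validator

    class ToyVectorDBQA(BaseModel):
        search_type: str = "similarity"

        @model_validator(mode="before")
        @classmethod
        def validate_search_type(cls, values: dict) -> Any:
            # Runs on the raw keyword dict, before field parsing.
            if "search_type" in values and values["search_type"] not in (
                "similarity",
                "mmr",
            ):
                raise ValueError(f"invalid search_type {values['search_type']!r}")
            return values

    print(ToyVectorDBQA(search_type="mmr").search_type)  # mmr
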
@@ -3,7 +3,8 @@
 from __future__ import annotations

 from abc import ABC
-from typing import Any, Dict, List, Mapping, NamedTuple, Optional
+from collections.abc import Mapping
+from typing import Any, NamedTuple, Optional

 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
@@ -17,17 +18,17 @@ from langchain.chains.base import Chain

 class Route(NamedTuple):
     destination: Optional[str]
-    next_inputs: Dict[str, Any]
+    next_inputs: dict[str, Any]


 class RouterChain(Chain, ABC):
     """Chain that outputs the name of a destination chain and the inputs to it."""

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         return ["destination", "next_inputs"]

-    def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
+    def route(self, inputs: dict[str, Any], callbacks: Callbacks = None) -> Route:
         """
         Route inputs to a destination chain.

@@ -42,7 +43,7 @@ class RouterChain(Chain, ABC):
         return Route(result["destination"], result["next_inputs"])

     async def aroute(
-        self, inputs: Dict[str, Any], callbacks: Callbacks = None
+        self, inputs: dict[str, Any], callbacks: Callbacks = None
     ) -> Route:
         result = await self.acall(inputs, callbacks=callbacks)
         return Route(result["destination"], result["next_inputs"])
@@ -67,7 +68,7 @@ class MultiRouteChain(Chain):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Will be whatever keys the router chain prompt expects.

         :meta private:
@@ -75,7 +76,7 @@ class MultiRouteChain(Chain):
         return self.router_chain.input_keys

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Will always return text key.

         :meta private:
@@ -84,9 +85,9 @@ class MultiRouteChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         route = self.router_chain.route(inputs, callbacks=callbacks)
@@ -109,9 +110,9 @@ class MultiRouteChain(Chain):

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         route = await self.router_chain.aroute(inputs, callbacks=callbacks)

@@ -1,6 +1,7 @@
 from __future__ import annotations

-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
+from collections.abc import Sequence
+from typing import Any, Optional

 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
@@ -18,7 +19,7 @@ class EmbeddingRouterChain(RouterChain):
     """Chain that uses embeddings to route between options."""

     vectorstore: VectorStore
-    routing_keys: List[str] = ["query"]
+    routing_keys: list[str] = ["query"]

     model_config = ConfigDict(
         arbitrary_types_allowed=True,
@@ -26,7 +27,7 @@ class EmbeddingRouterChain(RouterChain):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Will be whatever keys the LLM chain prompt expects.

         :meta private:
@@ -35,18 +36,18 @@ class EmbeddingRouterChain(RouterChain):

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _input = ", ".join([inputs[k] for k in self.routing_keys])
         results = self.vectorstore.similarity_search(_input, k=1)
         return {"next_inputs": inputs, "destination": results[0].metadata["name"]}

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _input = ", ".join([inputs[k] for k in self.routing_keys])
         results = await self.vectorstore.asimilarity_search(_input, k=1)
         return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
@@ -54,8 +55,8 @@ class EmbeddingRouterChain(RouterChain):
     @classmethod
     def from_names_and_descriptions(
         cls,
-        names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
-        vectorstore_cls: Type[VectorStore],
+        names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
+        vectorstore_cls: type[VectorStore],
         embeddings: Embeddings,
         **kwargs: Any,
     ) -> EmbeddingRouterChain:
@@ -72,8 +73,8 @@ class EmbeddingRouterChain(RouterChain):
     @classmethod
     async def afrom_names_and_descriptions(
         cls,
-        names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
-        vectorstore_cls: Type[VectorStore],
+        names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
+        vectorstore_cls: type[VectorStore],
         embeddings: Embeddings,
         **kwargs: Any,
     ) -> EmbeddingRouterChain:

@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import Any, Dict, List, Optional, Type, cast
+from typing import Any, Optional, cast

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -114,42 +114,42 @@ class LLMRouterChain(RouterChain):
         return self

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Will be whatever keys the LLM chain prompt expects.

         :meta private:
         """
         return self.llm_chain.input_keys

-    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
+    def _validate_outputs(self, outputs: dict[str, Any]) -> None:
         super()._validate_outputs(outputs)
         if not isinstance(outputs["next_inputs"], dict):
             raise ValueError

     def _call(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()

         prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
         output = cast(
-            Dict[str, Any],
+            dict[str, Any],
             self.llm_chain.prompt.output_parser.parse(prediction),
         )
         return output

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         output = cast(
-            Dict[str, Any],
+            dict[str, Any],
             await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
         )
         return output
@@ -163,14 +163,14 @@ class LLMRouterChain(RouterChain):
         return cls(llm_chain=llm_chain, **kwargs)


-class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
+class RouterOutputParser(BaseOutputParser[dict[str, str]]):
     """Parser for output of router chain in the multi-prompt chain."""

     default_destination: str = "DEFAULT"
-    next_inputs_type: Type = str
+    next_inputs_type: type = str
     next_inputs_inner_key: str = "input"

-    def parse(self, text: str) -> Dict[str, Any]:
+    def parse(self, text: str) -> dict[str, Any]:
         try:
             expected_keys = ["destination", "next_inputs"]
             parsed = parse_and_check_json_markdown(text, expected_keys)
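Two rewrites in this file run at call or class-definition time rather than living purely in annotations: `cast(dict[str, Any], ...)` and the base-class subscript `BaseOutputParser[dict[str, str]]`. Both keep working because PEP 585 makes `dict[str, str]` a real `types.GenericAlias` object on 3.9+. A self-contained demonstration with a stand-in parser class:

    from typing import Any, Generic, TypeVar, cast

    T = TypeVar("T")

    class OutputParser(Generic[T]):  # stand-in for BaseOutputParser
        def parse(self, text: str) -> T:
            raise NotImplementedError

    # The subscript below is evaluated eagerly, so it must be a real object:
    class RouterParser(OutputParser[dict[str, str]]):
        def parse(self, text: str) -> dict[str, str]:
            destination, _, rest = text.partition(":")
            return {"destination": destination, "next_inputs": rest.strip()}

    output = cast(dict[str, Any], RouterParser().parse("math: what is 2 + 2?"))
    print(type(dict[str, str]).__name__, output)
    # GenericAlias {'destination': 'math', 'next_inputs': 'what is 2 + 2?'}
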
@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core._api import deprecated
 from langchain_core.language_models import BaseLanguageModel
@@ -142,14 +142,14 @@ class MultiPromptChain(MultiRouteChain):
     """  # noqa: E501

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         return ["text"]

     @classmethod
     def from_prompts(
         cls,
         llm: BaseLanguageModel,
-        prompt_infos: List[Dict[str, str]],
+        prompt_infos: list[dict[str, str]],
         default_chain: Optional[Chain] = None,
         **kwargs: Any,
     ) -> MultiPromptChain:
@@ -2,7 +2,8 @@

 from __future__ import annotations

-from typing import Any, Dict, List, Mapping, Optional
+from collections.abc import Mapping
+from typing import Any, Optional

 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts import PromptTemplate
@@ -31,14 +32,14 @@ class MultiRetrievalQAChain(MultiRouteChain):  # type: ignore[override]
     """Default chain to use when router doesn't map input to one of the destinations."""

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         return ["result"]

     @classmethod
     def from_retrievers(
         cls,
         llm: BaseLanguageModel,
-        retriever_infos: List[Dict[str, Any]],
+        retriever_infos: list[dict[str, Any]],
         default_retriever: Optional[BaseRetriever] = None,
         default_prompt: Optional[PromptTemplate] = None,
         default_chain: Optional[Chain] = None,
@@ -1,6 +1,6 @@
 """Chain pipeline where the outputs of one step feed directly into next."""

-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
@@ -16,9 +16,9 @@ from langchain.chains.base import Chain
 class SequentialChain(Chain):
     """Chain where the outputs of one chain feed directly into next."""

-    chains: List[Chain]
-    input_variables: List[str]
-    output_variables: List[str]  #: :meta private:
+    chains: list[Chain]
+    input_variables: list[str]
+    output_variables: list[str]  #: :meta private:
     return_all: bool = False

     model_config = ConfigDict(
@@ -27,7 +27,7 @@ class SequentialChain(Chain):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return expected input keys to the chain.

         :meta private:
@@ -35,7 +35,7 @@ class SequentialChain(Chain):
         return self.input_variables

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return output key.

         :meta private:
@@ -44,7 +44,7 @@ class SequentialChain(Chain):

     @model_validator(mode="before")
     @classmethod
-    def validate_chains(cls, values: Dict) -> Any:
+    def validate_chains(cls, values: dict) -> Any:
         """Validate that the correct inputs exist for all chains."""
         chains = values["chains"]
         input_variables = values["input_variables"]
@@ -97,9 +97,9 @@ class SequentialChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         known_values = inputs.copy()
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         for i, chain in enumerate(self.chains):
@@ -110,9 +110,9 @@ class SequentialChain(Chain):

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         known_values = inputs.copy()
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
@@ -127,7 +127,7 @@ class SequentialChain(Chain):
 class SimpleSequentialChain(Chain):
     """Simple chain where the outputs of one step feed directly into next."""

-    chains: List[Chain]
+    chains: list[Chain]
     strip_outputs: bool = False
     input_key: str = "input"  #: :meta private:
     output_key: str = "output"  #: :meta private:
@@ -138,7 +138,7 @@ class SimpleSequentialChain(Chain):
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect input key.

         :meta private:
||||||
@ -146,7 +146,7 @@ class SimpleSequentialChain(Chain):
|
|||||||
return [self.input_key]
|
return [self.input_key]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> list[str]:
|
||||||
"""Return output key.
|
"""Return output key.
|
||||||
|
|
||||||
:meta private:
|
:meta private:
|
||||||
@ -171,9 +171,9 @@ class SimpleSequentialChain(Chain):
|
|||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: Dict[str, str],
|
inputs: dict[str, str],
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
) -> Dict[str, str]:
|
) -> dict[str, str]:
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
_input = inputs[self.input_key]
|
_input = inputs[self.input_key]
|
||||||
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
||||||
@ -190,9 +190,9 @@ class SimpleSequentialChain(Chain):
|
|||||||
|
|
||||||
async def _acall(
|
async def _acall(
|
||||||
self,
|
self,
|
||||||
inputs: Dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||||
_input = inputs[self.input_key]
|
_input = inputs[self.input_key]
|
||||||
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
||||||
|
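
One detail in the validate_chains hunk above: the bare Dict annotation becomes the plain builtin dict, and on 3.9+ parameterized forms like dict[str, list[str]] also evaluate at runtime, so modules without the from __future__ import annotations guard keep working. An illustrative sketch; the validate helper below is hypothetical, not code from the commit:

# On Python 3.9+ the builtin generics evaluate at runtime, so these annotations
# need no `from __future__ import annotations` guard.
def validate(values: dict) -> dict[str, list[str]]:
    # Toy validation loosely shaped like SequentialChain.validate_chains.
    return {"input_variables": list(values.get("input_variables", []))}

assert validate({"input_variables": ("a", "b")}) == {"input_variables": ["a", "b"]}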

@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union

 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.output_parsers import StrOutputParser
@@ -27,7 +27,7 @@ class SQLInputWithTables(TypedDict):
     """Input for a SQL Chain."""

     question: str
-    table_names_to_use: List[str]
+    table_names_to_use: list[str]


 def create_sql_query_chain(
@@ -35,7 +35,7 @@ def create_sql_query_chain(
     db: SQLDatabase,
     prompt: Optional[BasePromptTemplate] = None,
     k: int = 5,
-) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
+) -> Runnable[Union[SQLInput, SQLInputWithTables, dict[str, Any]], str]:
     """Create a chain that generates SQL queries.

     *Security Note*: This chain generates SQL queries for the given database.

@@ -1,5 +1,6 @@
 import json
-from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Literal, Optional, Union

 from langchain_core._api import deprecated
 from langchain_core.output_parsers import (
@@ -63,7 +64,7 @@ from pydantic import BaseModel
     ),
 )
 def create_openai_fn_runnable(
-    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
+    functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
     llm: Runnable,
     prompt: Optional[BasePromptTemplate] = None,
     *,
@@ -135,7 +136,7 @@ def create_openai_fn_runnable(
     if not functions:
        raise ValueError("Need to pass in at least one function. Received zero.")
     openai_functions = [convert_to_openai_function(f) for f in functions]
-    llm_kwargs_: Dict[str, Any] = {"functions": openai_functions, **llm_kwargs}
+    llm_kwargs_: dict[str, Any] = {"functions": openai_functions, **llm_kwargs}
     if len(openai_functions) == 1 and enforce_single_function_usage:
         llm_kwargs_["function_call"] = {"name": openai_functions[0]["name"]}
     output_parser = output_parser or get_openai_output_parser(functions)
@@ -181,7 +182,7 @@ def create_openai_fn_runnable(
     ),
 )
 def create_structured_output_runnable(
-    output_schema: Union[Dict[str, Any], Type[BaseModel]],
+    output_schema: Union[dict[str, Any], type[BaseModel]],
     llm: Runnable,
     prompt: Optional[BasePromptTemplate] = None,
     *,
@@ -437,7 +438,7 @@ def create_structured_output_runnable(


 def _create_openai_tools_runnable(
-    tool: Union[Dict[str, Any], Type[BaseModel], Callable],
+    tool: Union[dict[str, Any], type[BaseModel], Callable],
     llm: Runnable,
     *,
     prompt: Optional[BasePromptTemplate],
@@ -446,7 +447,7 @@ def _create_openai_tools_runnable(
     first_tool_only: bool,
 ) -> Runnable:
     oai_tool = convert_to_openai_tool(tool)
-    llm_kwargs: Dict[str, Any] = {"tools": [oai_tool]}
+    llm_kwargs: dict[str, Any] = {"tools": [oai_tool]}
     if enforce_tool_usage:
         llm_kwargs["tool_choice"] = {
             "type": "function",
@@ -462,7 +463,7 @@ def _create_openai_tools_runnable(


 def _get_openai_tool_output_parser(
-    tool: Union[Dict[str, Any], Type[BaseModel], Callable],
+    tool: Union[dict[str, Any], type[BaseModel], Callable],
     *,
     first_tool_only: bool = False,
 ) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
@@ -479,7 +480,7 @@ def _get_openai_tool_output_parser(


 def get_openai_output_parser(
-    functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
+    functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
 ) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
     """Get the appropriate function output parser given the user functions.

@@ -496,7 +497,7 @@ def get_openai_output_parser(
     """
     if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
         if len(functions) > 1:
-            pydantic_schema: Union[Dict, Type[BaseModel]] = {
+            pydantic_schema: Union[dict, type[BaseModel]] = {
                 convert_to_openai_function(fn)["name"]: fn for fn in functions
             }
         else:
@@ -510,7 +511,7 @@ def get_openai_output_parser(


 def _create_openai_json_runnable(
-    output_schema: Union[Dict[str, Any], Type[BaseModel]],
+    output_schema: Union[dict[str, Any], type[BaseModel]],
     llm: Runnable,
     prompt: Optional[BasePromptTemplate] = None,
     *,
@@ -537,7 +538,7 @@ def _create_openai_json_runnable(


 def _create_openai_functions_structured_output_runnable(
-    output_schema: Union[Dict[str, Any], Type[BaseModel]],
+    output_schema: Union[dict[str, Any], type[BaseModel]],
     llm: Runnable,
     prompt: Optional[BasePromptTemplate] = None,
     *,
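
As in several files above, Sequence now comes from collections.abc rather than typing, and Type[BaseModel] becomes the builtin type[BaseModel]. The collections.abc ABCs are subscriptable since 3.9 and double as isinstance targets. A small sketch under that assumption; count_tools is a made-up name, not part of the commit:

from collections.abc import Sequence
from typing import Any, Callable, Union

def count_tools(tools: Sequence[Union[dict[str, Any], type, Callable]]) -> int:
    # collections.abc ABCs are subscriptable on 3.9+ and usable with isinstance.
    assert isinstance(tools, Sequence)
    return len(tools)

print(count_tools([{"name": "search"}, str, print]))  # 3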

@@ -1,6 +1,7 @@
 """Load summarizing chains."""

-from typing import Any, Mapping, Optional, Protocol
+from collections.abc import Mapping
+from typing import Any, Optional, Protocol

 from langchain_core.callbacks import Callbacks
 from langchain_core.language_models import BaseLanguageModel

@@ -2,7 +2,8 @@

 import functools
 import logging
-from typing import Any, Awaitable, Callable, Dict, List, Optional
+from collections.abc import Awaitable
+from typing import Any, Callable, Optional

 from langchain_core.callbacks import (
     AsyncCallbackManagerForChainRun,
@@ -26,13 +27,13 @@ class TransformChain(Chain):
             output_variables["entities"], transform=func())
     """

-    input_variables: List[str]
+    input_variables: list[str]
     """The keys expected by the transform's input dictionary."""
-    output_variables: List[str]
+    output_variables: list[str]
     """The keys returned by the transform's output dictionary."""
-    transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
+    transform_cb: Callable[[dict[str, str]], dict[str, str]] = Field(alias="transform")
     """The transform function."""
-    atransform_cb: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = (
+    atransform_cb: Optional[Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]] = (
         Field(None, alias="atransform")
     )
     """The async coroutine transform function."""
@@ -47,7 +48,7 @@ class TransformChain(Chain):
         logger.warning(msg)

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Expect input keys.

         :meta private:
@@ -55,7 +56,7 @@ class TransformChain(Chain):
         return self.input_variables

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return output keys.

         :meta private:
@@ -64,16 +65,16 @@ class TransformChain(Chain):

     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         return self.transform_cb(inputs)

     async def _acall(
         self,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         if self.atransform_cb is not None:
             return await self.atransform_cb(inputs)
         else:
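
The TransformChain fields above show the rewrite applied inside Callable[...] parameters; note that Callable itself stays imported from typing in this file, only its dict arguments change. A hedged sketch of a transform callable matching the updated transform_cb annotation; the shout function is illustrative only:

from typing import Callable

def shout(inputs: dict[str, str]) -> dict[str, str]:
    # Upper-case the "text" key, returning the output dictionary a chain expects.
    return {"output": inputs["text"].upper()}

transform_cb: Callable[[dict[str, str]], dict[str, str]] = shout
print(transform_cb({"text": "hello"}))  # {'output': 'HELLO'}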

@@ -1,19 +1,13 @@
 from __future__ import annotations

 import warnings
+from collections.abc import AsyncIterator, Iterator, Sequence
 from importlib import util
 from typing import (
     Any,
-    AsyncIterator,
     Callable,
-    Dict,
-    Iterator,
-    List,
     Literal,
     Optional,
-    Sequence,
-    Tuple,
-    Type,
     Union,
     cast,
     overload,
@@ -73,7 +67,7 @@ def init_chat_model(
     model: Optional[str] = None,
     *,
     model_provider: Optional[str] = None,
-    configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = ...,
+    configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
     config_prefix: Optional[str] = None,
     **kwargs: Any,
 ) -> _ConfigurableModel: ...
@@ -87,7 +81,7 @@ def init_chat_model(
     *,
     model_provider: Optional[str] = None,
     configurable_fields: Optional[
-        Union[Literal["any"], List[str], Tuple[str, ...]]
+        Union[Literal["any"], list[str], tuple[str, ...]]
     ] = None,
     config_prefix: Optional[str] = None,
     **kwargs: Any,
@@ -514,7 +508,7 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
     return None


-def _parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, str]:
+def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
     if (
         not model_provider
         and ":" in model
@@ -554,12 +548,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         self,
         *,
         default_config: Optional[dict] = None,
-        configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = "any",
+        configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
         config_prefix: str = "",
-        queued_declarative_operations: Sequence[Tuple[str, Tuple, Dict]] = (),
+        queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
     ) -> None:
         self._default_config: dict = default_config or {}
-        self._configurable_fields: Union[Literal["any"], List[str]] = (
+        self._configurable_fields: Union[Literal["any"], list[str]] = (
             configurable_fields
             if configurable_fields == "any"
             else list(configurable_fields)
@@ -569,7 +563,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
             if config_prefix and not config_prefix.endswith("_")
             else config_prefix
         )
-        self._queued_declarative_operations: List[Tuple[str, Tuple, Dict]] = list(
+        self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
             queued_declarative_operations
         )

@@ -670,7 +664,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         return Union[
             str,
             Union[StringPromptValue, ChatPromptValueConcrete],
-            List[AnyMessage],
+            list[AnyMessage],
         ]

     def invoke(
@@ -708,12 +702,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):

     def batch(
         self,
-        inputs: List[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        inputs: list[LanguageModelInput],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Optional[Any],
-    ) -> List[Any]:
+    ) -> list[Any]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
         if config is None or isinstance(config, dict) or len(config) <= 1:
@@ -731,12 +725,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):

     async def abatch(
         self,
-        inputs: List[LanguageModelInput],
-        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
+        inputs: list[LanguageModelInput],
+        config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
         *,
         return_exceptions: bool = False,
         **kwargs: Optional[Any],
-    ) -> List[Any]:
+    ) -> list[Any]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
         if config is None or isinstance(config, dict) or len(config) <= 1:
@@ -759,7 +753,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
-    ) -> Iterator[Tuple[int, Union[Any, Exception]]]:
+    ) -> Iterator[tuple[int, Union[Any, Exception]]]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
         if config is None or isinstance(config, dict) or len(config) <= 1:
@@ -782,7 +776,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         *,
         return_exceptions: bool = False,
         **kwargs: Any,
-    ) -> AsyncIterator[Tuple[int, Any]]:
+    ) -> AsyncIterator[tuple[int, Any]]:
         config = config or None
         # If <= 1 config use the underlying models batch implementation.
         if config is None or isinstance(config, dict) or len(config) <= 1:
@@ -808,8 +802,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Optional[Any],
     ) -> Iterator[Any]:
-        for x in self._model(config).transform(input, config=config, **kwargs):
-            yield x
+        yield from self._model(config).transform(input, config=config, **kwargs)

     async def atransform(
         self,
@@ -915,13 +908,13 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     # Explicitly added to satisfy downstream linters.
     def bind_tools(
         self,
-        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
         return self.__getattr__("bind_tools")(tools, **kwargs)

     # Explicitly added to satisfy downstream linters.
     def with_structured_output(
-        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
-    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        self, schema: Union[dict, type[BaseModel]], **kwargs: Any
+    ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
         return self.__getattr__("with_structured_output")(schema, **kwargs)
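
Besides the type rewrites, pyupgrade also simplified _ConfigurableModel.transform above: a loop whose body only re-yields is replaced with yield from, which additionally delegates send(), throw(), and close() to the inner generator. A toy sketch of the same rewrite, with hypothetical inner/outer names:

from collections.abc import Iterator

def inner() -> Iterator[int]:
    yield 1
    yield 2

def outer() -> Iterator[int]:
    # Previously: for x in inner(): yield x
    yield from inner()

print(list(outer()))  # [1, 2]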

@@ -1,6 +1,6 @@
 import functools
 from importlib import util
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

 from langchain_core._api import beta
 from langchain_core.embeddings import Embeddings
@@ -25,7 +25,7 @@ def _get_provider_list() -> str:
     )


-def _parse_model_string(model_name: str) -> Tuple[str, str]:
+def _parse_model_string(model_name: str) -> tuple[str, str]:
     """Parse a model string into provider and model name components.

     The model string should be in the format 'provider:model-name', where provider
@@ -78,7 +78,7 @@ def _parse_model_string(model_name: str) -> Tuple[str, str]:

 def _infer_model_and_provider(
     model: str, *, provider: Optional[str] = None
-) -> Tuple[str, str]:
+) -> tuple[str, str]:
     if not model.strip():
         raise ValueError("Model name cannot be empty")
     if provider is None and ":" in model:
@@ -122,7 +122,7 @@ def init_embeddings(
     *,
     provider: Optional[str] = None,
     **kwargs: Any,
-) -> Union[Embeddings, Runnable[Any, List[float]]]:
+) -> Union[Embeddings, Runnable[Any, list[float]]]:
     """Initialize an embeddings model from a model name and optional provider.

     **Note:** Must have the integration package corresponding to the model provider
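
The tuple[str, str] return types above follow the same PEP 585 pattern for Tuple. A rough sketch of a provider/model split in that style; split_model is a hypothetical stand-in, not the commit's _parse_model_string, which validates its input and raises on bad formats:

def split_model(model: str) -> tuple[str, str]:
    # Split "provider:model" on the first colon; empty provider if none given.
    provider, _, name = model.partition(":")
    return (provider, name) if name else ("", provider)

print(split_model("openai:text-embedding-3-small"))  # ('openai', 'text-embedding-3-small')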

@@ -12,8 +12,9 @@ from __future__ import annotations
 import hashlib
 import json
 import uuid
+from collections.abc import Sequence
 from functools import partial
-from typing import Callable, List, Optional, Sequence, Union, cast
+from typing import Callable, Optional, Union, cast

 from langchain_core.embeddings import Embeddings
 from langchain_core.stores import BaseStore, ByteStore
@@ -45,9 +46,9 @@ def _value_serializer(value: Sequence[float]) -> bytes:
     return json.dumps(value).encode()


-def _value_deserializer(serialized_value: bytes) -> List[float]:
+def _value_deserializer(serialized_value: bytes) -> list[float]:
     """Deserialize a value."""
-    return cast(List[float], json.loads(serialized_value.decode()))
+    return cast(list[float], json.loads(serialized_value.decode()))


 class CacheBackedEmbeddings(Embeddings):
@@ -88,10 +89,10 @@ class CacheBackedEmbeddings(Embeddings):
     def __init__(
         self,
         underlying_embeddings: Embeddings,
-        document_embedding_store: BaseStore[str, List[float]],
+        document_embedding_store: BaseStore[str, list[float]],
         *,
         batch_size: Optional[int] = None,
-        query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
+        query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
     ) -> None:
         """Initialize the embedder.

@@ -108,7 +109,7 @@ class CacheBackedEmbeddings(Embeddings):
         self.underlying_embeddings = underlying_embeddings
         self.batch_size = batch_size

-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed a list of texts.

         The method first checks the cache for the embeddings.
@@ -121,10 +122,10 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             A list of embeddings for the given texts.
         """
-        vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
+        vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
             texts
         )
-        all_missing_indices: List[int] = [
+        all_missing_indices: list[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]

@@ -138,10 +139,10 @@ class CacheBackedEmbeddings(Embeddings):
             vectors[index] = updated_vector

         return cast(
-            List[List[float]], vectors
+            list[list[float]], vectors
         )  # Nones should have been resolved by now

-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed a list of texts.

         The method first checks the cache for the embeddings.
@@ -154,10 +155,10 @@ class CacheBackedEmbeddings(Embeddings):
         Returns:
             A list of embeddings for the given texts.
         """
-        vectors: List[
-            Union[List[float], None]
+        vectors: list[
+            Union[list[float], None]
         ] = await self.document_embedding_store.amget(texts)
-        all_missing_indices: List[int] = [
+        all_missing_indices: list[int] = [
             i for i, vector in enumerate(vectors) if vector is None
         ]

@@ -175,10 +176,10 @@ class CacheBackedEmbeddings(Embeddings):
             vectors[index] = updated_vector

         return cast(
-            List[List[float]], vectors
+            list[list[float]], vectors
         )  # Nones should have been resolved by now

-    def embed_query(self, text: str) -> List[float]:
+    def embed_query(self, text: str) -> list[float]:
         """Embed query text.

         By default, this method does not cache queries. To enable caching, set the
@@ -201,7 +202,7 @@ class CacheBackedEmbeddings(Embeddings):
         self.query_embedding_store.mset([(text, vector)])
         return vector

-    async def aembed_query(self, text: str) -> List[float]:
+    async def aembed_query(self, text: str) -> list[float]:
         """Embed query text.

         By default, this method does not cache queries. To enable caching, set the
@@ -250,7 +251,7 @@ class CacheBackedEmbeddings(Embeddings):
         """
         namespace = namespace
         key_encoder = _create_key_encoder(namespace)
-        document_embedding_store = EncoderBackedStore[str, List[float]](
+        document_embedding_store = EncoderBackedStore[str, list[float]](
             document_embedding_cache,
             key_encoder,
             _value_serializer,
@@ -261,7 +262,7 @@ class CacheBackedEmbeddings(Embeddings):
         elif query_embedding_cache is False:
             query_embedding_store = None
         else:
-            query_embedding_store = EncoderBackedStore[str, List[float]](
+            query_embedding_store = EncoderBackedStore[str, list[float]](
                 query_embedding_cache,
                 key_encoder,
                 _value_serializer,
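
Note in the cache hunks that typing.cast accepts the parameterized builtin generics directly, as in cast(list[list[float]], vectors); cast performs no runtime conversion or check, it only narrows the type for checkers. A minimal sketch mirroring the embed_documents flow, with toy values:

from typing import Union, cast

vectors: list[Union[list[float], None]] = [[0.1, 0.2], None]
vectors[1] = [0.3, 0.4]  # fill the cache miss, as embed_documents does
# cast() only tells the type checker the Nones have been resolved.
resolved = cast(list[list[float]], vectors)
print(resolved)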

@@ -6,13 +6,10 @@ chain (LLMChain) to generate the reasoning and scores.
 """

 import re
+from collections.abc import Sequence
 from typing import (
     Any,
-    Dict,
-    List,
     Optional,
-    Sequence,
-    Tuple,
     TypedDict,
     Union,
     cast,
@@ -145,7 +142,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
         # 0
     """

-    agent_tools: Optional[List[BaseTool]] = None
+    agent_tools: Optional[list[BaseTool]] = None
     """A list of tools available to the agent."""
     eval_chain: LLMChain
     """The language model chain used for evaluation."""
@@ -184,7 +181,7 @@ Description: {tool.description}"""

     @staticmethod
     def get_agent_trajectory(
-        steps: Union[str, Sequence[Tuple[AgentAction, str]]],
+        steps: Union[str, Sequence[tuple[AgentAction, str]]],
     ) -> str:
         """Get the agent trajectory as a formatted string.

@@ -263,7 +260,7 @@ The following is the expected answer. Use this to measure correctness:
     )

     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Get the input keys for the chain.

         Returns:
@@ -272,7 +269,7 @@ The following is the expected answer. Use this to measure correctness:
         return ["question", "agent_trajectory", "answer", "reference"]

     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Get the output keys for the chain.

         Returns:
@@ -280,16 +277,16 @@ The following is the expected answer. Use this to measure correctness:
         """
         return ["score", "reasoning"]

-    def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
+    def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
         """Validate and prep inputs."""
         inputs["reference"] = self._format_reference(inputs.get("reference"))
         return super().prep_inputs(inputs)

     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Run the chain and generate the output.

         Args:
@@ -311,9 +308,9 @@ The following is the expected answer. Use this to measure correctness:

     async def _acall(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Run the chain and generate the output.

         Args:
@@ -338,11 +335,11 @@ The following is the expected answer. Use this to measure correctness:
         *,
         prediction: str,
         input: str,
-        agent_trajectory: Sequence[Tuple[AgentAction, str]],
+        agent_trajectory: Sequence[tuple[AgentAction, str]],
         reference: Optional[str] = None,
         callbacks: Callbacks = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:
@@ -380,11 +377,11 @@ The following is the expected answer. Use this to measure correctness:
         *,
         prediction: str,
         input: str,
-        agent_trajectory: Sequence[Tuple[AgentAction, str]],
+        agent_trajectory: Sequence[tuple[AgentAction, str]],
         reference: Optional[str] = None,
         callbacks: Callbacks = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:

@@ -4,7 +4,7 @@ from __future__ import annotations

 import logging
 import re
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

 from langchain_core.callbacks.manager import Callbacks
 from langchain_core.language_models import BaseLanguageModel
@@ -49,7 +49,7 @@ _SUPPORTED_CRITERIA = {


 def resolve_pairwise_criteria(
-    criteria: Optional[Union[CRITERIA_TYPE, str, List[CRITERIA_TYPE]]],
+    criteria: Optional[Union[CRITERIA_TYPE, str, list[CRITERIA_TYPE]]],
 ) -> dict:
     """Resolve the criteria for the pairwise evaluator.

@@ -113,7 +113,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]):  # type: ignore[
         """
         return "pairwise_string_result"

-    def parse(self, text: str) -> Dict[str, Any]:
+    def parse(self, text: str) -> dict[str, Any]:
         """Parse the output text.

         Args:
@@ -314,8 +314,8 @@ Performance may be significantly worse with other models."
         input: Optional[str] = None,
         reference: Optional[str] = None,
         callbacks: Callbacks = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:
@@ -356,8 +356,8 @@ Performance may be significantly worse with other models."
         reference: Optional[str] = None,
         input: Optional[str] = None,
         callbacks: Callbacks = None,
-        tags: Optional[List[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         include_run_info: bool = False,
         **kwargs: Any,
     ) -> dict:

Some files were not shown because too many files have changed in this diff.