langchain[lint]: use pyupgrade to get to 3.9 standards (#30782)

This commit is contained in:
Sydney Runkle 2025-04-11 10:33:26 -04:00 committed by GitHub
parent d9b628e764
commit 48affc498b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
214 changed files with 1530 additions and 1554 deletions

View File

@ -1,5 +1,5 @@
import importlib import importlib
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Optional
from langchain_core._api import internal, warn_deprecated from langchain_core._api import internal, warn_deprecated
@ -15,8 +15,8 @@ ALLOWED_TOP_LEVEL_PKGS = {
def create_importer( def create_importer(
package: str, package: str,
*, *,
module_lookup: Optional[Dict[str, str]] = None, module_lookup: Optional[dict[str, str]] = None,
deprecated_lookups: Optional[Dict[str, str]] = None, deprecated_lookups: Optional[dict[str, str]] = None,
fallback_module: Optional[str] = None, fallback_module: Optional[str] = None,
) -> Callable[[str], Any]: ) -> Callable[[str], Any]:
"""Create a function that helps retrieve objects from their new locations. """Create a function that helps retrieve objects from their new locations.

View File

@ -3,21 +3,17 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import builtins
import json import json
import logging import logging
import time import time
from abc import abstractmethod from abc import abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Dict,
Iterator,
List,
Optional, Optional,
Sequence,
Tuple,
Union, Union,
cast, cast,
) )
@ -62,17 +58,17 @@ class BaseSingleActionAgent(BaseModel):
"""Base Single Action Agent class.""" """Base Single Action Agent class."""
@property @property
def return_values(self) -> List[str]: def return_values(self) -> list[str]:
"""Return values of the agent.""" """Return values of the agent."""
return ["output"] return ["output"]
def get_allowed_tools(self) -> Optional[List[str]]: def get_allowed_tools(self) -> Optional[list[str]]:
return None return None
@abstractmethod @abstractmethod
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -91,7 +87,7 @@ class BaseSingleActionAgent(BaseModel):
@abstractmethod @abstractmethod
async def aplan( async def aplan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -109,7 +105,7 @@ class BaseSingleActionAgent(BaseModel):
@property @property
@abstractmethod @abstractmethod
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the input keys. """Return the input keys.
:meta private: :meta private:
@ -118,7 +114,7 @@ class BaseSingleActionAgent(BaseModel):
def return_stopped_response( def return_stopped_response(
self, self,
early_stopping_method: str, early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any, **kwargs: Any,
) -> AgentFinish: ) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations. """Return response when agent has been stopped due to max iterations.
@ -171,7 +167,7 @@ class BaseSingleActionAgent(BaseModel):
"""Return Identifier of an agent type.""" """Return Identifier of an agent type."""
raise NotImplementedError raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent. """Return dictionary representation of agent.
Returns: Returns:
@ -223,7 +219,7 @@ class BaseSingleActionAgent(BaseModel):
else: else:
raise ValueError(f"{save_path} must be json or yaml") raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict: def tool_run_logging_kwargs(self) -> builtins.dict:
"""Return logging kwargs for tool run.""" """Return logging kwargs for tool run."""
return {} return {}
@ -232,11 +228,11 @@ class BaseMultiActionAgent(BaseModel):
"""Base Multi Action Agent class.""" """Base Multi Action Agent class."""
@property @property
def return_values(self) -> List[str]: def return_values(self) -> list[str]:
"""Return values of the agent.""" """Return values of the agent."""
return ["output"] return ["output"]
def get_allowed_tools(self) -> Optional[List[str]]: def get_allowed_tools(self) -> Optional[list[str]]:
"""Get allowed tools. """Get allowed tools.
Returns: Returns:
@ -247,10 +243,10 @@ class BaseMultiActionAgent(BaseModel):
@abstractmethod @abstractmethod
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
"""Given input, decided what to do. """Given input, decided what to do.
Args: Args:
@ -266,10 +262,10 @@ class BaseMultiActionAgent(BaseModel):
@abstractmethod @abstractmethod
async def aplan( async def aplan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
"""Async given input, decided what to do. """Async given input, decided what to do.
Args: Args:
@ -284,7 +280,7 @@ class BaseMultiActionAgent(BaseModel):
@property @property
@abstractmethod @abstractmethod
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the input keys. """Return the input keys.
:meta private: :meta private:
@ -293,7 +289,7 @@ class BaseMultiActionAgent(BaseModel):
def return_stopped_response( def return_stopped_response(
self, self,
early_stopping_method: str, early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any, **kwargs: Any,
) -> AgentFinish: ) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations. """Return response when agent has been stopped due to max iterations.
@ -323,7 +319,7 @@ class BaseMultiActionAgent(BaseModel):
"""Return Identifier of an agent type.""" """Return Identifier of an agent type."""
raise NotImplementedError raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().model_dump() _dict = super().model_dump()
try: try:
@ -371,7 +367,7 @@ class BaseMultiActionAgent(BaseModel):
else: else:
raise ValueError(f"{save_path} must be json or yaml") raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict: def tool_run_logging_kwargs(self) -> builtins.dict:
"""Return logging kwargs for tool run.""" """Return logging kwargs for tool run."""
return {} return {}
@ -386,7 +382,7 @@ class AgentOutputParser(BaseOutputParser[Union[AgentAction, AgentFinish]]):
class MultiActionAgentOutputParser( class MultiActionAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]] BaseOutputParser[Union[list[AgentAction], AgentFinish]]
): ):
"""Base class for parsing agent output into agent actions/finish. """Base class for parsing agent output into agent actions/finish.
@ -394,7 +390,7 @@ class MultiActionAgentOutputParser(
""" """
@abstractmethod @abstractmethod
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
"""Parse text into agent actions/finish. """Parse text into agent actions/finish.
Args: Args:
@ -411,8 +407,8 @@ class RunnableAgent(BaseSingleActionAgent):
runnable: Runnable[dict, Union[AgentAction, AgentFinish]] runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
"""Runnable to call to get agent action.""" """Runnable to call to get agent action."""
input_keys_arg: List[str] = [] input_keys_arg: list[str] = []
return_keys_arg: List[str] = [] return_keys_arg: list[str] = []
stream_runnable: bool = True stream_runnable: bool = True
"""Whether to stream from the runnable or not. """Whether to stream from the runnable or not.
@ -427,18 +423,18 @@ class RunnableAgent(BaseSingleActionAgent):
) )
@property @property
def return_values(self) -> List[str]: def return_values(self) -> list[str]:
"""Return values of the agent.""" """Return values of the agent."""
return self.return_keys_arg return self.return_keys_arg
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the input keys.""" """Return the input keys."""
return self.input_keys_arg return self.input_keys_arg
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -474,7 +470,7 @@ class RunnableAgent(BaseSingleActionAgent):
async def aplan( async def aplan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[ ) -> Union[
@ -518,10 +514,10 @@ class RunnableAgent(BaseSingleActionAgent):
class RunnableMultiActionAgent(BaseMultiActionAgent): class RunnableMultiActionAgent(BaseMultiActionAgent):
"""Agent powered by Runnables.""" """Agent powered by Runnables."""
runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]] runnable: Runnable[dict, Union[list[AgentAction], AgentFinish]]
"""Runnable to call to get agent actions.""" """Runnable to call to get agent actions."""
input_keys_arg: List[str] = [] input_keys_arg: list[str] = []
return_keys_arg: List[str] = [] return_keys_arg: list[str] = []
stream_runnable: bool = True stream_runnable: bool = True
"""Whether to stream from the runnable or not. """Whether to stream from the runnable or not.
@ -536,12 +532,12 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
) )
@property @property
def return_values(self) -> List[str]: def return_values(self) -> list[str]:
"""Return values of the agent.""" """Return values of the agent."""
return self.return_keys_arg return self.return_keys_arg
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the input keys. """Return the input keys.
Returns: Returns:
@ -551,11 +547,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[ ) -> Union[
List[AgentAction], list[AgentAction],
AgentFinish, AgentFinish,
]: ]:
"""Based on past history and current inputs, decide what to do. """Based on past history and current inputs, decide what to do.
@ -590,11 +586,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
async def aplan( async def aplan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[ ) -> Union[
List[AgentAction], list[AgentAction],
AgentFinish, AgentFinish,
]: ]:
"""Async based on past history and current inputs, decide what to do. """Async based on past history and current inputs, decide what to do.
@ -644,11 +640,11 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
"""LLMChain to use for agent.""" """LLMChain to use for agent."""
output_parser: AgentOutputParser output_parser: AgentOutputParser
"""Output parser to use for agent.""" """Output parser to use for agent."""
stop: List[str] stop: list[str]
"""List of strings to stop on.""" """List of strings to stop on."""
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the input keys. """Return the input keys.
Returns: Returns:
@ -656,7 +652,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
""" """
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"}) return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().dict() _dict = super().dict()
del _dict["output_parser"] del _dict["output_parser"]
@ -664,7 +660,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -689,7 +685,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
async def aplan( async def aplan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -712,7 +708,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
) )
return self.output_parser.parse(output) return self.output_parser.parse(output)
def tool_run_logging_kwargs(self) -> Dict: def tool_run_logging_kwargs(self) -> builtins.dict:
"""Return logging kwargs for tool run.""" """Return logging kwargs for tool run."""
return { return {
"llm_prefix": "", "llm_prefix": "",
@ -737,21 +733,21 @@ class Agent(BaseSingleActionAgent):
"""LLMChain to use for agent.""" """LLMChain to use for agent."""
output_parser: AgentOutputParser output_parser: AgentOutputParser
"""Output parser to use for agent.""" """Output parser to use for agent."""
allowed_tools: Optional[List[str]] = None allowed_tools: Optional[list[str]] = None
"""Allowed tools for the agent. If None, all tools are allowed.""" """Allowed tools for the agent. If None, all tools are allowed."""
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().dict() _dict = super().dict()
del _dict["output_parser"] del _dict["output_parser"]
return _dict return _dict
def get_allowed_tools(self) -> Optional[List[str]]: def get_allowed_tools(self) -> Optional[list[str]]:
"""Get allowed tools.""" """Get allowed tools."""
return self.allowed_tools return self.allowed_tools
@property @property
def return_values(self) -> List[str]: def return_values(self) -> list[str]:
"""Return values of the agent.""" """Return values of the agent."""
return ["output"] return ["output"]
@ -767,15 +763,15 @@ class Agent(BaseSingleActionAgent):
raise ValueError("fix_text not implemented for this agent.") raise ValueError("fix_text not implemented for this agent.")
@property @property
def _stop(self) -> List[str]: def _stop(self) -> list[str]:
return [ return [
f"\n{self.observation_prefix.rstrip()}", f"\n{self.observation_prefix.rstrip()}",
f"\n\t{self.observation_prefix.rstrip()}", f"\n\t{self.observation_prefix.rstrip()}",
] ]
def _construct_scratchpad( def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]] self, intermediate_steps: list[tuple[AgentAction, str]]
) -> Union[str, List[BaseMessage]]: ) -> Union[str, list[BaseMessage]]:
"""Construct the scratchpad that lets the agent continue its thought process.""" """Construct the scratchpad that lets the agent continue its thought process."""
thoughts = "" thoughts = ""
for action, observation in intermediate_steps: for action, observation in intermediate_steps:
@ -785,7 +781,7 @@ class Agent(BaseSingleActionAgent):
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -806,7 +802,7 @@ class Agent(BaseSingleActionAgent):
async def aplan( async def aplan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -827,8 +823,8 @@ class Agent(BaseSingleActionAgent):
return agent_output return agent_output
def get_full_inputs( def get_full_inputs(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any
) -> Dict[str, Any]: ) -> builtins.dict[str, Any]:
"""Create the full inputs for the LLMChain from intermediate steps. """Create the full inputs for the LLMChain from intermediate steps.
Args: Args:
@ -845,7 +841,7 @@ class Agent(BaseSingleActionAgent):
return full_inputs return full_inputs
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the input keys. """Return the input keys.
:meta private: :meta private:
@ -957,7 +953,7 @@ class Agent(BaseSingleActionAgent):
def return_stopped_response( def return_stopped_response(
self, self,
early_stopping_method: str, early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any, **kwargs: Any,
) -> AgentFinish: ) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations. """Return response when agent has been stopped due to max iterations.
@ -1009,7 +1005,7 @@ class Agent(BaseSingleActionAgent):
f"got {early_stopping_method}" f"got {early_stopping_method}"
) )
def tool_run_logging_kwargs(self) -> Dict: def tool_run_logging_kwargs(self) -> builtins.dict:
"""Return logging kwargs for tool run.""" """Return logging kwargs for tool run."""
return { return {
"llm_prefix": self.llm_prefix, "llm_prefix": self.llm_prefix,
@ -1040,7 +1036,7 @@ class ExceptionTool(BaseTool): # type: ignore[override]
return query return query
NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]] NextStepOutput = list[Union[AgentFinish, AgentAction, AgentStep]]
RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent] RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
@ -1086,7 +1082,7 @@ class AgentExecutor(Chain):
as an observation. as an observation.
""" """
trim_intermediate_steps: Union[ trim_intermediate_steps: Union[
int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]] int, Callable[[list[tuple[AgentAction, str]]], list[tuple[AgentAction, str]]]
] = -1 ] = -1
"""How to trim the intermediate steps before returning them. """How to trim the intermediate steps before returning them.
Defaults to -1, which means no trimming. Defaults to -1, which means no trimming.
@ -1144,7 +1140,7 @@ class AgentExecutor(Chain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_runnable_agent(cls, values: Dict) -> Any: def validate_runnable_agent(cls, values: dict) -> Any:
"""Convert runnable to agent if passed in. """Convert runnable to agent if passed in.
Args: Args:
@ -1160,7 +1156,7 @@ class AgentExecutor(Chain):
except Exception as _: except Exception as _:
multi_action = False multi_action = False
else: else:
multi_action = output_type == Union[List[AgentAction], AgentFinish] multi_action = output_type == Union[list[AgentAction], AgentFinish]
stream_runnable = values.pop("stream_runnable", True) stream_runnable = values.pop("stream_runnable", True)
if multi_action: if multi_action:
@ -1239,7 +1235,7 @@ class AgentExecutor(Chain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the input keys. """Return the input keys.
:meta private: :meta private:
@ -1247,7 +1243,7 @@ class AgentExecutor(Chain):
return self._action_agent.input_keys return self._action_agent.input_keys
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return the singular output key. """Return the singular output key.
:meta private: :meta private:
@ -1284,7 +1280,7 @@ class AgentExecutor(Chain):
output: AgentFinish, output: AgentFinish,
intermediate_steps: list, intermediate_steps: list,
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
if run_manager: if run_manager:
run_manager.on_agent_finish(output, color="green", verbose=self.verbose) run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
final_output = output.return_values final_output = output.return_values
@ -1297,7 +1293,7 @@ class AgentExecutor(Chain):
output: AgentFinish, output: AgentFinish,
intermediate_steps: list, intermediate_steps: list,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
if run_manager: if run_manager:
await run_manager.on_agent_finish( await run_manager.on_agent_finish(
output, color="green", verbose=self.verbose output, color="green", verbose=self.verbose
@ -1309,7 +1305,7 @@ class AgentExecutor(Chain):
def _consume_next_step( def _consume_next_step(
self, values: NextStepOutput self, values: NextStepOutput
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
if isinstance(values[-1], AgentFinish): if isinstance(values[-1], AgentFinish):
assert len(values) == 1 assert len(values) == 1
return values[-1] return values[-1]
@ -1320,12 +1316,12 @@ class AgentExecutor(Chain):
def _take_next_step( def _take_next_step(
self, self,
name_to_tool_map: Dict[str, BaseTool], name_to_tool_map: dict[str, BaseTool],
color_mapping: Dict[str, str], color_mapping: dict[str, str],
inputs: Dict[str, str], inputs: dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
return self._consume_next_step( return self._consume_next_step(
[ [
a a
@ -1341,10 +1337,10 @@ class AgentExecutor(Chain):
def _iter_next_step( def _iter_next_step(
self, self,
name_to_tool_map: Dict[str, BaseTool], name_to_tool_map: dict[str, BaseTool],
color_mapping: Dict[str, str], color_mapping: dict[str, str],
inputs: Dict[str, str], inputs: dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Iterator[Union[AgentFinish, AgentAction, AgentStep]]: ) -> Iterator[Union[AgentFinish, AgentAction, AgentStep]]:
"""Take a single step in the thought-action-observation loop. """Take a single step in the thought-action-observation loop.
@ -1404,7 +1400,7 @@ class AgentExecutor(Chain):
yield output yield output
return return
actions: List[AgentAction] actions: list[AgentAction]
if isinstance(output, AgentAction): if isinstance(output, AgentAction):
actions = [output] actions = [output]
else: else:
@ -1418,8 +1414,8 @@ class AgentExecutor(Chain):
def _perform_agent_action( def _perform_agent_action(
self, self,
name_to_tool_map: Dict[str, BaseTool], name_to_tool_map: dict[str, BaseTool],
color_mapping: Dict[str, str], color_mapping: dict[str, str],
agent_action: AgentAction, agent_action: AgentAction,
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> AgentStep: ) -> AgentStep:
@ -1457,12 +1453,12 @@ class AgentExecutor(Chain):
async def _atake_next_step( async def _atake_next_step(
self, self,
name_to_tool_map: Dict[str, BaseTool], name_to_tool_map: dict[str, BaseTool],
color_mapping: Dict[str, str], color_mapping: dict[str, str],
inputs: Dict[str, str], inputs: dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
return self._consume_next_step( return self._consume_next_step(
[ [
a a
@ -1478,10 +1474,10 @@ class AgentExecutor(Chain):
async def _aiter_next_step( async def _aiter_next_step(
self, self,
name_to_tool_map: Dict[str, BaseTool], name_to_tool_map: dict[str, BaseTool],
color_mapping: Dict[str, str], color_mapping: dict[str, str],
inputs: Dict[str, str], inputs: dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> AsyncIterator[Union[AgentFinish, AgentAction, AgentStep]]: ) -> AsyncIterator[Union[AgentFinish, AgentAction, AgentStep]]:
"""Take a single step in the thought-action-observation loop. """Take a single step in the thought-action-observation loop.
@ -1539,7 +1535,7 @@ class AgentExecutor(Chain):
yield output yield output
return return
actions: List[AgentAction] actions: list[AgentAction]
if isinstance(output, AgentAction): if isinstance(output, AgentAction):
actions = [output] actions = [output]
else: else:
@ -1563,8 +1559,8 @@ class AgentExecutor(Chain):
async def _aperform_agent_action( async def _aperform_agent_action(
self, self,
name_to_tool_map: Dict[str, BaseTool], name_to_tool_map: dict[str, BaseTool],
color_mapping: Dict[str, str], color_mapping: dict[str, str],
agent_action: AgentAction, agent_action: AgentAction,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> AgentStep: ) -> AgentStep:
@ -1604,9 +1600,9 @@ class AgentExecutor(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Run text through and get agent response.""" """Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup # Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools} name_to_tool_map = {tool.name: tool for tool in self.tools}
@ -1614,7 +1610,7 @@ class AgentExecutor(Chain):
color_mapping = get_color_mapping( color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green", "red"] [tool.name for tool in self.tools], excluded_colors=["green", "red"]
) )
intermediate_steps: List[Tuple[AgentAction, str]] = [] intermediate_steps: list[tuple[AgentAction, str]] = []
# Let's start tracking the number of iterations and time elapsed # Let's start tracking the number of iterations and time elapsed
iterations = 0 iterations = 0
time_elapsed = 0.0 time_elapsed = 0.0
@ -1651,9 +1647,9 @@ class AgentExecutor(Chain):
async def _acall( async def _acall(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
"""Async run text through and get agent response.""" """Async run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup # Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools} name_to_tool_map = {tool.name: tool for tool in self.tools}
@ -1661,7 +1657,7 @@ class AgentExecutor(Chain):
color_mapping = get_color_mapping( color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"] [tool.name for tool in self.tools], excluded_colors=["green"]
) )
intermediate_steps: List[Tuple[AgentAction, str]] = [] intermediate_steps: list[tuple[AgentAction, str]] = []
# Let's start tracking the number of iterations and time elapsed # Let's start tracking the number of iterations and time elapsed
iterations = 0 iterations = 0
time_elapsed = 0.0 time_elapsed = 0.0
@ -1712,7 +1708,7 @@ class AgentExecutor(Chain):
) )
def _get_tool_return( def _get_tool_return(
self, next_step_output: Tuple[AgentAction, str] self, next_step_output: tuple[AgentAction, str]
) -> Optional[AgentFinish]: ) -> Optional[AgentFinish]:
"""Check if the tool is a returning tool.""" """Check if the tool is a returning tool."""
agent_action, observation = next_step_output agent_action, observation = next_step_output
@ -1730,8 +1726,8 @@ class AgentExecutor(Chain):
return None return None
def _prepare_intermediate_steps( def _prepare_intermediate_steps(
self, intermediate_steps: List[Tuple[AgentAction, str]] self, intermediate_steps: list[tuple[AgentAction, str]]
) -> List[Tuple[AgentAction, str]]: ) -> list[tuple[AgentAction, str]]:
if ( if (
isinstance(self.trim_intermediate_steps, int) isinstance(self.trim_intermediate_steps, int)
and self.trim_intermediate_steps > 0 and self.trim_intermediate_steps > 0
@ -1744,7 +1740,7 @@ class AgentExecutor(Chain):
def stream( def stream(
self, self,
input: Union[Dict[str, Any], Any], input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[AddableDict]: ) -> Iterator[AddableDict]:
@ -1770,12 +1766,11 @@ class AgentExecutor(Chain):
yield_actions=True, yield_actions=True,
**kwargs, **kwargs,
) )
for step in iterator: yield from iterator
yield step
async def astream( async def astream(
self, self,
input: Union[Dict[str, Any], Any], input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[AddableDict]: ) -> AsyncIterator[AddableDict]:

View File

@ -3,15 +3,11 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time import time
from collections.abc import AsyncIterator, Iterator
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional, Optional,
Tuple,
Union, Union,
) )
from uuid import UUID from uuid import UUID
@ -53,7 +49,7 @@ class AgentExecutorIterator:
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
run_id: Optional[UUID] = None, run_id: Optional[UUID] = None,
include_run_info: bool = False, include_run_info: bool = False,
@ -90,17 +86,17 @@ class AgentExecutorIterator:
self.yield_actions = yield_actions self.yield_actions = yield_actions
self.reset() self.reset()
_inputs: Dict[str, str] _inputs: dict[str, str]
callbacks: Callbacks callbacks: Callbacks
tags: Optional[list[str]] tags: Optional[list[str]]
metadata: Optional[Dict[str, Any]] metadata: Optional[dict[str, Any]]
run_name: Optional[str] run_name: Optional[str]
run_id: Optional[UUID] run_id: Optional[UUID]
include_run_info: bool include_run_info: bool
yield_actions: bool yield_actions: bool
@property @property
def inputs(self) -> Dict[str, str]: def inputs(self) -> dict[str, str]:
"""The inputs to the AgentExecutor.""" """The inputs to the AgentExecutor."""
return self._inputs return self._inputs
@ -120,12 +116,12 @@ class AgentExecutorIterator:
self.inputs = self.inputs self.inputs = self.inputs
@property @property
def name_to_tool_map(self) -> Dict[str, BaseTool]: def name_to_tool_map(self) -> dict[str, BaseTool]:
"""A mapping of tool names to tools.""" """A mapping of tool names to tools."""
return {tool.name: tool for tool in self.agent_executor.tools} return {tool.name: tool for tool in self.agent_executor.tools}
@property @property
def color_mapping(self) -> Dict[str, str]: def color_mapping(self) -> dict[str, str]:
"""A mapping of tool names to colors.""" """A mapping of tool names to colors."""
return get_color_mapping( return get_color_mapping(
[tool.name for tool in self.agent_executor.tools], [tool.name for tool in self.agent_executor.tools],
@ -156,7 +152,7 @@ class AgentExecutorIterator:
def make_final_outputs( def make_final_outputs(
self, self,
outputs: Dict[str, Any], outputs: dict[str, Any],
run_manager: Union[CallbackManagerForChainRun, AsyncCallbackManagerForChainRun], run_manager: Union[CallbackManagerForChainRun, AsyncCallbackManagerForChainRun],
) -> AddableDict: ) -> AddableDict:
# have access to intermediate steps by design in iterator, # have access to intermediate steps by design in iterator,
@ -171,7 +167,7 @@ class AgentExecutorIterator:
prepared_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) prepared_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return prepared_outputs return prepared_outputs
def __iter__(self: "AgentExecutorIterator") -> Iterator[AddableDict]: def __iter__(self: AgentExecutorIterator) -> Iterator[AddableDict]:
logger.debug("Initialising AgentExecutorIterator") logger.debug("Initialising AgentExecutorIterator")
self.reset() self.reset()
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
@ -311,7 +307,7 @@ class AgentExecutorIterator:
def _process_next_step_output( def _process_next_step_output(
self, self,
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]], next_step_output: Union[AgentFinish, list[tuple[AgentAction, str]]],
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
) -> AddableDict: ) -> AddableDict:
""" """
@ -339,7 +335,7 @@ class AgentExecutorIterator:
async def _aprocess_next_step_output( async def _aprocess_next_step_output(
self, self,
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]], next_step_output: Union[AgentFinish, list[tuple[AgentAction, str]]],
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> AddableDict: ) -> AddableDict:
""" """

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional from typing import Any, Optional
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.memory import BaseMemory from langchain_core.memory import BaseMemory
@ -26,7 +26,7 @@ def _get_default_system_message() -> SystemMessage:
def create_conversational_retrieval_agent( def create_conversational_retrieval_agent(
llm: BaseLanguageModel, llm: BaseLanguageModel,
tools: List[BaseTool], tools: list[BaseTool],
remember_intermediate_steps: bool = True, remember_intermediate_steps: bool = True,
memory_key: str = "chat_history", memory_key: str = "chat_history",
system_message: Optional[SystemMessage] = None, system_message: Optional[SystemMessage] = None,

View File

@ -1,6 +1,6 @@
"""VectorStore agent.""" """VectorStore agent."""
from typing import Any, Dict, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks.base import BaseCallbackManager from langchain_core.callbacks.base import BaseCallbackManager
@ -36,7 +36,7 @@ def create_vectorstore_agent(
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX, prefix: str = PREFIX,
verbose: bool = False, verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None, agent_executor_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> AgentExecutor: ) -> AgentExecutor:
"""Construct a VectorStore agent from an LLM and tools. """Construct a VectorStore agent from an LLM and tools.
@ -129,7 +129,7 @@ def create_vectorstore_router_agent(
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = ROUTER_PREFIX, prefix: str = ROUTER_PREFIX,
verbose: bool = False, verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None, agent_executor_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> AgentExecutor: ) -> AgentExecutor:
"""Construct a VectorStore router agent from an LLM and tools. """Construct a VectorStore router agent from an LLM and tools.

View File

@ -1,7 +1,5 @@
"""Toolkit for interacting with a vector store.""" """Toolkit for interacting with a vector store."""
from typing import List
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit from langchain_core.tools.base import BaseToolkit
@ -31,7 +29,7 @@ class VectorStoreToolkit(BaseToolkit):
arbitrary_types_allowed=True, arbitrary_types_allowed=True,
) )
def get_tools(self) -> List[BaseTool]: def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit.""" """Get the tools in the toolkit."""
try: try:
from langchain_community.tools.vectorstore.tool import ( from langchain_community.tools.vectorstore.tool import (
@ -66,16 +64,16 @@ class VectorStoreToolkit(BaseToolkit):
class VectorStoreRouterToolkit(BaseToolkit): class VectorStoreRouterToolkit(BaseToolkit):
"""Toolkit for routing between Vector Stores.""" """Toolkit for routing between Vector Stores."""
vectorstores: List[VectorStoreInfo] = Field(exclude=True) vectorstores: list[VectorStoreInfo] = Field(exclude=True)
llm: BaseLanguageModel llm: BaseLanguageModel
model_config = ConfigDict( model_config = ConfigDict(
arbitrary_types_allowed=True, arbitrary_types_allowed=True,
) )
def get_tools(self) -> List[BaseTool]: def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit.""" """Get the tools in the toolkit."""
tools: List[BaseTool] = [] tools: list[BaseTool] = []
try: try:
from langchain_community.tools.vectorstore.tool import ( from langchain_community.tools.vectorstore.tool import (
VectorStoreQATool, VectorStoreQATool,

View File

@ -1,4 +1,5 @@
from typing import Any, List, Optional, Sequence, Tuple from collections.abc import Sequence
from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
@ -48,7 +49,7 @@ class ChatAgent(Agent):
return "Thought:" return "Thought:"
def _construct_scratchpad( def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]] self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str: ) -> str:
agent_scratchpad = super()._construct_scratchpad(intermediate_steps) agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str): if not isinstance(agent_scratchpad, str):
@ -72,7 +73,7 @@ class ChatAgent(Agent):
validate_tools_single_input(class_name=cls.__name__, tools=tools) validate_tools_single_input(class_name=cls.__name__, tools=tools)
@property @property
def _stop(self) -> List[str]: def _stop(self) -> list[str]:
return ["Observation:"] return ["Observation:"]
@classmethod @classmethod
@ -83,7 +84,7 @@ class ChatAgent(Agent):
system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX, system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
human_message: str = HUMAN_MESSAGE, human_message: str = HUMAN_MESSAGE,
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
"""Create a prompt from a list of tools. """Create a prompt from a list of tools.
@ -132,7 +133,7 @@ class ChatAgent(Agent):
system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX, system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
human_message: str = HUMAN_MESSAGE, human_message: str = HUMAN_MESSAGE,
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Agent: ) -> Agent:
"""Construct an agent from an LLM and tools. """Construct an agent from an LLM and tools.

View File

@ -1,6 +1,7 @@
import json import json
import re import re
from typing import Pattern, Union from re import Pattern
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Optional, Sequence from collections.abc import Sequence
from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager from langchain_core.callbacks import BaseCallbackManager
@ -71,7 +72,7 @@ class ConversationalAgent(Agent):
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
ai_prefix: str = "AI", ai_prefix: str = "AI",
human_prefix: str = "Human", human_prefix: str = "Human",
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
) -> PromptTemplate: ) -> PromptTemplate:
"""Create prompt in the style of the zero-shot agent. """Create prompt in the style of the zero-shot agent.
@ -120,7 +121,7 @@ class ConversationalAgent(Agent):
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
ai_prefix: str = "AI", ai_prefix: str = "AI",
human_prefix: str = "Human", human_prefix: str = "Human",
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Agent: ) -> Agent:
"""Construct an agent from an LLM and tools. """Construct an agent from an LLM and tools.

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Optional, Sequence, Tuple from collections.abc import Sequence
from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
@ -77,7 +78,7 @@ class ConversationalChatAgent(Agent):
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
system_message: str = PREFIX, system_message: str = PREFIX,
human_message: str = SUFFIX, human_message: str = SUFFIX,
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
output_parser: Optional[BaseOutputParser] = None, output_parser: Optional[BaseOutputParser] = None,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
"""Create a prompt for the agent. """Create a prompt for the agent.
@ -116,10 +117,10 @@ class ConversationalChatAgent(Agent):
return ChatPromptTemplate(input_variables=input_variables, messages=messages) # type: ignore[arg-type] return ChatPromptTemplate(input_variables=input_variables, messages=messages) # type: ignore[arg-type]
def _construct_scratchpad( def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]] self, intermediate_steps: list[tuple[AgentAction, str]]
) -> List[BaseMessage]: ) -> list[BaseMessage]:
"""Construct the scratchpad that lets the agent continue its thought process.""" """Construct the scratchpad that lets the agent continue its thought process."""
thoughts: List[BaseMessage] = [] thoughts: list[BaseMessage] = []
for action, observation in intermediate_steps: for action, observation in intermediate_steps:
thoughts.append(AIMessage(content=action.log)) thoughts.append(AIMessage(content=action.log))
human_message = HumanMessage( human_message = HumanMessage(
@ -137,7 +138,7 @@ class ConversationalChatAgent(Agent):
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
system_message: str = PREFIX, system_message: str = PREFIX,
human_message: str = SUFFIX, human_message: str = SUFFIX,
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Agent: ) -> Agent:
"""Construct an agent from an LLM and tools. """Construct an agent from an LLM and tools.

View File

@ -1,10 +1,8 @@
from typing import List, Tuple
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
def format_log_to_str( def format_log_to_str(
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
observation_prefix: str = "Observation: ", observation_prefix: str = "Observation: ",
llm_prefix: str = "Thought: ", llm_prefix: str = "Thought: ",
) -> str: ) -> str:

View File

@ -1,13 +1,11 @@
from typing import List, Tuple
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
def format_log_to_messages( def format_log_to_messages(
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
template_tool_response: str = "{observation}", template_tool_response: str = "{observation}",
) -> List[BaseMessage]: ) -> list[BaseMessage]:
"""Construct the scratchpad that lets the agent continue its thought process. """Construct the scratchpad that lets the agent continue its thought process.
Args: Args:
@ -18,7 +16,7 @@ def format_log_to_messages(
Returns: Returns:
List[BaseMessage]: The scratchpad. List[BaseMessage]: The scratchpad.
""" """
thoughts: List[BaseMessage] = [] thoughts: list[BaseMessage] = []
for action, observation in intermediate_steps: for action, observation in intermediate_steps:
thoughts.append(AIMessage(content=action.log)) thoughts.append(AIMessage(content=action.log))
human_message = HumanMessage( human_message = HumanMessage(

View File

@ -1,5 +1,5 @@
import json import json
from typing import List, Sequence, Tuple from collections.abc import Sequence
from langchain_core.agents import AgentAction, AgentActionMessageLog from langchain_core.agents import AgentAction, AgentActionMessageLog
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
def _convert_agent_action_to_messages( def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str agent_action: AgentAction, observation: str
) -> List[BaseMessage]: ) -> list[BaseMessage]:
"""Convert an agent action to a message. """Convert an agent action to a message.
This code is used to reconstruct the original AI message from the agent action. This code is used to reconstruct the original AI message from the agent action.
@ -54,8 +54,8 @@ def _create_function_message(
def format_to_openai_function_messages( def format_to_openai_function_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]], intermediate_steps: Sequence[tuple[AgentAction, str]],
) -> List[BaseMessage]: ) -> list[BaseMessage]:
"""Convert (AgentAction, tool output) tuples into FunctionMessages. """Convert (AgentAction, tool output) tuples into FunctionMessages.
Args: Args:

View File

@ -1,5 +1,5 @@
import json import json
from typing import List, Sequence, Tuple from collections.abc import Sequence
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
from langchain_core.messages import ( from langchain_core.messages import (
@ -40,8 +40,8 @@ def _create_tool_message(
def format_to_tool_messages( def format_to_tool_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]], intermediate_steps: Sequence[tuple[AgentAction, str]],
) -> List[BaseMessage]: ) -> list[BaseMessage]:
"""Convert (AgentAction, tool output) tuples into ToolMessages. """Convert (AgentAction, tool output) tuples into ToolMessages.
Args: Args:

View File

@ -1,10 +1,8 @@
from typing import List, Tuple
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
def format_xml( def format_xml(
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
) -> str: ) -> str:
"""Format the intermediate steps as XML. """Format the intermediate steps as XML.

View File

@ -1,6 +1,7 @@
"""Load agent.""" """Load agent."""
from typing import Any, Optional, Sequence from collections.abc import Sequence
from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager from langchain_core.callbacks import BaseCallbackManager

View File

@ -1,4 +1,5 @@
from typing import List, Sequence, Union from collections.abc import Sequence
from typing import Union
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.prompts.chat import ChatPromptTemplate
@ -15,7 +16,7 @@ def create_json_chat_agent(
llm: BaseLanguageModel, llm: BaseLanguageModel,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
prompt: ChatPromptTemplate, prompt: ChatPromptTemplate,
stop_sequence: Union[bool, List[str]] = True, stop_sequence: Union[bool, list[str]] = True,
tools_renderer: ToolsRenderer = render_text_description, tools_renderer: ToolsRenderer = render_text_description,
template_tool_response: str = TEMPLATE_TOOL_RESPONSE, template_tool_response: str = TEMPLATE_TOOL_RESPONSE,
) -> Runnable: ) -> Runnable:

View File

@ -3,7 +3,7 @@
import json import json
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional, Union from typing import Any, Optional, Union
import yaml import yaml
from langchain_core._api import deprecated from langchain_core._api import deprecated
@ -20,7 +20,7 @@ URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/age
def _load_agent_from_tools( def _load_agent_from_tools(
config: dict, llm: BaseLanguageModel, tools: List[Tool], **kwargs: Any config: dict, llm: BaseLanguageModel, tools: list[Tool], **kwargs: Any
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: ) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
config_type = config.pop("_type") config_type = config.pop("_type")
if config_type not in AGENT_TO_CLASS: if config_type not in AGENT_TO_CLASS:
@ -35,7 +35,7 @@ def _load_agent_from_tools(
def load_agent_from_config( def load_agent_from_config(
config: dict, config: dict,
llm: Optional[BaseLanguageModel] = None, llm: Optional[BaseLanguageModel] = None,
tools: Optional[List[Tool]] = None, tools: Optional[list[Tool]] = None,
**kwargs: Any, **kwargs: Any,
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: ) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
"""Load agent from Config Dict. """Load agent from Config Dict.
@ -130,7 +130,7 @@ def _load_agent_from_file(
with open(file_path) as f: with open(file_path) as f:
config = json.load(f) config = json.load(f)
elif file_path.suffix[1:] == "yaml": elif file_path.suffix[1:] == "yaml":
with open(file_path, "r") as f: with open(file_path) as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
else: else:
raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.") raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, List, NamedTuple, Optional, Sequence from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager from langchain_core.callbacks import BaseCallbackManager
@ -83,7 +84,7 @@ class ZeroShotAgent(Agent):
prefix: str = PREFIX, prefix: str = PREFIX,
suffix: str = SUFFIX, suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
) -> PromptTemplate: ) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent. """Create prompt in the style of the zero shot agent.
@ -118,7 +119,7 @@ class ZeroShotAgent(Agent):
prefix: str = PREFIX, prefix: str = PREFIX,
suffix: str = SUFFIX, suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Agent: ) -> Agent:
"""Construct an agent from an LLM and tools. """Construct an agent from an LLM and tools.
@ -183,7 +184,7 @@ class MRKLChain(AgentExecutor):
@classmethod @classmethod
def from_chains( def from_chains(
cls, llm: BaseLanguageModel, chains: List[ChainConfig], **kwargs: Any cls, llm: BaseLanguageModel, chains: list[ChainConfig], **kwargs: Any
) -> AgentExecutor: ) -> AgentExecutor:
"""User-friendly way to initialize the MRKL chain. """User-friendly way to initialize the MRKL chain.

View File

@ -2,18 +2,14 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
from collections.abc import Sequence
from json import JSONDecodeError from json import JSONDecodeError
from time import sleep from time import sleep
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
Dict,
List,
Optional, Optional,
Sequence,
Tuple,
Type,
Union, Union,
) )
@ -111,7 +107,7 @@ def _get_openai_async_client() -> openai.AsyncOpenAI:
def _is_assistants_builtin_tool( def _is_assistants_builtin_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
) -> bool: ) -> bool:
"""Determine if tool corresponds to OpenAI Assistants built-in.""" """Determine if tool corresponds to OpenAI Assistants built-in."""
assistants_builtin_tools = ("code_interpreter", "file_search") assistants_builtin_tools = ("code_interpreter", "file_search")
@ -123,8 +119,8 @@ def _is_assistants_builtin_tool(
def _get_assistants_tool( def _get_assistants_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Convert a raw function/class to an OpenAI tool. """Convert a raw function/class to an OpenAI tool.
Note that OpenAI assistants supports several built-in tools, Note that OpenAI assistants supports several built-in tools,
@ -137,14 +133,14 @@ def _get_assistants_tool(
OutputType = Union[ OutputType = Union[
List[OpenAIAssistantAction], list[OpenAIAssistantAction],
OpenAIAssistantFinish, OpenAIAssistantFinish,
List["ThreadMessage"], list["ThreadMessage"],
List["RequiredActionFunctionToolCall"], list["RequiredActionFunctionToolCall"],
] ]
class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
"""Run an OpenAI Assistant. """Run an OpenAI Assistant.
Example using OpenAI tools: Example using OpenAI tools:
@ -498,7 +494,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
return response return response
def _parse_intermediate_steps( def _parse_intermediate_steps(
self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]] self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]]
) -> dict: ) -> dict:
last_action, last_output = intermediate_steps[-1] last_action, last_output = intermediate_steps[-1]
run = self._wait_for_run(last_action.run_id, last_action.thread_id) run = self._wait_for_run(last_action.run_id, last_action.thread_id)
@ -652,7 +648,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
return run return run
async def _aparse_intermediate_steps( async def _aparse_intermediate_steps(
self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]] self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]]
) -> dict: ) -> dict:
last_action, last_output = intermediate_steps[-1] last_action, last_output = intermediate_steps[-1]
run = self._wait_for_run(last_action.run_id, last_action.thread_id) run = self._wait_for_run(last_action.run_id, last_action.thread_id)

View File

@ -1,6 +1,6 @@
"""Memory used to save agent output AND intermediate steps.""" """Memory used to save agent output AND intermediate steps."""
from typing import Any, Dict, List from typing import Any
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
@ -43,19 +43,19 @@ class AgentTokenBufferMemory(BaseChatMemory): # type: ignore[override]
format_as_tools: bool = False format_as_tools: bool = False
@property @property
def buffer(self) -> List[BaseMessage]: def buffer(self) -> list[BaseMessage]:
"""String buffer of memory.""" """String buffer of memory."""
return self.chat_memory.messages return self.chat_memory.messages
@property @property
def memory_variables(self) -> List[str]: def memory_variables(self) -> list[str]:
"""Always return list of memory variables. """Always return list of memory variables.
:meta private: :meta private:
""" """
return [self.memory_key] return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer. """Return history buffer.
Args: Args:
@ -74,7 +74,7 @@ class AgentTokenBufferMemory(BaseChatMemory): # type: ignore[override]
) )
return {self.memory_key: final_buffer} return {self.memory_key: final_buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None: def save_context(self, inputs: dict[str, Any], outputs: dict[str, Any]) -> None:
"""Save context from this conversation to buffer. Pruned. """Save context from this conversation to buffer. Pruned.
Args: Args:

View File

@ -1,6 +1,7 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API.""" """Module implements an agent that uses OpenAI's APIs function enabled API."""
from typing import Any, List, Optional, Sequence, Tuple, Type, Union from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
@ -51,11 +52,11 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
llm: BaseLanguageModel llm: BaseLanguageModel
tools: Sequence[BaseTool] tools: Sequence[BaseTool]
prompt: BasePromptTemplate prompt: BasePromptTemplate
output_parser: Type[OpenAIFunctionsAgentOutputParser] = ( output_parser: type[OpenAIFunctionsAgentOutputParser] = (
OpenAIFunctionsAgentOutputParser OpenAIFunctionsAgentOutputParser
) )
def get_allowed_tools(self) -> List[str]: def get_allowed_tools(self) -> list[str]:
"""Get allowed tools.""" """Get allowed tools."""
return [t.name for t in self.tools] return [t.name for t in self.tools]
@ -81,19 +82,19 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
return self return self
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Get input keys. Input refers to user input here.""" """Get input keys. Input refers to user input here."""
return ["input"] return ["input"]
@property @property
def functions(self) -> List[dict]: def functions(self) -> list[dict]:
"""Get functions.""" """Get functions."""
return [dict(convert_to_openai_function(t)) for t in self.tools] return [dict(convert_to_openai_function(t)) for t in self.tools]
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
with_functions: bool = True, with_functions: bool = True,
**kwargs: Any, **kwargs: Any,
@ -135,7 +136,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
async def aplan( async def aplan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -168,7 +169,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
def return_stopped_response( def return_stopped_response(
self, self,
early_stopping_method: str, early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any, **kwargs: Any,
) -> AgentFinish: ) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations. """Return response when agent has been stopped due to max iterations.
@ -213,7 +214,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
system_message: Optional[SystemMessage] = SystemMessage( system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant." content="You are a helpful AI assistant."
), ),
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
) -> ChatPromptTemplate: ) -> ChatPromptTemplate:
"""Create prompt for this agent. """Create prompt for this agent.
@ -227,7 +228,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
A prompt template to pass into this agent. A prompt template to pass into this agent.
""" """
_prompts = extra_prompt_messages or [] _prompts = extra_prompt_messages or []
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message: if system_message:
messages = [system_message] messages = [system_message]
else: else:
@ -248,7 +249,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
llm: BaseLanguageModel, llm: BaseLanguageModel,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage( system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant." content="You are a helpful AI assistant."
), ),

View File

@ -1,8 +1,9 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API.""" """Module implements an agent that uses OpenAI's APIs function enabled API."""
import json import json
from collections.abc import Sequence
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, List, Optional, Sequence, Tuple, Union from typing import Any, Optional, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
@ -34,7 +35,7 @@ from langchain.agents.format_scratchpad.openai_functions import (
_FunctionsAgentAction = AgentActionMessageLog _FunctionsAgentAction = AgentActionMessageLog
def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]: def _parse_ai_message(message: BaseMessage) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message.""" """Parse an AI message."""
if not isinstance(message, AIMessage): if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}") raise TypeError(f"Expected an AI message got {type(message)}")
@ -58,7 +59,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFin
f"the `arguments` JSON does not contain `actions` key." f"the `arguments` JSON does not contain `actions` key."
) )
final_tools: List[AgentAction] = [] final_tools: list[AgentAction] = []
for tool_schema in tools: for tool_schema in tools:
if "action" in tool_schema: if "action" in tool_schema:
_tool_input = tool_schema["action"] _tool_input = tool_schema["action"]
@ -112,7 +113,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
tools: Sequence[BaseTool] tools: Sequence[BaseTool]
prompt: BasePromptTemplate prompt: BasePromptTemplate
def get_allowed_tools(self) -> List[str]: def get_allowed_tools(self) -> list[str]:
"""Get allowed tools.""" """Get allowed tools."""
return [t.name for t in self.tools] return [t.name for t in self.tools]
@ -127,12 +128,12 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
return self return self
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Get input keys. Input refers to user input here.""" """Get input keys. Input refers to user input here."""
return ["input"] return ["input"]
@property @property
def functions(self) -> List[dict]: def functions(self) -> list[dict]:
"""Get the functions for the agent.""" """Get the functions for the agent."""
enum_vals = [t.name for t in self.tools] enum_vals = [t.name for t in self.tools]
tool_selection = { tool_selection = {
@ -194,10 +195,10 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
"""Given input, decided what to do. """Given input, decided what to do.
Args: Args:
@ -224,10 +225,10 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
async def aplan( async def aplan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
"""Async given input, decided what to do. """Async given input, decided what to do.
Args: Args:
@ -258,7 +259,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
system_message: Optional[SystemMessage] = SystemMessage( system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant." content="You are a helpful AI assistant."
), ),
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
"""Create prompt for this agent. """Create prompt for this agent.
@ -272,7 +273,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
A prompt template to pass into this agent. A prompt template to pass into this agent.
""" """
_prompts = extra_prompt_messages or [] _prompts = extra_prompt_messages or []
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message: if system_message:
messages = [system_message] messages = [system_message]
else: else:
@ -293,7 +294,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
llm: BaseLanguageModel, llm: BaseLanguageModel,
tools: Sequence[BaseTool], tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None, callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage( system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant." content="You are a helpful AI assistant."
), ),

View File

@ -1,4 +1,5 @@
from typing import Optional, Sequence from collections.abc import Sequence
from typing import Optional
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.prompts.chat import ChatPromptTemplate

View File

@ -1,6 +1,6 @@
import json import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import List, Union from typing import Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
@ -77,7 +77,7 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
) )
def parse_result( def parse_result(
self, result: List[Generation], *, partial: bool = False self, result: list[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration): if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output") raise ValueError("This output parser only works on ChatGeneration output")

View File

@ -1,4 +1,4 @@
from typing import List, Union from typing import Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
@ -15,12 +15,12 @@ OpenAIToolAgentAction = ToolAgentAction
def parse_ai_message_to_openai_tool_action( def parse_ai_message_to_openai_tool_action(
message: BaseMessage, message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls.""" """Parse an AI message potentially containing tool_calls."""
tool_actions = parse_ai_message_to_tool_action(message) tool_actions = parse_ai_message_to_tool_action(message)
if isinstance(tool_actions, AgentFinish): if isinstance(tool_actions, AgentFinish):
return tool_actions return tool_actions
final_actions: List[AgentAction] = [] final_actions: list[AgentAction] = []
for action in tool_actions: for action in tool_actions:
if isinstance(action, ToolAgentAction): if isinstance(action, ToolAgentAction):
final_actions.append( final_actions.append(
@ -54,12 +54,12 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
return "openai-tools-agent-output-parser" return "openai-tools-agent-output-parser"
def parse_result( def parse_result(
self, result: List[Generation], *, partial: bool = False self, result: list[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration): if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output") raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message message = result[0].message
return parse_ai_message_to_openai_tool_action(message) return parse_ai_message_to_openai_tool_action(message)
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages") raise ValueError("Can only parse messages")

View File

@ -1,6 +1,7 @@
import json import json
import re import re
from typing import Pattern, Union from re import Pattern
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException

View File

@ -1,4 +1,5 @@
from typing import Sequence, Union from collections.abc import Sequence
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException

View File

@ -1,6 +1,6 @@
import json import json
from json import JSONDecodeError from json import JSONDecodeError
from typing import List, Union from typing import Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
@ -21,12 +21,12 @@ class ToolAgentAction(AgentActionMessageLog): # type: ignore[override]
def parse_ai_message_to_tool_action( def parse_ai_message_to_tool_action(
message: BaseMessage, message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls.""" """Parse an AI message potentially containing tool_calls."""
if not isinstance(message, AIMessage): if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}") raise TypeError(f"Expected an AI message got {type(message)}")
actions: List = [] actions: list = []
if message.tool_calls: if message.tool_calls:
tool_calls = message.tool_calls tool_calls = message.tool_calls
else: else:
@ -91,12 +91,12 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
return "tools-agent-output-parser" return "tools-agent-output-parser"
def parse_result( def parse_result(
self, result: List[Generation], *, partial: bool = False self, result: list[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]: ) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration): if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output") raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message message = result[0].message
return parse_ai_message_to_tool_action(message) return parse_ai_message_to_tool_action(message)
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages") raise ValueError("Can only parse messages")

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import List, Optional, Sequence, Union from collections.abc import Sequence
from typing import Optional, Union
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts import BasePromptTemplate
@ -20,7 +21,7 @@ def create_react_agent(
output_parser: Optional[AgentOutputParser] = None, output_parser: Optional[AgentOutputParser] = None,
tools_renderer: ToolsRenderer = render_text_description, tools_renderer: ToolsRenderer = render_text_description,
*, *,
stop_sequence: Union[bool, List[str]] = True, stop_sequence: Union[bool, list[str]] = True,
) -> Runnable: ) -> Runnable:
"""Create an agent that uses ReAct prompting. """Create an agent that uses ReAct prompting.

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Optional, Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.documents import Document from langchain_core.documents import Document
@ -65,7 +66,7 @@ class ReActDocstoreAgent(Agent):
return "Observation: " return "Observation: "
@property @property
def _stop(self) -> List[str]: def _stop(self) -> list[str]:
return ["\nObservation:"] return ["\nObservation:"]
@property @property
@ -122,7 +123,7 @@ class DocstoreExplorer:
return self._paragraphs[0] return self._paragraphs[0]
@property @property
def _paragraphs(self) -> List[str]: def _paragraphs(self) -> list[str]:
if self.document is None: if self.document is None:
raise ValueError("Cannot get paragraphs without a document") raise ValueError("Cannot get paragraphs without a document")
return self.document.page_content.split("\n\n") return self.document.page_content.split("\n\n")

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple from typing import Any
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.prompts.chat import ChatPromptTemplate
@ -12,7 +12,7 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
return False return False
def _construct_agent_scratchpad( def _construct_agent_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]] self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str: ) -> str:
if len(intermediate_steps) == 0: if len(intermediate_steps) == 0:
return "" return ""
@ -26,7 +26,7 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
f"you return as final answer):\n{thoughts}" f"you return as final answer):\n{thoughts}"
) )
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]: def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
intermediate_steps = kwargs.pop("intermediate_steps") intermediate_steps = kwargs.pop("intermediate_steps")
kwargs["agent_scratchpad"] = self._construct_agent_scratchpad( kwargs["agent_scratchpad"] = self._construct_agent_scratchpad(
intermediate_steps intermediate_steps

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, Union from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel

View File

@ -1,5 +1,6 @@
import re import re
from typing import Any, List, Optional, Sequence, Tuple, Union from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
@ -49,7 +50,7 @@ class StructuredChatAgent(Agent):
return "Thought:" return "Thought:"
def _construct_scratchpad( def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]] self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str: ) -> str:
agent_scratchpad = super()._construct_scratchpad(intermediate_steps) agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str): if not isinstance(agent_scratchpad, str):
@ -74,7 +75,7 @@ class StructuredChatAgent(Agent):
return StructuredChatOutputParserWithRetries.from_llm(llm=llm) return StructuredChatOutputParserWithRetries.from_llm(llm=llm)
@property @property
def _stop(self) -> List[str]: def _stop(self) -> list[str]:
return ["Observation:"] return ["Observation:"]
@classmethod @classmethod
@ -85,8 +86,8 @@ class StructuredChatAgent(Agent):
suffix: str = SUFFIX, suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE, human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None, memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
tool_strings = [] tool_strings = []
for tool in tools: for tool in tools:
@ -117,8 +118,8 @@ class StructuredChatAgent(Agent):
suffix: str = SUFFIX, suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE, human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS, format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None, input_variables: Optional[list[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None, memory_prompts: Optional[list[BasePromptTemplate]] = None,
**kwargs: Any, **kwargs: Any,
) -> Agent: ) -> Agent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
@ -157,7 +158,7 @@ def create_structured_chat_agent(
prompt: ChatPromptTemplate, prompt: ChatPromptTemplate,
tools_renderer: ToolsRenderer = render_text_description_and_args, tools_renderer: ToolsRenderer = render_text_description_and_args,
*, *,
stop_sequence: Union[bool, List[str]] = True, stop_sequence: Union[bool, list[str]] = True,
) -> Runnable: ) -> Runnable:
"""Create an agent aimed at supporting tools with multiple inputs. """Create an agent aimed at supporting tools with multiple inputs.

View File

@ -3,7 +3,8 @@ from __future__ import annotations
import json import json
import logging import logging
import re import re
from typing import Optional, Pattern, Union from re import Pattern
from typing import Optional, Union
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException

View File

@ -1,4 +1,5 @@
from typing import Callable, List, Sequence, Tuple from collections.abc import Sequence
from typing import Callable
from langchain_core.agents import AgentAction from langchain_core.agents import AgentAction
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
@ -12,7 +13,7 @@ from langchain.agents.format_scratchpad.tools import (
) )
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
MessageFormatter = Callable[[Sequence[Tuple[AgentAction, str]]], List[BaseMessage]] MessageFormatter = Callable[[Sequence[tuple[AgentAction, str]]], list[BaseMessage]]
def create_tool_calling_agent( def create_tool_calling_agent(

View File

@ -1,6 +1,6 @@
"""Interface for tools.""" """Interface for tools."""
from typing import List, Optional from typing import Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
@ -20,7 +20,7 @@ class InvalidTool(BaseTool): # type: ignore[override]
def _run( def _run(
self, self,
requested_tool_name: str, requested_tool_name: str,
available_tool_names: List[str], available_tool_names: list[str],
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str: ) -> str:
"""Use the tool.""" """Use the tool."""
@ -33,7 +33,7 @@ class InvalidTool(BaseTool): # type: ignore[override]
async def _arun( async def _arun(
self, self,
requested_tool_name: str, requested_tool_name: str,
available_tool_names: List[str], available_tool_names: list[str],
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str: ) -> str:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""

View File

@ -1,4 +1,4 @@
from typing import Dict, Type, Union from typing import Union
from langchain.agents.agent import BaseSingleActionAgent from langchain.agents.agent import BaseSingleActionAgent
from langchain.agents.agent_types import AgentType from langchain.agents.agent_types import AgentType
@ -12,9 +12,9 @@ from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.structured_chat.base import StructuredChatAgent from langchain.agents.structured_chat.base import StructuredChatAgent
AGENT_TYPE = Union[Type[BaseSingleActionAgent], Type[OpenAIMultiFunctionsAgent]] AGENT_TYPE = Union[type[BaseSingleActionAgent], type[OpenAIMultiFunctionsAgent]]
AGENT_TO_CLASS: Dict[AgentType, AGENT_TYPE] = { AGENT_TO_CLASS: dict[AgentType, AGENT_TYPE] = {
AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent, AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent,
AgentType.REACT_DOCSTORE: ReActDocstoreAgent, AgentType.REACT_DOCSTORE: ReActDocstoreAgent,
AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent, AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent,

View File

@ -1,4 +1,4 @@
from typing import Sequence from collections.abc import Sequence
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool

View File

@ -1,4 +1,5 @@
from typing import Any, List, Sequence, Tuple, Union from collections.abc import Sequence
from typing import Any, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentFinish from langchain_core.agents import AgentAction, AgentFinish
@ -38,13 +39,13 @@ class XMLAgent(BaseSingleActionAgent):
""" """
tools: List[BaseTool] tools: list[BaseTool]
"""List of tools this agent has access to.""" """List of tools this agent has access to."""
llm_chain: LLMChain llm_chain: LLMChain
"""Chain to use to predict action.""" """Chain to use to predict action."""
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
return ["input"] return ["input"]
@staticmethod @staticmethod
@ -60,7 +61,7 @@ class XMLAgent(BaseSingleActionAgent):
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -84,7 +85,7 @@ class XMLAgent(BaseSingleActionAgent):
async def aplan( async def aplan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Union[AgentAction, AgentFinish]: ) -> Union[AgentAction, AgentFinish]:
@ -113,7 +114,7 @@ def create_xml_agent(
prompt: BasePromptTemplate, prompt: BasePromptTemplate,
tools_renderer: ToolsRenderer = render_text_description, tools_renderer: ToolsRenderer = render_text_description,
*, *,
stop_sequence: Union[bool, List[str]] = True, stop_sequence: Union[bool, list[str]] = True,
) -> Runnable: ) -> Runnable:
"""Create an agent that uses XML to format its logic. """Create an agent that uses XML to format its logic.

View File

@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast from collections.abc import AsyncIterator
from typing import Any, Literal, Union, cast
from langchain_core.callbacks import AsyncCallbackHandler from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.outputs import LLMResult from langchain_core.outputs import LLMResult
@ -25,7 +26,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
self.done = asyncio.Event() self.done = asyncio.Event()
async def on_llm_start( async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None: ) -> None:
# If two calls are made in a row, this resets the state # If two calls are made in a row, this resets the state
self.done.clear() self.done.clear()

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core.outputs import LLMResult from langchain_core.outputs import LLMResult
@ -30,7 +30,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def __init__( def __init__(
self, self,
*, *,
answer_prefix_tokens: Optional[List[str]] = None, answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True, strip_tokens: bool = True,
stream_prefix: bool = False, stream_prefix: bool = False,
) -> None: ) -> None:
@ -62,7 +62,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.answer_reached = False self.answer_reached = False
async def on_llm_start( async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None: ) -> None:
# If two calls are made in a row, this resets the state # If two calls are made in a row, this resets the state
self.done.clear() self.done.clear()

View File

@ -1,7 +1,7 @@
"""Callback Handler streams to stdout on new llm token.""" """Callback Handler streams to stdout on new llm token."""
import sys import sys
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core.callbacks import StreamingStdOutCallbackHandler from langchain_core.callbacks import StreamingStdOutCallbackHandler
@ -31,7 +31,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def __init__( def __init__(
self, self,
*, *,
answer_prefix_tokens: Optional[List[str]] = None, answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True, strip_tokens: bool = True,
stream_prefix: bool = False, stream_prefix: bool = False,
) -> None: ) -> None:
@ -63,7 +63,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
self.answer_reached = False self.answer_reached = False
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None: ) -> None:
"""Run when LLM starts running.""" """Run when LLM starts running."""
self.answer_reached = False self.answer_reached = False

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple from collections.abc import Sequence
from typing import Any, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from langchain_core._api import deprecated from langchain_core._api import deprecated
@ -20,7 +21,7 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
def _extract_scheme_and_domain(url: str) -> Tuple[str, str]: def _extract_scheme_and_domain(url: str) -> tuple[str, str]:
"""Extract the scheme + domain from a given URL. """Extract the scheme + domain from a given URL.
Args: Args:
@ -215,7 +216,7 @@ try:
""" """
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -223,7 +224,7 @@ try:
return [self.question_key] return [self.question_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Expect output key. """Expect output key.
:meta private: :meta private:
@ -243,7 +244,7 @@ try:
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_limit_to_domains(cls, values: Dict) -> Any: def validate_limit_to_domains(cls, values: dict) -> Any:
"""Check that allowed domains are valid.""" """Check that allowed domains are valid."""
# This check must be a pre=True check, so that a default of None # This check must be a pre=True check, so that a default of None
# won't be set to limit_to_domains if it's not provided. # won't be set to limit_to_domains if it's not provided.
@ -275,9 +276,9 @@ try:
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.question_key] question = inputs[self.question_key]
api_url = self.api_request_chain.predict( api_url = self.api_request_chain.predict(
@ -308,9 +309,9 @@ try:
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = ( _run_manager = (
run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
) )

View File

@ -1,12 +1,13 @@
"""Base interface that all chains should implement.""" """Base interface that all chains should implement."""
import builtins
import inspect import inspect
import json import json
import logging import logging
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union, cast from typing import Any, Optional, Union, cast
import yaml import yaml
from langchain_core._api import deprecated from langchain_core._api import deprecated
@ -46,7 +47,7 @@ def _get_verbosity() -> bool:
return get_verbose() return get_verbose()
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components. """Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like Chains should be used to encode a sequence of calls to components like
@ -86,13 +87,13 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs """Whether or not run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to the global `verbose` value, will be printed to the console. Defaults to the global `verbose` value,
accessible via `langchain.globals.get_verbose()`.""" accessible via `langchain.globals.get_verbose()`."""
tags: Optional[List[str]] = None tags: Optional[list[str]] = None
"""Optional list of tags associated with the chain. Defaults to None. """Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain, These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case. You can use these to eg identify a specific instance of a chain with its use case.
""" """
metadata: Optional[Dict[str, Any]] = None metadata: Optional[dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None. """Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain, This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`. and passed as arguments to the handlers defined in `callbacks`.
@ -107,7 +108,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so. # This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload] return create_model( # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys} "ChainInput", **{k: (Any, None) for k in self.input_keys}
@ -115,7 +116,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so. # This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload] return create_model( # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys} "ChainOutput", **{k: (Any, None) for k in self.output_keys}
@ -123,10 +124,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def invoke( def invoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
config = ensure_config(config) config = ensure_config(config)
callbacks = config.get("callbacks") callbacks = config.get("callbacks")
tags = config.get("tags") tags = config.get("tags")
@ -162,7 +163,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
else self._call(inputs) else self._call(inputs)
) )
final_outputs: Dict[str, Any] = self.prep_outputs( final_outputs: dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs inputs, outputs, return_only_outputs
) )
except BaseException as e: except BaseException as e:
@ -176,10 +177,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
async def ainvoke( async def ainvoke(
self, self,
input: Dict[str, Any], input: dict[str, Any],
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Any, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
config = ensure_config(config) config = ensure_config(config)
callbacks = config.get("callbacks") callbacks = config.get("callbacks")
tags = config.get("tags") tags = config.get("tags")
@ -213,7 +214,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
if new_arg_supported if new_arg_supported
else await self._acall(inputs) else await self._acall(inputs)
) )
final_outputs: Dict[str, Any] = await self.aprep_outputs( final_outputs: dict[str, Any] = await self.aprep_outputs(
inputs, outputs, return_only_outputs inputs, outputs, return_only_outputs
) )
except BaseException as e: except BaseException as e:
@ -231,7 +232,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_callback_manager_deprecation(cls, values: Dict) -> Any: def raise_callback_manager_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used.""" """Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None: if values.get("callback_manager") is not None:
if values.get("callbacks") is not None: if values.get("callbacks") is not None:
@ -261,15 +262,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@property @property
@abstractmethod @abstractmethod
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Keys expected to be in the chain input.""" """Keys expected to be in the chain input."""
@property @property
@abstractmethod @abstractmethod
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Keys expected to be in the chain output.""" """Keys expected to be in the chain output."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None: def _validate_inputs(self, inputs: dict[str, Any]) -> None:
"""Check that all inputs are present.""" """Check that all inputs are present."""
if not isinstance(inputs, dict): if not isinstance(inputs, dict):
_input_keys = set(self.input_keys) _input_keys = set(self.input_keys)
@ -289,7 +290,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
if missing_keys: if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}") raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None: def _validate_outputs(self, outputs: dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs) missing_keys = set(self.output_keys).difference(outputs)
if missing_keys: if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}") raise ValueError(f"Missing some output keys: {missing_keys}")
@ -297,9 +298,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@abstractmethod @abstractmethod
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Execute the chain. """Execute the chain.
This is a private method that is not user-facing. It is only called within This is a private method that is not user-facing. It is only called within
@ -319,9 +320,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Asynchronously execute the chain. """Asynchronously execute the chain.
This is a private method that is not user-facing. It is only called within This is a private method that is not user-facing. It is only called within
@ -345,15 +346,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@deprecated("0.1.0", alternative="invoke", removal="1.0") @deprecated("0.1.0", alternative="invoke", removal="1.0")
def __call__( def __call__(
self, self,
inputs: Union[Dict[str, Any], Any], inputs: Union[dict[str, Any], Any],
return_only_outputs: bool = False, return_only_outputs: bool = False,
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
include_run_info: bool = False, include_run_info: bool = False,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Execute the chain. """Execute the chain.
Args: Args:
@ -396,15 +397,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@deprecated("0.1.0", alternative="ainvoke", removal="1.0") @deprecated("0.1.0", alternative="ainvoke", removal="1.0")
async def acall( async def acall(
self, self,
inputs: Union[Dict[str, Any], Any], inputs: Union[dict[str, Any], Any],
return_only_outputs: bool = False, return_only_outputs: bool = False,
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
include_run_info: bool = False, include_run_info: bool = False,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Asynchronously execute the chain. """Asynchronously execute the chain.
Args: Args:
@ -445,10 +446,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def prep_outputs( def prep_outputs(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
outputs: Dict[str, str], outputs: dict[str, str],
return_only_outputs: bool = False, return_only_outputs: bool = False,
) -> Dict[str, str]: ) -> dict[str, str]:
"""Validate and prepare chain outputs, and save info about this run to memory. """Validate and prepare chain outputs, and save info about this run to memory.
Args: Args:
@ -471,10 +472,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
async def aprep_outputs( async def aprep_outputs(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
outputs: Dict[str, str], outputs: dict[str, str],
return_only_outputs: bool = False, return_only_outputs: bool = False,
) -> Dict[str, str]: ) -> dict[str, str]:
"""Validate and prepare chain outputs, and save info about this run to memory. """Validate and prepare chain outputs, and save info about this run to memory.
Args: Args:
@ -495,7 +496,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
else: else:
return {**inputs, **outputs} return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory. """Prepare chain inputs, including adding inputs from memory.
Args: Args:
@ -519,7 +520,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
inputs = dict(inputs, **external_context) inputs = dict(inputs, **external_context)
return inputs return inputs
async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: async def aprep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory. """Prepare chain inputs, including adding inputs from memory.
Args: Args:
@ -557,8 +558,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
self, self,
*args: Any, *args: Any,
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Convenience method for executing chain. """Convenience method for executing chain.
@ -628,8 +629,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
self, self,
*args: Any, *args: Any,
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Convenience method for executing chain. """Convenience method for executing chain.
@ -695,7 +696,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
f" but not both. Got args: {args} and kwargs: {kwargs}." f" but not both. Got args: {args} and kwargs: {kwargs}."
) )
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> dict:
"""Dictionary representation of chain. """Dictionary representation of chain.
Expects `Chain._chain_type` property to be implemented and for memory to be Expects `Chain._chain_type` property to be implemented and for memory to be
@ -763,7 +764,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@deprecated("0.1.0", alternative="batch", removal="1.0") @deprecated("0.1.0", alternative="batch", removal="1.0")
def apply( def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None self, input_list: list[builtins.dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]: ) -> list[builtins.dict[str, str]]:
"""Call the chain on all inputs in the list.""" """Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list] return [self(inputs, callbacks=callbacks) for inputs in input_list]

View File

@ -1,7 +1,7 @@
"""Base interface for chains combining documents.""" """Base interface for chains combining documents."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -47,22 +47,22 @@ class BaseCombineDocumentsChain(Chain, ABC):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
return create_model( return create_model(
"CombineDocumentsInput", "CombineDocumentsInput",
**{self.input_key: (List[Document], None)}, # type: ignore[call-overload] **{self.input_key: (list[Document], None)}, # type: ignore[call-overload]
) )
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
return create_model( return create_model(
"CombineDocumentsOutput", "CombineDocumentsOutput",
**{self.output_key: (str, None)}, # type: ignore[call-overload] **{self.output_key: (str, None)}, # type: ignore[call-overload]
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -70,14 +70,14 @@ class BaseCombineDocumentsChain(Chain, ABC):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return output key. """Return output key.
:meta private: :meta private:
""" """
return [self.output_key] return [self.output_key]
def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]: def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
"""Return the prompt length given the documents passed in. """Return the prompt length given the documents passed in.
This can be used by a caller to determine whether passing in a list This can be used by a caller to determine whether passing in a list
@ -96,7 +96,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
return None return None
@abstractmethod @abstractmethod
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: def combine_docs(self, docs: list[Document], **kwargs: Any) -> tuple[str, dict]:
"""Combine documents into a single string. """Combine documents into a single string.
Args: Args:
@ -111,8 +111,8 @@ class BaseCombineDocumentsChain(Chain, ABC):
@abstractmethod @abstractmethod
async def acombine_docs( async def acombine_docs(
self, docs: List[Document], **kwargs: Any self, docs: list[Document], **kwargs: Any
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Combine documents into a single string. """Combine documents into a single string.
Args: Args:
@ -127,9 +127,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
def _call( def _call(
self, self,
inputs: Dict[str, List[Document]], inputs: dict[str, list[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
"""Prepare inputs, call combine docs, prepare outputs.""" """Prepare inputs, call combine docs, prepare outputs."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key] docs = inputs[self.input_key]
@ -143,9 +143,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
async def _acall( async def _acall(
self, self,
inputs: Dict[str, List[Document]], inputs: dict[str, list[Document]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
"""Prepare inputs, call combine docs, prepare outputs.""" """Prepare inputs, call combine docs, prepare outputs."""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key] docs = inputs[self.input_key]
@ -229,7 +229,7 @@ class AnalyzeDocumentChain(Chain):
combine_docs_chain: BaseCombineDocumentsChain combine_docs_chain: BaseCombineDocumentsChain
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -237,7 +237,7 @@ class AnalyzeDocumentChain(Chain):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return output key. """Return output key.
:meta private: :meta private:
@ -246,7 +246,7 @@ class AnalyzeDocumentChain(Chain):
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
return create_model( return create_model(
"AnalyzeDocumentChain", "AnalyzeDocumentChain",
**{self.input_key: (str, None)}, # type: ignore[call-overload] **{self.input_key: (str, None)}, # type: ignore[call-overload]
@ -254,20 +254,20 @@ class AnalyzeDocumentChain(Chain):
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
return self.combine_docs_chain.get_output_schema(config) return self.combine_docs_chain.get_output_schema(config)
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
"""Split document into chunks and pass to CombineDocumentsChain.""" """Split document into chunks and pass to CombineDocumentsChain."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
document = inputs[self.input_key] document = inputs[self.input_key]
docs = self.text_splitter.create_documents([document]) docs = self.text_splitter.create_documents([document])
# Other keys are assumed to be needed for LLM prediction # Other keys are assumed to be needed for LLM prediction
other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key} other_keys: dict = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys[self.combine_docs_chain.input_key] = docs other_keys[self.combine_docs_chain.input_key] = docs
return self.combine_docs_chain( return self.combine_docs_chain(
other_keys, return_only_outputs=True, callbacks=_run_manager.get_child() other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
@ -113,20 +113,20 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
if self.return_intermediate_steps: if self.return_intermediate_steps:
return create_model( return create_model(
"MapReduceDocumentsOutput", "MapReduceDocumentsOutput",
**{ **{
self.output_key: (str, None), self.output_key: (str, None),
"intermediate_steps": (List[str], None), "intermediate_steps": (list[str], None),
}, # type: ignore[call-overload] }, # type: ignore[call-overload]
) )
return super().get_output_schema(config) return super().get_output_schema(config)
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -143,7 +143,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def get_reduce_chain(cls, values: Dict) -> Any: def get_reduce_chain(cls, values: dict) -> Any:
"""For backwards compatibility.""" """For backwards compatibility."""
if "combine_document_chain" in values: if "combine_document_chain" in values:
if "reduce_documents_chain" in values: if "reduce_documents_chain" in values:
@ -167,7 +167,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def get_return_intermediate_steps(cls, values: Dict) -> Any: def get_return_intermediate_steps(cls, values: dict) -> Any:
"""For backwards compatibility.""" """For backwards compatibility."""
if "return_map_steps" in values: if "return_map_steps" in values:
values["return_intermediate_steps"] = values["return_map_steps"] values["return_intermediate_steps"] = values["return_map_steps"]
@ -176,7 +176,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any: def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided.""" """Get default document variable name, if not provided."""
if "llm_chain" not in values: if "llm_chain" not in values:
raise ValueError("llm_chain must be provided") raise ValueError("llm_chain must be provided")
@ -227,11 +227,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
def combine_docs( def combine_docs(
self, self,
docs: List[Document], docs: list[Document],
token_max: Optional[int] = None, token_max: Optional[int] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Combine documents in a map reduce manner. """Combine documents in a map reduce manner.
Combine by mapping first chain over all documents, then reducing the results. Combine by mapping first chain over all documents, then reducing the results.
@ -258,11 +258,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs( async def acombine_docs(
self, self,
docs: List[Document], docs: list[Document],
token_max: Optional[int] = None, token_max: Optional[int] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Combine documents in a map reduce manner. """Combine documents in a map reduce manner.
Combine by mapping first chain over all documents, then reducing the results. Combine by mapping first chain over all documents, then reducing the results.

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
@ -79,7 +80,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Key in output of llm_chain to rank on.""" """Key in output of llm_chain to rank on."""
answer_key: str answer_key: str
"""Key in output of llm_chain to return as answer.""" """Key in output of llm_chain to return as answer."""
metadata_keys: Optional[List[str]] = None metadata_keys: Optional[list[str]] = None
"""Additional metadata from the chosen document to return.""" """Additional metadata from the chosen document to return."""
return_intermediate_steps: bool = False return_intermediate_steps: bool = False
"""Return intermediate steps. """Return intermediate steps.
@ -92,19 +93,19 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
def get_output_schema( def get_output_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
schema: Dict[str, Any] = { schema: dict[str, Any] = {
self.output_key: (str, None), self.output_key: (str, None),
} }
if self.return_intermediate_steps: if self.return_intermediate_steps:
schema["intermediate_steps"] = (List[str], None) schema["intermediate_steps"] = (list[str], None)
if self.metadata_keys: if self.metadata_keys:
schema.update({key: (Any, None) for key in self.metadata_keys}) schema.update({key: (Any, None) for key in self.metadata_keys})
return create_model("MapRerankOutput", **schema) return create_model("MapRerankOutput", **schema)
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -140,7 +141,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any: def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided.""" """Get default document variable name, if not provided."""
if "llm_chain" not in values: if "llm_chain" not in values:
raise ValueError("llm_chain must be provided") raise ValueError("llm_chain must be provided")
@ -163,8 +164,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
return values return values
def combine_docs( def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Combine documents in a map rerank manner. """Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results. Combine by mapping first chain over all documents, then reranking the results.
@ -187,8 +188,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
return self._process_results(docs, results) return self._process_results(docs, results)
async def acombine_docs( async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Combine documents in a map rerank manner. """Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results. Combine by mapping first chain over all documents, then reranking the results.
@ -212,10 +213,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
def _process_results( def _process_results(
self, self,
docs: List[Document], docs: list[Document],
results: Sequence[Union[str, List[str], Dict[str, str]]], results: Sequence[Union[str, list[str], dict[str, str]]],
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
typed_results = cast(List[dict], results) typed_results = cast(list[dict], results)
sorted_res = sorted( sorted_res = sorted(
zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key]) zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
) )

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, List, Optional, Protocol, Tuple from typing import Any, Callable, Optional, Protocol
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
@ -15,20 +15,20 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
class CombineDocsProtocol(Protocol): class CombineDocsProtocol(Protocol):
"""Interface for the combine_docs method.""" """Interface for the combine_docs method."""
def __call__(self, docs: List[Document], **kwargs: Any) -> str: def __call__(self, docs: list[Document], **kwargs: Any) -> str:
"""Interface for the combine_docs method.""" """Interface for the combine_docs method."""
class AsyncCombineDocsProtocol(Protocol): class AsyncCombineDocsProtocol(Protocol):
"""Interface for the combine_docs method.""" """Interface for the combine_docs method."""
async def __call__(self, docs: List[Document], **kwargs: Any) -> str: async def __call__(self, docs: list[Document], **kwargs: Any) -> str:
"""Async interface for the combine_docs method.""" """Async interface for the combine_docs method."""
def split_list_of_docs( def split_list_of_docs(
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any docs: list[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> List[List[Document]]: ) -> list[list[Document]]:
"""Split Documents into subsets that each meet a cumulative length constraint. """Split Documents into subsets that each meet a cumulative length constraint.
Args: Args:
@ -59,7 +59,7 @@ def split_list_of_docs(
def collapse_docs( def collapse_docs(
docs: List[Document], docs: list[Document],
combine_document_func: CombineDocsProtocol, combine_document_func: CombineDocsProtocol,
**kwargs: Any, **kwargs: Any,
) -> Document: ) -> Document:
@ -91,7 +91,7 @@ def collapse_docs(
async def acollapse_docs( async def acollapse_docs(
docs: List[Document], docs: list[Document],
combine_document_func: AsyncCombineDocsProtocol, combine_document_func: AsyncCombineDocsProtocol,
**kwargs: Any, **kwargs: Any,
) -> Document: ) -> Document:
@ -229,11 +229,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
def combine_docs( def combine_docs(
self, self,
docs: List[Document], docs: list[Document],
token_max: Optional[int] = None, token_max: Optional[int] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Combine multiple documents recursively. """Combine multiple documents recursively.
Args: Args:
@ -258,11 +258,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs( async def acombine_docs(
self, self,
docs: List[Document], docs: list[Document],
token_max: Optional[int] = None, token_max: Optional[int] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Async combine multiple documents recursively. """Async combine multiple documents recursively.
Args: Args:
@ -287,16 +287,16 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
def _collapse( def _collapse(
self, self,
docs: List[Document], docs: list[Document],
token_max: Optional[int] = None, token_max: Optional[int] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Tuple[List[Document], dict]: ) -> tuple[list[Document], dict]:
result_docs = docs result_docs = docs
length_func = self.combine_documents_chain.prompt_length length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs) num_tokens = length_func(result_docs, **kwargs)
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str: def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str:
return self._collapse_chain.run( return self._collapse_chain.run(
input_documents=docs, callbacks=callbacks, **kwargs input_documents=docs, callbacks=callbacks, **kwargs
) )
@ -322,16 +322,16 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
async def _acollapse( async def _acollapse(
self, self,
docs: List[Document], docs: list[Document],
token_max: Optional[int] = None, token_max: Optional[int] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> Tuple[List[Document], dict]: ) -> tuple[list[Document], dict]:
result_docs = docs result_docs = docs
length_func = self.combine_documents_chain.prompt_length length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs) num_tokens = length_func(result_docs, **kwargs)
async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str: async def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str:
return await self._collapse_chain.arun( return await self._collapse_chain.arun(
input_documents=docs, callbacks=callbacks, **kwargs input_documents=docs, callbacks=callbacks, **kwargs
) )

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Tuple from typing import Any
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
@ -98,7 +98,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
"""Return the results of the refine steps in the output.""" """Return the results of the refine steps in the output."""
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -115,7 +115,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def get_return_intermediate_steps(cls, values: Dict) -> Any: def get_return_intermediate_steps(cls, values: dict) -> Any:
"""For backwards compatibility.""" """For backwards compatibility."""
if "return_refine_steps" in values: if "return_refine_steps" in values:
values["return_intermediate_steps"] = values["return_refine_steps"] values["return_intermediate_steps"] = values["return_refine_steps"]
@ -124,7 +124,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any: def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided.""" """Get default document variable name, if not provided."""
if "initial_llm_chain" not in values: if "initial_llm_chain" not in values:
raise ValueError("initial_llm_chain must be provided") raise ValueError("initial_llm_chain must be provided")
@ -147,8 +147,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
return values return values
def combine_docs( def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain. """Combine by mapping first chain over all, then stuffing into final chain.
Args: Args:
@ -172,8 +172,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
return self._construct_result(refine_steps, res) return self._construct_result(refine_steps, res)
async def acombine_docs( async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Async combine by mapping a first chain over all, then stuffing """Async combine by mapping a first chain over all, then stuffing
into a final chain. into a final chain.
@ -197,22 +197,22 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
refine_steps.append(res) refine_steps.append(res)
return self._construct_result(refine_steps, res) return self._construct_result(refine_steps, res)
def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]: def _construct_result(self, refine_steps: list[str], res: str) -> tuple[str, dict]:
if self.return_intermediate_steps: if self.return_intermediate_steps:
extra_return_dict = {"intermediate_steps": refine_steps} extra_return_dict = {"intermediate_steps": refine_steps}
else: else:
extra_return_dict = {} extra_return_dict = {}
return res, extra_return_dict return res, extra_return_dict
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]: def _construct_refine_inputs(self, doc: Document, res: str) -> dict[str, Any]:
return { return {
self.document_variable_name: format_document(doc, self.document_prompt), self.document_variable_name: format_document(doc, self.document_prompt),
self.initial_response_name: res, self.initial_response_name: res,
} }
def _construct_initial_inputs( def _construct_initial_inputs(
self, docs: List[Document], **kwargs: Any self, docs: list[Document], **kwargs: Any
) -> Dict[str, Any]: ) -> dict[str, Any]:
base_info = {"page_content": docs[0].page_content} base_info = {"page_content": docs[0].page_content}
base_info.update(docs[0].metadata) base_info.update(docs[0].metadata)
document_info = {k: base_info[k] for k in self.document_prompt.input_variables} document_info = {k: base_info[k] for k in self.document_prompt.input_variables}

View File

@ -1,6 +1,6 @@
"""Chain that combines documents by stuffing into context.""" """Chain that combines documents by stuffing into context."""
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
@ -29,7 +29,7 @@ def create_stuff_documents_chain(
document_prompt: Optional[BasePromptTemplate] = None, document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR, document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
document_variable_name: str = DOCUMENTS_KEY, document_variable_name: str = DOCUMENTS_KEY,
) -> Runnable[Dict[str, Any], Any]: ) -> Runnable[dict[str, Any], Any]:
"""Create a chain for passing a list of Documents to a model. """Create a chain for passing a list of Documents to a model.
Args: Args:
@ -163,7 +163,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any: def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided. """Get default document variable name, if not provided.
If only one variable is present in the llm_chain.prompt, If only one variable is present in the llm_chain.prompt,
@ -188,13 +188,13 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return values return values
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
extra_keys = [ extra_keys = [
k for k in self.llm_chain.input_keys if k != self.document_variable_name k for k in self.llm_chain.input_keys if k != self.document_variable_name
] ]
return super().input_keys + extra_keys return super().input_keys + extra_keys
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: def _get_inputs(self, docs: list[Document], **kwargs: Any) -> dict:
"""Construct inputs from kwargs and docs. """Construct inputs from kwargs and docs.
Format and then join all the documents together into one input with name Format and then join all the documents together into one input with name
@ -220,7 +220,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
inputs[self.document_variable_name] = self.document_separator.join(doc_strings) inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
return inputs return inputs
def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]: def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
"""Return the prompt length given the documents passed in. """Return the prompt length given the documents passed in.
This can be used by a caller to determine whether passing in a list This can be used by a caller to determine whether passing in a list
@ -241,8 +241,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return self.llm_chain._get_num_tokens(prompt) return self.llm_chain._get_num_tokens(prompt)
def combine_docs( def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM. """Stuff all documents into one prompt and pass to LLM.
Args: Args:
@ -259,8 +259,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return self.llm_chain.predict(callbacks=callbacks, **inputs), {} return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
async def acombine_docs( async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]: ) -> tuple[str, dict]:
"""Async stuff all documents into one prompt and pass to LLM. """Async stuff all documents into one prompt and pass to LLM.
Args: Args:

View File

@ -1,6 +1,6 @@
"""Chain for applying constitutional principles to the outputs of another chain.""" """Chain for applying constitutional principles to the outputs of another chain."""
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.callbacks import CallbackManagerForChainRun
@ -190,15 +190,15 @@ class ConstitutionalChain(Chain):
""" # noqa: E501 """ # noqa: E501
chain: LLMChain chain: LLMChain
constitutional_principles: List[ConstitutionalPrinciple] constitutional_principles: list[ConstitutionalPrinciple]
critique_chain: LLMChain critique_chain: LLMChain
revision_chain: LLMChain revision_chain: LLMChain
return_intermediate_steps: bool = False return_intermediate_steps: bool = False
@classmethod @classmethod
def get_principles( def get_principles(
cls, names: Optional[List[str]] = None cls, names: Optional[list[str]] = None
) -> List[ConstitutionalPrinciple]: ) -> list[ConstitutionalPrinciple]:
if names is None: if names is None:
return list(PRINCIPLES.values()) return list(PRINCIPLES.values())
else: else:
@ -224,12 +224,12 @@ class ConstitutionalChain(Chain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Input keys.""" """Input keys."""
return self.chain.input_keys return self.chain.input_keys
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Output keys.""" """Output keys."""
if self.return_intermediate_steps: if self.return_intermediate_steps:
return ["output", "critiques_and_revisions", "initial_output"] return ["output", "critiques_and_revisions", "initial_output"]
@ -237,9 +237,9 @@ class ConstitutionalChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
response = self.chain.run( response = self.chain.run(
**inputs, **inputs,
@ -305,7 +305,7 @@ class ConstitutionalChain(Chain):
color="yellow", color="yellow",
) )
final_output: Dict[str, Any] = {"output": response} final_output: dict[str, Any] = {"output": response}
if self.return_intermediate_steps: if self.return_intermediate_steps:
final_output["initial_output"] = initial_response final_output["initial_output"] = initial_response
final_output["critiques_and_revisions"] = critiques_and_revisions final_output["critiques_and_revisions"] = critiques_and_revisions

View File

@ -1,7 +1,5 @@
"""Chain that carries on a conversation and calls an LLM.""" """Chain that carries on a conversation and calls an LLM."""
from typing import List
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.memory import BaseMemory from langchain_core.memory import BaseMemory
from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts import BasePromptTemplate
@ -121,7 +119,7 @@ class ConversationChain(LLMChain): # type: ignore[override, override]
return False return False
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Use this since so some prompt vars come from history.""" """Use this since so some prompt vars come from history."""
return [self.input_key] return [self.input_key]

View File

@ -6,7 +6,7 @@ import inspect
import warnings import warnings
from abc import abstractmethod from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Callable, Optional, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -32,13 +32,13 @@ from langchain.chains.question_answering import load_qa_chain
# Depending on the memory type and configuration, the chat history format may differ. # Depending on the memory type and configuration, the chat history format may differ.
# This needs to be consolidated. # This needs to be consolidated.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage] CHAT_TURN_TYPE = Union[tuple[str, str], BaseMessage]
_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "} _ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}
def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str: def _get_chat_history(chat_history: list[CHAT_TURN_TYPE]) -> str:
buffer = "" buffer = ""
for dialogue_turn in chat_history: for dialogue_turn in chat_history:
if isinstance(dialogue_turn, BaseMessage): if isinstance(dialogue_turn, BaseMessage):
@ -64,7 +64,7 @@ class InputType(BaseModel):
question: str question: str
"""The question to answer.""" """The question to answer."""
chat_history: List[CHAT_TURN_TYPE] = Field(default_factory=list) chat_history: list[CHAT_TURN_TYPE] = Field(default_factory=list)
"""The chat history to use for retrieval.""" """The chat history to use for retrieval."""
@ -89,7 +89,7 @@ class BaseConversationalRetrievalChain(Chain):
"""Return the retrieved source documents as part of the final result.""" """Return the retrieved source documents as part of the final result."""
return_generated_question: bool = False return_generated_question: bool = False
"""Return the generated question as part of the final result.""" """Return the generated question as part of the final result."""
get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None get_chat_history: Optional[Callable[[list[CHAT_TURN_TYPE]], str]] = None
"""An optional function to get a string of the chat history. """An optional function to get a string of the chat history.
If None is provided, will use a default.""" If None is provided, will use a default."""
response_if_no_docs_found: Optional[str] = None response_if_no_docs_found: Optional[str] = None
@ -103,17 +103,17 @@ class BaseConversationalRetrievalChain(Chain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Input keys.""" """Input keys."""
return ["question", "chat_history"] return ["question", "chat_history"]
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]: ) -> type[BaseModel]:
return InputType return InputType
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return the output keys. """Return the output keys.
:meta private: :meta private:
@ -129,17 +129,17 @@ class BaseConversationalRetrievalChain(Chain):
def _get_docs( def _get_docs(
self, self,
question: str, question: str,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs["question"] question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history get_chat_history = self.get_chat_history or _get_chat_history
@ -159,7 +159,7 @@ class BaseConversationalRetrievalChain(Chain):
docs = self._get_docs(new_question, inputs, run_manager=_run_manager) docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
else: else:
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg] docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
output: Dict[str, Any] = {} output: dict[str, Any] = {}
if self.response_if_no_docs_found is not None and len(docs) == 0: if self.response_if_no_docs_found is not None and len(docs) == 0:
output[self.output_key] = self.response_if_no_docs_found output[self.output_key] = self.response_if_no_docs_found
else: else:
@ -182,17 +182,17 @@ class BaseConversationalRetrievalChain(Chain):
async def _aget_docs( async def _aget_docs(
self, self,
question: str, question: str,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs["question"] question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history get_chat_history = self.get_chat_history or _get_chat_history
@ -212,7 +212,7 @@ class BaseConversationalRetrievalChain(Chain):
else: else:
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg] docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
output: Dict[str, Any] = {} output: dict[str, Any] = {}
if self.response_if_no_docs_found is not None and len(docs) == 0: if self.response_if_no_docs_found is not None and len(docs) == 0:
output[self.output_key] = self.response_if_no_docs_found output[self.output_key] = self.response_if_no_docs_found
else: else:
@ -368,7 +368,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
"""If set, enforces that the documents returned are less than this limit. """If set, enforces that the documents returned are less than this limit.
This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain.""" This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]: def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
num_docs = len(docs) num_docs = len(docs)
if self.max_tokens_limit and isinstance( if self.max_tokens_limit and isinstance(
@ -388,10 +388,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
def _get_docs( def _get_docs(
self, self,
question: str, question: str,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
docs = self.retriever.invoke( docs = self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()} question, config={"callbacks": run_manager.get_child()}
@ -401,10 +401,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
async def _aget_docs( async def _aget_docs(
self, self,
question: str, question: str,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
docs = await self.retriever.ainvoke( docs = await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()} question, config={"callbacks": run_manager.get_child()}
@ -420,7 +420,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
chain_type: str = "stuff", chain_type: str = "stuff",
verbose: bool = False, verbose: bool = False,
condense_question_llm: Optional[BaseLanguageModel] = None, condense_question_llm: Optional[BaseLanguageModel] = None,
combine_docs_chain_kwargs: Optional[Dict] = None, combine_docs_chain_kwargs: Optional[dict] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> BaseConversationalRetrievalChain: ) -> BaseConversationalRetrievalChain:
@ -485,7 +485,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_deprecation(cls, values: Dict) -> Any: def raise_deprecation(cls, values: dict) -> Any:
warnings.warn( warnings.warn(
"`ChatVectorDBChain` is deprecated - " "`ChatVectorDBChain` is deprecated - "
"please use `from langchain.chains import ConversationalRetrievalChain`" "please use `from langchain.chains import ConversationalRetrievalChain`"
@ -495,10 +495,10 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
def _get_docs( def _get_docs(
self, self,
question: str, question: str,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
vectordbkwargs = inputs.get("vectordbkwargs", {}) vectordbkwargs = inputs.get("vectordbkwargs", {})
full_kwargs = {**self.search_kwargs, **vectordbkwargs} full_kwargs = {**self.search_kwargs, **vectordbkwargs}
@ -509,10 +509,10 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
async def _aget_docs( async def _aget_docs(
self, self,
question: str, question: str,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
raise NotImplementedError("ChatVectorDBChain does not support async") raise NotImplementedError("ChatVectorDBChain does not support async")
@ -523,7 +523,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
vectorstore: VectorStore, vectorstore: VectorStore,
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT, condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
chain_type: str = "stuff", chain_type: str = "stuff",
combine_docs_chain_kwargs: Optional[Dict] = None, combine_docs_chain_kwargs: Optional[dict] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
**kwargs: Any, **kwargs: Any,
) -> BaseConversationalRetrievalChain: ) -> BaseConversationalRetrievalChain:

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Optional
from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
@ -44,8 +44,8 @@ class ElasticsearchDatabaseChain(Chain):
"""Elasticsearch database to connect to of type elasticsearch.Elasticsearch.""" """Elasticsearch database to connect to of type elasticsearch.Elasticsearch."""
top_k: int = 10 top_k: int = 10
"""Number of results to return from the query""" """Number of results to return from the query"""
ignore_indices: Optional[List[str]] = None ignore_indices: Optional[list[str]] = None
include_indices: Optional[List[str]] = None include_indices: Optional[list[str]] = None
input_key: str = "question" #: :meta private: input_key: str = "question" #: :meta private:
output_key: str = "result" #: :meta private: output_key: str = "result" #: :meta private:
sample_documents_in_index_info: int = 3 sample_documents_in_index_info: int = 3
@ -66,7 +66,7 @@ class ElasticsearchDatabaseChain(Chain):
return self return self
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the singular input key. """Return the singular input key.
:meta private: :meta private:
@ -74,7 +74,7 @@ class ElasticsearchDatabaseChain(Chain):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return the singular output key. """Return the singular output key.
:meta private: :meta private:
@ -84,7 +84,7 @@ class ElasticsearchDatabaseChain(Chain):
else: else:
return [self.output_key, INTERMEDIATE_STEPS_KEY] return [self.output_key, INTERMEDIATE_STEPS_KEY]
def _list_indices(self) -> List[str]: def _list_indices(self) -> list[str]:
all_indices = [ all_indices = [
index["index"] for index in self.database.cat.indices(format="json") index["index"] for index in self.database.cat.indices(format="json")
] ]
@ -96,7 +96,7 @@ class ElasticsearchDatabaseChain(Chain):
return all_indices return all_indices
def _get_indices_infos(self, indices: List[str]) -> str: def _get_indices_infos(self, indices: list[str]) -> str:
mappings = self.database.indices.get_mapping(index=",".join(indices)) mappings = self.database.indices.get_mapping(index=",".join(indices))
if self.sample_documents_in_index_info > 0: if self.sample_documents_in_index_info > 0:
for k, v in mappings.items(): for k, v in mappings.items():
@ -114,15 +114,15 @@ class ElasticsearchDatabaseChain(Chain):
] ]
) )
def _search(self, indices: List[str], query: str) -> str: def _search(self, indices: list[str], query: str) -> str:
result = self.database.search(index=",".join(indices), body=query) result = self.database.search(index=",".join(indices), body=query)
return str(result) return str(result)
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
input_text = f"{inputs[self.input_key]}\nESQuery:" input_text = f"{inputs[self.input_key]}\nESQuery:"
_run_manager.on_text(input_text, verbose=self.verbose) _run_manager.on_text(input_text, verbose=self.verbose)
@ -134,7 +134,7 @@ class ElasticsearchDatabaseChain(Chain):
"indices_info": indices_info, "indices_info": indices_info,
"stop": ["\nESResult:"], "stop": ["\nESResult:"],
} }
intermediate_steps: List = [] intermediate_steps: list = []
try: try:
intermediate_steps.append(query_inputs) # input: es generation intermediate_steps.append(query_inputs) # input: es generation
es_cmd = self.query_chain.invoke( es_cmd = self.query_chain.invoke(
@ -163,7 +163,7 @@ class ElasticsearchDatabaseChain(Chain):
intermediate_steps.append(final_result) # output: final answer intermediate_steps.append(final_result) # output: final answer
_run_manager.on_text(final_result, color="green", verbose=self.verbose) _run_manager.on_text(final_result, color="green", verbose=self.verbose)
chain_result: Dict[str, Any] = {self.output_key: final_result} chain_result: dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps: if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result return chain_result

View File

@ -1,5 +1,3 @@
from typing import List
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.few_shot import FewShotPromptTemplate from langchain_core.prompts.few_shot import FewShotPromptTemplate
@ -9,7 +7,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."
def generate_example( def generate_example(
examples: List[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate examples: list[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate
) -> str: ) -> str:
"""Return another example given a list of examples for a prompt.""" """Return another example given a list of examples for a prompt."""
prompt = FewShotPromptTemplate( prompt = FewShotPromptTemplate(

View File

@ -2,7 +2,8 @@ from __future__ import annotations
import logging import logging
import re import re
from typing import Any, Dict, List, Optional, Sequence, Tuple from collections.abc import Sequence
from typing import Any, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
CallbackManagerForChainRun, CallbackManagerForChainRun,
@ -26,7 +27,7 @@ from langchain.chains.llm import LLMChain
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]: def _extract_tokens_and_log_probs(response: AIMessage) -> tuple[list[str], list[float]]:
"""Extract tokens and log probabilities from chat model response.""" """Extract tokens and log probabilities from chat model response."""
tokens = [] tokens = []
log_probs = [] log_probs = []
@ -47,7 +48,7 @@ class QuestionGeneratorChain(LLMChain):
return False return False
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Input keys for the chain.""" """Input keys for the chain."""
return ["user_input", "context", "response"] return ["user_input", "context", "response"]
@ -58,7 +59,7 @@ def _low_confidence_spans(
min_prob: float, min_prob: float,
min_token_gap: int, min_token_gap: int,
num_pad_tokens: int, num_pad_tokens: int,
) -> List[str]: ) -> list[str]:
try: try:
import numpy as np import numpy as np
@ -117,22 +118,22 @@ class FlareChain(Chain):
"""Whether to start with retrieval.""" """Whether to start with retrieval."""
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Input keys for the chain.""" """Input keys for the chain."""
return ["user_input"] return ["user_input"]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Output keys for the chain.""" """Output keys for the chain."""
return ["response"] return ["response"]
def _do_generation( def _do_generation(
self, self,
questions: List[str], questions: list[str],
user_input: str, user_input: str,
response: str, response: str,
_run_manager: CallbackManagerForChainRun, _run_manager: CallbackManagerForChainRun,
) -> Tuple[str, bool]: ) -> tuple[str, bool]:
callbacks = _run_manager.get_child() callbacks = _run_manager.get_child()
docs = [] docs = []
for question in questions: for question in questions:
@ -153,12 +154,12 @@ class FlareChain(Chain):
def _do_retrieval( def _do_retrieval(
self, self,
low_confidence_spans: List[str], low_confidence_spans: list[str],
_run_manager: CallbackManagerForChainRun, _run_manager: CallbackManagerForChainRun,
user_input: str, user_input: str,
response: str, response: str,
initial_response: str, initial_response: str,
) -> Tuple[str, bool]: ) -> tuple[str, bool]:
question_gen_inputs = [ question_gen_inputs = [
{ {
"user_input": user_input, "user_input": user_input,
@ -187,9 +188,9 @@ class FlareChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
user_input = inputs[self.input_keys[0]] user_input = inputs[self.input_keys[0]]

View File

@ -1,16 +1,14 @@
from typing import Tuple
from langchain_core.output_parsers import BaseOutputParser from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]): class FinishedOutputParser(BaseOutputParser[tuple[str, bool]]):
"""Output parser that checks if the output is finished.""" """Output parser that checks if the output is finished."""
finished_value: str = "FINISHED" finished_value: str = "FINISHED"
"""Value that indicates the output is finished.""" """Value that indicates the output is finished."""
def parse(self, text: str) -> Tuple[str, bool]: def parse(self, text: str) -> tuple[str, bool]:
cleaned = text.strip() cleaned = text.strip()
finished = self.finished_value in cleaned finished = self.finished_value in cleaned
return cleaned.replace(self.finished_value, ""), finished return cleaned.replace(self.finished_value, ""), finished

View File

@ -6,7 +6,7 @@ https://arxiv.org/abs/2212.10496
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
@ -38,23 +38,23 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Input keys for Hyde's LLM chain.""" """Input keys for Hyde's LLM chain."""
return self.llm_chain.input_schema.model_json_schema()["required"] return self.llm_chain.input_schema.model_json_schema()["required"]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Output keys for Hyde's LLM chain.""" """Output keys for Hyde's LLM chain."""
if isinstance(self.llm_chain, LLMChain): if isinstance(self.llm_chain, LLMChain):
return self.llm_chain.output_keys return self.llm_chain.output_keys
else: else:
return ["text"] return ["text"]
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Call the base embeddings.""" """Call the base embeddings."""
return self.base_embeddings.embed_documents(texts) return self.base_embeddings.embed_documents(texts)
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]: def combine_embeddings(self, embeddings: list[list[float]]) -> list[float]:
"""Combine embeddings into final embeddings.""" """Combine embeddings into final embeddings."""
try: try:
import numpy as np import numpy as np
@ -73,7 +73,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
num_vectors = len(embeddings) num_vectors = len(embeddings)
return [sum(dim_values) / num_vectors for dim_values in zip(*embeddings)] return [sum(dim_values) / num_vectors for dim_values in zip(*embeddings)]
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Generate a hypothetical document and embedded it.""" """Generate a hypothetical document and embedded it."""
var_name = self.input_keys[0] var_name = self.input_keys[0]
result = self.llm_chain.invoke({var_name: text}) result = self.llm_chain.invoke({var_name: text})
@ -86,9 +86,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
"""Call the internal llm chain.""" """Call the internal llm chain."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
return self.llm_chain.invoke( return self.llm_chain.invoke(

View File

@ -3,7 +3,8 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -100,7 +101,7 @@ class LLMChain(Chain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Will be whatever keys the prompt expects. """Will be whatever keys the prompt expects.
:meta private: :meta private:
@ -108,7 +109,7 @@ class LLMChain(Chain):
return self.prompt.input_variables return self.prompt.input_variables
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Will always return text key. """Will always return text key.
:meta private: :meta private:
@ -120,15 +121,15 @@ class LLMChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
response = self.generate([inputs], run_manager=run_manager) response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0] return self.create_outputs(response)[0]
def generate( def generate(
self, self,
input_list: List[Dict[str, Any]], input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult: ) -> LLMResult:
"""Generate LLM result from inputs.""" """Generate LLM result from inputs."""
@ -143,9 +144,9 @@ class LLMChain(Chain):
) )
else: else:
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch( results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
cast(List, prompts), {"callbacks": callbacks} cast(list, prompts), {"callbacks": callbacks}
) )
generations: List[List[Generation]] = [] generations: list[list[Generation]] = []
for res in results: for res in results:
if isinstance(res, BaseMessage): if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)]) generations.append([ChatGeneration(message=res)])
@ -155,7 +156,7 @@ class LLMChain(Chain):
async def agenerate( async def agenerate(
self, self,
input_list: List[Dict[str, Any]], input_list: list[dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> LLMResult: ) -> LLMResult:
"""Generate LLM result from inputs.""" """Generate LLM result from inputs."""
@ -170,9 +171,9 @@ class LLMChain(Chain):
) )
else: else:
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch( results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
cast(List, prompts), {"callbacks": callbacks} cast(list, prompts), {"callbacks": callbacks}
) )
generations: List[List[Generation]] = [] generations: list[list[Generation]] = []
for res in results: for res in results:
if isinstance(res, BaseMessage): if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)]) generations.append([ChatGeneration(message=res)])
@ -182,9 +183,9 @@ class LLMChain(Chain):
def prep_prompts( def prep_prompts(
self, self,
input_list: List[Dict[str, Any]], input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]: ) -> tuple[list[PromptValue], Optional[list[str]]]:
"""Prepare prompts from inputs.""" """Prepare prompts from inputs."""
stop = None stop = None
if len(input_list) == 0: if len(input_list) == 0:
@ -208,9 +209,9 @@ class LLMChain(Chain):
async def aprep_prompts( async def aprep_prompts(
self, self,
input_list: List[Dict[str, Any]], input_list: list[dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]: ) -> tuple[list[PromptValue], Optional[list[str]]]:
"""Prepare prompts from inputs.""" """Prepare prompts from inputs."""
stop = None stop = None
if len(input_list) == 0: if len(input_list) == 0:
@ -233,8 +234,8 @@ class LLMChain(Chain):
return prompts, stop return prompts, stop
def apply( def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
"""Utilize the LLM generate method for speed gains.""" """Utilize the LLM generate method for speed gains."""
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks, self.callbacks, self.verbose
@ -254,8 +255,8 @@ class LLMChain(Chain):
return outputs return outputs
async def aapply( async def aapply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]: ) -> list[dict[str, str]]:
"""Utilize the LLM generate method for speed gains.""" """Utilize the LLM generate method for speed gains."""
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks, self.callbacks, self.verbose
@ -278,7 +279,7 @@ class LLMChain(Chain):
def _run_output_key(self) -> str: def _run_output_key(self) -> str:
return self.output_key return self.output_key
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]: def create_outputs(self, llm_result: LLMResult) -> list[dict[str, Any]]:
"""Create outputs from response.""" """Create outputs from response."""
result = [ result = [
# Get the text of the top generated string. # Get the text of the top generated string.
@ -294,9 +295,9 @@ class LLMChain(Chain):
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
response = await self.agenerate([inputs], run_manager=run_manager) response = await self.agenerate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0] return self.create_outputs(response)[0]
@ -336,7 +337,7 @@ class LLMChain(Chain):
def predict_and_parse( def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]: ) -> Union[str, list[str], dict[str, Any]]:
"""Call predict and then parse the results.""" """Call predict and then parse the results."""
warnings.warn( warnings.warn(
"The predict_and_parse method is deprecated, " "The predict_and_parse method is deprecated, "
@ -350,7 +351,7 @@ class LLMChain(Chain):
async def apredict_and_parse( async def apredict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]: ) -> Union[str, list[str], dict[str, str]]:
"""Call apredict and then parse the results.""" """Call apredict and then parse the results."""
warnings.warn( warnings.warn(
"The apredict_and_parse method is deprecated, " "The apredict_and_parse method is deprecated, "
@ -363,8 +364,8 @@ class LLMChain(Chain):
return result return result
def apply_and_parse( def apply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]: ) -> Sequence[Union[str, list[str], dict[str, str]]]:
"""Call apply and then parse the results.""" """Call apply and then parse the results."""
warnings.warn( warnings.warn(
"The apply_and_parse method is deprecated, " "The apply_and_parse method is deprecated, "
@ -374,8 +375,8 @@ class LLMChain(Chain):
return self._parse_generation(result) return self._parse_generation(result)
def _parse_generation( def _parse_generation(
self, generation: List[Dict[str, str]] self, generation: list[dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]: ) -> Sequence[Union[str, list[str], dict[str, str]]]:
if self.prompt.output_parser is not None: if self.prompt.output_parser is not None:
return [ return [
self.prompt.output_parser.parse(res[self.output_key]) self.prompt.output_parser.parse(res[self.output_key])
@ -385,8 +386,8 @@ class LLMChain(Chain):
return generation return generation
async def aapply_and_parse( async def aapply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]: ) -> Sequence[Union[str, list[str], dict[str, str]]]:
"""Call apply and then parse the results.""" """Call apply and then parse the results."""
warnings.warn( warnings.warn(
"The aapply_and_parse method is deprecated, " "The aapply_and_parse method is deprecated, "

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.callbacks import CallbackManagerForChainRun
@ -107,7 +107,7 @@ class LLMCheckerChain(Chain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_deprecation(cls, values: Dict) -> Any: def raise_deprecation(cls, values: dict) -> Any:
if "llm" in values: if "llm" in values:
warnings.warn( warnings.warn(
"Directly instantiating an LLMCheckerChain with an llm is deprecated. " "Directly instantiating an LLMCheckerChain with an llm is deprecated. "
@ -135,7 +135,7 @@ class LLMCheckerChain(Chain):
return values return values
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the singular input key. """Return the singular input key.
:meta private: :meta private:
@ -143,7 +143,7 @@ class LLMCheckerChain(Chain):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return the singular output key. """Return the singular output key.
:meta private: :meta private:
@ -152,9 +152,9 @@ class LLMCheckerChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key] question = inputs[self.input_key]

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import math import math
import re import re
import warnings import warnings
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -163,7 +163,7 @@ class LLMMathChain(Chain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_deprecation(cls, values: Dict) -> Any: def raise_deprecation(cls, values: dict) -> Any:
try: try:
import numexpr # noqa: F401 import numexpr # noqa: F401
except ImportError: except ImportError:
@ -183,7 +183,7 @@ class LLMMathChain(Chain):
return values return values
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -191,7 +191,7 @@ class LLMMathChain(Chain):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Expect output key. """Expect output key.
:meta private: :meta private:
@ -221,7 +221,7 @@ class LLMMathChain(Chain):
def _process_llm_result( def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]: ) -> dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose) run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip() llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
@ -243,7 +243,7 @@ class LLMMathChain(Chain):
self, self,
llm_output: str, llm_output: str,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]: ) -> dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose) await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip() llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
@ -263,9 +263,9 @@ class LLMMathChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key]) _run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict( llm_output = self.llm_chain.predict(
@ -277,9 +277,9 @@ class LLMMathChain(Chain):
async def _acall( async def _acall(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key]) await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict( llm_output = await self.llm_chain.apredict(

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.callbacks import CallbackManagerForChainRun
@ -112,7 +112,7 @@ class LLMSummarizationCheckerChain(Chain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_deprecation(cls, values: Dict) -> Any: def raise_deprecation(cls, values: dict) -> Any:
if "llm" in values: if "llm" in values:
warnings.warn( warnings.warn(
"Directly instantiating an LLMSummarizationCheckerChain with an llm is " "Directly instantiating an LLMSummarizationCheckerChain with an llm is "
@ -131,7 +131,7 @@ class LLMSummarizationCheckerChain(Chain):
return values return values
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return the singular input key. """Return the singular input key.
:meta private: :meta private:
@ -139,7 +139,7 @@ class LLMSummarizationCheckerChain(Chain):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return the singular output key. """Return the singular output key.
:meta private: :meta private:
@ -148,9 +148,9 @@ class LLMSummarizationCheckerChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
all_true = False all_true = False
count = 0 count = 0

View File

@ -702,7 +702,7 @@ def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
with open(file_path) as f: with open(file_path) as f:
config = json.load(f) config = json.load(f)
elif file_path.suffix.endswith((".yaml", ".yml")): elif file_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "r") as f: with open(file_path) as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
else: else:
raise ValueError("File type must be json or yaml") raise ValueError("File type must be json or yaml")

View File

@ -6,7 +6,8 @@ then combines the results with another one.
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional from collections.abc import Mapping
from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
@ -84,7 +85,7 @@ class MapReduceChain(Chain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -92,7 +93,7 @@ class MapReduceChain(Chain):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return output key. """Return output key.
:meta private: :meta private:
@ -101,15 +102,15 @@ class MapReduceChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
# Split the larger text into smaller chunks. # Split the larger text into smaller chunks.
doc_text = inputs.pop(self.input_key) doc_text = inputs.pop(self.input_key)
texts = self.text_splitter.split_text(doc_text) texts = self.text_splitter.split_text(doc_text)
docs = [Document(page_content=text) for text in texts] docs = [Document(page_content=text) for text in texts]
_inputs: Dict[str, Any] = { _inputs: dict[str, Any] = {
**inputs, **inputs,
self.combine_documents_chain.input_key: docs, self.combine_documents_chain.input_key: docs,
} }

View File

@ -1,6 +1,6 @@
"""Pass input through a moderation endpoint.""" """Pass input through a moderation endpoint."""
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -42,7 +42,7 @@ class OpenAIModerationChain(Chain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_environment(cls, values: Dict) -> Any: def validate_environment(cls, values: dict) -> Any:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env( openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY" values, "openai_api_key", "OPENAI_API_KEY"
@ -78,7 +78,7 @@ class OpenAIModerationChain(Chain):
return values return values
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -86,7 +86,7 @@ class OpenAIModerationChain(Chain):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return output key. """Return output key.
:meta private: :meta private:
@ -108,9 +108,9 @@ class OpenAIModerationChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
text = inputs[self.input_key] text = inputs[self.input_key]
if self.openai_pre_1_0: if self.openai_pre_1_0:
results = self.client.create(text) results = self.client.create(text)
@ -122,9 +122,9 @@ class OpenAIModerationChain(Chain):
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
if self.openai_pre_1_0: if self.openai_pre_1_0:
return await super()._acall(inputs, run_manager=run_manager) return await super()._acall(inputs, run_manager=run_manager)
text = inputs[self.input_key] text = inputs[self.input_key]

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.caches import BaseCache as BaseCache from langchain_core.caches import BaseCache as BaseCache
@ -68,7 +68,7 @@ class NatBotChain(Chain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_deprecation(cls, values: Dict) -> Any: def raise_deprecation(cls, values: dict) -> Any:
if "llm" in values: if "llm" in values:
warnings.warn( warnings.warn(
"Directly instantiating an NatBotChain with an llm is deprecated. " "Directly instantiating an NatBotChain with an llm is deprecated. "
@ -97,7 +97,7 @@ class NatBotChain(Chain):
return cls(llm_chain=llm_chain, objective=objective, **kwargs) return cls(llm_chain=llm_chain, objective=objective, **kwargs)
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect url and browser content. """Expect url and browser content.
:meta private: :meta private:
@ -105,7 +105,7 @@ class NatBotChain(Chain):
return [self.input_url_key, self.input_browser_content_key] return [self.input_url_key, self.input_browser_content_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return command. """Return command.
:meta private: :meta private:
@ -114,9 +114,9 @@ class NatBotChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
url = inputs[self.input_url_key] url = inputs[self.input_url_key]
browser_content = inputs[self.input_browser_content_key] browser_content = inputs[self.input_browser_content_key]

View File

@ -1,12 +1,10 @@
"""Methods for creating chains that use OpenAI function-calling APIs.""" """Methods for creating chains that use OpenAI function-calling APIs."""
from collections.abc import Sequence
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
Optional, Optional,
Sequence,
Type,
Union, Union,
) )
@ -45,7 +43,7 @@ __all__ = [
@deprecated(since="0.1.1", removal="1.0", alternative="create_openai_fn_runnable") @deprecated(since="0.1.1", removal="1.0", alternative="create_openai_fn_runnable")
def create_openai_fn_chain( def create_openai_fn_chain(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: BasePromptTemplate, prompt: BasePromptTemplate,
*, *,
@ -128,7 +126,7 @@ def create_openai_fn_chain(
raise ValueError("Need to pass in at least one function. Received zero.") raise ValueError("Need to pass in at least one function. Received zero.")
openai_functions = [convert_to_openai_function(f) for f in functions] openai_functions = [convert_to_openai_function(f) for f in functions]
output_parser = output_parser or get_openai_output_parser(functions) output_parser = output_parser or get_openai_output_parser(functions)
llm_kwargs: Dict[str, Any] = { llm_kwargs: dict[str, Any] = {
"functions": openai_functions, "functions": openai_functions,
} }
if len(openai_functions) == 1 and enforce_single_function_usage: if len(openai_functions) == 1 and enforce_single_function_usage:
@ -148,7 +146,7 @@ def create_openai_fn_chain(
since="0.1.1", removal="1.0", alternative="ChatOpenAI.with_structured_output" since="0.1.1", removal="1.0", alternative="ChatOpenAI.with_structured_output"
) )
def create_structured_output_chain( def create_structured_output_chain(
output_schema: Union[Dict[str, Any], Type[BaseModel]], output_schema: Union[dict[str, Any], type[BaseModel]],
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: BasePromptTemplate, prompt: BasePromptTemplate,
*, *,

View File

@ -1,4 +1,4 @@
from typing import Iterator, List from collections.abc import Iterator
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.language_models import BaseChatModel, BaseLanguageModel from langchain_core.language_models import BaseChatModel, BaseLanguageModel
@ -21,7 +21,7 @@ class FactWithEvidence(BaseModel):
""" """
fact: str = Field(..., description="Body of the sentence, as part of a response") fact: str = Field(..., description="Body of the sentence, as part of a response")
substring_quote: List[str] = Field( substring_quote: list[str] = Field(
..., ...,
description=( description=(
"Each source should be a direct quote from the context, " "Each source should be a direct quote from the context, "
@ -54,7 +54,7 @@ class QuestionAnswer(BaseModel):
each sentence contains a body and a list of sources.""" each sentence contains a body and a list of sources."""
question: str = Field(..., description="Question that was asked") question: str = Field(..., description="Question that was asked")
answer: List[FactWithEvidence] = Field( answer: list[FactWithEvidence] = Field(
..., ...,
description=( description=(
"Body of the answer, each fact should be " "Body of the answer, each fact should be "

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
@ -83,7 +83,7 @@ def create_extraction_chain(
schema: dict, schema: dict,
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None, prompt: Optional[BasePromptTemplate] = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
verbose: bool = False, verbose: bool = False,
) -> Chain: ) -> Chain:
"""Creates a chain that extracts information from a passage. """Creates a chain that extracts information from a passage.
@ -170,7 +170,7 @@ def create_extraction_chain_pydantic(
""" """
class PydanticSchema(BaseModel): class PydanticSchema(BaseModel):
info: List[pydantic_schema] # type: ignore info: list[pydantic_schema] # type: ignore
if hasattr(pydantic_schema, "model_json_schema"): if hasattr(pydantic_schema, "model_json_schema"):
openai_schema = pydantic_schema.model_json_schema() openai_schema = pydantic_schema.model_json_schema()

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json import json
import re import re
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import requests import requests
from langchain_core._api import deprecated from langchain_core._api import deprecated
@ -70,7 +70,7 @@ def _format_url(url: str, path_params: dict) -> str:
return url.format(**new_params) return url.format(**new_params)
def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -> dict: def _openapi_params_to_json_schema(params: list[Parameter], spec: OpenAPISpec) -> dict:
properties = {} properties = {}
required = [] required = []
for p in params: for p in params:
@ -89,7 +89,7 @@ def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -
def openapi_spec_to_openai_fn( def openapi_spec_to_openai_fn(
spec: OpenAPISpec, spec: OpenAPISpec,
) -> Tuple[List[Dict[str, Any]], Callable]: ) -> tuple[list[dict[str, Any]], Callable]:
"""Convert a valid OpenAPI spec to the JSON Schema format expected for OpenAI """Convert a valid OpenAPI spec to the JSON Schema format expected for OpenAI
functions. functions.
@ -208,18 +208,18 @@ class SimpleRequestChain(Chain):
"""Key to use for the input of the request.""" """Key to use for the input of the request."""
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
return [self.output_key] return [self.output_key]
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Run the logic of this chain and return the output.""" """Run the logic of this chain and return the output."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
name = inputs[self.input_key].pop("name") name = inputs[self.input_key].pop("name")
@ -257,10 +257,10 @@ def get_openapi_chain(
llm: Optional[BaseLanguageModel] = None, llm: Optional[BaseLanguageModel] = None,
prompt: Optional[BasePromptTemplate] = None, prompt: Optional[BasePromptTemplate] = None,
request_chain: Optional[Chain] = None, request_chain: Optional[Chain] = None,
llm_chain_kwargs: Optional[Dict] = None, llm_chain_kwargs: Optional[dict] = None,
verbose: bool = False, verbose: bool = False,
headers: Optional[Dict] = None, headers: Optional[dict] = None,
params: Optional[Dict] = None, params: Optional[dict] = None,
**kwargs: Any, **kwargs: Any,
) -> SequentialChain: ) -> SequentialChain:
"""Create a chain for querying an API from a OpenAPI spec. """Create a chain for querying an API from a OpenAPI spec.

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Type, Union, cast from typing import Any, Optional, Union, cast
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
@ -21,7 +21,7 @@ class AnswerWithSources(BaseModel):
"""An answer to the question, with sources.""" """An answer to the question, with sources."""
answer: str = Field(..., description="Answer to the question that was asked") answer: str = Field(..., description="Answer to the question that was asked")
sources: List[str] = Field( sources: list[str] = Field(
..., description="List of sources used to answer the question" ..., description="List of sources used to answer the question"
) )
@ -37,7 +37,7 @@ class AnswerWithSources(BaseModel):
) )
def create_qa_with_structure_chain( def create_qa_with_structure_chain(
llm: BaseLanguageModel, llm: BaseLanguageModel,
schema: Union[dict, Type[BaseModel]], schema: Union[dict, type[BaseModel]],
output_parser: str = "base", output_parser: str = "base",
prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None, prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
verbose: bool = False, verbose: bool = False,

View File

@ -1,7 +1,7 @@
from typing import Any, Dict from typing import Any
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any: def _resolve_schema_references(schema: Any, definitions: dict[str, Any]) -> Any:
""" """
Resolve the $ref keys in a JSON schema object using the provided definitions. Resolve the $ref keys in a JSON schema object using the provided definitions.
""" """

View File

@ -1,4 +1,4 @@
from typing import List, Type, Union from typing import Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
@ -51,7 +51,7 @@ If a property is not present and is not required in the function parameters, do
), ),
) )
def create_extraction_chain_pydantic( def create_extraction_chain_pydantic(
pydantic_schemas: Union[List[Type[BaseModel]], Type[BaseModel]], pydantic_schemas: Union[list[type[BaseModel]], type[BaseModel]],
llm: BaseLanguageModel, llm: BaseLanguageModel,
system_message: str = _EXTRACTION_TEMPLATE, system_message: str = _EXTRACTION_TEMPLATE,
) -> Runnable: ) -> Runnable:

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List, Tuple from typing import Callable
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
@ -21,8 +21,8 @@ class ConditionalPromptSelector(BasePromptSelector):
default_prompt: BasePromptTemplate default_prompt: BasePromptTemplate
"""Default prompt to use if no conditionals match.""" """Default prompt to use if no conditionals match."""
conditionals: List[ conditionals: list[
Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate] tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
] = Field(default_factory=list) ] = Field(default_factory=list)
"""List of conditionals and prompts to use if the conditionals match.""" """List of conditionals and prompts to use if the conditionals match."""

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.callbacks import CallbackManagerForChainRun
@ -103,18 +103,18 @@ class QAGenerationChain(Chain):
raise NotImplementedError raise NotImplementedError
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
return [self.output_key] return [self.output_key]
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, List]: ) -> dict[str, list]:
docs = self.text_splitter.create_documents([inputs[self.input_key]]) docs = self.text_splitter.create_documents([inputs[self.input_key]])
results = self.llm_chain.generate( results = self.llm_chain.generate(
[{"text": d.page_content} for d in docs], run_manager=run_manager [{"text": d.page_content} for d in docs], run_manager=run_manager

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import inspect import inspect
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -103,7 +103,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -111,7 +111,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
return [self.question_key] return [self.question_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return output key. """Return output key.
:meta private: :meta private:
@ -123,13 +123,13 @@ class BaseQAWithSourcesChain(Chain, ABC):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_naming(cls, values: Dict) -> Any: def validate_naming(cls, values: dict) -> Any:
"""Fix backwards compatibility in naming.""" """Fix backwards compatibility in naming."""
if "combine_document_chain" in values: if "combine_document_chain" in values:
values["combine_documents_chain"] = values.pop("combine_document_chain") values["combine_documents_chain"] = values.pop("combine_document_chain")
return values return values
def _split_sources(self, answer: str) -> Tuple[str, str]: def _split_sources(self, answer: str) -> tuple[str, str]:
"""Split sources from answer.""" """Split sources from answer."""
if re.search(r"SOURCES?:", answer, re.IGNORECASE): if re.search(r"SOURCES?:", answer, re.IGNORECASE):
answer, sources = re.split( answer, sources = re.split(
@ -143,17 +143,17 @@ class BaseQAWithSourcesChain(Chain, ABC):
@abstractmethod @abstractmethod
def _get_docs( def _get_docs(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs to run questioning over.""" """Get docs to run questioning over."""
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
accepts_run_manager = ( accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters "run_manager" in inspect.signature(self._get_docs).parameters
@ -167,7 +167,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
input_documents=docs, callbacks=_run_manager.get_child(), **inputs input_documents=docs, callbacks=_run_manager.get_child(), **inputs
) )
answer, sources = self._split_sources(answer) answer, sources = self._split_sources(answer)
result: Dict[str, Any] = { result: dict[str, Any] = {
self.answer_key: answer, self.answer_key: answer,
self.sources_answer_key: sources, self.sources_answer_key: sources,
} }
@ -178,17 +178,17 @@ class BaseQAWithSourcesChain(Chain, ABC):
@abstractmethod @abstractmethod
async def _aget_docs( async def _aget_docs(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs to run questioning over.""" """Get docs to run questioning over."""
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
accepts_run_manager = ( accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters "run_manager" in inspect.signature(self._aget_docs).parameters
@ -201,7 +201,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
input_documents=docs, callbacks=_run_manager.get_child(), **inputs input_documents=docs, callbacks=_run_manager.get_child(), **inputs
) )
answer, sources = self._split_sources(answer) answer, sources = self._split_sources(answer)
result: Dict[str, Any] = { result: dict[str, Any] = {
self.answer_key: answer, self.answer_key: answer,
self.sources_answer_key: sources, self.sources_answer_key: sources,
} }
@ -225,7 +225,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
input_docs_key: str = "docs" #: :meta private: input_docs_key: str = "docs" #: :meta private:
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -234,19 +234,19 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
def _get_docs( def _get_docs(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs to run questioning over.""" """Get docs to run questioning over."""
return inputs.pop(self.input_docs_key) return inputs.pop(self.input_docs_key)
async def _aget_docs( async def _aget_docs(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
*, *,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs to run questioning over.""" """Get docs to run questioning over."""
return inputs.pop(self.input_docs_key) return inputs.pop(self.input_docs_key)

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Mapping, Optional, Protocol from collections.abc import Mapping
from typing import Any, Optional, Protocol
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel

View File

@ -1,6 +1,6 @@
"""Question-answering with sources over an index.""" """Question-answering with sources over an index."""
from typing import Any, Dict, List from typing import Any
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -25,7 +25,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
"""Restrict the docs to return from store based on tokens, """Restrict the docs to return from store based on tokens,
enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true""" enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true"""
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]: def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
num_docs = len(docs) num_docs = len(docs)
if self.reduce_k_below_max_tokens and isinstance( if self.reduce_k_below_max_tokens and isinstance(
@ -43,8 +43,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs] return docs[:num_docs]
def _get_docs( def _get_docs(
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> List[Document]: ) -> list[Document]:
question = inputs[self.question_key] question = inputs[self.question_key]
docs = self.retriever.invoke( docs = self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()} question, config={"callbacks": run_manager.get_child()}
@ -52,8 +52,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
return self._reduce_tokens_below_limit(docs) return self._reduce_tokens_below_limit(docs)
async def _aget_docs( async def _aget_docs(
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]: ) -> list[Document]:
question = inputs[self.question_key] question = inputs[self.question_key]
docs = await self.retriever.ainvoke( docs = await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()} question, config={"callbacks": run_manager.get_child()}

View File

@ -1,7 +1,7 @@
"""Question-answering with sources over a vector database.""" """Question-answering with sources over a vector database."""
import warnings import warnings
from typing import Any, Dict, List from typing import Any
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -27,10 +27,10 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
max_tokens_limit: int = 3375 max_tokens_limit: int = 3375
"""Restrict the docs to return from store based on tokens, """Restrict the docs to return from store based on tokens,
enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true""" enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true"""
search_kwargs: Dict[str, Any] = Field(default_factory=dict) search_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Extra search args.""" """Extra search args."""
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]: def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
num_docs = len(docs) num_docs = len(docs)
if self.reduce_k_below_max_tokens and isinstance( if self.reduce_k_below_max_tokens and isinstance(
@ -48,8 +48,8 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs] return docs[:num_docs]
def _get_docs( def _get_docs(
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> List[Document]: ) -> list[Document]:
question = inputs[self.question_key] question = inputs[self.question_key]
docs = self.vectorstore.similarity_search( docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs question, k=self.k, **self.search_kwargs
@ -57,13 +57,13 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return self._reduce_tokens_below_limit(docs) return self._reduce_tokens_below_limit(docs)
async def _aget_docs( async def _aget_docs(
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]: ) -> list[Document]:
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async") raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_deprecation(cls, values: Dict) -> Any: def raise_deprecation(cls, values: dict) -> Any:
warnings.warn( warnings.warn(
"`VectorDBQAWithSourcesChain` is deprecated - " "`VectorDBQAWithSourcesChain` is deprecated - "
"please use `from langchain.chains import RetrievalQAWithSourcesChain`" "please use `from langchain.chains import RetrievalQAWithSourcesChain`"

View File

@ -3,7 +3,8 @@
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast from collections.abc import Sequence
from typing import Any, Callable, Optional, Union, cast
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
@ -172,7 +173,7 @@ def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}") return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]: def construct_examples(input_output_pairs: Sequence[tuple[str, dict]]) -> list[dict]:
"""Construct examples from input-output pairs. """Construct examples from input-output pairs.
Args: Args:
@ -267,7 +268,7 @@ def load_query_constructor_chain(
llm: BaseLanguageModel, llm: BaseLanguageModel,
document_contents: str, document_contents: str,
attribute_info: Sequence[Union[AttributeInfo, dict]], attribute_info: Sequence[Union[AttributeInfo, dict]],
examples: Optional[List] = None, examples: Optional[list] = None,
allowed_comparators: Sequence[Comparator] = tuple(Comparator), allowed_comparators: Sequence[Comparator] = tuple(Comparator),
allowed_operators: Sequence[Operator] = tuple(Operator), allowed_operators: Sequence[Operator] = tuple(Operator),
enable_limit: bool = False, enable_limit: bool = False,

View File

@ -1,6 +1,7 @@
import datetime import datetime
import warnings import warnings
from typing import Any, Literal, Optional, Sequence, Union from collections.abc import Sequence
from typing import Any, Literal, Optional, Union
from langchain_core.utils import check_package_version from langchain_core.utils import check_package_version
from typing_extensions import TypedDict from typing_extensions import TypedDict

View File

@ -1,6 +1,7 @@
"""Load question answering chains.""" """Load question answering chains."""
from typing import Any, Mapping, Optional, Protocol from collections.abc import Mapping
from typing import Any, Optional, Protocol
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager, Callbacks from langchain_core.callbacks import BaseCallbackManager, Callbacks

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Union from typing import Any, Union
from langchain_core.retrievers import ( from langchain_core.retrievers import (
BaseRetriever, BaseRetriever,
@ -11,7 +11,7 @@ from langchain_core.runnables import Runnable, RunnablePassthrough
def create_retrieval_chain( def create_retrieval_chain(
retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]], retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
combine_docs_chain: Runnable[Dict[str, Any], str], combine_docs_chain: Runnable[dict[str, Any], str],
) -> Runnable: ) -> Runnable:
"""Create retrieval chain that retrieves documents and then passes them on. """Create retrieval chain that retrieves documents and then passes them on.

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import inspect import inspect
import warnings import warnings
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -54,7 +54,7 @@ class BaseRetrievalQA(Chain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Input keys. """Input keys.
:meta private: :meta private:
@ -62,7 +62,7 @@ class BaseRetrievalQA(Chain):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Output keys. """Output keys.
:meta private: :meta private:
@ -123,14 +123,14 @@ class BaseRetrievalQA(Chain):
question: str, question: str,
*, *,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get documents to do question answering over.""" """Get documents to do question answering over."""
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Run get_relevant_text and llm on input query. """Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns If chain has 'return_source_documents' as 'True', returns
@ -166,14 +166,14 @@ class BaseRetrievalQA(Chain):
question: str, question: str,
*, *,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get documents to do question answering over.""" """Get documents to do question answering over."""
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Run get_relevant_text and llm on input query. """Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns If chain has 'return_source_documents' as 'True', returns
@ -266,7 +266,7 @@ class RetrievalQA(BaseRetrievalQA):
question: str, question: str,
*, *,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
return self.retriever.invoke( return self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()} question, config={"callbacks": run_manager.get_child()}
@ -277,7 +277,7 @@ class RetrievalQA(BaseRetrievalQA):
question: str, question: str,
*, *,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
return await self.retriever.ainvoke( return await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()} question, config={"callbacks": run_manager.get_child()}
@ -307,12 +307,12 @@ class VectorDBQA(BaseRetrievalQA):
"""Number of documents to query for.""" """Number of documents to query for."""
search_type: str = "similarity" search_type: str = "similarity"
"""Search type to use over vectorstore. `similarity` or `mmr`.""" """Search type to use over vectorstore. `similarity` or `mmr`."""
search_kwargs: Dict[str, Any] = Field(default_factory=dict) search_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Extra search args.""" """Extra search args."""
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def raise_deprecation(cls, values: Dict) -> Any: def raise_deprecation(cls, values: dict) -> Any:
warnings.warn( warnings.warn(
"`VectorDBQA` is deprecated - " "`VectorDBQA` is deprecated - "
"please use `from langchain.chains import RetrievalQA`" "please use `from langchain.chains import RetrievalQA`"
@ -321,7 +321,7 @@ class VectorDBQA(BaseRetrievalQA):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_search_type(cls, values: Dict) -> Any: def validate_search_type(cls, values: dict) -> Any:
"""Validate search type.""" """Validate search type."""
if "search_type" in values: if "search_type" in values:
search_type = values["search_type"] search_type = values["search_type"]
@ -334,7 +334,7 @@ class VectorDBQA(BaseRetrievalQA):
question: str, question: str,
*, *,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
if self.search_type == "similarity": if self.search_type == "similarity":
docs = self.vectorstore.similarity_search( docs = self.vectorstore.similarity_search(
@ -353,7 +353,7 @@ class VectorDBQA(BaseRetrievalQA):
question: str, question: str,
*, *,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]: ) -> list[Document]:
"""Get docs.""" """Get docs."""
raise NotImplementedError("VectorDBQA does not support async") raise NotImplementedError("VectorDBQA does not support async")

View File

@ -3,7 +3,8 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC from abc import ABC
from typing import Any, Dict, List, Mapping, NamedTuple, Optional from collections.abc import Mapping
from typing import Any, NamedTuple, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -17,17 +18,17 @@ from langchain.chains.base import Chain
class Route(NamedTuple): class Route(NamedTuple):
destination: Optional[str] destination: Optional[str]
next_inputs: Dict[str, Any] next_inputs: dict[str, Any]
class RouterChain(Chain, ABC): class RouterChain(Chain, ABC):
"""Chain that outputs the name of a destination chain and the inputs to it.""" """Chain that outputs the name of a destination chain and the inputs to it."""
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
return ["destination", "next_inputs"] return ["destination", "next_inputs"]
def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route: def route(self, inputs: dict[str, Any], callbacks: Callbacks = None) -> Route:
""" """
Route inputs to a destination chain. Route inputs to a destination chain.
@ -42,7 +43,7 @@ class RouterChain(Chain, ABC):
return Route(result["destination"], result["next_inputs"]) return Route(result["destination"], result["next_inputs"])
async def aroute( async def aroute(
self, inputs: Dict[str, Any], callbacks: Callbacks = None self, inputs: dict[str, Any], callbacks: Callbacks = None
) -> Route: ) -> Route:
result = await self.acall(inputs, callbacks=callbacks) result = await self.acall(inputs, callbacks=callbacks)
return Route(result["destination"], result["next_inputs"]) return Route(result["destination"], result["next_inputs"])
@ -67,7 +68,7 @@ class MultiRouteChain(Chain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Will be whatever keys the router chain prompt expects. """Will be whatever keys the router chain prompt expects.
:meta private: :meta private:
@ -75,7 +76,7 @@ class MultiRouteChain(Chain):
return self.router_chain.input_keys return self.router_chain.input_keys
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Will always return text key. """Will always return text key.
:meta private: :meta private:
@ -84,9 +85,9 @@ class MultiRouteChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child() callbacks = _run_manager.get_child()
route = self.router_chain.route(inputs, callbacks=callbacks) route = self.router_chain.route(inputs, callbacks=callbacks)
@ -109,9 +110,9 @@ class MultiRouteChain(Chain):
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child() callbacks = _run_manager.get_child()
route = await self.router_chain.aroute(inputs, callbacks=callbacks) route = await self.router_chain.aroute(inputs, callbacks=callbacks)

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type from collections.abc import Sequence
from typing import Any, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -18,7 +19,7 @@ class EmbeddingRouterChain(RouterChain):
"""Chain that uses embeddings to route between options.""" """Chain that uses embeddings to route between options."""
vectorstore: VectorStore vectorstore: VectorStore
routing_keys: List[str] = ["query"] routing_keys: list[str] = ["query"]
model_config = ConfigDict( model_config = ConfigDict(
arbitrary_types_allowed=True, arbitrary_types_allowed=True,
@ -26,7 +27,7 @@ class EmbeddingRouterChain(RouterChain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Will be whatever keys the LLM chain prompt expects. """Will be whatever keys the LLM chain prompt expects.
:meta private: :meta private:
@ -35,18 +36,18 @@ class EmbeddingRouterChain(RouterChain):
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_input = ", ".join([inputs[k] for k in self.routing_keys]) _input = ", ".join([inputs[k] for k in self.routing_keys])
results = self.vectorstore.similarity_search(_input, k=1) results = self.vectorstore.similarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]} return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_input = ", ".join([inputs[k] for k in self.routing_keys]) _input = ", ".join([inputs[k] for k in self.routing_keys])
results = await self.vectorstore.asimilarity_search(_input, k=1) results = await self.vectorstore.asimilarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]} return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
@ -54,8 +55,8 @@ class EmbeddingRouterChain(RouterChain):
@classmethod @classmethod
def from_names_and_descriptions( def from_names_and_descriptions(
cls, cls,
names_and_descriptions: Sequence[Tuple[str, Sequence[str]]], names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
vectorstore_cls: Type[VectorStore], vectorstore_cls: type[VectorStore],
embeddings: Embeddings, embeddings: Embeddings,
**kwargs: Any, **kwargs: Any,
) -> EmbeddingRouterChain: ) -> EmbeddingRouterChain:
@ -72,8 +73,8 @@ class EmbeddingRouterChain(RouterChain):
@classmethod @classmethod
async def afrom_names_and_descriptions( async def afrom_names_and_descriptions(
cls, cls,
names_and_descriptions: Sequence[Tuple[str, Sequence[str]]], names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
vectorstore_cls: Type[VectorStore], vectorstore_cls: type[VectorStore],
embeddings: Embeddings, embeddings: Embeddings,
**kwargs: Any, **kwargs: Any,
) -> EmbeddingRouterChain: ) -> EmbeddingRouterChain:

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, cast from typing import Any, Optional, cast
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -114,42 +114,42 @@ class LLMRouterChain(RouterChain):
return self return self
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Will be whatever keys the LLM chain prompt expects. """Will be whatever keys the LLM chain prompt expects.
:meta private: :meta private:
""" """
return self.llm_chain.input_keys return self.llm_chain.input_keys
def _validate_outputs(self, outputs: Dict[str, Any]) -> None: def _validate_outputs(self, outputs: dict[str, Any]) -> None:
super()._validate_outputs(outputs) super()._validate_outputs(outputs)
if not isinstance(outputs["next_inputs"], dict): if not isinstance(outputs["next_inputs"], dict):
raise ValueError raise ValueError
def _call( def _call(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child() callbacks = _run_manager.get_child()
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs) prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
output = cast( output = cast(
Dict[str, Any], dict[str, Any],
self.llm_chain.prompt.output_parser.parse(prediction), self.llm_chain.prompt.output_parser.parse(prediction),
) )
return output return output
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child() callbacks = _run_manager.get_child()
output = cast( output = cast(
Dict[str, Any], dict[str, Any],
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs), await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
) )
return output return output
@ -163,14 +163,14 @@ class LLMRouterChain(RouterChain):
return cls(llm_chain=llm_chain, **kwargs) return cls(llm_chain=llm_chain, **kwargs)
class RouterOutputParser(BaseOutputParser[Dict[str, str]]): class RouterOutputParser(BaseOutputParser[dict[str, str]]):
"""Parser for output of router chain in the multi-prompt chain.""" """Parser for output of router chain in the multi-prompt chain."""
default_destination: str = "DEFAULT" default_destination: str = "DEFAULT"
next_inputs_type: Type = str next_inputs_type: type = str
next_inputs_inner_key: str = "input" next_inputs_inner_key: str = "input"
def parse(self, text: str) -> Dict[str, Any]: def parse(self, text: str) -> dict[str, Any]:
try: try:
expected_keys = ["destination", "next_inputs"] expected_keys = ["destination", "next_inputs"]
parsed = parse_and_check_json_markdown(text, expected_keys) parsed = parse_and_check_json_markdown(text, expected_keys)

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
@ -142,14 +142,14 @@ class MultiPromptChain(MultiRouteChain):
""" # noqa: E501 """ # noqa: E501
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
return ["text"] return ["text"]
@classmethod @classmethod
def from_prompts( def from_prompts(
cls, cls,
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt_infos: List[Dict[str, str]], prompt_infos: list[dict[str, str]],
default_chain: Optional[Chain] = None, default_chain: Optional[Chain] = None,
**kwargs: Any, **kwargs: Any,
) -> MultiPromptChain: ) -> MultiPromptChain:

View File

@ -2,7 +2,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional from collections.abc import Mapping
from typing import Any, Optional
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
@ -31,14 +32,14 @@ class MultiRetrievalQAChain(MultiRouteChain): # type: ignore[override]
"""Default chain to use when router doesn't map input to one of the destinations.""" """Default chain to use when router doesn't map input to one of the destinations."""
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
return ["result"] return ["result"]
@classmethod @classmethod
def from_retrievers( def from_retrievers(
cls, cls,
llm: BaseLanguageModel, llm: BaseLanguageModel,
retriever_infos: List[Dict[str, Any]], retriever_infos: list[dict[str, Any]],
default_retriever: Optional[BaseRetriever] = None, default_retriever: Optional[BaseRetriever] = None,
default_prompt: Optional[PromptTemplate] = None, default_prompt: Optional[PromptTemplate] = None,
default_chain: Optional[Chain] = None, default_chain: Optional[Chain] = None,

View File

@ -1,6 +1,6 @@
"""Chain pipeline where the outputs of one step feed directly into next.""" """Chain pipeline where the outputs of one step feed directly into next."""
from typing import Any, Dict, List, Optional from typing import Any, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -16,9 +16,9 @@ from langchain.chains.base import Chain
class SequentialChain(Chain): class SequentialChain(Chain):
"""Chain where the outputs of one chain feed directly into next.""" """Chain where the outputs of one chain feed directly into next."""
chains: List[Chain] chains: list[Chain]
input_variables: List[str] input_variables: list[str]
output_variables: List[str] #: :meta private: output_variables: list[str] #: :meta private:
return_all: bool = False return_all: bool = False
model_config = ConfigDict( model_config = ConfigDict(
@ -27,7 +27,7 @@ class SequentialChain(Chain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Return expected input keys to the chain. """Return expected input keys to the chain.
:meta private: :meta private:
@ -35,7 +35,7 @@ class SequentialChain(Chain):
return self.input_variables return self.input_variables
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return output key. """Return output key.
:meta private: :meta private:
@ -44,7 +44,7 @@ class SequentialChain(Chain):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_chains(cls, values: Dict) -> Any: def validate_chains(cls, values: dict) -> Any:
"""Validate that the correct inputs exist for all chains.""" """Validate that the correct inputs exist for all chains."""
chains = values["chains"] chains = values["chains"]
input_variables = values["input_variables"] input_variables = values["input_variables"]
@ -97,9 +97,9 @@ class SequentialChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
known_values = inputs.copy() known_values = inputs.copy()
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
for i, chain in enumerate(self.chains): for i, chain in enumerate(self.chains):
@ -110,9 +110,9 @@ class SequentialChain(Chain):
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
known_values = inputs.copy() known_values = inputs.copy()
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child() callbacks = _run_manager.get_child()
@ -127,7 +127,7 @@ class SequentialChain(Chain):
class SimpleSequentialChain(Chain): class SimpleSequentialChain(Chain):
"""Simple chain where the outputs of one step feed directly into next.""" """Simple chain where the outputs of one step feed directly into next."""
chains: List[Chain] chains: list[Chain]
strip_outputs: bool = False strip_outputs: bool = False
input_key: str = "input" #: :meta private: input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private: output_key: str = "output" #: :meta private:
@ -138,7 +138,7 @@ class SimpleSequentialChain(Chain):
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input key. """Expect input key.
:meta private: :meta private:
@ -146,7 +146,7 @@ class SimpleSequentialChain(Chain):
return [self.input_key] return [self.input_key]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return output key. """Return output key.
:meta private: :meta private:
@ -171,9 +171,9 @@ class SimpleSequentialChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_input = inputs[self.input_key] _input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
@ -190,9 +190,9 @@ class SimpleSequentialChain(Chain):
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
_input = inputs[self.input_key] _input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
@ -27,7 +27,7 @@ class SQLInputWithTables(TypedDict):
"""Input for a SQL Chain.""" """Input for a SQL Chain."""
question: str question: str
table_names_to_use: List[str] table_names_to_use: list[str]
def create_sql_query_chain( def create_sql_query_chain(
@ -35,7 +35,7 @@ def create_sql_query_chain(
db: SQLDatabase, db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None, prompt: Optional[BasePromptTemplate] = None,
k: int = 5, k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]: ) -> Runnable[Union[SQLInput, SQLInputWithTables, dict[str, Any]], str]:
"""Create a chain that generates SQL queries. """Create a chain that generates SQL queries.
*Security Note*: This chain generates SQL queries for the given database. *Security Note*: This chain generates SQL queries for the given database.

View File

@ -1,5 +1,6 @@
import json import json
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union from collections.abc import Sequence
from typing import Any, Callable, Literal, Optional, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.output_parsers import ( from langchain_core.output_parsers import (
@ -63,7 +64,7 @@ from pydantic import BaseModel
), ),
) )
def create_openai_fn_runnable( def create_openai_fn_runnable(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
llm: Runnable, llm: Runnable,
prompt: Optional[BasePromptTemplate] = None, prompt: Optional[BasePromptTemplate] = None,
*, *,
@ -135,7 +136,7 @@ def create_openai_fn_runnable(
if not functions: if not functions:
raise ValueError("Need to pass in at least one function. Received zero.") raise ValueError("Need to pass in at least one function. Received zero.")
openai_functions = [convert_to_openai_function(f) for f in functions] openai_functions = [convert_to_openai_function(f) for f in functions]
llm_kwargs_: Dict[str, Any] = {"functions": openai_functions, **llm_kwargs} llm_kwargs_: dict[str, Any] = {"functions": openai_functions, **llm_kwargs}
if len(openai_functions) == 1 and enforce_single_function_usage: if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs_["function_call"] = {"name": openai_functions[0]["name"]} llm_kwargs_["function_call"] = {"name": openai_functions[0]["name"]}
output_parser = output_parser or get_openai_output_parser(functions) output_parser = output_parser or get_openai_output_parser(functions)
@ -181,7 +182,7 @@ def create_openai_fn_runnable(
), ),
) )
def create_structured_output_runnable( def create_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]], output_schema: Union[dict[str, Any], type[BaseModel]],
llm: Runnable, llm: Runnable,
prompt: Optional[BasePromptTemplate] = None, prompt: Optional[BasePromptTemplate] = None,
*, *,
@ -437,7 +438,7 @@ def create_structured_output_runnable(
def _create_openai_tools_runnable( def _create_openai_tools_runnable(
tool: Union[Dict[str, Any], Type[BaseModel], Callable], tool: Union[dict[str, Any], type[BaseModel], Callable],
llm: Runnable, llm: Runnable,
*, *,
prompt: Optional[BasePromptTemplate], prompt: Optional[BasePromptTemplate],
@ -446,7 +447,7 @@ def _create_openai_tools_runnable(
first_tool_only: bool, first_tool_only: bool,
) -> Runnable: ) -> Runnable:
oai_tool = convert_to_openai_tool(tool) oai_tool = convert_to_openai_tool(tool)
llm_kwargs: Dict[str, Any] = {"tools": [oai_tool]} llm_kwargs: dict[str, Any] = {"tools": [oai_tool]}
if enforce_tool_usage: if enforce_tool_usage:
llm_kwargs["tool_choice"] = { llm_kwargs["tool_choice"] = {
"type": "function", "type": "function",
@ -462,7 +463,7 @@ def _create_openai_tools_runnable(
def _get_openai_tool_output_parser( def _get_openai_tool_output_parser(
tool: Union[Dict[str, Any], Type[BaseModel], Callable], tool: Union[dict[str, Any], type[BaseModel], Callable],
*, *,
first_tool_only: bool = False, first_tool_only: bool = False,
) -> Union[BaseOutputParser, BaseGenerationOutputParser]: ) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
@ -479,7 +480,7 @@ def _get_openai_tool_output_parser(
def get_openai_output_parser( def get_openai_output_parser(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]], functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
) -> Union[BaseOutputParser, BaseGenerationOutputParser]: ) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
"""Get the appropriate function output parser given the user functions. """Get the appropriate function output parser given the user functions.
@ -496,7 +497,7 @@ def get_openai_output_parser(
""" """
if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]): if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
if len(functions) > 1: if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = { pydantic_schema: Union[dict, type[BaseModel]] = {
convert_to_openai_function(fn)["name"]: fn for fn in functions convert_to_openai_function(fn)["name"]: fn for fn in functions
} }
else: else:
@ -510,7 +511,7 @@ def get_openai_output_parser(
def _create_openai_json_runnable( def _create_openai_json_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]], output_schema: Union[dict[str, Any], type[BaseModel]],
llm: Runnable, llm: Runnable,
prompt: Optional[BasePromptTemplate] = None, prompt: Optional[BasePromptTemplate] = None,
*, *,
@ -537,7 +538,7 @@ def _create_openai_json_runnable(
def _create_openai_functions_structured_output_runnable( def _create_openai_functions_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]], output_schema: Union[dict[str, Any], type[BaseModel]],
llm: Runnable, llm: Runnable,
prompt: Optional[BasePromptTemplate] = None, prompt: Optional[BasePromptTemplate] = None,
*, *,

View File

@ -1,6 +1,7 @@
"""Load summarizing chains.""" """Load summarizing chains."""
from typing import Any, Mapping, Optional, Protocol from collections.abc import Mapping
from typing import Any, Optional, Protocol
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel

View File

@ -2,7 +2,8 @@
import functools import functools
import logging import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional from collections.abc import Awaitable
from typing import Any, Callable, Optional
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -26,13 +27,13 @@ class TransformChain(Chain):
output_variables["entities"], transform=func()) output_variables["entities"], transform=func())
""" """
input_variables: List[str] input_variables: list[str]
"""The keys expected by the transform's input dictionary.""" """The keys expected by the transform's input dictionary."""
output_variables: List[str] output_variables: list[str]
"""The keys returned by the transform's output dictionary.""" """The keys returned by the transform's output dictionary."""
transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform") transform_cb: Callable[[dict[str, str]], dict[str, str]] = Field(alias="transform")
"""The transform function.""" """The transform function."""
atransform_cb: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = ( atransform_cb: Optional[Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]] = (
Field(None, alias="atransform") Field(None, alias="atransform")
) )
"""The async coroutine transform function.""" """The async coroutine transform function."""
@ -47,7 +48,7 @@ class TransformChain(Chain):
logger.warning(msg) logger.warning(msg)
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Expect input keys. """Expect input keys.
:meta private: :meta private:
@ -55,7 +56,7 @@ class TransformChain(Chain):
return self.input_variables return self.input_variables
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Return output keys. """Return output keys.
:meta private: :meta private:
@ -64,16 +65,16 @@ class TransformChain(Chain):
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> dict[str, str]:
return self.transform_cb(inputs) return self.transform_cb(inputs)
async def _acall( async def _acall(
self, self,
inputs: Dict[str, Any], inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
if self.atransform_cb is not None: if self.atransform_cb is not None:
return await self.atransform_cb(inputs) return await self.atransform_cb(inputs)
else: else:

View File

@ -1,19 +1,13 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from collections.abc import AsyncIterator, Iterator, Sequence
from importlib import util from importlib import util
from typing import ( from typing import (
Any, Any,
AsyncIterator,
Callable, Callable,
Dict,
Iterator,
List,
Literal, Literal,
Optional, Optional,
Sequence,
Tuple,
Type,
Union, Union,
cast, cast,
overload, overload,
@ -73,7 +67,7 @@ def init_chat_model(
model: Optional[str] = None, model: Optional[str] = None,
*, *,
model_provider: Optional[str] = None, model_provider: Optional[str] = None,
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = ..., configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
config_prefix: Optional[str] = None, config_prefix: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> _ConfigurableModel: ... ) -> _ConfigurableModel: ...
@ -87,7 +81,7 @@ def init_chat_model(
*, *,
model_provider: Optional[str] = None, model_provider: Optional[str] = None,
configurable_fields: Optional[ configurable_fields: Optional[
Union[Literal["any"], List[str], Tuple[str, ...]] Union[Literal["any"], list[str], tuple[str, ...]]
] = None, ] = None,
config_prefix: Optional[str] = None, config_prefix: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
@ -514,7 +508,7 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
return None return None
def _parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, str]: def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
if ( if (
not model_provider not model_provider
and ":" in model and ":" in model
@ -554,12 +548,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
self, self,
*, *,
default_config: Optional[dict] = None, default_config: Optional[dict] = None,
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = "any", configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
config_prefix: str = "", config_prefix: str = "",
queued_declarative_operations: Sequence[Tuple[str, Tuple, Dict]] = (), queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
) -> None: ) -> None:
self._default_config: dict = default_config or {} self._default_config: dict = default_config or {}
self._configurable_fields: Union[Literal["any"], List[str]] = ( self._configurable_fields: Union[Literal["any"], list[str]] = (
configurable_fields configurable_fields
if configurable_fields == "any" if configurable_fields == "any"
else list(configurable_fields) else list(configurable_fields)
@ -569,7 +563,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
if config_prefix and not config_prefix.endswith("_") if config_prefix and not config_prefix.endswith("_")
else config_prefix else config_prefix
) )
self._queued_declarative_operations: List[Tuple[str, Tuple, Dict]] = list( self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
queued_declarative_operations queued_declarative_operations
) )
@ -670,7 +664,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
return Union[ return Union[
str, str,
Union[StringPromptValue, ChatPromptValueConcrete], Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage], list[AnyMessage],
] ]
def invoke( def invoke(
@ -708,12 +702,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def batch( def batch(
self, self,
inputs: List[LanguageModelInput], inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Any]: ) -> list[Any]:
config = config or None config = config or None
# If <= 1 config use the underlying models batch implementation. # If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1: if config is None or isinstance(config, dict) or len(config) <= 1:
@ -731,12 +725,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
async def abatch( async def abatch(
self, self,
inputs: List[LanguageModelInput], inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Any]: ) -> list[Any]:
config = config or None config = config or None
# If <= 1 config use the underlying models batch implementation. # If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1: if config is None or isinstance(config, dict) or len(config) <= 1:
@ -759,7 +753,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Tuple[int, Union[Any, Exception]]]: ) -> Iterator[tuple[int, Union[Any, Exception]]]:
config = config or None config = config or None
# If <= 1 config use the underlying models batch implementation. # If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1: if config is None or isinstance(config, dict) or len(config) <= 1:
@ -782,7 +776,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Any, **kwargs: Any,
) -> AsyncIterator[Tuple[int, Any]]: ) -> AsyncIterator[tuple[int, Any]]:
config = config or None config = config or None
# If <= 1 config use the underlying models batch implementation. # If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1: if config is None or isinstance(config, dict) or len(config) <= 1:
@ -808,8 +802,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> Iterator[Any]: ) -> Iterator[Any]:
for x in self._model(config).transform(input, config=config, **kwargs): yield from self._model(config).transform(input, config=config, **kwargs)
yield x
async def atransform( async def atransform(
self, self,
@ -915,13 +908,13 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
# Explicitly added to satisfy downstream linters. # Explicitly added to satisfy downstream linters.
def bind_tools( def bind_tools(
self, self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs) return self.__getattr__("bind_tools")(tools, **kwargs)
# Explicitly added to satisfy downstream linters. # Explicitly added to satisfy downstream linters.
def with_structured_output( def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any self, schema: Union[dict, type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return self.__getattr__("with_structured_output")(schema, **kwargs) return self.__getattr__("with_structured_output")(schema, **kwargs)

View File

@ -1,6 +1,6 @@
import functools import functools
from importlib import util from importlib import util
from typing import Any, List, Optional, Tuple, Union from typing import Any, Optional, Union
from langchain_core._api import beta from langchain_core._api import beta
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
@ -25,7 +25,7 @@ def _get_provider_list() -> str:
) )
def _parse_model_string(model_name: str) -> Tuple[str, str]: def _parse_model_string(model_name: str) -> tuple[str, str]:
"""Parse a model string into provider and model name components. """Parse a model string into provider and model name components.
The model string should be in the format 'provider:model-name', where provider The model string should be in the format 'provider:model-name', where provider
@ -78,7 +78,7 @@ def _parse_model_string(model_name: str) -> Tuple[str, str]:
def _infer_model_and_provider( def _infer_model_and_provider(
model: str, *, provider: Optional[str] = None model: str, *, provider: Optional[str] = None
) -> Tuple[str, str]: ) -> tuple[str, str]:
if not model.strip(): if not model.strip():
raise ValueError("Model name cannot be empty") raise ValueError("Model name cannot be empty")
if provider is None and ":" in model: if provider is None and ":" in model:
@ -122,7 +122,7 @@ def init_embeddings(
*, *,
provider: Optional[str] = None, provider: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> Union[Embeddings, Runnable[Any, List[float]]]: ) -> Union[Embeddings, Runnable[Any, list[float]]]:
"""Initialize an embeddings model from a model name and optional provider. """Initialize an embeddings model from a model name and optional provider.
**Note:** Must have the integration package corresponding to the model provider **Note:** Must have the integration package corresponding to the model provider

View File

@ -12,8 +12,9 @@ from __future__ import annotations
import hashlib import hashlib
import json import json
import uuid import uuid
from collections.abc import Sequence
from functools import partial from functools import partial
from typing import Callable, List, Optional, Sequence, Union, cast from typing import Callable, Optional, Union, cast
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore, ByteStore from langchain_core.stores import BaseStore, ByteStore
@ -45,9 +46,9 @@ def _value_serializer(value: Sequence[float]) -> bytes:
return json.dumps(value).encode() return json.dumps(value).encode()
def _value_deserializer(serialized_value: bytes) -> List[float]: def _value_deserializer(serialized_value: bytes) -> list[float]:
"""Deserialize a value.""" """Deserialize a value."""
return cast(List[float], json.loads(serialized_value.decode())) return cast(list[float], json.loads(serialized_value.decode()))
class CacheBackedEmbeddings(Embeddings): class CacheBackedEmbeddings(Embeddings):
@ -88,10 +89,10 @@ class CacheBackedEmbeddings(Embeddings):
def __init__( def __init__(
self, self,
underlying_embeddings: Embeddings, underlying_embeddings: Embeddings,
document_embedding_store: BaseStore[str, List[float]], document_embedding_store: BaseStore[str, list[float]],
*, *,
batch_size: Optional[int] = None, batch_size: Optional[int] = None,
query_embedding_store: Optional[BaseStore[str, List[float]]] = None, query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
) -> None: ) -> None:
"""Initialize the embedder. """Initialize the embedder.
@ -108,7 +109,7 @@ class CacheBackedEmbeddings(Embeddings):
self.underlying_embeddings = underlying_embeddings self.underlying_embeddings = underlying_embeddings
self.batch_size = batch_size self.batch_size = batch_size
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of texts. """Embed a list of texts.
The method first checks the cache for the embeddings. The method first checks the cache for the embeddings.
@ -121,10 +122,10 @@ class CacheBackedEmbeddings(Embeddings):
Returns: Returns:
A list of embeddings for the given texts. A list of embeddings for the given texts.
""" """
vectors: List[Union[List[float], None]] = self.document_embedding_store.mget( vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
texts texts
) )
all_missing_indices: List[int] = [ all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None i for i, vector in enumerate(vectors) if vector is None
] ]
@ -138,10 +139,10 @@ class CacheBackedEmbeddings(Embeddings):
vectors[index] = updated_vector vectors[index] = updated_vector
return cast( return cast(
List[List[float]], vectors list[list[float]], vectors
) # Nones should have been resolved by now ) # Nones should have been resolved by now
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of texts. """Embed a list of texts.
The method first checks the cache for the embeddings. The method first checks the cache for the embeddings.
@ -154,10 +155,10 @@ class CacheBackedEmbeddings(Embeddings):
Returns: Returns:
A list of embeddings for the given texts. A list of embeddings for the given texts.
""" """
vectors: List[ vectors: list[
Union[List[float], None] Union[list[float], None]
] = await self.document_embedding_store.amget(texts) ] = await self.document_embedding_store.amget(texts)
all_missing_indices: List[int] = [ all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None i for i, vector in enumerate(vectors) if vector is None
] ]
@ -175,10 +176,10 @@ class CacheBackedEmbeddings(Embeddings):
vectors[index] = updated_vector vectors[index] = updated_vector
return cast( return cast(
List[List[float]], vectors list[list[float]], vectors
) # Nones should have been resolved by now ) # Nones should have been resolved by now
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> list[float]:
"""Embed query text. """Embed query text.
By default, this method does not cache queries. To enable caching, set the By default, this method does not cache queries. To enable caching, set the
@ -201,7 +202,7 @@ class CacheBackedEmbeddings(Embeddings):
self.query_embedding_store.mset([(text, vector)]) self.query_embedding_store.mset([(text, vector)])
return vector return vector
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> list[float]:
"""Embed query text. """Embed query text.
By default, this method does not cache queries. To enable caching, set the By default, this method does not cache queries. To enable caching, set the
@ -250,7 +251,7 @@ class CacheBackedEmbeddings(Embeddings):
""" """
namespace = namespace namespace = namespace
key_encoder = _create_key_encoder(namespace) key_encoder = _create_key_encoder(namespace)
document_embedding_store = EncoderBackedStore[str, List[float]]( document_embedding_store = EncoderBackedStore[str, list[float]](
document_embedding_cache, document_embedding_cache,
key_encoder, key_encoder,
_value_serializer, _value_serializer,
@ -261,7 +262,7 @@ class CacheBackedEmbeddings(Embeddings):
elif query_embedding_cache is False: elif query_embedding_cache is False:
query_embedding_store = None query_embedding_store = None
else: else:
query_embedding_store = EncoderBackedStore[str, List[float]]( query_embedding_store = EncoderBackedStore[str, list[float]](
query_embedding_cache, query_embedding_cache,
key_encoder, key_encoder,
_value_serializer, _value_serializer,

View File

@ -6,13 +6,10 @@ chain (LLMChain) to generate the reasoning and scores.
""" """
import re import re
from collections.abc import Sequence
from typing import ( from typing import (
Any, Any,
Dict,
List,
Optional, Optional,
Sequence,
Tuple,
TypedDict, TypedDict,
Union, Union,
cast, cast,
@ -145,7 +142,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
# 0 # 0
""" """
agent_tools: Optional[List[BaseTool]] = None agent_tools: Optional[list[BaseTool]] = None
"""A list of tools available to the agent.""" """A list of tools available to the agent."""
eval_chain: LLMChain eval_chain: LLMChain
"""The language model chain used for evaluation.""" """The language model chain used for evaluation."""
@ -184,7 +181,7 @@ Description: {tool.description}"""
@staticmethod @staticmethod
def get_agent_trajectory( def get_agent_trajectory(
steps: Union[str, Sequence[Tuple[AgentAction, str]]], steps: Union[str, Sequence[tuple[AgentAction, str]]],
) -> str: ) -> str:
"""Get the agent trajectory as a formatted string. """Get the agent trajectory as a formatted string.
@ -263,7 +260,7 @@ The following is the expected answer. Use this to measure correctness:
) )
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> list[str]:
"""Get the input keys for the chain. """Get the input keys for the chain.
Returns: Returns:
@ -272,7 +269,7 @@ The following is the expected answer. Use this to measure correctness:
return ["question", "agent_trajectory", "answer", "reference"] return ["question", "agent_trajectory", "answer", "reference"]
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> list[str]:
"""Get the output keys for the chain. """Get the output keys for the chain.
Returns: Returns:
@ -280,16 +277,16 @@ The following is the expected answer. Use this to measure correctness:
""" """
return ["score", "reasoning"] return ["score", "reasoning"]
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Validate and prep inputs.""" """Validate and prep inputs."""
inputs["reference"] = self._format_reference(inputs.get("reference")) inputs["reference"] = self._format_reference(inputs.get("reference"))
return super().prep_inputs(inputs) return super().prep_inputs(inputs)
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Run the chain and generate the output. """Run the chain and generate the output.
Args: Args:
@ -311,9 +308,9 @@ The following is the expected answer. Use this to measure correctness:
async def _acall( async def _acall(
self, self,
inputs: Dict[str, str], inputs: dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None, run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Run the chain and generate the output. """Run the chain and generate the output.
Args: Args:
@ -338,11 +335,11 @@ The following is the expected answer. Use this to measure correctness:
*, *,
prediction: str, prediction: str,
input: str, input: str,
agent_trajectory: Sequence[Tuple[AgentAction, str]], agent_trajectory: Sequence[tuple[AgentAction, str]],
reference: Optional[str] = None, reference: Optional[str] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False, include_run_info: bool = False,
**kwargs: Any, **kwargs: Any,
) -> dict: ) -> dict:
@ -380,11 +377,11 @@ The following is the expected answer. Use this to measure correctness:
*, *,
prediction: str, prediction: str,
input: str, input: str,
agent_trajectory: Sequence[Tuple[AgentAction, str]], agent_trajectory: Sequence[tuple[AgentAction, str]],
reference: Optional[str] = None, reference: Optional[str] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False, include_run_info: bool = False,
**kwargs: Any, **kwargs: Any,
) -> dict: ) -> dict:

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import logging import logging
import re import re
from typing import Any, Dict, List, Optional, Union from typing import Any, Optional, Union
from langchain_core.callbacks.manager import Callbacks from langchain_core.callbacks.manager import Callbacks
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
@ -49,7 +49,7 @@ _SUPPORTED_CRITERIA = {
def resolve_pairwise_criteria( def resolve_pairwise_criteria(
criteria: Optional[Union[CRITERIA_TYPE, str, List[CRITERIA_TYPE]]], criteria: Optional[Union[CRITERIA_TYPE, str, list[CRITERIA_TYPE]]],
) -> dict: ) -> dict:
"""Resolve the criteria for the pairwise evaluator. """Resolve the criteria for the pairwise evaluator.
@ -113,7 +113,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]): # type: ignore[
""" """
return "pairwise_string_result" return "pairwise_string_result"
def parse(self, text: str) -> Dict[str, Any]: def parse(self, text: str) -> dict[str, Any]:
"""Parse the output text. """Parse the output text.
Args: Args:
@ -314,8 +314,8 @@ Performance may be significantly worse with other models."
input: Optional[str] = None, input: Optional[str] = None,
reference: Optional[str] = None, reference: Optional[str] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False, include_run_info: bool = False,
**kwargs: Any, **kwargs: Any,
) -> dict: ) -> dict:
@ -356,8 +356,8 @@ Performance may be significantly worse with other models."
reference: Optional[str] = None, reference: Optional[str] = None,
input: Optional[str] = None, input: Optional[str] = None,
callbacks: Callbacks = None, callbacks: Callbacks = None,
tags: Optional[List[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False, include_run_info: bool = False,
**kwargs: Any, **kwargs: Any,
) -> dict: ) -> dict:

Some files were not shown because too many files have changed in this diff Show More