1
0
mirror of https://github.com/hwchase17/langchain.git synced 2025-05-05 07:08:03 +00:00

langchain[lint]: use pyupgrade to get to 3.9 standards ()

This commit is contained in:
Sydney Runkle 2025-04-11 10:33:26 -04:00 committed by GitHub
parent d9b628e764
commit 48affc498b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
214 changed files with 1530 additions and 1554 deletions
libs/langchain/langchain
_api
agents
callbacks
chains
chat_models
embeddings
evaluation

View File

@ -1,5 +1,5 @@
import importlib
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
from langchain_core._api import internal, warn_deprecated
@ -15,8 +15,8 @@ ALLOWED_TOP_LEVEL_PKGS = {
def create_importer(
package: str,
*,
module_lookup: Optional[Dict[str, str]] = None,
deprecated_lookups: Optional[Dict[str, str]] = None,
module_lookup: Optional[dict[str, str]] = None,
deprecated_lookups: Optional[dict[str, str]] = None,
fallback_module: Optional[str] = None,
) -> Callable[[str], Any]:
"""Create a function that helps retrieve objects from their new locations.

View File

@ -3,21 +3,17 @@
from __future__ import annotations
import asyncio
import builtins
import json
import logging
import time
from abc import abstractmethod
from collections.abc import AsyncIterator, Iterator, Sequence
from pathlib import Path
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)
@ -62,17 +58,17 @@ class BaseSingleActionAgent(BaseModel):
"""Base Single Action Agent class."""
@property
def return_values(self) -> List[str]:
def return_values(self) -> list[str]:
"""Return values of the agent."""
return ["output"]
def get_allowed_tools(self) -> Optional[List[str]]:
def get_allowed_tools(self) -> Optional[list[str]]:
return None
@abstractmethod
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -91,7 +87,7 @@ class BaseSingleActionAgent(BaseModel):
@abstractmethod
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -109,7 +105,7 @@ class BaseSingleActionAgent(BaseModel):
@property
@abstractmethod
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the input keys.
:meta private:
@ -118,7 +114,7 @@ class BaseSingleActionAgent(BaseModel):
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.
@ -171,7 +167,7 @@ class BaseSingleActionAgent(BaseModel):
"""Return Identifier of an agent type."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.
Returns:
@ -223,7 +219,7 @@ class BaseSingleActionAgent(BaseModel):
else:
raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict:
def tool_run_logging_kwargs(self) -> builtins.dict:
"""Return logging kwargs for tool run."""
return {}
@ -232,11 +228,11 @@ class BaseMultiActionAgent(BaseModel):
"""Base Multi Action Agent class."""
@property
def return_values(self) -> List[str]:
def return_values(self) -> list[str]:
"""Return values of the agent."""
return ["output"]
def get_allowed_tools(self) -> Optional[List[str]]:
def get_allowed_tools(self) -> Optional[list[str]]:
"""Get allowed tools.
Returns:
@ -247,10 +243,10 @@ class BaseMultiActionAgent(BaseModel):
@abstractmethod
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]:
) -> Union[list[AgentAction], AgentFinish]:
"""Given input, decided what to do.
Args:
@ -266,10 +262,10 @@ class BaseMultiActionAgent(BaseModel):
@abstractmethod
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]:
) -> Union[list[AgentAction], AgentFinish]:
"""Async given input, decided what to do.
Args:
@ -284,7 +280,7 @@ class BaseMultiActionAgent(BaseModel):
@property
@abstractmethod
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the input keys.
:meta private:
@ -293,7 +289,7 @@ class BaseMultiActionAgent(BaseModel):
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.
@ -323,7 +319,7 @@ class BaseMultiActionAgent(BaseModel):
"""Return Identifier of an agent type."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent."""
_dict = super().model_dump()
try:
@ -371,7 +367,7 @@ class BaseMultiActionAgent(BaseModel):
else:
raise ValueError(f"{save_path} must be json or yaml")
def tool_run_logging_kwargs(self) -> Dict:
def tool_run_logging_kwargs(self) -> builtins.dict:
"""Return logging kwargs for tool run."""
return {}
@ -386,7 +382,7 @@ class AgentOutputParser(BaseOutputParser[Union[AgentAction, AgentFinish]]):
class MultiActionAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
BaseOutputParser[Union[list[AgentAction], AgentFinish]]
):
"""Base class for parsing agent output into agent actions/finish.
@ -394,7 +390,7 @@ class MultiActionAgentOutputParser(
"""
@abstractmethod
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
"""Parse text into agent actions/finish.
Args:
@ -411,8 +407,8 @@ class RunnableAgent(BaseSingleActionAgent):
runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
"""Runnable to call to get agent action."""
input_keys_arg: List[str] = []
return_keys_arg: List[str] = []
input_keys_arg: list[str] = []
return_keys_arg: list[str] = []
stream_runnable: bool = True
"""Whether to stream from the runnable or not.
@ -427,18 +423,18 @@ class RunnableAgent(BaseSingleActionAgent):
)
@property
def return_values(self) -> List[str]:
def return_values(self) -> list[str]:
"""Return values of the agent."""
return self.return_keys_arg
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the input keys."""
return self.input_keys_arg
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -474,7 +470,7 @@ class RunnableAgent(BaseSingleActionAgent):
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[
@ -518,10 +514,10 @@ class RunnableAgent(BaseSingleActionAgent):
class RunnableMultiActionAgent(BaseMultiActionAgent):
"""Agent powered by Runnables."""
runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]]
runnable: Runnable[dict, Union[list[AgentAction], AgentFinish]]
"""Runnable to call to get agent actions."""
input_keys_arg: List[str] = []
return_keys_arg: List[str] = []
input_keys_arg: list[str] = []
return_keys_arg: list[str] = []
stream_runnable: bool = True
"""Whether to stream from the runnable or not.
@ -536,12 +532,12 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
)
@property
def return_values(self) -> List[str]:
def return_values(self) -> list[str]:
"""Return values of the agent."""
return self.return_keys_arg
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the input keys.
Returns:
@ -551,11 +547,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[
List[AgentAction],
list[AgentAction],
AgentFinish,
]:
"""Based on past history and current inputs, decide what to do.
@ -590,11 +586,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[
List[AgentAction],
list[AgentAction],
AgentFinish,
]:
"""Async based on past history and current inputs, decide what to do.
@ -644,11 +640,11 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
"""LLMChain to use for agent."""
output_parser: AgentOutputParser
"""Output parser to use for agent."""
stop: List[str]
stop: list[str]
"""List of strings to stop on."""
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the input keys.
Returns:
@ -656,7 +652,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
"""
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
del _dict["output_parser"]
@ -664,7 +660,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -689,7 +685,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -712,7 +708,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
)
return self.output_parser.parse(output)
def tool_run_logging_kwargs(self) -> Dict:
def tool_run_logging_kwargs(self) -> builtins.dict:
"""Return logging kwargs for tool run."""
return {
"llm_prefix": "",
@ -737,21 +733,21 @@ class Agent(BaseSingleActionAgent):
"""LLMChain to use for agent."""
output_parser: AgentOutputParser
"""Output parser to use for agent."""
allowed_tools: Optional[List[str]] = None
allowed_tools: Optional[list[str]] = None
"""Allowed tools for the agent. If None, all tools are allowed."""
def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
del _dict["output_parser"]
return _dict
def get_allowed_tools(self) -> Optional[List[str]]:
def get_allowed_tools(self) -> Optional[list[str]]:
"""Get allowed tools."""
return self.allowed_tools
@property
def return_values(self) -> List[str]:
def return_values(self) -> list[str]:
"""Return values of the agent."""
return ["output"]
@ -767,15 +763,15 @@ class Agent(BaseSingleActionAgent):
raise ValueError("fix_text not implemented for this agent.")
@property
def _stop(self) -> List[str]:
def _stop(self) -> list[str]:
return [
f"\n{self.observation_prefix.rstrip()}",
f"\n\t{self.observation_prefix.rstrip()}",
]
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> Union[str, List[BaseMessage]]:
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> Union[str, list[BaseMessage]]:
"""Construct the scratchpad that lets the agent continue its thought process."""
thoughts = ""
for action, observation in intermediate_steps:
@ -785,7 +781,7 @@ class Agent(BaseSingleActionAgent):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -806,7 +802,7 @@ class Agent(BaseSingleActionAgent):
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -827,8 +823,8 @@ class Agent(BaseSingleActionAgent):
return agent_output
def get_full_inputs(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Dict[str, Any]:
self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any
) -> builtins.dict[str, Any]:
"""Create the full inputs for the LLMChain from intermediate steps.
Args:
@ -845,7 +841,7 @@ class Agent(BaseSingleActionAgent):
return full_inputs
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the input keys.
:meta private:
@ -957,7 +953,7 @@ class Agent(BaseSingleActionAgent):
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.
@ -1009,7 +1005,7 @@ class Agent(BaseSingleActionAgent):
f"got {early_stopping_method}"
)
def tool_run_logging_kwargs(self) -> Dict:
def tool_run_logging_kwargs(self) -> builtins.dict:
"""Return logging kwargs for tool run."""
return {
"llm_prefix": self.llm_prefix,
@ -1040,7 +1036,7 @@ class ExceptionTool(BaseTool): # type: ignore[override]
return query
NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
NextStepOutput = list[Union[AgentFinish, AgentAction, AgentStep]]
RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
@ -1086,7 +1082,7 @@ class AgentExecutor(Chain):
as an observation.
"""
trim_intermediate_steps: Union[
int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]]
int, Callable[[list[tuple[AgentAction, str]]], list[tuple[AgentAction, str]]]
] = -1
"""How to trim the intermediate steps before returning them.
Defaults to -1, which means no trimming.
@ -1144,7 +1140,7 @@ class AgentExecutor(Chain):
@model_validator(mode="before")
@classmethod
def validate_runnable_agent(cls, values: Dict) -> Any:
def validate_runnable_agent(cls, values: dict) -> Any:
"""Convert runnable to agent if passed in.
Args:
@ -1160,7 +1156,7 @@ class AgentExecutor(Chain):
except Exception as _:
multi_action = False
else:
multi_action = output_type == Union[List[AgentAction], AgentFinish]
multi_action = output_type == Union[list[AgentAction], AgentFinish]
stream_runnable = values.pop("stream_runnable", True)
if multi_action:
@ -1239,7 +1235,7 @@ class AgentExecutor(Chain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the input keys.
:meta private:
@ -1247,7 +1243,7 @@ class AgentExecutor(Chain):
return self._action_agent.input_keys
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return the singular output key.
:meta private:
@ -1284,7 +1280,7 @@ class AgentExecutor(Chain):
output: AgentFinish,
intermediate_steps: list,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if run_manager:
run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
final_output = output.return_values
@ -1297,7 +1293,7 @@ class AgentExecutor(Chain):
output: AgentFinish,
intermediate_steps: list,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if run_manager:
await run_manager.on_agent_finish(
output, color="green", verbose=self.verbose
@ -1309,7 +1305,7 @@ class AgentExecutor(Chain):
def _consume_next_step(
self, values: NextStepOutput
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
if isinstance(values[-1], AgentFinish):
assert len(values) == 1
return values[-1]
@ -1320,12 +1316,12 @@ class AgentExecutor(Chain):
def _take_next_step(
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
name_to_tool_map: dict[str, BaseTool],
color_mapping: dict[str, str],
inputs: dict[str, str],
intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
return self._consume_next_step(
[
a
@ -1341,10 +1337,10 @@ class AgentExecutor(Chain):
def _iter_next_step(
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
name_to_tool_map: dict[str, BaseTool],
color_mapping: dict[str, str],
inputs: dict[str, str],
intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Iterator[Union[AgentFinish, AgentAction, AgentStep]]:
"""Take a single step in the thought-action-observation loop.
@ -1404,7 +1400,7 @@ class AgentExecutor(Chain):
yield output
return
actions: List[AgentAction]
actions: list[AgentAction]
if isinstance(output, AgentAction):
actions = [output]
else:
@ -1418,8 +1414,8 @@ class AgentExecutor(Chain):
def _perform_agent_action(
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
name_to_tool_map: dict[str, BaseTool],
color_mapping: dict[str, str],
agent_action: AgentAction,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> AgentStep:
@ -1457,12 +1453,12 @@ class AgentExecutor(Chain):
async def _atake_next_step(
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
name_to_tool_map: dict[str, BaseTool],
color_mapping: dict[str, str],
inputs: dict[str, str],
intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
return self._consume_next_step(
[
a
@ -1478,10 +1474,10 @@ class AgentExecutor(Chain):
async def _aiter_next_step(
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
name_to_tool_map: dict[str, BaseTool],
color_mapping: dict[str, str],
inputs: dict[str, str],
intermediate_steps: list[tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> AsyncIterator[Union[AgentFinish, AgentAction, AgentStep]]:
"""Take a single step in the thought-action-observation loop.
@ -1539,7 +1535,7 @@ class AgentExecutor(Chain):
yield output
return
actions: List[AgentAction]
actions: list[AgentAction]
if isinstance(output, AgentAction):
actions = [output]
else:
@ -1563,8 +1559,8 @@ class AgentExecutor(Chain):
async def _aperform_agent_action(
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
name_to_tool_map: dict[str, BaseTool],
color_mapping: dict[str, str],
agent_action: AgentAction,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> AgentStep:
@ -1604,9 +1600,9 @@ class AgentExecutor(Chain):
def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
@ -1614,7 +1610,7 @@ class AgentExecutor(Chain):
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green", "red"]
)
intermediate_steps: List[Tuple[AgentAction, str]] = []
intermediate_steps: list[tuple[AgentAction, str]] = []
# Let's start tracking the number of iterations and time elapsed
iterations = 0
time_elapsed = 0.0
@ -1651,9 +1647,9 @@ class AgentExecutor(Chain):
async def _acall(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Async run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
@ -1661,7 +1657,7 @@ class AgentExecutor(Chain):
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"]
)
intermediate_steps: List[Tuple[AgentAction, str]] = []
intermediate_steps: list[tuple[AgentAction, str]] = []
# Let's start tracking the number of iterations and time elapsed
iterations = 0
time_elapsed = 0.0
@ -1712,7 +1708,7 @@ class AgentExecutor(Chain):
)
def _get_tool_return(
self, next_step_output: Tuple[AgentAction, str]
self, next_step_output: tuple[AgentAction, str]
) -> Optional[AgentFinish]:
"""Check if the tool is a returning tool."""
agent_action, observation = next_step_output
@ -1730,8 +1726,8 @@ class AgentExecutor(Chain):
return None
def _prepare_intermediate_steps(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> List[Tuple[AgentAction, str]]:
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> list[tuple[AgentAction, str]]:
if (
isinstance(self.trim_intermediate_steps, int)
and self.trim_intermediate_steps > 0
@ -1744,7 +1740,7 @@ class AgentExecutor(Chain):
def stream(
self,
input: Union[Dict[str, Any], Any],
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[AddableDict]:
@ -1770,12 +1766,11 @@ class AgentExecutor(Chain):
yield_actions=True,
**kwargs,
)
for step in iterator:
yield step
yield from iterator
async def astream(
self,
input: Union[Dict[str, Any], Any],
input: Union[dict[str, Any], Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> AsyncIterator[AddableDict]:

View File

@ -3,15 +3,11 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import AsyncIterator, Iterator
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)
from uuid import UUID
@ -53,7 +49,7 @@ class AgentExecutorIterator:
callbacks: Callbacks = None,
*,
tags: Optional[list[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
run_id: Optional[UUID] = None,
include_run_info: bool = False,
@ -90,17 +86,17 @@ class AgentExecutorIterator:
self.yield_actions = yield_actions
self.reset()
_inputs: Dict[str, str]
_inputs: dict[str, str]
callbacks: Callbacks
tags: Optional[list[str]]
metadata: Optional[Dict[str, Any]]
metadata: Optional[dict[str, Any]]
run_name: Optional[str]
run_id: Optional[UUID]
include_run_info: bool
yield_actions: bool
@property
def inputs(self) -> Dict[str, str]:
def inputs(self) -> dict[str, str]:
"""The inputs to the AgentExecutor."""
return self._inputs
@ -120,12 +116,12 @@ class AgentExecutorIterator:
self.inputs = self.inputs
@property
def name_to_tool_map(self) -> Dict[str, BaseTool]:
def name_to_tool_map(self) -> dict[str, BaseTool]:
"""A mapping of tool names to tools."""
return {tool.name: tool for tool in self.agent_executor.tools}
@property
def color_mapping(self) -> Dict[str, str]:
def color_mapping(self) -> dict[str, str]:
"""A mapping of tool names to colors."""
return get_color_mapping(
[tool.name for tool in self.agent_executor.tools],
@ -156,7 +152,7 @@ class AgentExecutorIterator:
def make_final_outputs(
self,
outputs: Dict[str, Any],
outputs: dict[str, Any],
run_manager: Union[CallbackManagerForChainRun, AsyncCallbackManagerForChainRun],
) -> AddableDict:
# have access to intermediate steps by design in iterator,
@ -171,7 +167,7 @@ class AgentExecutorIterator:
prepared_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return prepared_outputs
def __iter__(self: "AgentExecutorIterator") -> Iterator[AddableDict]:
def __iter__(self: AgentExecutorIterator) -> Iterator[AddableDict]:
logger.debug("Initialising AgentExecutorIterator")
self.reset()
callback_manager = CallbackManager.configure(
@ -311,7 +307,7 @@ class AgentExecutorIterator:
def _process_next_step_output(
self,
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
next_step_output: Union[AgentFinish, list[tuple[AgentAction, str]]],
run_manager: CallbackManagerForChainRun,
) -> AddableDict:
"""
@ -339,7 +335,7 @@ class AgentExecutorIterator:
async def _aprocess_next_step_output(
self,
next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
next_step_output: Union[AgentFinish, list[tuple[AgentAction, str]]],
run_manager: AsyncCallbackManagerForChainRun,
) -> AddableDict:
"""

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_core.memory import BaseMemory
@ -26,7 +26,7 @@ def _get_default_system_message() -> SystemMessage:
def create_conversational_retrieval_agent(
llm: BaseLanguageModel,
tools: List[BaseTool],
tools: list[BaseTool],
remember_intermediate_steps: bool = True,
memory_key: str = "chat_history",
system_message: Optional[SystemMessage] = None,

View File

@ -1,6 +1,6 @@
"""VectorStore agent."""
from typing import Any, Dict, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks.base import BaseCallbackManager
@ -36,7 +36,7 @@ def create_vectorstore_agent(
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
agent_executor_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a VectorStore agent from an LLM and tools.
@ -129,7 +129,7 @@ def create_vectorstore_router_agent(
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = ROUTER_PREFIX,
verbose: bool = False,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
agent_executor_kwargs: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> AgentExecutor:
"""Construct a VectorStore router agent from an LLM and tools.

View File

@ -1,7 +1,5 @@
"""Toolkit for interacting with a vector store."""
from typing import List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
@ -31,7 +29,7 @@ class VectorStoreToolkit(BaseToolkit):
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit."""
try:
from langchain_community.tools.vectorstore.tool import (
@ -66,16 +64,16 @@ class VectorStoreToolkit(BaseToolkit):
class VectorStoreRouterToolkit(BaseToolkit):
"""Toolkit for routing between Vector Stores."""
vectorstores: List[VectorStoreInfo] = Field(exclude=True)
vectorstores: list[VectorStoreInfo] = Field(exclude=True)
llm: BaseLanguageModel
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_tools(self) -> List[BaseTool]:
def get_tools(self) -> list[BaseTool]:
"""Get the tools in the toolkit."""
tools: List[BaseTool] = []
tools: list[BaseTool] = []
try:
from langchain_community.tools.vectorstore.tool import (
VectorStoreQATool,

View File

@ -1,4 +1,5 @@
from typing import Any, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.agents import AgentAction
@ -48,7 +49,7 @@ class ChatAgent(Agent):
return "Thought:"
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str):
@ -72,7 +73,7 @@ class ChatAgent(Agent):
validate_tools_single_input(class_name=cls.__name__, tools=tools)
@property
def _stop(self) -> List[str]:
def _stop(self) -> list[str]:
return ["Observation:"]
@classmethod
@ -83,7 +84,7 @@ class ChatAgent(Agent):
system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
human_message: str = HUMAN_MESSAGE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
) -> BasePromptTemplate:
"""Create a prompt from a list of tools.
@ -132,7 +133,7 @@ class ChatAgent(Agent):
system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
human_message: str = HUMAN_MESSAGE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools.

View File

@ -1,6 +1,7 @@
import json
import re
from typing import Pattern, Union
from re import Pattern
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, List, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager
@ -71,7 +72,7 @@ class ConversationalAgent(Agent):
format_instructions: str = FORMAT_INSTRUCTIONS,
ai_prefix: str = "AI",
human_prefix: str = "Human",
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero-shot agent.
@ -120,7 +121,7 @@ class ConversationalAgent(Agent):
format_instructions: str = FORMAT_INSTRUCTIONS,
ai_prefix: str = "AI",
human_prefix: str = "Human",
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools.

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.agents import AgentAction
@ -77,7 +78,7 @@ class ConversationalChatAgent(Agent):
tools: Sequence[BaseTool],
system_message: str = PREFIX,
human_message: str = SUFFIX,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
output_parser: Optional[BaseOutputParser] = None,
) -> BasePromptTemplate:
"""Create a prompt for the agent.
@ -116,10 +117,10 @@ class ConversationalChatAgent(Agent):
return ChatPromptTemplate(input_variables=input_variables, messages=messages) # type: ignore[arg-type]
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> List[BaseMessage]:
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> list[BaseMessage]:
"""Construct the scratchpad that lets the agent continue its thought process."""
thoughts: List[BaseMessage] = []
thoughts: list[BaseMessage] = []
for action, observation in intermediate_steps:
thoughts.append(AIMessage(content=action.log))
human_message = HumanMessage(
@ -137,7 +138,7 @@ class ConversationalChatAgent(Agent):
output_parser: Optional[AgentOutputParser] = None,
system_message: str = PREFIX,
human_message: str = SUFFIX,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools.

View File

@ -1,10 +1,8 @@
from typing import List, Tuple
from langchain_core.agents import AgentAction
def format_log_to_str(
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
observation_prefix: str = "Observation: ",
llm_prefix: str = "Thought: ",
) -> str:

View File

@ -1,13 +1,11 @@
from typing import List, Tuple
from langchain_core.agents import AgentAction
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
def format_log_to_messages(
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
template_tool_response: str = "{observation}",
) -> List[BaseMessage]:
) -> list[BaseMessage]:
"""Construct the scratchpad that lets the agent continue its thought process.
Args:
@ -18,7 +16,7 @@ def format_log_to_messages(
Returns:
List[BaseMessage]: The scratchpad.
"""
thoughts: List[BaseMessage] = []
thoughts: list[BaseMessage] = []
for action, observation in intermediate_steps:
thoughts.append(AIMessage(content=action.log))
human_message = HumanMessage(

View File

@ -1,5 +1,5 @@
import json
from typing import List, Sequence, Tuple
from collections.abc import Sequence
from langchain_core.agents import AgentAction, AgentActionMessageLog
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
) -> list[BaseMessage]:
"""Convert an agent action to a message.
This code is used to reconstruct the original AI message from the agent action.
@ -54,8 +54,8 @@ def _create_function_message(
def format_to_openai_function_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
intermediate_steps: Sequence[tuple[AgentAction, str]],
) -> list[BaseMessage]:
"""Convert (AgentAction, tool output) tuples into FunctionMessages.
Args:

View File

@ -1,5 +1,5 @@
import json
from typing import List, Sequence, Tuple
from collections.abc import Sequence
from langchain_core.agents import AgentAction
from langchain_core.messages import (
@ -40,8 +40,8 @@ def _create_tool_message(
def format_to_tool_messages(
intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
intermediate_steps: Sequence[tuple[AgentAction, str]],
) -> list[BaseMessage]:
"""Convert (AgentAction, tool output) tuples into ToolMessages.
Args:

View File

@ -1,10 +1,8 @@
from typing import List, Tuple
from langchain_core.agents import AgentAction
def format_xml(
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
) -> str:
"""Format the intermediate steps as XML.

View File

@ -1,6 +1,7 @@
"""Load agent."""
from typing import Any, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager

View File

@ -1,4 +1,5 @@
from typing import List, Sequence, Union
from collections.abc import Sequence
from typing import Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
@ -15,7 +16,7 @@ def create_json_chat_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
stop_sequence: Union[bool, List[str]] = True,
stop_sequence: Union[bool, list[str]] = True,
tools_renderer: ToolsRenderer = render_text_description,
template_tool_response: str = TEMPLATE_TOOL_RESPONSE,
) -> Runnable:

View File

@ -3,7 +3,7 @@
import json
import logging
from pathlib import Path
from typing import Any, List, Optional, Union
from typing import Any, Optional, Union
import yaml
from langchain_core._api import deprecated
@ -20,7 +20,7 @@ URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/age
def _load_agent_from_tools(
config: dict, llm: BaseLanguageModel, tools: List[Tool], **kwargs: Any
config: dict, llm: BaseLanguageModel, tools: list[Tool], **kwargs: Any
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
config_type = config.pop("_type")
if config_type not in AGENT_TO_CLASS:
@ -35,7 +35,7 @@ def _load_agent_from_tools(
def load_agent_from_config(
config: dict,
llm: Optional[BaseLanguageModel] = None,
tools: Optional[List[Tool]] = None,
tools: Optional[list[Tool]] = None,
**kwargs: Any,
) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
"""Load agent from Config Dict.
@ -130,7 +130,7 @@ def _load_agent_from_file(
with open(file_path) as f:
config = json.load(f)
elif file_path.suffix[1:] == "yaml":
with open(file_path, "r") as f:
with open(file_path) as f:
config = yaml.safe_load(f)
else:
raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, Callable, List, NamedTuple, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Callable, NamedTuple, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager
@ -83,7 +84,7 @@ class ZeroShotAgent(Agent):
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
@ -118,7 +119,7 @@ class ZeroShotAgent(Agent):
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
input_variables: Optional[list[str]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools.
@ -183,7 +184,7 @@ class MRKLChain(AgentExecutor):
@classmethod
def from_chains(
cls, llm: BaseLanguageModel, chains: List[ChainConfig], **kwargs: Any
cls, llm: BaseLanguageModel, chains: list[ChainConfig], **kwargs: Any
) -> AgentExecutor:
"""User-friendly way to initialize the MRKL chain.

View File

@ -2,18 +2,14 @@ from __future__ import annotations
import asyncio
import json
from collections.abc import Sequence
from json import JSONDecodeError
from time import sleep
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
@ -111,7 +107,7 @@ def _get_openai_async_client() -> openai.AsyncOpenAI:
def _is_assistants_builtin_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
) -> bool:
"""Determine if tool corresponds to OpenAI Assistants built-in."""
assistants_builtin_tools = ("code_interpreter", "file_search")
@ -123,8 +119,8 @@ def _is_assistants_builtin_tool(
def _get_assistants_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> Dict[str, Any]:
tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
) -> dict[str, Any]:
"""Convert a raw function/class to an OpenAI tool.
Note that OpenAI assistants supports several built-in tools,
@ -137,14 +133,14 @@ def _get_assistants_tool(
OutputType = Union[
List[OpenAIAssistantAction],
list[OpenAIAssistantAction],
OpenAIAssistantFinish,
List["ThreadMessage"],
List["RequiredActionFunctionToolCall"],
list["ThreadMessage"],
list["RequiredActionFunctionToolCall"],
]
class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
"""Run an OpenAI Assistant.
Example using OpenAI tools:
@ -498,7 +494,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
return response
def _parse_intermediate_steps(
self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]]
) -> dict:
last_action, last_output = intermediate_steps[-1]
run = self._wait_for_run(last_action.run_id, last_action.thread_id)
@ -652,7 +648,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
return run
async def _aparse_intermediate_steps(
self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]]
) -> dict:
last_action, last_output = intermediate_steps[-1]
run = self._wait_for_run(last_action.run_id, last_action.thread_id)

View File

@ -1,6 +1,6 @@
"""Memory used to save agent output AND intermediate steps."""
from typing import Any, Dict, List
from typing import Any
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string
@ -43,19 +43,19 @@ class AgentTokenBufferMemory(BaseChatMemory): # type: ignore[override]
format_as_tools: bool = False
@property
def buffer(self) -> List[BaseMessage]:
def buffer(self) -> list[BaseMessage]:
"""String buffer of memory."""
return self.chat_memory.messages
@property
def memory_variables(self) -> List[str]:
def memory_variables(self) -> list[str]:
"""Always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer.
Args:
@ -74,7 +74,7 @@ class AgentTokenBufferMemory(BaseChatMemory): # type: ignore[override]
)
return {self.memory_key: final_buffer}
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
def save_context(self, inputs: dict[str, Any], outputs: dict[str, Any]) -> None:
"""Save context from this conversation to buffer. Pruned.
Args:

View File

@ -1,6 +1,7 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
from typing import Any, List, Optional, Sequence, Tuple, Type, Union
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentFinish
@ -51,11 +52,11 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
llm: BaseLanguageModel
tools: Sequence[BaseTool]
prompt: BasePromptTemplate
output_parser: Type[OpenAIFunctionsAgentOutputParser] = (
output_parser: type[OpenAIFunctionsAgentOutputParser] = (
OpenAIFunctionsAgentOutputParser
)
def get_allowed_tools(self) -> List[str]:
def get_allowed_tools(self) -> list[str]:
"""Get allowed tools."""
return [t.name for t in self.tools]
@ -81,19 +82,19 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
return self
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Get input keys. Input refers to user input here."""
return ["input"]
@property
def functions(self) -> List[dict]:
def functions(self) -> list[dict]:
"""Get functions."""
return [dict(convert_to_openai_function(t)) for t in self.tools]
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
with_functions: bool = True,
**kwargs: Any,
@ -135,7 +136,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -168,7 +169,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.
@ -213,7 +214,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
) -> ChatPromptTemplate:
"""Create prompt for this agent.
@ -227,7 +228,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
A prompt template to pass into this agent.
"""
_prompts = extra_prompt_messages or []
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message:
messages = [system_message]
else:
@ -248,7 +249,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),

View File

@ -1,8 +1,9 @@
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
import json
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, Optional, Union
from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
@ -34,7 +35,7 @@ from langchain.agents.format_scratchpad.openai_functions import (
_FunctionsAgentAction = AgentActionMessageLog
def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]:
def _parse_ai_message(message: BaseMessage) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")
@ -58,7 +59,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFin
f"the `arguments` JSON does not contain `actions` key."
)
final_tools: List[AgentAction] = []
final_tools: list[AgentAction] = []
for tool_schema in tools:
if "action" in tool_schema:
_tool_input = tool_schema["action"]
@ -112,7 +113,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
tools: Sequence[BaseTool]
prompt: BasePromptTemplate
def get_allowed_tools(self) -> List[str]:
def get_allowed_tools(self) -> list[str]:
"""Get allowed tools."""
return [t.name for t in self.tools]
@ -127,12 +128,12 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
return self
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Get input keys. Input refers to user input here."""
return ["input"]
@property
def functions(self) -> List[dict]:
def functions(self) -> list[dict]:
"""Get the functions for the agent."""
enum_vals = [t.name for t in self.tools]
tool_selection = {
@ -194,10 +195,10 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]:
) -> Union[list[AgentAction], AgentFinish]:
"""Given input, decided what to do.
Args:
@ -224,10 +225,10 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]:
) -> Union[list[AgentAction], AgentFinish]:
"""Async given input, decided what to do.
Args:
@ -258,7 +259,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
) -> BasePromptTemplate:
"""Create prompt for this agent.
@ -272,7 +273,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
A prompt template to pass into this agent.
"""
_prompts = extra_prompt_messages or []
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
if system_message:
messages = [system_message]
else:
@ -293,7 +294,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),

View File

@ -1,4 +1,5 @@
from typing import Optional, Sequence
from collections.abc import Sequence
from typing import Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate

View File

@ -1,6 +1,6 @@
import json
from json import JSONDecodeError
from typing import List, Union
from typing import Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
@ -77,7 +77,7 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
)
def parse_result(
self, result: List[Generation], *, partial: bool = False
self, result: list[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")

View File

@ -1,4 +1,4 @@
from typing import List, Union
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
@ -15,12 +15,12 @@ OpenAIToolAgentAction = ToolAgentAction
def parse_ai_message_to_openai_tool_action(
message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
tool_actions = parse_ai_message_to_tool_action(message)
if isinstance(tool_actions, AgentFinish):
return tool_actions
final_actions: List[AgentAction] = []
final_actions: list[AgentAction] = []
for action in tool_actions:
if isinstance(action, ToolAgentAction):
final_actions.append(
@ -54,12 +54,12 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
return "openai-tools-agent-output-parser"
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
self, result: list[Generation], *, partial: bool = False
) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return parse_ai_message_to_openai_tool_action(message)
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")

View File

@ -1,6 +1,7 @@
import json
import re
from typing import Pattern, Union
from re import Pattern
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException

View File

@ -1,4 +1,5 @@
from typing import Sequence, Union
from collections.abc import Sequence
from typing import Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException

View File

@ -1,6 +1,6 @@
import json
from json import JSONDecodeError
from typing import List, Union
from typing import Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
@ -21,12 +21,12 @@ class ToolAgentAction(AgentActionMessageLog): # type: ignore[override]
def parse_ai_message_to_tool_action(
message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")
actions: List = []
actions: list = []
if message.tool_calls:
tool_calls = message.tool_calls
else:
@ -91,12 +91,12 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
return "tools-agent-output-parser"
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
self, result: list[Generation], *, partial: bool = False
) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return parse_ai_message_to_tool_action(message)
def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import List, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Optional, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
@ -20,7 +21,7 @@ def create_react_agent(
output_parser: Optional[AgentOutputParser] = None,
tools_renderer: ToolsRenderer = render_text_description,
*,
stop_sequence: Union[bool, List[str]] = True,
stop_sequence: Union[bool, list[str]] = True,
) -> Runnable:
"""Create an agent that uses ReAct prompting.

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, List, Optional, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from langchain_core._api import deprecated
from langchain_core.documents import Document
@ -65,7 +66,7 @@ class ReActDocstoreAgent(Agent):
return "Observation: "
@property
def _stop(self) -> List[str]:
def _stop(self) -> list[str]:
return ["\nObservation:"]
@property
@ -122,7 +123,7 @@ class DocstoreExplorer:
return self._paragraphs[0]
@property
def _paragraphs(self) -> List[str]:
def _paragraphs(self) -> list[str]:
if self.document is None:
raise ValueError("Cannot get paragraphs without a document")
return self.document.page_content.split("\n\n")

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any
from langchain_core.agents import AgentAction
from langchain_core.prompts.chat import ChatPromptTemplate
@ -12,7 +12,7 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
return False
def _construct_agent_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
if len(intermediate_steps) == 0:
return ""
@ -26,7 +26,7 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
f"you return as final answer):\n{thoughts}"
)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
intermediate_steps = kwargs.pop("intermediate_steps")
kwargs["agent_scratchpad"] = self._construct_agent_scratchpad(
intermediate_steps

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Union
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel

View File

@ -1,5 +1,6 @@
import re
from typing import Any, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Optional, Union
from langchain_core._api import deprecated
from langchain_core.agents import AgentAction
@ -49,7 +50,7 @@ class StructuredChatAgent(Agent):
return "Thought:"
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str):
@ -74,7 +75,7 @@ class StructuredChatAgent(Agent):
return StructuredChatOutputParserWithRetries.from_llm(llm=llm)
@property
def _stop(self) -> List[str]:
def _stop(self) -> list[str]:
return ["Observation:"]
@classmethod
@ -85,8 +86,8 @@ class StructuredChatAgent(Agent):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
@ -117,8 +118,8 @@ class StructuredChatAgent(Agent):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
@ -157,7 +158,7 @@ def create_structured_chat_agent(
prompt: ChatPromptTemplate,
tools_renderer: ToolsRenderer = render_text_description_and_args,
*,
stop_sequence: Union[bool, List[str]] = True,
stop_sequence: Union[bool, list[str]] = True,
) -> Runnable:
"""Create an agent aimed at supporting tools with multiple inputs.

View File

@ -3,7 +3,8 @@ from __future__ import annotations
import json
import logging
import re
from typing import Optional, Pattern, Union
from re import Pattern
from typing import Optional, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException

View File

@ -1,4 +1,5 @@
from typing import Callable, List, Sequence, Tuple
from collections.abc import Sequence
from typing import Callable
from langchain_core.agents import AgentAction
from langchain_core.language_models import BaseLanguageModel
@ -12,7 +13,7 @@ from langchain.agents.format_scratchpad.tools import (
)
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
MessageFormatter = Callable[[Sequence[Tuple[AgentAction, str]]], List[BaseMessage]]
MessageFormatter = Callable[[Sequence[tuple[AgentAction, str]]], list[BaseMessage]]
def create_tool_calling_agent(

View File

@ -1,6 +1,6 @@
"""Interface for tools."""
from typing import List, Optional
from typing import Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
@ -20,7 +20,7 @@ class InvalidTool(BaseTool): # type: ignore[override]
def _run(
self,
requested_tool_name: str,
available_tool_names: List[str],
available_tool_names: list[str],
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
@ -33,7 +33,7 @@ class InvalidTool(BaseTool): # type: ignore[override]
async def _arun(
self,
requested_tool_name: str,
available_tool_names: List[str],
available_tool_names: list[str],
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""

View File

@ -1,4 +1,4 @@
from typing import Dict, Type, Union
from typing import Union
from langchain.agents.agent import BaseSingleActionAgent
from langchain.agents.agent_types import AgentType
@ -12,9 +12,9 @@ from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.structured_chat.base import StructuredChatAgent
AGENT_TYPE = Union[Type[BaseSingleActionAgent], Type[OpenAIMultiFunctionsAgent]]
AGENT_TYPE = Union[type[BaseSingleActionAgent], type[OpenAIMultiFunctionsAgent]]
AGENT_TO_CLASS: Dict[AgentType, AGENT_TYPE] = {
AGENT_TO_CLASS: dict[AgentType, AGENT_TYPE] = {
AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent,
AgentType.REACT_DOCSTORE: ReActDocstoreAgent,
AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent,

View File

@ -1,4 +1,4 @@
from typing import Sequence
from collections.abc import Sequence
from langchain_core.tools import BaseTool

View File

@ -1,4 +1,5 @@
from typing import Any, List, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Union
from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentFinish
@ -38,13 +39,13 @@ class XMLAgent(BaseSingleActionAgent):
"""
tools: List[BaseTool]
tools: list[BaseTool]
"""List of tools this agent has access to."""
llm_chain: LLMChain
"""Chain to use to predict action."""
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
return ["input"]
@staticmethod
@ -60,7 +61,7 @@ class XMLAgent(BaseSingleActionAgent):
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -84,7 +85,7 @@ class XMLAgent(BaseSingleActionAgent):
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -113,7 +114,7 @@ def create_xml_agent(
prompt: BasePromptTemplate,
tools_renderer: ToolsRenderer = render_text_description,
*,
stop_sequence: Union[bool, List[str]] = True,
stop_sequence: Union[bool, list[str]] = True,
) -> Runnable:
"""Create an agent that uses XML to format its logic.

View File

@ -1,7 +1,8 @@
from __future__ import annotations
import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
from collections.abc import AsyncIterator
from typing import Any, Literal, Union, cast
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.outputs import LLMResult
@ -25,7 +26,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
self.done = asyncio.Event()
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core.outputs import LLMResult
@ -30,7 +30,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def __init__(
self,
*,
answer_prefix_tokens: Optional[List[str]] = None,
answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
) -> None:
@ -62,7 +62,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.answer_reached = False
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()

View File

@ -1,7 +1,7 @@
"""Callback Handler streams to stdout on new llm token."""
import sys
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core.callbacks import StreamingStdOutCallbackHandler
@ -31,7 +31,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def __init__(
self,
*,
answer_prefix_tokens: Optional[List[str]] = None,
answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
) -> None:
@ -63,7 +63,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
self.answer_reached = False
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.answer_reached = False

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Any, Optional
from urllib.parse import urlparse
from langchain_core._api import deprecated
@ -20,7 +21,7 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
def _extract_scheme_and_domain(url: str) -> tuple[str, str]:
"""Extract the scheme + domain from a given URL.
Args:
@ -215,7 +216,7 @@ try:
"""
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -223,7 +224,7 @@ try:
return [self.question_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect output key.
:meta private:
@ -243,7 +244,7 @@ try:
@model_validator(mode="before")
@classmethod
def validate_limit_to_domains(cls, values: Dict) -> Any:
def validate_limit_to_domains(cls, values: dict) -> Any:
"""Check that allowed domains are valid."""
# This check must be a pre=True check, so that a default of None
# won't be set to limit_to_domains if it's not provided.
@ -275,9 +276,9 @@ try:
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.question_key]
api_url = self.api_request_chain.predict(
@ -308,9 +309,9 @@ try:
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = (
run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
)

View File

@ -1,12 +1,13 @@
"""Base interface that all chains should implement."""
import builtins
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union, cast
from typing import Any, Optional, Union, cast
import yaml
from langchain_core._api import deprecated
@ -46,7 +47,7 @@ def _get_verbosity() -> bool:
return get_verbose()
class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like
@ -86,13 +87,13 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to the global `verbose` value,
accessible via `langchain.globals.get_verbose()`."""
tags: Optional[List[str]] = None
tags: Optional[list[str]] = None
"""Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
@ -107,7 +108,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys}
@ -115,7 +116,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
@ -123,10 +124,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def invoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
config = ensure_config(config)
callbacks = config.get("callbacks")
tags = config.get("tags")
@ -162,7 +163,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
else self._call(inputs)
)
final_outputs: Dict[str, Any] = self.prep_outputs(
final_outputs: dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
@ -176,10 +177,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
async def ainvoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
config = ensure_config(config)
callbacks = config.get("callbacks")
tags = config.get("tags")
@ -213,7 +214,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
if new_arg_supported
else await self._acall(inputs)
)
final_outputs: Dict[str, Any] = await self.aprep_outputs(
final_outputs: dict[str, Any] = await self.aprep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
@ -231,7 +232,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@model_validator(mode="before")
@classmethod
def raise_callback_manager_deprecation(cls, values: Dict) -> Any:
def raise_callback_manager_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
if values.get("callbacks") is not None:
@ -261,15 +262,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@property
@abstractmethod
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Keys expected to be in the chain input."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Keys expected to be in the chain output."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
def _validate_inputs(self, inputs: dict[str, Any]) -> None:
"""Check that all inputs are present."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
@ -289,7 +290,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
def _validate_outputs(self, outputs: dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@ -297,9 +298,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Execute the chain.
This is a private method that is not user-facing. It is only called within
@ -319,9 +320,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Asynchronously execute the chain.
This is a private method that is not user-facing. It is only called within
@ -345,15 +346,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@deprecated("0.1.0", alternative="invoke", removal="1.0")
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
inputs: Union[dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Execute the chain.
Args:
@ -396,15 +397,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@deprecated("0.1.0", alternative="ainvoke", removal="1.0")
async def acall(
self,
inputs: Union[Dict[str, Any], Any],
inputs: Union[dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Asynchronously execute the chain.
Args:
@ -445,10 +446,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
inputs: dict[str, str],
outputs: dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Validate and prepare chain outputs, and save info about this run to memory.
Args:
@ -471,10 +472,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
async def aprep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
inputs: dict[str, str],
outputs: dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Validate and prepare chain outputs, and save info about this run to memory.
Args:
@ -495,7 +496,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.
Args:
@ -519,7 +520,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
inputs = dict(inputs, **external_context)
return inputs
async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
async def aprep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.
Args:
@ -557,8 +558,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Convenience method for executing chain.
@ -628,8 +629,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Convenience method for executing chain.
@ -695,7 +696,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> dict:
"""Dictionary representation of chain.
Expects `Chain._chain_type` property to be implemented and for memory to be
@ -763,7 +764,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@deprecated("0.1.0", alternative="batch", removal="1.0")
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
self, input_list: list[builtins.dict[str, Any]], callbacks: Callbacks = None
) -> list[builtins.dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]

View File

@ -1,7 +1,7 @@
"""Base interface for chains combining documents."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -47,22 +47,22 @@ class BaseCombineDocumentsChain(Chain, ABC):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return create_model(
"CombineDocumentsInput",
**{self.input_key: (List[Document], None)}, # type: ignore[call-overload]
**{self.input_key: (list[Document], None)}, # type: ignore[call-overload]
)
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return create_model(
"CombineDocumentsOutput",
**{self.output_key: (str, None)}, # type: ignore[call-overload]
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -70,14 +70,14 @@ class BaseCombineDocumentsChain(Chain, ABC):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
"""Return the prompt length given the documents passed in.
This can be used by a caller to determine whether passing in a list
@ -96,7 +96,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
return None
@abstractmethod
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
def combine_docs(self, docs: list[Document], **kwargs: Any) -> tuple[str, dict]:
"""Combine documents into a single string.
Args:
@ -111,8 +111,8 @@ class BaseCombineDocumentsChain(Chain, ABC):
@abstractmethod
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], **kwargs: Any
) -> tuple[str, dict]:
"""Combine documents into a single string.
Args:
@ -127,9 +127,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
def _call(
self,
inputs: Dict[str, List[Document]],
inputs: dict[str, list[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Prepare inputs, call combine docs, prepare outputs."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
@ -143,9 +143,9 @@ class BaseCombineDocumentsChain(Chain, ABC):
async def _acall(
self,
inputs: Dict[str, List[Document]],
inputs: dict[str, list[Document]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Prepare inputs, call combine docs, prepare outputs."""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
@ -229,7 +229,7 @@ class AnalyzeDocumentChain(Chain):
combine_docs_chain: BaseCombineDocumentsChain
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -237,7 +237,7 @@ class AnalyzeDocumentChain(Chain):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.
:meta private:
@ -246,7 +246,7 @@ class AnalyzeDocumentChain(Chain):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return create_model(
"AnalyzeDocumentChain",
**{self.input_key: (str, None)}, # type: ignore[call-overload]
@ -254,20 +254,20 @@ class AnalyzeDocumentChain(Chain):
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return self.combine_docs_chain.get_output_schema(config)
def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Split document into chunks and pass to CombineDocumentsChain."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
document = inputs[self.input_key]
docs = self.text_splitter.create_documents([document])
# Other keys are assumed to be needed for LLM prediction
other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys: dict = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys[self.combine_docs_chain.input_key] = docs
return self.combine_docs_chain(
other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -113,20 +113,20 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
if self.return_intermediate_steps:
return create_model(
"MapReduceDocumentsOutput",
**{
self.output_key: (str, None),
"intermediate_steps": (List[str], None),
"intermediate_steps": (list[str], None),
}, # type: ignore[call-overload]
)
return super().get_output_schema(config)
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -143,7 +143,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before")
@classmethod
def get_reduce_chain(cls, values: Dict) -> Any:
def get_reduce_chain(cls, values: dict) -> Any:
"""For backwards compatibility."""
if "combine_document_chain" in values:
if "reduce_documents_chain" in values:
@ -167,7 +167,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before")
@classmethod
def get_return_intermediate_steps(cls, values: Dict) -> Any:
def get_return_intermediate_steps(cls, values: dict) -> Any:
"""For backwards compatibility."""
if "return_map_steps" in values:
values["return_intermediate_steps"] = values["return_map_steps"]
@ -176,7 +176,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before")
@classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any:
def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided."""
if "llm_chain" not in values:
raise ValueError("llm_chain must be provided")
@ -227,11 +227,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
def combine_docs(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Combine documents in a map reduce manner.
Combine by mapping first chain over all documents, then reducing the results.
@ -258,11 +258,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Combine documents in a map reduce manner.
Combine by mapping first chain over all documents, then reducing the results.

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -79,7 +80,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Key in output of llm_chain to rank on."""
answer_key: str
"""Key in output of llm_chain to return as answer."""
metadata_keys: Optional[List[str]] = None
metadata_keys: Optional[list[str]] = None
"""Additional metadata from the chosen document to return."""
return_intermediate_steps: bool = False
"""Return intermediate steps.
@ -92,19 +93,19 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
schema: Dict[str, Any] = {
) -> type[BaseModel]:
schema: dict[str, Any] = {
self.output_key: (str, None),
}
if self.return_intermediate_steps:
schema["intermediate_steps"] = (List[str], None)
schema["intermediate_steps"] = (list[str], None)
if self.metadata_keys:
schema.update({key: (Any, None) for key in self.metadata_keys})
return create_model("MapRerankOutput", **schema)
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -140,7 +141,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before")
@classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any:
def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided."""
if "llm_chain" not in values:
raise ValueError("llm_chain must be provided")
@ -163,8 +164,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
return values
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results.
@ -187,8 +188,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
return self._process_results(docs, results)
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results.
@ -212,10 +213,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
def _process_results(
self,
docs: List[Document],
results: Sequence[Union[str, List[str], Dict[str, str]]],
) -> Tuple[str, dict]:
typed_results = cast(List[dict], results)
docs: list[Document],
results: Sequence[Union[str, list[str], dict[str, str]]],
) -> tuple[str, dict]:
typed_results = cast(list[dict], results)
sorted_res = sorted(
zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Callable, List, Optional, Protocol, Tuple
from typing import Any, Callable, Optional, Protocol
from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -15,20 +15,20 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
class CombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
def __call__(self, docs: List[Document], **kwargs: Any) -> str:
def __call__(self, docs: list[Document], **kwargs: Any) -> str:
"""Interface for the combine_docs method."""
class AsyncCombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
async def __call__(self, docs: list[Document], **kwargs: Any) -> str:
"""Async interface for the combine_docs method."""
def split_list_of_docs(
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> List[List[Document]]:
docs: list[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> list[list[Document]]:
"""Split Documents into subsets that each meet a cumulative length constraint.
Args:
@ -59,7 +59,7 @@ def split_list_of_docs(
def collapse_docs(
docs: List[Document],
docs: list[Document],
combine_document_func: CombineDocsProtocol,
**kwargs: Any,
) -> Document:
@ -91,7 +91,7 @@ def collapse_docs(
async def acollapse_docs(
docs: List[Document],
docs: list[Document],
combine_document_func: AsyncCombineDocsProtocol,
**kwargs: Any,
) -> Document:
@ -229,11 +229,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
def combine_docs(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Combine multiple documents recursively.
Args:
@ -258,11 +258,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Async combine multiple documents recursively.
Args:
@ -287,16 +287,16 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
def _collapse(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
) -> tuple[list[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str:
return self._collapse_chain.run(
input_documents=docs, callbacks=callbacks, **kwargs
)
@ -322,16 +322,16 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
async def _acollapse(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
) -> tuple[list[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
async def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str:
return await self._collapse_chain.arun(
input_documents=docs, callbacks=callbacks, **kwargs
)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Tuple
from typing import Any
from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -98,7 +98,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
"""Return the results of the refine steps in the output."""
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -115,7 +115,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before")
@classmethod
def get_return_intermediate_steps(cls, values: Dict) -> Any:
def get_return_intermediate_steps(cls, values: dict) -> Any:
"""For backwards compatibility."""
if "return_refine_steps" in values:
values["return_intermediate_steps"] = values["return_refine_steps"]
@ -124,7 +124,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before")
@classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any:
def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided."""
if "initial_llm_chain" not in values:
raise ValueError("initial_llm_chain must be provided")
@ -147,8 +147,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
return values
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain.
Args:
@ -172,8 +172,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
return self._construct_result(refine_steps, res)
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Async combine by mapping a first chain over all, then stuffing
into a final chain.
@ -197,22 +197,22 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
refine_steps.append(res)
return self._construct_result(refine_steps, res)
def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
def _construct_result(self, refine_steps: list[str], res: str) -> tuple[str, dict]:
if self.return_intermediate_steps:
extra_return_dict = {"intermediate_steps": refine_steps}
else:
extra_return_dict = {}
return res, extra_return_dict
def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
def _construct_refine_inputs(self, doc: Document, res: str) -> dict[str, Any]:
return {
self.document_variable_name: format_document(doc, self.document_prompt),
self.initial_response_name: res,
}
def _construct_initial_inputs(
self, docs: List[Document], **kwargs: Any
) -> Dict[str, Any]:
self, docs: list[Document], **kwargs: Any
) -> dict[str, Any]:
base_info = {"page_content": docs[0].page_content}
base_info.update(docs[0].metadata)
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}

View File

@ -1,6 +1,6 @@
"""Chain that combines documents by stuffing into context."""
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -29,7 +29,7 @@ def create_stuff_documents_chain(
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
document_variable_name: str = DOCUMENTS_KEY,
) -> Runnable[Dict[str, Any], Any]:
) -> Runnable[dict[str, Any], Any]:
"""Create a chain for passing a list of Documents to a model.
Args:
@ -163,7 +163,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
@model_validator(mode="before")
@classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any:
def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided.
If only one variable is present in the llm_chain.prompt,
@ -188,13 +188,13 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return values
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
extra_keys = [
k for k in self.llm_chain.input_keys if k != self.document_variable_name
]
return super().input_keys + extra_keys
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
def _get_inputs(self, docs: list[Document], **kwargs: Any) -> dict:
"""Construct inputs from kwargs and docs.
Format and then join all the documents together into one input with name
@ -220,7 +220,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
return inputs
def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
"""Return the prompt length given the documents passed in.
This can be used by a caller to determine whether passing in a list
@ -241,8 +241,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return self.llm_chain._get_num_tokens(prompt)
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM.
Args:
@ -259,8 +259,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Async stuff all documents into one prompt and pass to LLM.
Args:

View File

@ -1,6 +1,6 @@
"""Chain for applying constitutional principles to the outputs of another chain."""
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
@ -190,15 +190,15 @@ class ConstitutionalChain(Chain):
""" # noqa: E501
chain: LLMChain
constitutional_principles: List[ConstitutionalPrinciple]
constitutional_principles: list[ConstitutionalPrinciple]
critique_chain: LLMChain
revision_chain: LLMChain
return_intermediate_steps: bool = False
@classmethod
def get_principles(
cls, names: Optional[List[str]] = None
) -> List[ConstitutionalPrinciple]:
cls, names: Optional[list[str]] = None
) -> list[ConstitutionalPrinciple]:
if names is None:
return list(PRINCIPLES.values())
else:
@ -224,12 +224,12 @@ class ConstitutionalChain(Chain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys."""
return self.chain.input_keys
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Output keys."""
if self.return_intermediate_steps:
return ["output", "critiques_and_revisions", "initial_output"]
@ -237,9 +237,9 @@ class ConstitutionalChain(Chain):
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
response = self.chain.run(
**inputs,
@ -305,7 +305,7 @@ class ConstitutionalChain(Chain):
color="yellow",
)
final_output: Dict[str, Any] = {"output": response}
final_output: dict[str, Any] = {"output": response}
if self.return_intermediate_steps:
final_output["initial_output"] = initial_response
final_output["critiques_and_revisions"] = critiques_and_revisions

View File

@ -1,7 +1,5 @@
"""Chain that carries on a conversation and calls an LLM."""
from typing import List
from langchain_core._api import deprecated
from langchain_core.memory import BaseMemory
from langchain_core.prompts import BasePromptTemplate
@ -121,7 +119,7 @@ class ConversationChain(LLMChain): # type: ignore[override, override]
return False
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Use this since so some prompt vars come from history."""
return [self.input_key]

View File

@ -6,7 +6,7 @@ import inspect
import warnings
from abc import abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Optional, Union
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -32,13 +32,13 @@ from langchain.chains.question_answering import load_qa_chain
# Depending on the memory type and configuration, the chat history format may differ.
# This needs to be consolidated.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
CHAT_TURN_TYPE = Union[tuple[str, str], BaseMessage]
_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}
def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
def _get_chat_history(chat_history: list[CHAT_TURN_TYPE]) -> str:
buffer = ""
for dialogue_turn in chat_history:
if isinstance(dialogue_turn, BaseMessage):
@ -64,7 +64,7 @@ class InputType(BaseModel):
question: str
"""The question to answer."""
chat_history: List[CHAT_TURN_TYPE] = Field(default_factory=list)
chat_history: list[CHAT_TURN_TYPE] = Field(default_factory=list)
"""The chat history to use for retrieval."""
@ -89,7 +89,7 @@ class BaseConversationalRetrievalChain(Chain):
"""Return the retrieved source documents as part of the final result."""
return_generated_question: bool = False
"""Return the generated question as part of the final result."""
get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None
get_chat_history: Optional[Callable[[list[CHAT_TURN_TYPE]], str]] = None
"""An optional function to get a string of the chat history.
If None is provided, will use a default."""
response_if_no_docs_found: Optional[str] = None
@ -103,17 +103,17 @@ class BaseConversationalRetrievalChain(Chain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys."""
return ["question", "chat_history"]
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return InputType
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return the output keys.
:meta private:
@ -129,17 +129,17 @@ class BaseConversationalRetrievalChain(Chain):
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history
@ -159,7 +159,7 @@ class BaseConversationalRetrievalChain(Chain):
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
else:
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
output: Dict[str, Any] = {}
output: dict[str, Any] = {}
if self.response_if_no_docs_found is not None and len(docs) == 0:
output[self.output_key] = self.response_if_no_docs_found
else:
@ -182,17 +182,17 @@ class BaseConversationalRetrievalChain(Chain):
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history
@ -212,7 +212,7 @@ class BaseConversationalRetrievalChain(Chain):
else:
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
output: Dict[str, Any] = {}
output: dict[str, Any] = {}
if self.response_if_no_docs_found is not None and len(docs) == 0:
output[self.output_key] = self.response_if_no_docs_found
else:
@ -368,7 +368,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
"""If set, enforces that the documents returned are less than this limit.
This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
num_docs = len(docs)
if self.max_tokens_limit and isinstance(
@ -388,10 +388,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
docs = self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()}
@ -401,10 +401,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
docs = await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()}
@ -420,7 +420,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
chain_type: str = "stuff",
verbose: bool = False,
condense_question_llm: Optional[BaseLanguageModel] = None,
combine_docs_chain_kwargs: Optional[Dict] = None,
combine_docs_chain_kwargs: Optional[dict] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseConversationalRetrievalChain:
@ -485,7 +485,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
warnings.warn(
"`ChatVectorDBChain` is deprecated - "
"please use `from langchain.chains import ConversationalRetrievalChain`"
@ -495,10 +495,10 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
vectordbkwargs = inputs.get("vectordbkwargs", {})
full_kwargs = {**self.search_kwargs, **vectordbkwargs}
@ -509,10 +509,10 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
raise NotImplementedError("ChatVectorDBChain does not support async")
@ -523,7 +523,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
vectorstore: VectorStore,
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
chain_type: str = "stuff",
combine_docs_chain_kwargs: Optional[Dict] = None,
combine_docs_chain_kwargs: Optional[dict] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseConversationalRetrievalChain:

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
@ -44,8 +44,8 @@ class ElasticsearchDatabaseChain(Chain):
"""Elasticsearch database to connect to of type elasticsearch.Elasticsearch."""
top_k: int = 10
"""Number of results to return from the query"""
ignore_indices: Optional[List[str]] = None
include_indices: Optional[List[str]] = None
ignore_indices: Optional[list[str]] = None
include_indices: Optional[list[str]] = None
input_key: str = "question" #: :meta private:
output_key: str = "result" #: :meta private:
sample_documents_in_index_info: int = 3
@ -66,7 +66,7 @@ class ElasticsearchDatabaseChain(Chain):
return self
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the singular input key.
:meta private:
@ -74,7 +74,7 @@ class ElasticsearchDatabaseChain(Chain):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return the singular output key.
:meta private:
@ -84,7 +84,7 @@ class ElasticsearchDatabaseChain(Chain):
else:
return [self.output_key, INTERMEDIATE_STEPS_KEY]
def _list_indices(self) -> List[str]:
def _list_indices(self) -> list[str]:
all_indices = [
index["index"] for index in self.database.cat.indices(format="json")
]
@ -96,7 +96,7 @@ class ElasticsearchDatabaseChain(Chain):
return all_indices
def _get_indices_infos(self, indices: List[str]) -> str:
def _get_indices_infos(self, indices: list[str]) -> str:
mappings = self.database.indices.get_mapping(index=",".join(indices))
if self.sample_documents_in_index_info > 0:
for k, v in mappings.items():
@ -114,15 +114,15 @@ class ElasticsearchDatabaseChain(Chain):
]
)
def _search(self, indices: List[str], query: str) -> str:
def _search(self, indices: list[str], query: str) -> str:
result = self.database.search(index=",".join(indices), body=query)
return str(result)
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
input_text = f"{inputs[self.input_key]}\nESQuery:"
_run_manager.on_text(input_text, verbose=self.verbose)
@ -134,7 +134,7 @@ class ElasticsearchDatabaseChain(Chain):
"indices_info": indices_info,
"stop": ["\nESResult:"],
}
intermediate_steps: List = []
intermediate_steps: list = []
try:
intermediate_steps.append(query_inputs) # input: es generation
es_cmd = self.query_chain.invoke(
@ -163,7 +163,7 @@ class ElasticsearchDatabaseChain(Chain):
intermediate_steps.append(final_result) # output: final answer
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
chain_result: Dict[str, Any] = {self.output_key: final_result}
chain_result: dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

View File

@ -1,5 +1,3 @@
from typing import List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.few_shot import FewShotPromptTemplate
@ -9,7 +7,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."
def generate_example(
examples: List[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate
examples: list[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate
) -> str:
"""Return another example given a list of examples for a prompt."""
prompt = FewShotPromptTemplate(

View File

@ -2,7 +2,8 @@ from __future__ import annotations
import logging
import re
from typing import Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Any, Optional
from langchain_core.callbacks import (
CallbackManagerForChainRun,
@ -26,7 +27,7 @@ from langchain.chains.llm import LLMChain
logger = logging.getLogger(__name__)
def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
def _extract_tokens_and_log_probs(response: AIMessage) -> tuple[list[str], list[float]]:
"""Extract tokens and log probabilities from chat model response."""
tokens = []
log_probs = []
@ -47,7 +48,7 @@ class QuestionGeneratorChain(LLMChain):
return False
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys for the chain."""
return ["user_input", "context", "response"]
@ -58,7 +59,7 @@ def _low_confidence_spans(
min_prob: float,
min_token_gap: int,
num_pad_tokens: int,
) -> List[str]:
) -> list[str]:
try:
import numpy as np
@ -117,22 +118,22 @@ class FlareChain(Chain):
"""Whether to start with retrieval."""
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys for the chain."""
return ["user_input"]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Output keys for the chain."""
return ["response"]
def _do_generation(
self,
questions: List[str],
questions: list[str],
user_input: str,
response: str,
_run_manager: CallbackManagerForChainRun,
) -> Tuple[str, bool]:
) -> tuple[str, bool]:
callbacks = _run_manager.get_child()
docs = []
for question in questions:
@ -153,12 +154,12 @@ class FlareChain(Chain):
def _do_retrieval(
self,
low_confidence_spans: List[str],
low_confidence_spans: list[str],
_run_manager: CallbackManagerForChainRun,
user_input: str,
response: str,
initial_response: str,
) -> Tuple[str, bool]:
) -> tuple[str, bool]:
question_gen_inputs = [
{
"user_input": user_input,
@ -187,9 +188,9 @@ class FlareChain(Chain):
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
user_input = inputs[self.input_keys[0]]

View File

@ -1,16 +1,14 @@
from typing import Tuple
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):
class FinishedOutputParser(BaseOutputParser[tuple[str, bool]]):
"""Output parser that checks if the output is finished."""
finished_value: str = "FINISHED"
"""Value that indicates the output is finished."""
def parse(self, text: str) -> Tuple[str, bool]:
def parse(self, text: str) -> tuple[str, bool]:
cleaned = text.strip()
finished = self.finished_value in cleaned
return cleaned.replace(self.finished_value, ""), finished

View File

@ -6,7 +6,7 @@ https://arxiv.org/abs/2212.10496
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
@ -38,23 +38,23 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys for Hyde's LLM chain."""
return self.llm_chain.input_schema.model_json_schema()["required"]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Output keys for Hyde's LLM chain."""
if isinstance(self.llm_chain, LLMChain):
return self.llm_chain.output_keys
else:
return ["text"]
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Call the base embeddings."""
return self.base_embeddings.embed_documents(texts)
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
def combine_embeddings(self, embeddings: list[list[float]]) -> list[float]:
"""Combine embeddings into final embeddings."""
try:
import numpy as np
@ -73,7 +73,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
num_vectors = len(embeddings)
return [sum(dim_values) / num_vectors for dim_values in zip(*embeddings)]
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Generate a hypothetical document and embedded it."""
var_name = self.input_keys[0]
result = self.llm_chain.invoke({var_name: text})
@ -86,9 +86,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Call the internal llm chain."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
return self.llm_chain.invoke(

View File

@ -3,7 +3,8 @@
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -100,7 +101,7 @@ class LLMChain(Chain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Will be whatever keys the prompt expects.
:meta private:
@ -108,7 +109,7 @@ class LLMChain(Chain):
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Will always return text key.
:meta private:
@ -120,15 +121,15 @@ class LLMChain(Chain):
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def generate(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
@ -143,9 +144,9 @@ class LLMChain(Chain):
)
else:
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
cast(List, prompts), {"callbacks": callbacks}
cast(list, prompts), {"callbacks": callbacks}
)
generations: List[List[Generation]] = []
generations: list[list[Generation]] = []
for res in results:
if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)])
@ -155,7 +156,7 @@ class LLMChain(Chain):
async def agenerate(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
@ -170,9 +171,9 @@ class LLMChain(Chain):
)
else:
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
cast(List, prompts), {"callbacks": callbacks}
cast(list, prompts), {"callbacks": callbacks}
)
generations: List[List[Generation]] = []
generations: list[list[Generation]] = []
for res in results:
if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)])
@ -182,9 +183,9 @@ class LLMChain(Chain):
def prep_prompts(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
) -> tuple[list[PromptValue], Optional[list[str]]]:
"""Prepare prompts from inputs."""
stop = None
if len(input_list) == 0:
@ -208,9 +209,9 @@ class LLMChain(Chain):
async def aprep_prompts(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
) -> tuple[list[PromptValue], Optional[list[str]]]:
"""Prepare prompts from inputs."""
stop = None
if len(input_list) == 0:
@ -233,8 +234,8 @@ class LLMChain(Chain):
return prompts, stop
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> list[dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
@ -254,8 +255,8 @@ class LLMChain(Chain):
return outputs
async def aapply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> list[dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
@ -278,7 +279,7 @@ class LLMChain(Chain):
def _run_output_key(self) -> str:
return self.output_key
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
def create_outputs(self, llm_result: LLMResult) -> list[dict[str, Any]]:
"""Create outputs from response."""
result = [
# Get the text of the top generated string.
@ -294,9 +295,9 @@ class LLMChain(Chain):
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
response = await self.agenerate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
@ -336,7 +337,7 @@ class LLMChain(Chain):
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
) -> Union[str, list[str], dict[str, Any]]:
"""Call predict and then parse the results."""
warnings.warn(
"The predict_and_parse method is deprecated, "
@ -350,7 +351,7 @@ class LLMChain(Chain):
async def apredict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
) -> Union[str, list[str], dict[str, str]]:
"""Call apredict and then parse the results."""
warnings.warn(
"The apredict_and_parse method is deprecated, "
@ -363,8 +364,8 @@ class LLMChain(Chain):
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, list[str], dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The apply_and_parse method is deprecated, "
@ -374,8 +375,8 @@ class LLMChain(Chain):
return self._parse_generation(result)
def _parse_generation(
self, generation: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
self, generation: list[dict[str, str]]
) -> Sequence[Union[str, list[str], dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key])
@ -385,8 +386,8 @@ class LLMChain(Chain):
return generation
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, list[str], dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The aapply_and_parse method is deprecated, "

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
@ -107,7 +107,7 @@ class LLMCheckerChain(Chain):
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMCheckerChain with an llm is deprecated. "
@ -135,7 +135,7 @@ class LLMCheckerChain(Chain):
return values
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the singular input key.
:meta private:
@ -143,7 +143,7 @@ class LLMCheckerChain(Chain):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return the singular output key.
:meta private:
@ -152,9 +152,9 @@ class LLMCheckerChain(Chain):
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import math
import re
import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -163,7 +163,7 @@ class LLMMathChain(Chain):
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
try:
import numexpr # noqa: F401
except ImportError:
@ -183,7 +183,7 @@ class LLMMathChain(Chain):
return values
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -191,7 +191,7 @@ class LLMMathChain(Chain):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect output key.
:meta private:
@ -221,7 +221,7 @@ class LLMMathChain(Chain):
def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
) -> dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
@ -243,7 +243,7 @@ class LLMMathChain(Chain):
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
) -> dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
@ -263,9 +263,9 @@ class LLMMathChain(Chain):
def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict(
@ -277,9 +277,9 @@ class LLMMathChain(Chain):
async def _acall(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict(

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
@ -112,7 +112,7 @@ class LLMSummarizationCheckerChain(Chain):
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMSummarizationCheckerChain with an llm is "
@ -131,7 +131,7 @@ class LLMSummarizationCheckerChain(Chain):
return values
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the singular input key.
:meta private:
@ -139,7 +139,7 @@ class LLMSummarizationCheckerChain(Chain):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return the singular output key.
:meta private:
@ -148,9 +148,9 @@ class LLMSummarizationCheckerChain(Chain):
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
all_true = False
count = 0

View File

@ -702,7 +702,7 @@ def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
with open(file_path) as f:
config = json.load(f)
elif file_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "r") as f:
with open(file_path) as f:
config = yaml.safe_load(f)
else:
raise ValueError("File type must be json or yaml")

View File

@ -6,7 +6,8 @@ then combines the results with another one.
from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional
from collections.abc import Mapping
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
@ -84,7 +85,7 @@ class MapReduceChain(Chain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -92,7 +93,7 @@ class MapReduceChain(Chain):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.
:meta private:
@ -101,15 +102,15 @@ class MapReduceChain(Chain):
def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
# Split the larger text into smaller chunks.
doc_text = inputs.pop(self.input_key)
texts = self.text_splitter.split_text(doc_text)
docs = [Document(page_content=text) for text in texts]
_inputs: Dict[str, Any] = {
_inputs: dict[str, Any] = {
**inputs,
self.combine_documents_chain.input_key: docs,
}

View File

@ -1,6 +1,6 @@
"""Pass input through a moderation endpoint."""
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -42,7 +42,7 @@ class OpenAIModerationChain(Chain):
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
def validate_environment(cls, values: dict) -> Any:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
@ -78,7 +78,7 @@ class OpenAIModerationChain(Chain):
return values
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -86,7 +86,7 @@ class OpenAIModerationChain(Chain):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.
:meta private:
@ -108,9 +108,9 @@ class OpenAIModerationChain(Chain):
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
text = inputs[self.input_key]
if self.openai_pre_1_0:
results = self.client.create(text)
@ -122,9 +122,9 @@ class OpenAIModerationChain(Chain):
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if self.openai_pre_1_0:
return await super()._acall(inputs, run_manager=run_manager)
text = inputs[self.input_key]

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.caches import BaseCache as BaseCache
@ -68,7 +68,7 @@ class NatBotChain(Chain):
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
if "llm" in values:
warnings.warn(
"Directly instantiating an NatBotChain with an llm is deprecated. "
@ -97,7 +97,7 @@ class NatBotChain(Chain):
return cls(llm_chain=llm_chain, objective=objective, **kwargs)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect url and browser content.
:meta private:
@ -105,7 +105,7 @@ class NatBotChain(Chain):
return [self.input_url_key, self.input_browser_content_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return command.
:meta private:
@ -114,9 +114,9 @@ class NatBotChain(Chain):
def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
url = inputs[self.input_url_key]
browser_content = inputs[self.input_browser_content_key]

View File

@ -1,12 +1,10 @@
"""Methods for creating chains that use OpenAI function-calling APIs."""
from collections.abc import Sequence
from typing import (
Any,
Callable,
Dict,
Optional,
Sequence,
Type,
Union,
)
@ -45,7 +43,7 @@ __all__ = [
@deprecated(since="0.1.1", removal="1.0", alternative="create_openai_fn_runnable")
def create_openai_fn_chain(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
@ -128,7 +126,7 @@ def create_openai_fn_chain(
raise ValueError("Need to pass in at least one function. Received zero.")
openai_functions = [convert_to_openai_function(f) for f in functions]
output_parser = output_parser or get_openai_output_parser(functions)
llm_kwargs: Dict[str, Any] = {
llm_kwargs: dict[str, Any] = {
"functions": openai_functions,
}
if len(openai_functions) == 1 and enforce_single_function_usage:
@ -148,7 +146,7 @@ def create_openai_fn_chain(
since="0.1.1", removal="1.0", alternative="ChatOpenAI.with_structured_output"
)
def create_structured_output_chain(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
output_schema: Union[dict[str, Any], type[BaseModel]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,

View File

@ -1,4 +1,4 @@
from typing import Iterator, List
from collections.abc import Iterator
from langchain_core._api import deprecated
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
@ -21,7 +21,7 @@ class FactWithEvidence(BaseModel):
"""
fact: str = Field(..., description="Body of the sentence, as part of a response")
substring_quote: List[str] = Field(
substring_quote: list[str] = Field(
...,
description=(
"Each source should be a direct quote from the context, "
@ -54,7 +54,7 @@ class QuestionAnswer(BaseModel):
each sentence contains a body and a list of sources."""
question: str = Field(..., description="Question that was asked")
answer: List[FactWithEvidence] = Field(
answer: list[FactWithEvidence] = Field(
...,
description=(
"Body of the answer, each fact should be "

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@ -83,7 +83,7 @@ def create_extraction_chain(
schema: dict,
llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
verbose: bool = False,
) -> Chain:
"""Creates a chain that extracts information from a passage.
@ -170,7 +170,7 @@ def create_extraction_chain_pydantic(
"""
class PydanticSchema(BaseModel):
info: List[pydantic_schema] # type: ignore
info: list[pydantic_schema] # type: ignore
if hasattr(pydantic_schema, "model_json_schema"):
openai_schema = pydantic_schema.model_json_schema()

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json
import re
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import requests
from langchain_core._api import deprecated
@ -70,7 +70,7 @@ def _format_url(url: str, path_params: dict) -> str:
return url.format(**new_params)
def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -> dict:
def _openapi_params_to_json_schema(params: list[Parameter], spec: OpenAPISpec) -> dict:
properties = {}
required = []
for p in params:
@ -89,7 +89,7 @@ def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -
def openapi_spec_to_openai_fn(
spec: OpenAPISpec,
) -> Tuple[List[Dict[str, Any]], Callable]:
) -> tuple[list[dict[str, Any]], Callable]:
"""Convert a valid OpenAPI spec to the JSON Schema format expected for OpenAI
functions.
@ -208,18 +208,18 @@ class SimpleRequestChain(Chain):
"""Key to use for the input of the request."""
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run the logic of this chain and return the output."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
name = inputs[self.input_key].pop("name")
@ -257,10 +257,10 @@ def get_openapi_chain(
llm: Optional[BaseLanguageModel] = None,
prompt: Optional[BasePromptTemplate] = None,
request_chain: Optional[Chain] = None,
llm_chain_kwargs: Optional[Dict] = None,
llm_chain_kwargs: Optional[dict] = None,
verbose: bool = False,
headers: Optional[Dict] = None,
params: Optional[Dict] = None,
headers: Optional[dict] = None,
params: Optional[dict] = None,
**kwargs: Any,
) -> SequentialChain:
"""Create a chain for querying an API from a OpenAPI spec.

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Type, Union, cast
from typing import Any, Optional, Union, cast
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@ -21,7 +21,7 @@ class AnswerWithSources(BaseModel):
"""An answer to the question, with sources."""
answer: str = Field(..., description="Answer to the question that was asked")
sources: List[str] = Field(
sources: list[str] = Field(
..., description="List of sources used to answer the question"
)
@ -37,7 +37,7 @@ class AnswerWithSources(BaseModel):
)
def create_qa_with_structure_chain(
llm: BaseLanguageModel,
schema: Union[dict, Type[BaseModel]],
schema: Union[dict, type[BaseModel]],
output_parser: str = "base",
prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
verbose: bool = False,

View File

@ -1,7 +1,7 @@
from typing import Any, Dict
from typing import Any
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
def _resolve_schema_references(schema: Any, definitions: dict[str, Any]) -> Any:
"""
Resolve the $ref keys in a JSON schema object using the provided definitions.
"""

View File

@ -1,4 +1,4 @@
from typing import List, Type, Union
from typing import Union
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@ -51,7 +51,7 @@ If a property is not present and is not required in the function parameters, do
),
)
def create_extraction_chain_pydantic(
pydantic_schemas: Union[List[Type[BaseModel]], Type[BaseModel]],
pydantic_schemas: Union[list[type[BaseModel]], type[BaseModel]],
llm: BaseLanguageModel,
system_message: str = _EXTRACTION_TEMPLATE,
) -> Runnable:

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple
from typing import Callable
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
@ -21,8 +21,8 @@ class ConditionalPromptSelector(BasePromptSelector):
default_prompt: BasePromptTemplate
"""Default prompt to use if no conditionals match."""
conditionals: List[
Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
conditionals: list[
tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
] = Field(default_factory=list)
"""List of conditionals and prompts to use if the conditionals match."""

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
@ -103,18 +103,18 @@ class QAGenerationChain(Chain):
raise NotImplementedError
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return [self.output_key]
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, List]:
) -> dict[str, list]:
docs = self.text_splitter.create_documents([inputs[self.input_key]])
results = self.llm_chain.generate(
[{"text": d.page_content} for d in docs], run_manager=run_manager

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import inspect
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -103,7 +103,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -111,7 +111,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
return [self.question_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.
:meta private:
@ -123,13 +123,13 @@ class BaseQAWithSourcesChain(Chain, ABC):
@model_validator(mode="before")
@classmethod
def validate_naming(cls, values: Dict) -> Any:
def validate_naming(cls, values: dict) -> Any:
"""Fix backwards compatibility in naming."""
if "combine_document_chain" in values:
values["combine_documents_chain"] = values.pop("combine_document_chain")
return values
def _split_sources(self, answer: str) -> Tuple[str, str]:
def _split_sources(self, answer: str) -> tuple[str, str]:
"""Split sources from answer."""
if re.search(r"SOURCES?:", answer, re.IGNORECASE):
answer, sources = re.split(
@ -143,17 +143,17 @@ class BaseQAWithSourcesChain(Chain, ABC):
@abstractmethod
def _get_docs(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs to run questioning over."""
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
@ -167,7 +167,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
answer, sources = self._split_sources(answer)
result: Dict[str, Any] = {
result: dict[str, Any] = {
self.answer_key: answer,
self.sources_answer_key: sources,
}
@ -178,17 +178,17 @@ class BaseQAWithSourcesChain(Chain, ABC):
@abstractmethod
async def _aget_docs(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs to run questioning over."""
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
@ -201,7 +201,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
answer, sources = self._split_sources(answer)
result: Dict[str, Any] = {
result: dict[str, Any] = {
self.answer_key: answer,
self.sources_answer_key: sources,
}
@ -225,7 +225,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
input_docs_key: str = "docs" #: :meta private:
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -234,19 +234,19 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
def _get_docs(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs to run questioning over."""
return inputs.pop(self.input_docs_key)
async def _aget_docs(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs to run questioning over."""
return inputs.pop(self.input_docs_key)

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, Mapping, Optional, Protocol
from collections.abc import Mapping
from typing import Any, Optional, Protocol
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel

View File

@ -1,6 +1,6 @@
"""Question-answering with sources over an index."""
from typing import Any, Dict, List
from typing import Any
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -25,7 +25,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
"""Restrict the docs to return from store based on tokens,
enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true"""
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
num_docs = len(docs)
if self.reduce_k_below_max_tokens and isinstance(
@ -43,8 +43,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]
def _get_docs(
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> list[Document]:
question = inputs[self.question_key]
docs = self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()}
@ -52,8 +52,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> list[Document]:
question = inputs[self.question_key]
docs = await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()}

View File

@ -1,7 +1,7 @@
"""Question-answering with sources over a vector database."""
import warnings
from typing import Any, Dict, List
from typing import Any
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -27,10 +27,10 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
max_tokens_limit: int = 3375
"""Restrict the docs to return from store based on tokens,
enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true"""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
search_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
num_docs = len(docs)
if self.reduce_k_below_max_tokens and isinstance(
@ -48,8 +48,8 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]
def _get_docs(
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> list[Document]:
question = inputs[self.question_key]
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
@ -57,13 +57,13 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> list[Document]:
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
warnings.warn(
"`VectorDBQAWithSourcesChain` is deprecated - "
"please use `from langchain.chains import RetrievalQAWithSourcesChain`"

View File

@ -3,7 +3,8 @@
from __future__ import annotations
import json
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union, cast
from langchain_core._api import deprecated
from langchain_core.exceptions import OutputParserException
@ -172,7 +173,7 @@ def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
def construct_examples(input_output_pairs: Sequence[tuple[str, dict]]) -> list[dict]:
"""Construct examples from input-output pairs.
Args:
@ -267,7 +268,7 @@ def load_query_constructor_chain(
llm: BaseLanguageModel,
document_contents: str,
attribute_info: Sequence[Union[AttributeInfo, dict]],
examples: Optional[List] = None,
examples: Optional[list] = None,
allowed_comparators: Sequence[Comparator] = tuple(Comparator),
allowed_operators: Sequence[Operator] = tuple(Operator),
enable_limit: bool = False,

View File

@ -1,6 +1,7 @@
import datetime
import warnings
from typing import Any, Literal, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Any, Literal, Optional, Union
from langchain_core.utils import check_package_version
from typing_extensions import TypedDict

View File

@ -1,6 +1,7 @@
"""Load question answering chains."""
from typing import Any, Mapping, Optional, Protocol
from collections.abc import Mapping
from typing import Any, Optional, Protocol
from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager, Callbacks

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict, Union
from typing import Any, Union
from langchain_core.retrievers import (
BaseRetriever,
@ -11,7 +11,7 @@ from langchain_core.runnables import Runnable, RunnablePassthrough
def create_retrieval_chain(
retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
combine_docs_chain: Runnable[Dict[str, Any], str],
combine_docs_chain: Runnable[dict[str, Any], str],
) -> Runnable:
"""Create retrieval chain that retrieves documents and then passes them on.

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import inspect
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -54,7 +54,7 @@ class BaseRetrievalQA(Chain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys.
:meta private:
@ -62,7 +62,7 @@ class BaseRetrievalQA(Chain):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Output keys.
:meta private:
@ -123,14 +123,14 @@ class BaseRetrievalQA(Chain):
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get documents to do question answering over."""
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
@ -166,14 +166,14 @@ class BaseRetrievalQA(Chain):
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get documents to do question answering over."""
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
@ -266,7 +266,7 @@ class RetrievalQA(BaseRetrievalQA):
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
return self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()}
@ -277,7 +277,7 @@ class RetrievalQA(BaseRetrievalQA):
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
return await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()}
@ -307,12 +307,12 @@ class VectorDBQA(BaseRetrievalQA):
"""Number of documents to query for."""
search_type: str = "similarity"
"""Search type to use over vectorstore. `similarity` or `mmr`."""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
search_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
warnings.warn(
"`VectorDBQA` is deprecated - "
"please use `from langchain.chains import RetrievalQA`"
@ -321,7 +321,7 @@ class VectorDBQA(BaseRetrievalQA):
@model_validator(mode="before")
@classmethod
def validate_search_type(cls, values: Dict) -> Any:
def validate_search_type(cls, values: dict) -> Any:
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
@ -334,7 +334,7 @@ class VectorDBQA(BaseRetrievalQA):
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(
@ -353,7 +353,7 @@ class VectorDBQA(BaseRetrievalQA):
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
raise NotImplementedError("VectorDBQA does not support async")

View File

@ -3,7 +3,8 @@
from __future__ import annotations
from abc import ABC
from typing import Any, Dict, List, Mapping, NamedTuple, Optional
from collections.abc import Mapping
from typing import Any, NamedTuple, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -17,17 +18,17 @@ from langchain.chains.base import Chain
class Route(NamedTuple):
destination: Optional[str]
next_inputs: Dict[str, Any]
next_inputs: dict[str, Any]
class RouterChain(Chain, ABC):
"""Chain that outputs the name of a destination chain and the inputs to it."""
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return ["destination", "next_inputs"]
def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
def route(self, inputs: dict[str, Any], callbacks: Callbacks = None) -> Route:
"""
Route inputs to a destination chain.
@ -42,7 +43,7 @@ class RouterChain(Chain, ABC):
return Route(result["destination"], result["next_inputs"])
async def aroute(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
self, inputs: dict[str, Any], callbacks: Callbacks = None
) -> Route:
result = await self.acall(inputs, callbacks=callbacks)
return Route(result["destination"], result["next_inputs"])
@ -67,7 +68,7 @@ class MultiRouteChain(Chain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Will be whatever keys the router chain prompt expects.
:meta private:
@ -75,7 +76,7 @@ class MultiRouteChain(Chain):
return self.router_chain.input_keys
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Will always return text key.
:meta private:
@ -84,9 +85,9 @@ class MultiRouteChain(Chain):
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
route = self.router_chain.route(inputs, callbacks=callbacks)
@ -109,9 +110,9 @@ class MultiRouteChain(Chain):
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
route = await self.router_chain.aroute(inputs, callbacks=callbacks)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from collections.abc import Sequence
from typing import Any, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -18,7 +19,7 @@ class EmbeddingRouterChain(RouterChain):
"""Chain that uses embeddings to route between options."""
vectorstore: VectorStore
routing_keys: List[str] = ["query"]
routing_keys: list[str] = ["query"]
model_config = ConfigDict(
arbitrary_types_allowed=True,
@ -26,7 +27,7 @@ class EmbeddingRouterChain(RouterChain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Will be whatever keys the LLM chain prompt expects.
:meta private:
@ -35,18 +36,18 @@ class EmbeddingRouterChain(RouterChain):
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_input = ", ".join([inputs[k] for k in self.routing_keys])
results = self.vectorstore.similarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_input = ", ".join([inputs[k] for k in self.routing_keys])
results = await self.vectorstore.asimilarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
@ -54,8 +55,8 @@ class EmbeddingRouterChain(RouterChain):
@classmethod
def from_names_and_descriptions(
cls,
names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
vectorstore_cls: Type[VectorStore],
names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
vectorstore_cls: type[VectorStore],
embeddings: Embeddings,
**kwargs: Any,
) -> EmbeddingRouterChain:
@ -72,8 +73,8 @@ class EmbeddingRouterChain(RouterChain):
@classmethod
async def afrom_names_and_descriptions(
cls,
names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
vectorstore_cls: Type[VectorStore],
names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
vectorstore_cls: type[VectorStore],
embeddings: Embeddings,
**kwargs: Any,
) -> EmbeddingRouterChain:

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, cast
from typing import Any, Optional, cast
from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -114,42 +114,42 @@ class LLMRouterChain(RouterChain):
return self
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Will be whatever keys the LLM chain prompt expects.
:meta private:
"""
return self.llm_chain.input_keys
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
def _validate_outputs(self, outputs: dict[str, Any]) -> None:
super()._validate_outputs(outputs)
if not isinstance(outputs["next_inputs"], dict):
raise ValueError
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
output = cast(
Dict[str, Any],
dict[str, Any],
self.llm_chain.prompt.output_parser.parse(prediction),
)
return output
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
output = cast(
Dict[str, Any],
dict[str, Any],
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
)
return output
@ -163,14 +163,14 @@ class LLMRouterChain(RouterChain):
return cls(llm_chain=llm_chain, **kwargs)
class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
class RouterOutputParser(BaseOutputParser[dict[str, str]]):
"""Parser for output of router chain in the multi-prompt chain."""
default_destination: str = "DEFAULT"
next_inputs_type: Type = str
next_inputs_type: type = str
next_inputs_inner_key: str = "input"
def parse(self, text: str) -> Dict[str, Any]:
def parse(self, text: str) -> dict[str, Any]:
try:
expected_keys = ["destination", "next_inputs"]
parsed = parse_and_check_json_markdown(text, expected_keys)

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@ -142,14 +142,14 @@ class MultiPromptChain(MultiRouteChain):
""" # noqa: E501
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return ["text"]
@classmethod
def from_prompts(
cls,
llm: BaseLanguageModel,
prompt_infos: List[Dict[str, str]],
prompt_infos: list[dict[str, str]],
default_chain: Optional[Chain] = None,
**kwargs: Any,
) -> MultiPromptChain:

View File

@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional
from collections.abc import Mapping
from typing import Any, Optional
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
@ -31,14 +32,14 @@ class MultiRetrievalQAChain(MultiRouteChain): # type: ignore[override]
"""Default chain to use when router doesn't map input to one of the destinations."""
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return ["result"]
@classmethod
def from_retrievers(
cls,
llm: BaseLanguageModel,
retriever_infos: List[Dict[str, Any]],
retriever_infos: list[dict[str, Any]],
default_retriever: Optional[BaseRetriever] = None,
default_prompt: Optional[PromptTemplate] = None,
default_chain: Optional[Chain] = None,

View File

@ -1,6 +1,6 @@
"""Chain pipeline where the outputs of one step feed directly into next."""
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -16,9 +16,9 @@ from langchain.chains.base import Chain
class SequentialChain(Chain):
"""Chain where the outputs of one chain feed directly into next."""
chains: List[Chain]
input_variables: List[str]
output_variables: List[str] #: :meta private:
chains: list[Chain]
input_variables: list[str]
output_variables: list[str] #: :meta private:
return_all: bool = False
model_config = ConfigDict(
@ -27,7 +27,7 @@ class SequentialChain(Chain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return expected input keys to the chain.
:meta private:
@ -35,7 +35,7 @@ class SequentialChain(Chain):
return self.input_variables
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.
:meta private:
@ -44,7 +44,7 @@ class SequentialChain(Chain):
@model_validator(mode="before")
@classmethod
def validate_chains(cls, values: Dict) -> Any:
def validate_chains(cls, values: dict) -> Any:
"""Validate that the correct inputs exist for all chains."""
chains = values["chains"]
input_variables = values["input_variables"]
@ -97,9 +97,9 @@ class SequentialChain(Chain):
def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
known_values = inputs.copy()
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
for i, chain in enumerate(self.chains):
@ -110,9 +110,9 @@ class SequentialChain(Chain):
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
known_values = inputs.copy()
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
@ -127,7 +127,7 @@ class SequentialChain(Chain):
class SimpleSequentialChain(Chain):
"""Simple chain where the outputs of one step feed directly into next."""
chains: List[Chain]
chains: list[Chain]
strip_outputs: bool = False
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
@ -138,7 +138,7 @@ class SimpleSequentialChain(Chain):
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.
:meta private:
@ -146,7 +146,7 @@ class SimpleSequentialChain(Chain):
return [self.input_key]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.
:meta private:
@ -171,9 +171,9 @@ class SimpleSequentialChain(Chain):
def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
@ -190,9 +190,9 @@ class SimpleSequentialChain(Chain):
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
@ -27,7 +27,7 @@ class SQLInputWithTables(TypedDict):
"""Input for a SQL Chain."""
question: str
table_names_to_use: List[str]
table_names_to_use: list[str]
def create_sql_query_chain(
@ -35,7 +35,7 @@ def create_sql_query_chain(
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
) -> Runnable[Union[SQLInput, SQLInputWithTables, dict[str, Any]], str]:
"""Create a chain that generates SQL queries.
*Security Note*: This chain generates SQL queries for the given database.

View File

@ -1,5 +1,6 @@
import json
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
from collections.abc import Sequence
from typing import Any, Callable, Literal, Optional, Union
from langchain_core._api import deprecated
from langchain_core.output_parsers import (
@ -63,7 +64,7 @@ from pydantic import BaseModel
),
)
def create_openai_fn_runnable(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
llm: Runnable,
prompt: Optional[BasePromptTemplate] = None,
*,
@ -135,7 +136,7 @@ def create_openai_fn_runnable(
if not functions:
raise ValueError("Need to pass in at least one function. Received zero.")
openai_functions = [convert_to_openai_function(f) for f in functions]
llm_kwargs_: Dict[str, Any] = {"functions": openai_functions, **llm_kwargs}
llm_kwargs_: dict[str, Any] = {"functions": openai_functions, **llm_kwargs}
if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs_["function_call"] = {"name": openai_functions[0]["name"]}
output_parser = output_parser or get_openai_output_parser(functions)
@ -181,7 +182,7 @@ def create_openai_fn_runnable(
),
)
def create_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
output_schema: Union[dict[str, Any], type[BaseModel]],
llm: Runnable,
prompt: Optional[BasePromptTemplate] = None,
*,
@ -437,7 +438,7 @@ def create_structured_output_runnable(
def _create_openai_tools_runnable(
tool: Union[Dict[str, Any], Type[BaseModel], Callable],
tool: Union[dict[str, Any], type[BaseModel], Callable],
llm: Runnable,
*,
prompt: Optional[BasePromptTemplate],
@ -446,7 +447,7 @@ def _create_openai_tools_runnable(
first_tool_only: bool,
) -> Runnable:
oai_tool = convert_to_openai_tool(tool)
llm_kwargs: Dict[str, Any] = {"tools": [oai_tool]}
llm_kwargs: dict[str, Any] = {"tools": [oai_tool]}
if enforce_tool_usage:
llm_kwargs["tool_choice"] = {
"type": "function",
@ -462,7 +463,7 @@ def _create_openai_tools_runnable(
def _get_openai_tool_output_parser(
tool: Union[Dict[str, Any], Type[BaseModel], Callable],
tool: Union[dict[str, Any], type[BaseModel], Callable],
*,
first_tool_only: bool = False,
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
@ -479,7 +480,7 @@ def _get_openai_tool_output_parser(
def get_openai_output_parser(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
"""Get the appropriate function output parser given the user functions.
@ -496,7 +497,7 @@ def get_openai_output_parser(
"""
if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = {
pydantic_schema: Union[dict, type[BaseModel]] = {
convert_to_openai_function(fn)["name"]: fn for fn in functions
}
else:
@ -510,7 +511,7 @@ def get_openai_output_parser(
def _create_openai_json_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
output_schema: Union[dict[str, Any], type[BaseModel]],
llm: Runnable,
prompt: Optional[BasePromptTemplate] = None,
*,
@ -537,7 +538,7 @@ def _create_openai_json_runnable(
def _create_openai_functions_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
output_schema: Union[dict[str, Any], type[BaseModel]],
llm: Runnable,
prompt: Optional[BasePromptTemplate] = None,
*,

View File

@ -1,6 +1,7 @@
"""Load summarizing chains."""
from typing import Any, Mapping, Optional, Protocol
from collections.abc import Mapping
from typing import Any, Optional, Protocol
from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel

View File

@ -2,7 +2,8 @@
import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from collections.abc import Awaitable
from typing import Any, Callable, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -26,13 +27,13 @@ class TransformChain(Chain):
output_variables["entities"], transform=func())
"""
input_variables: List[str]
input_variables: list[str]
"""The keys expected by the transform's input dictionary."""
output_variables: List[str]
output_variables: list[str]
"""The keys returned by the transform's output dictionary."""
transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
transform_cb: Callable[[dict[str, str]], dict[str, str]] = Field(alias="transform")
"""The transform function."""
atransform_cb: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = (
atransform_cb: Optional[Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]] = (
Field(None, alias="atransform")
)
"""The async coroutine transform function."""
@ -47,7 +48,7 @@ class TransformChain(Chain):
logger.warning(msg)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input keys.
:meta private:
@ -55,7 +56,7 @@ class TransformChain(Chain):
return self.input_variables
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output keys.
:meta private:
@ -64,16 +65,16 @@ class TransformChain(Chain):
def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
return self.transform_cb(inputs)
async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if self.atransform_cb is not None:
return await self.atransform_cb(inputs)
else:

View File

@ -1,19 +1,13 @@
from __future__ import annotations
import warnings
from collections.abc import AsyncIterator, Iterator, Sequence
from importlib import util
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
overload,
@ -73,7 +67,7 @@ def init_chat_model(
model: Optional[str] = None,
*,
model_provider: Optional[str] = None,
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = ...,
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
@ -87,7 +81,7 @@ def init_chat_model(
*,
model_provider: Optional[str] = None,
configurable_fields: Optional[
Union[Literal["any"], List[str], Tuple[str, ...]]
Union[Literal["any"], list[str], tuple[str, ...]]
] = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
@ -514,7 +508,7 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
return None
def _parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, str]:
def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
if (
not model_provider
and ":" in model
@ -554,12 +548,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
self,
*,
default_config: Optional[dict] = None,
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = "any",
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
config_prefix: str = "",
queued_declarative_operations: Sequence[Tuple[str, Tuple, Dict]] = (),
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
) -> None:
self._default_config: dict = default_config or {}
self._configurable_fields: Union[Literal["any"], List[str]] = (
self._configurable_fields: Union[Literal["any"], list[str]] = (
configurable_fields
if configurable_fields == "any"
else list(configurable_fields)
@ -569,7 +563,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
if config_prefix and not config_prefix.endswith("_")
else config_prefix
)
self._queued_declarative_operations: List[Tuple[str, Tuple, Dict]] = list(
self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
queued_declarative_operations
)
@ -670,7 +664,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
list[AnyMessage],
]
def invoke(
@ -708,12 +702,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
def batch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Any]:
) -> list[Any]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
@ -731,12 +725,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
async def abatch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Any]:
) -> list[Any]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
@ -759,7 +753,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Any, Exception]]]:
) -> Iterator[tuple[int, Union[Any, Exception]]]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
@ -782,7 +776,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> AsyncIterator[Tuple[int, Any]]:
) -> AsyncIterator[tuple[int, Any]]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
@ -808,8 +802,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Any]:
for x in self._model(config).transform(input, config=config, **kwargs):
yield x
yield from self._model(config).transform(input, config=config, **kwargs)
async def atransform(
self,
@ -915,13 +908,13 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
# Explicitly added to satisfy downstream linters.
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs)
# Explicitly added to satisfy downstream linters.
def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return self.__getattr__("with_structured_output")(schema, **kwargs)

View File

@ -1,6 +1,6 @@
import functools
from importlib import util
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Union
from langchain_core._api import beta
from langchain_core.embeddings import Embeddings
@ -25,7 +25,7 @@ def _get_provider_list() -> str:
)
def _parse_model_string(model_name: str) -> Tuple[str, str]:
def _parse_model_string(model_name: str) -> tuple[str, str]:
"""Parse a model string into provider and model name components.
The model string should be in the format 'provider:model-name', where provider
@ -78,7 +78,7 @@ def _parse_model_string(model_name: str) -> Tuple[str, str]:
def _infer_model_and_provider(
model: str, *, provider: Optional[str] = None
) -> Tuple[str, str]:
) -> tuple[str, str]:
if not model.strip():
raise ValueError("Model name cannot be empty")
if provider is None and ":" in model:
@ -122,7 +122,7 @@ def init_embeddings(
*,
provider: Optional[str] = None,
**kwargs: Any,
) -> Union[Embeddings, Runnable[Any, List[float]]]:
) -> Union[Embeddings, Runnable[Any, list[float]]]:
"""Initialize an embeddings model from a model name and optional provider.
**Note:** Must have the integration package corresponding to the model provider

View File

@ -12,8 +12,9 @@ from __future__ import annotations
import hashlib
import json
import uuid
from collections.abc import Sequence
from functools import partial
from typing import Callable, List, Optional, Sequence, Union, cast
from typing import Callable, Optional, Union, cast
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore, ByteStore
@ -45,9 +46,9 @@ def _value_serializer(value: Sequence[float]) -> bytes:
return json.dumps(value).encode()
def _value_deserializer(serialized_value: bytes) -> List[float]:
def _value_deserializer(serialized_value: bytes) -> list[float]:
"""Deserialize a value."""
return cast(List[float], json.loads(serialized_value.decode()))
return cast(list[float], json.loads(serialized_value.decode()))
class CacheBackedEmbeddings(Embeddings):
@ -88,10 +89,10 @@ class CacheBackedEmbeddings(Embeddings):
def __init__(
self,
underlying_embeddings: Embeddings,
document_embedding_store: BaseStore[str, List[float]],
document_embedding_store: BaseStore[str, list[float]],
*,
batch_size: Optional[int] = None,
query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
) -> None:
"""Initialize the embedder.
@ -108,7 +109,7 @@ class CacheBackedEmbeddings(Embeddings):
self.underlying_embeddings = underlying_embeddings
self.batch_size = batch_size
def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of texts.
The method first checks the cache for the embeddings.
@ -121,10 +122,10 @@ class CacheBackedEmbeddings(Embeddings):
Returns:
A list of embeddings for the given texts.
"""
vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
texts
)
all_missing_indices: List[int] = [
all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None
]
@ -138,10 +139,10 @@ class CacheBackedEmbeddings(Embeddings):
vectors[index] = updated_vector
return cast(
List[List[float]], vectors
list[list[float]], vectors
) # Nones should have been resolved by now
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of texts.
The method first checks the cache for the embeddings.
@ -154,10 +155,10 @@ class CacheBackedEmbeddings(Embeddings):
Returns:
A list of embeddings for the given texts.
"""
vectors: List[
Union[List[float], None]
vectors: list[
Union[list[float], None]
] = await self.document_embedding_store.amget(texts)
all_missing_indices: List[int] = [
all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None
]
@ -175,10 +176,10 @@ class CacheBackedEmbeddings(Embeddings):
vectors[index] = updated_vector
return cast(
List[List[float]], vectors
list[list[float]], vectors
) # Nones should have been resolved by now
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text.
By default, this method does not cache queries. To enable caching, set the
@ -201,7 +202,7 @@ class CacheBackedEmbeddings(Embeddings):
self.query_embedding_store.mset([(text, vector)])
return vector
async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
"""Embed query text.
By default, this method does not cache queries. To enable caching, set the
@ -250,7 +251,7 @@ class CacheBackedEmbeddings(Embeddings):
"""
namespace = namespace
key_encoder = _create_key_encoder(namespace)
document_embedding_store = EncoderBackedStore[str, List[float]](
document_embedding_store = EncoderBackedStore[str, list[float]](
document_embedding_cache,
key_encoder,
_value_serializer,
@ -261,7 +262,7 @@ class CacheBackedEmbeddings(Embeddings):
elif query_embedding_cache is False:
query_embedding_store = None
else:
query_embedding_store = EncoderBackedStore[str, List[float]](
query_embedding_store = EncoderBackedStore[str, list[float]](
query_embedding_cache,
key_encoder,
_value_serializer,

View File

@ -6,13 +6,10 @@ chain (LLMChain) to generate the reasoning and scores.
"""
import re
from collections.abc import Sequence
from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Tuple,
TypedDict,
Union,
cast,
@ -145,7 +142,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
# 0
"""
agent_tools: Optional[List[BaseTool]] = None
agent_tools: Optional[list[BaseTool]] = None
"""A list of tools available to the agent."""
eval_chain: LLMChain
"""The language model chain used for evaluation."""
@ -184,7 +181,7 @@ Description: {tool.description}"""
@staticmethod
def get_agent_trajectory(
steps: Union[str, Sequence[Tuple[AgentAction, str]]],
steps: Union[str, Sequence[tuple[AgentAction, str]]],
) -> str:
"""Get the agent trajectory as a formatted string.
@ -263,7 +260,7 @@ The following is the expected answer. Use this to measure correctness:
)
@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Get the input keys for the chain.
Returns:
@ -272,7 +269,7 @@ The following is the expected answer. Use this to measure correctness:
return ["question", "agent_trajectory", "answer", "reference"]
@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Get the output keys for the chain.
Returns:
@ -280,16 +277,16 @@ The following is the expected answer. Use this to measure correctness:
"""
return ["score", "reasoning"]
def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Validate and prep inputs."""
inputs["reference"] = self._format_reference(inputs.get("reference"))
return super().prep_inputs(inputs)
def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run the chain and generate the output.
Args:
@ -311,9 +308,9 @@ The following is the expected answer. Use this to measure correctness:
async def _acall(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run the chain and generate the output.
Args:
@ -338,11 +335,11 @@ The following is the expected answer. Use this to measure correctness:
*,
prediction: str,
input: str,
agent_trajectory: Sequence[Tuple[AgentAction, str]],
agent_trajectory: Sequence[tuple[AgentAction, str]],
reference: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
@ -380,11 +377,11 @@ The following is the expected answer. Use this to measure correctness:
*,
prediction: str,
input: str,
agent_trajectory: Sequence[Tuple[AgentAction, str]],
agent_trajectory: Sequence[tuple[AgentAction, str]],
reference: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import re
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from langchain_core.callbacks.manager import Callbacks
from langchain_core.language_models import BaseLanguageModel
@ -49,7 +49,7 @@ _SUPPORTED_CRITERIA = {
def resolve_pairwise_criteria(
criteria: Optional[Union[CRITERIA_TYPE, str, List[CRITERIA_TYPE]]],
criteria: Optional[Union[CRITERIA_TYPE, str, list[CRITERIA_TYPE]]],
) -> dict:
"""Resolve the criteria for the pairwise evaluator.
@ -113,7 +113,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]): # type: ignore[
"""
return "pairwise_string_result"
def parse(self, text: str) -> Dict[str, Any]:
def parse(self, text: str) -> dict[str, Any]:
"""Parse the output text.
Args:
@ -314,8 +314,8 @@ Performance may be significantly worse with other models."
input: Optional[str] = None,
reference: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
@ -356,8 +356,8 @@ Performance may be significantly worse with other models."
reference: Optional[str] = None,
input: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:

Some files were not shown because too many files have changed in this diff Show More