Mirror of https://github.com/hwchase17/langchain.git (synced 2025-05-05 07:08:03 +00:00)
langchain[lint]: use pyupgrade to get to 3.9 standards (#30782)
This commit is contained in:
parent: d9b628e764
commit: 48affc498b
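Note: pyupgrade rewrites older typing idioms to the newest syntax the target Python version supports. A minimal sketch of the dominant rewrite in this commit (illustrative code, not taken from the repository): under a 3.9 target, the deprecated `typing` aliases are replaced by the PEP 585 built-in generics, while `Optional`/`Union` are left alone because the `X | Y` spelling only becomes available with a 3.10 target.

# Before: typing aliases, required on Python 3.8 and earlier
from typing import Dict, List, Optional, Tuple

def plan(steps: List[Tuple[str, str]], config: Optional[Dict[str, str]] = None) -> List[str]:
    return [action for action, _observation in steps]

# After pyupgrade --py39-plus: PEP 585 built-in generics
from typing import Optional

def plan(steps: list[tuple[str, str]], config: Optional[dict[str, str]] = None) -> list[str]:
    return [action for action, _observation in steps]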
libs/langchain/langchain
  _api
  agents
    agent.py
    agent_iterator.py
    agent_toolkits
    chat
    conversational
    conversational_chat
    format_scratchpad
    initialize.py
    json_chat
    loading.py
    mrkl
    openai_assistant
    openai_functions_agent
    openai_functions_multi_agent
    openai_tools
    output_parsers
    react
    schema.py
    self_ask_with_search
    structured_chat
    tool_calling_agent
    tools.py
    types.py
    utils.py
    xml
  callbacks
  chains
    api
    base.py
    combine_documents
    constitutional_ai
    conversation
    conversational_retrieval
    elasticsearch_database
    example_generator.py
    flare
    hyde
    llm.py
    llm_checker
    llm_math
    llm_summarization_checker
    loading.py
    mapreduce.py
    moderation.py
    natbot
    openai_functions
    openai_tools
    prompt_selector.py
    qa_generation
    qa_with_sources
    query_constructor
    question_answering
    retrieval.py
    retrieval_qa
    router
    sequential.py
    sql_database
    structured_output
    summarize
    transform.py
  chat_models
  embeddings
  evaluation
@@ -1,5 +1,5 @@
 import importlib
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional
 
 from langchain_core._api import internal, warn_deprecated
 
@@ -15,8 +15,8 @@ ALLOWED_TOP_LEVEL_PKGS = {
 def create_importer(
     package: str,
     *,
-    module_lookup: Optional[Dict[str, str]] = None,
-    deprecated_lookups: Optional[Dict[str, str]] = None,
+    module_lookup: Optional[dict[str, str]] = None,
+    deprecated_lookups: Optional[dict[str, str]] = None,
     fallback_module: Optional[str] = None,
 ) -> Callable[[str], Any]:
     """Create a function that helps retrieve objects from their new locations.

@@ -3,21 +3,17 @@
 from __future__ import annotations
 
 import asyncio
+import builtins
 import json
 import logging
 import time
 from abc import abstractmethod
+from collections.abc import AsyncIterator, Iterator, Sequence
 from pathlib import Path
 from typing import (
     Any,
-    AsyncIterator,
     Callable,
-    Dict,
-    Iterator,
-    List,
     Optional,
-    Sequence,
-    Tuple,
     Union,
     cast,
 )
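Note: two imports are added here rather than removed. `collections.abc` is the canonical home of `AsyncIterator`, `Iterator`, and `Sequence` once the deprecated `typing` aliases go away (PEP 585), and `import builtins` is needed because classes in this module define a method literally named `dict` (see the hunks below). Illustrative sketch, not from the repository:

from collections.abc import Iterator, Sequence

# Since Python 3.9 these ABCs are generic in their real home,
# collections.abc; the typing.Iterator / typing.Sequence (and
# typing.AsyncIterator) aliases still work but are deprecated.
def first(items: Sequence[str]) -> str:
    return items[0]

def numbers() -> Iterator[int]:
    yield 1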
@@ -62,17 +58,17 @@ class BaseSingleActionAgent(BaseModel):
     """Base Single Action Agent class."""
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return ["output"]
 
-    def get_allowed_tools(self) -> Optional[List[str]]:
+    def get_allowed_tools(self) -> Optional[list[str]]:
         return None
 
     @abstractmethod
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -91,7 +87,7 @@ class BaseSingleActionAgent(BaseModel):
     @abstractmethod
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -109,7 +105,7 @@ class BaseSingleActionAgent(BaseModel):
 
     @property
     @abstractmethod
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         :meta private:
@@ -118,7 +114,7 @@ class BaseSingleActionAgent(BaseModel):
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         """Return response when agent has been stopped due to max iterations.
@@ -171,7 +167,7 @@ class BaseSingleActionAgent(BaseModel):
         """Return Identifier of an agent type."""
         raise NotImplementedError
 
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> builtins.dict:
         """Return dictionary representation of agent.
 
         Returns:
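Note: the `Dict` return annotations on these methods become `builtins.dict` rather than plain `dict`. Because the class itself defines a method named `dict`, a bare `dict` inside the class body would refer to that method, not the built-in type, so the fully qualified name is used. A minimal sketch of the collision (illustrative, not from the repository):

import builtins

class Agent:
    # Inside this class body the name `dict` refers to the method below,
    # so annotating a later return type as plain `dict` would resolve to
    # the method; `builtins.dict` names the built-in type unambiguously.
    def dict(self, **kwargs: object) -> builtins.dict:
        return {"type": "agent"}

    def tool_run_logging_kwargs(self) -> builtins.dict:
        return {}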
@@ -223,7 +219,7 @@ class BaseSingleActionAgent(BaseModel):
         else:
             raise ValueError(f"{save_path} must be json or yaml")
 
-    def tool_run_logging_kwargs(self) -> Dict:
+    def tool_run_logging_kwargs(self) -> builtins.dict:
         """Return logging kwargs for tool run."""
         return {}
 
@@ -232,11 +228,11 @@ class BaseMultiActionAgent(BaseModel):
     """Base Multi Action Agent class."""
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return ["output"]
 
-    def get_allowed_tools(self) -> Optional[List[str]]:
+    def get_allowed_tools(self) -> Optional[list[str]]:
         """Get allowed tools.
 
         Returns:
@@ -247,10 +243,10 @@ class BaseMultiActionAgent(BaseModel):
     @abstractmethod
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Union[List[AgentAction], AgentFinish]:
+    ) -> Union[list[AgentAction], AgentFinish]:
         """Given input, decided what to do.
 
         Args:
@@ -266,10 +262,10 @@ class BaseMultiActionAgent(BaseModel):
     @abstractmethod
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Union[List[AgentAction], AgentFinish]:
+    ) -> Union[list[AgentAction], AgentFinish]:
         """Async given input, decided what to do.
 
         Args:
@@ -284,7 +280,7 @@ class BaseMultiActionAgent(BaseModel):
 
     @property
     @abstractmethod
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         :meta private:
@@ -293,7 +289,7 @@ class BaseMultiActionAgent(BaseModel):
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         """Return response when agent has been stopped due to max iterations.
@@ -323,7 +319,7 @@ class BaseMultiActionAgent(BaseModel):
         """Return Identifier of an agent type."""
         raise NotImplementedError
 
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> builtins.dict:
         """Return dictionary representation of agent."""
         _dict = super().model_dump()
         try:
@@ -371,7 +367,7 @@ class BaseMultiActionAgent(BaseModel):
         else:
             raise ValueError(f"{save_path} must be json or yaml")
 
-    def tool_run_logging_kwargs(self) -> Dict:
+    def tool_run_logging_kwargs(self) -> builtins.dict:
         """Return logging kwargs for tool run."""
 
         return {}
@@ -386,7 +382,7 @@ class AgentOutputParser(BaseOutputParser[Union[AgentAction, AgentFinish]]):
 
 
 class MultiActionAgentOutputParser(
-    BaseOutputParser[Union[List[AgentAction], AgentFinish]]
+    BaseOutputParser[Union[list[AgentAction], AgentFinish]]
 ):
     """Base class for parsing agent output into agent actions/finish.
 
@@ -394,7 +390,7 @@ class MultiActionAgentOutputParser(
     """
 
     @abstractmethod
-    def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
+    def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
         """Parse text into agent actions/finish.
 
         Args:
@@ -411,8 +407,8 @@ class RunnableAgent(BaseSingleActionAgent):
 
     runnable: Runnable[dict, Union[AgentAction, AgentFinish]]
     """Runnable to call to get agent action."""
-    input_keys_arg: List[str] = []
-    return_keys_arg: List[str] = []
+    input_keys_arg: list[str] = []
+    return_keys_arg: list[str] = []
     stream_runnable: bool = True
     """Whether to stream from the runnable or not.
 
@@ -427,18 +423,18 @@ class RunnableAgent(BaseSingleActionAgent):
         )
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return self.return_keys_arg
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys."""
         return self.input_keys_arg
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -474,7 +470,7 @@ class RunnableAgent(BaseSingleActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[
@@ -518,10 +514,10 @@ class RunnableAgent(BaseSingleActionAgent):
 class RunnableMultiActionAgent(BaseMultiActionAgent):
     """Agent powered by Runnables."""
 
-    runnable: Runnable[dict, Union[List[AgentAction], AgentFinish]]
+    runnable: Runnable[dict, Union[list[AgentAction], AgentFinish]]
     """Runnable to call to get agent actions."""
-    input_keys_arg: List[str] = []
-    return_keys_arg: List[str] = []
+    input_keys_arg: list[str] = []
+    return_keys_arg: list[str] = []
     stream_runnable: bool = True
     """Whether to stream from the runnable or not.
 
@@ -536,12 +532,12 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
         )
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return self.return_keys_arg
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         Returns:
@@ -551,11 +547,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[
-        List[AgentAction],
+        list[AgentAction],
         AgentFinish,
     ]:
         """Based on past history and current inputs, decide what to do.
@@ -590,11 +586,11 @@ class RunnableMultiActionAgent(BaseMultiActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[
-        List[AgentAction],
+        list[AgentAction],
         AgentFinish,
     ]:
         """Async based on past history and current inputs, decide what to do.
@@ -644,11 +640,11 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
     """LLMChain to use for agent."""
     output_parser: AgentOutputParser
     """Output parser to use for agent."""
-    stop: List[str]
+    stop: list[str]
     """List of strings to stop on."""
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         Returns:
@@ -656,7 +652,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
         """
         return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
 
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> builtins.dict:
         """Return dictionary representation of agent."""
         _dict = super().dict()
         del _dict["output_parser"]
@@ -664,7 +660,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -689,7 +685,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -712,7 +708,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
         )
         return self.output_parser.parse(output)
 
-    def tool_run_logging_kwargs(self) -> Dict:
+    def tool_run_logging_kwargs(self) -> builtins.dict:
         """Return logging kwargs for tool run."""
         return {
             "llm_prefix": "",
@@ -737,21 +733,21 @@ class Agent(BaseSingleActionAgent):
     """LLMChain to use for agent."""
     output_parser: AgentOutputParser
     """Output parser to use for agent."""
-    allowed_tools: Optional[List[str]] = None
+    allowed_tools: Optional[list[str]] = None
     """Allowed tools for the agent. If None, all tools are allowed."""
 
-    def dict(self, **kwargs: Any) -> Dict:
+    def dict(self, **kwargs: Any) -> builtins.dict:
         """Return dictionary representation of agent."""
         _dict = super().dict()
         del _dict["output_parser"]
         return _dict
 
-    def get_allowed_tools(self) -> Optional[List[str]]:
+    def get_allowed_tools(self) -> Optional[list[str]]:
         """Get allowed tools."""
         return self.allowed_tools
 
     @property
-    def return_values(self) -> List[str]:
+    def return_values(self) -> list[str]:
         """Return values of the agent."""
         return ["output"]
 
@@ -767,15 +763,15 @@ class Agent(BaseSingleActionAgent):
         raise ValueError("fix_text not implemented for this agent.")
 
     @property
-    def _stop(self) -> List[str]:
+    def _stop(self) -> list[str]:
         return [
             f"\n{self.observation_prefix.rstrip()}",
             f"\n\t{self.observation_prefix.rstrip()}",
         ]
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
-    ) -> Union[str, List[BaseMessage]]:
+        self, intermediate_steps: list[tuple[AgentAction, str]]
+    ) -> Union[str, list[BaseMessage]]:
         """Construct the scratchpad that lets the agent continue its thought process."""
         thoughts = ""
         for action, observation in intermediate_steps:
@@ -785,7 +781,7 @@ class Agent(BaseSingleActionAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -806,7 +802,7 @@ class Agent(BaseSingleActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -827,8 +823,8 @@ class Agent(BaseSingleActionAgent):
         return agent_output
 
     def get_full_inputs(
-        self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
-    ) -> Dict[str, Any]:
+        self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any
+    ) -> builtins.dict[str, Any]:
         """Create the full inputs for the LLMChain from intermediate steps.
 
         Args:
@@ -845,7 +841,7 @@ class Agent(BaseSingleActionAgent):
         return full_inputs
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         :meta private:
@@ -957,7 +953,7 @@ class Agent(BaseSingleActionAgent):
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         """Return response when agent has been stopped due to max iterations.
@@ -1009,7 +1005,7 @@ class Agent(BaseSingleActionAgent):
                 f"got {early_stopping_method}"
             )
 
    def tool_run_logging_kwargs(self) -> Dict:
-    def tool_run_logging_kwargs(self) -> Dict:
+    def tool_run_logging_kwargs(self) -> builtins.dict:
         """Return logging kwargs for tool run."""
         return {
             "llm_prefix": self.llm_prefix,
@@ -1040,7 +1036,7 @@ class ExceptionTool(BaseTool):  # type: ignore[override]
         return query
 
 
-NextStepOutput = List[Union[AgentFinish, AgentAction, AgentStep]]
+NextStepOutput = list[Union[AgentFinish, AgentAction, AgentStep]]
 RunnableAgentType = Union[RunnableAgent, RunnableMultiActionAgent]
 
 
@@ -1086,7 +1082,7 @@ class AgentExecutor(Chain):
         as an observation.
     """
     trim_intermediate_steps: Union[
-        int, Callable[[List[Tuple[AgentAction, str]]], List[Tuple[AgentAction, str]]]
+        int, Callable[[list[tuple[AgentAction, str]]], list[tuple[AgentAction, str]]]
     ] = -1
     """How to trim the intermediate steps before returning them.
     Defaults to -1, which means no trimming.
@@ -1144,7 +1140,7 @@ class AgentExecutor(Chain):
 
     @model_validator(mode="before")
     @classmethod
-    def validate_runnable_agent(cls, values: Dict) -> Any:
+    def validate_runnable_agent(cls, values: dict) -> Any:
         """Convert runnable to agent if passed in.
 
         Args:
@@ -1160,7 +1156,7 @@ class AgentExecutor(Chain):
         except Exception as _:
             multi_action = False
         else:
-            multi_action = output_type == Union[List[AgentAction], AgentFinish]
+            multi_action = output_type == Union[list[AgentAction], AgentFinish]
 
         stream_runnable = values.pop("stream_runnable", True)
         if multi_action:
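Note: unlike the annotation-only rewrites above, this last hunk touches runtime behavior. The validator inspects the runnable's declared output type and compares it with an ordinary `==`, so the comparison target is rewritten in lockstep with the `RunnableMultiActionAgent.runnable` annotation it mirrors. A rough sketch of the shape of that check (illustrative, not the actual validator):

from typing import Union

from langchain_core.agents import AgentAction, AgentFinish

def is_multi_action(output_type: object) -> bool:
    # A runtime equality test, not a static-typing construct: the Union
    # built here must be spelled consistently with the annotation it is
    # compared against.
    return output_type == Union[list[AgentAction], AgentFinish]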
@@ -1239,7 +1235,7 @@ class AgentExecutor(Chain):
         )
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Return the input keys.
 
         :meta private:
@@ -1247,7 +1243,7 @@ class AgentExecutor(Chain):
         return self._action_agent.input_keys
 
     @property
-    def output_keys(self) -> List[str]:
+    def output_keys(self) -> list[str]:
         """Return the singular output key.
 
         :meta private:
@@ -1284,7 +1280,7 @@ class AgentExecutor(Chain):
         output: AgentFinish,
         intermediate_steps: list,
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         if run_manager:
             run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
         final_output = output.return_values
@@ -1297,7 +1293,7 @@ class AgentExecutor(Chain):
         output: AgentFinish,
         intermediate_steps: list,
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         if run_manager:
             await run_manager.on_agent_finish(
                 output, color="green", verbose=self.verbose
@@ -1309,7 +1305,7 @@ class AgentExecutor(Chain):
 
     def _consume_next_step(
         self, values: NextStepOutput
-    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
+    ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
         if isinstance(values[-1], AgentFinish):
             assert len(values) == 1
             return values[-1]
@@ -1320,12 +1316,12 @@ class AgentExecutor(Chain):
 
     def _take_next_step(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        inputs: Dict[str, str],
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
+        inputs: dict[str, str],
+        intermediate_steps: list[tuple[AgentAction, str]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
+    ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
         return self._consume_next_step(
             [
                 a
@@ -1341,10 +1337,10 @@ class AgentExecutor(Chain):
 
     def _iter_next_step(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        inputs: Dict[str, str],
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
+        inputs: dict[str, str],
+        intermediate_steps: list[tuple[AgentAction, str]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> Iterator[Union[AgentFinish, AgentAction, AgentStep]]:
         """Take a single step in the thought-action-observation loop.
@@ -1404,7 +1400,7 @@ class AgentExecutor(Chain):
             yield output
             return
 
-        actions: List[AgentAction]
+        actions: list[AgentAction]
         if isinstance(output, AgentAction):
             actions = [output]
         else:
@@ -1418,8 +1414,8 @@ class AgentExecutor(Chain):
 
     def _perform_agent_action(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
         agent_action: AgentAction,
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> AgentStep:
@@ -1457,12 +1453,12 @@ class AgentExecutor(Chain):
 
     async def _atake_next_step(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        inputs: Dict[str, str],
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
+        inputs: dict[str, str],
+        intermediate_steps: list[tuple[AgentAction, str]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
+    ) -> Union[AgentFinish, list[tuple[AgentAction, str]]]:
         return self._consume_next_step(
             [
                 a
@@ -1478,10 +1474,10 @@ class AgentExecutor(Chain):
 
     async def _aiter_next_step(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
-        inputs: Dict[str, str],
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
+        inputs: dict[str, str],
+        intermediate_steps: list[tuple[AgentAction, str]],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> AsyncIterator[Union[AgentFinish, AgentAction, AgentStep]]:
         """Take a single step in the thought-action-observation loop.
@@ -1539,7 +1535,7 @@ class AgentExecutor(Chain):
             yield output
             return
 
-        actions: List[AgentAction]
+        actions: list[AgentAction]
         if isinstance(output, AgentAction):
             actions = [output]
         else:
@@ -1563,8 +1559,8 @@ class AgentExecutor(Chain):
 
     async def _aperform_agent_action(
         self,
-        name_to_tool_map: Dict[str, BaseTool],
-        color_mapping: Dict[str, str],
+        name_to_tool_map: dict[str, BaseTool],
+        color_mapping: dict[str, str],
         agent_action: AgentAction,
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
     ) -> AgentStep:
@@ -1604,9 +1600,9 @@ class AgentExecutor(Chain):
 
     def _call(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Run text through and get agent response."""
         # Construct a mapping of tool name to tool for easy lookup
         name_to_tool_map = {tool.name: tool for tool in self.tools}
@@ -1614,7 +1610,7 @@ class AgentExecutor(Chain):
         color_mapping = get_color_mapping(
             [tool.name for tool in self.tools], excluded_colors=["green", "red"]
         )
-        intermediate_steps: List[Tuple[AgentAction, str]] = []
+        intermediate_steps: list[tuple[AgentAction, str]] = []
         # Let's start tracking the number of iterations and time elapsed
         iterations = 0
         time_elapsed = 0.0
@@ -1651,9 +1647,9 @@ class AgentExecutor(Chain):
 
     async def _acall(
         self,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
         run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
-    ) -> Dict[str, str]:
+    ) -> dict[str, str]:
         """Async run text through and get agent response."""
         # Construct a mapping of tool name to tool for easy lookup
         name_to_tool_map = {tool.name: tool for tool in self.tools}
@@ -1661,7 +1657,7 @@ class AgentExecutor(Chain):
         color_mapping = get_color_mapping(
             [tool.name for tool in self.tools], excluded_colors=["green"]
         )
-        intermediate_steps: List[Tuple[AgentAction, str]] = []
+        intermediate_steps: list[tuple[AgentAction, str]] = []
         # Let's start tracking the number of iterations and time elapsed
         iterations = 0
         time_elapsed = 0.0
@@ -1712,7 +1708,7 @@ class AgentExecutor(Chain):
         )
 
     def _get_tool_return(
-        self, next_step_output: Tuple[AgentAction, str]
+        self, next_step_output: tuple[AgentAction, str]
     ) -> Optional[AgentFinish]:
         """Check if the tool is a returning tool."""
         agent_action, observation = next_step_output
@@ -1730,8 +1726,8 @@ class AgentExecutor(Chain):
         return None
 
     def _prepare_intermediate_steps(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
-    ) -> List[Tuple[AgentAction, str]]:
+        self, intermediate_steps: list[tuple[AgentAction, str]]
+    ) -> list[tuple[AgentAction, str]]:
         if (
             isinstance(self.trim_intermediate_steps, int)
             and self.trim_intermediate_steps > 0
@@ -1744,7 +1740,7 @@ class AgentExecutor(Chain):
 
     def stream(
         self,
-        input: Union[Dict[str, Any], Any],
+        input: Union[dict[str, Any], Any],
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Iterator[AddableDict]:
@@ -1770,12 +1766,11 @@ class AgentExecutor(Chain):
             yield_actions=True,
             **kwargs,
         )
-        for step in iterator:
-            yield step
+        yield from iterator
 
     async def astream(
         self,
-        input: Union[Dict[str, Any], Any],
+        input: Union[dict[str, Any], Any],
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> AsyncIterator[AddableDict]:
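Note: the `stream` hunk above is one genuinely behavioral cleanup rather than a type-annotation change. The hand-written re-yield loop becomes a PEP 380 `yield from`; beyond brevity, delegation also forwards `send()` and `throw()` to the inner generator, though nothing here relies on that. Illustrative before/after:

from collections.abc import Iterator

def stream_old(iterator: Iterator[dict]) -> Iterator[dict]:
    for step in iterator:
        yield step

def stream_new(iterator: Iterator[dict]) -> Iterator[dict]:
    yield from iterator  # PEP 380 generator delegation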

@@ -3,15 +3,11 @@ from __future__ import annotations
 import asyncio
 import logging
 import time
+from collections.abc import AsyncIterator, Iterator
 from typing import (
     TYPE_CHECKING,
     Any,
-    AsyncIterator,
-    Dict,
-    Iterator,
-    List,
     Optional,
-    Tuple,
     Union,
 )
 from uuid import UUID
@@ -53,7 +49,7 @@ class AgentExecutorIterator:
         callbacks: Callbacks = None,
         *,
         tags: Optional[list[str]] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         run_name: Optional[str] = None,
         run_id: Optional[UUID] = None,
         include_run_info: bool = False,
@@ -90,17 +86,17 @@ class AgentExecutorIterator:
         self.yield_actions = yield_actions
         self.reset()
 
-    _inputs: Dict[str, str]
+    _inputs: dict[str, str]
     callbacks: Callbacks
     tags: Optional[list[str]]
-    metadata: Optional[Dict[str, Any]]
+    metadata: Optional[dict[str, Any]]
     run_name: Optional[str]
     run_id: Optional[UUID]
     include_run_info: bool
     yield_actions: bool
 
     @property
-    def inputs(self) -> Dict[str, str]:
+    def inputs(self) -> dict[str, str]:
         """The inputs to the AgentExecutor."""
         return self._inputs
 
@@ -120,12 +116,12 @@ class AgentExecutorIterator:
         self.inputs = self.inputs
 
     @property
-    def name_to_tool_map(self) -> Dict[str, BaseTool]:
+    def name_to_tool_map(self) -> dict[str, BaseTool]:
         """A mapping of tool names to tools."""
         return {tool.name: tool for tool in self.agent_executor.tools}
 
     @property
-    def color_mapping(self) -> Dict[str, str]:
+    def color_mapping(self) -> dict[str, str]:
         """A mapping of tool names to colors."""
         return get_color_mapping(
             [tool.name for tool in self.agent_executor.tools],
@@ -156,7 +152,7 @@ class AgentExecutorIterator:
 
     def make_final_outputs(
         self,
-        outputs: Dict[str, Any],
+        outputs: dict[str, Any],
         run_manager: Union[CallbackManagerForChainRun, AsyncCallbackManagerForChainRun],
     ) -> AddableDict:
         # have access to intermediate steps by design in iterator,
@@ -171,7 +167,7 @@ class AgentExecutorIterator:
             prepared_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
         return prepared_outputs
 
-    def __iter__(self: "AgentExecutorIterator") -> Iterator[AddableDict]:
+    def __iter__(self: AgentExecutorIterator) -> Iterator[AddableDict]:
         logger.debug("Initialising AgentExecutorIterator")
         self.reset()
         callback_manager = CallbackManager.configure(
@@ -311,7 +307,7 @@ class AgentExecutorIterator:
 
     def _process_next_step_output(
         self,
-        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
+        next_step_output: Union[AgentFinish, list[tuple[AgentAction, str]]],
         run_manager: CallbackManagerForChainRun,
     ) -> AddableDict:
         """
@@ -339,7 +335,7 @@ class AgentExecutorIterator:
 
     async def _aprocess_next_step_output(
         self,
-        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
+        next_step_output: Union[AgentFinish, list[tuple[AgentAction, str]]],
         run_manager: AsyncCallbackManagerForChainRun,
     ) -> AddableDict:
         """
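Note: the `__iter__` signature above also drops the quotes around its own class name. Because the module has `from __future__ import annotations` (PEP 563), every annotation is stored as a string and evaluated lazily, so explicit forward-reference quoting is redundant. Illustrative sketch with a hypothetical `clone` method, not from the repository:

from __future__ import annotations


class AgentExecutorIterator:
    # With PEP 563 in effect this annotation is never evaluated eagerly,
    # so the class can name itself without quotes even mid-definition.
    def clone(self) -> AgentExecutorIterator:
        return self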

@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any, Optional
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.memory import BaseMemory
@@ -26,7 +26,7 @@ def _get_default_system_message() -> SystemMessage:
 
 def create_conversational_retrieval_agent(
     llm: BaseLanguageModel,
-    tools: List[BaseTool],
+    tools: list[BaseTool],
     remember_intermediate_steps: bool = True,
     memory_key: str = "chat_history",
     system_message: Optional[SystemMessage] = None,

@@ -1,6 +1,6 @@
 """VectorStore agent."""
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks.base import BaseCallbackManager
@@ -36,7 +36,7 @@ def create_vectorstore_agent(
     callback_manager: Optional[BaseCallbackManager] = None,
     prefix: str = PREFIX,
     verbose: bool = False,
-    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+    agent_executor_kwargs: Optional[dict[str, Any]] = None,
     **kwargs: Any,
 ) -> AgentExecutor:
     """Construct a VectorStore agent from an LLM and tools.
@@ -129,7 +129,7 @@ def create_vectorstore_router_agent(
     callback_manager: Optional[BaseCallbackManager] = None,
     prefix: str = ROUTER_PREFIX,
     verbose: bool = False,
-    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
+    agent_executor_kwargs: Optional[dict[str, Any]] = None,
     **kwargs: Any,
 ) -> AgentExecutor:
     """Construct a VectorStore router agent from an LLM and tools.

@@ -1,7 +1,5 @@
 """Toolkit for interacting with a vector store."""
 
-from typing import List
-
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.tools import BaseTool
 from langchain_core.tools.base import BaseToolkit
@@ -31,7 +29,7 @@ class VectorStoreToolkit(BaseToolkit):
         arbitrary_types_allowed=True,
     )
 
-    def get_tools(self) -> List[BaseTool]:
+    def get_tools(self) -> list[BaseTool]:
         """Get the tools in the toolkit."""
         try:
             from langchain_community.tools.vectorstore.tool import (
@@ -66,16 +64,16 @@
 class VectorStoreRouterToolkit(BaseToolkit):
     """Toolkit for routing between Vector Stores."""
 
-    vectorstores: List[VectorStoreInfo] = Field(exclude=True)
+    vectorstores: list[VectorStoreInfo] = Field(exclude=True)
     llm: BaseLanguageModel
 
     model_config = ConfigDict(
         arbitrary_types_allowed=True,
     )
 
-    def get_tools(self) -> List[BaseTool]:
+    def get_tools(self) -> list[BaseTool]:
         """Get the tools in the toolkit."""
-        tools: List[BaseTool] = []
+        tools: list[BaseTool] = []
         try:
             from langchain_community.tools.vectorstore.tool import (
                 VectorStoreQATool,

@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.agents import AgentAction
@@ -48,7 +49,7 @@ class ChatAgent(Agent):
         return "Thought:"
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+        self, intermediate_steps: list[tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
         if not isinstance(agent_scratchpad, str):
@@ -72,7 +73,7 @@ class ChatAgent(Agent):
         validate_tools_single_input(class_name=cls.__name__, tools=tools)
 
     @property
-    def _stop(self) -> List[str]:
+    def _stop(self) -> list[str]:
         return ["Observation:"]
 
     @classmethod
@@ -83,7 +84,7 @@ class ChatAgent(Agent):
         system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
         human_message: str = HUMAN_MESSAGE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
     ) -> BasePromptTemplate:
         """Create a prompt from a list of tools.
 
@@ -132,7 +133,7 @@ class ChatAgent(Agent):
         system_message_suffix: str = SYSTEM_MESSAGE_SUFFIX,
         human_message: str = HUMAN_MESSAGE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools.

@@ -1,6 +1,7 @@
 import json
 import re
-from typing import Pattern, Union
+from re import Pattern
+from typing import Union
 
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.exceptions import OutputParserException
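Note: `typing.Pattern` has been deprecated since Python 3.8 (and is removed in 3.12); `re.Pattern` is the actual runtime class and is generic from 3.9 on, so the parser now imports it from `re`. Illustrative sketch with a hypothetical pattern (the module's real regex differs):

import re
from re import Pattern

ACTION_RE: Pattern[str] = re.compile(r"\{.*\}", re.DOTALL)

def has_action(text: str) -> bool:
    return ACTION_RE.search(text) is not None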

@@ -2,7 +2,8 @@
 
 from __future__ import annotations
 
-from typing import Any, List, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import BaseCallbackManager
@@ -71,7 +72,7 @@ class ConversationalAgent(Agent):
         format_instructions: str = FORMAT_INSTRUCTIONS,
         ai_prefix: str = "AI",
         human_prefix: str = "Human",
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
     ) -> PromptTemplate:
         """Create prompt in the style of the zero-shot agent.
 
@@ -120,7 +121,7 @@ class ConversationalAgent(Agent):
         format_instructions: str = FORMAT_INSTRUCTIONS,
         ai_prefix: str = "AI",
         human_prefix: str = "Human",
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools.

@@ -2,7 +2,8 @@
 
 from __future__ import annotations
 
-from typing import Any, List, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.agents import AgentAction
@@ -77,7 +78,7 @@ class ConversationalChatAgent(Agent):
         tools: Sequence[BaseTool],
         system_message: str = PREFIX,
         human_message: str = SUFFIX,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
         output_parser: Optional[BaseOutputParser] = None,
     ) -> BasePromptTemplate:
         """Create a prompt for the agent.
@@ -116,10 +117,10 @@ class ConversationalChatAgent(Agent):
         return ChatPromptTemplate(input_variables=input_variables, messages=messages)  # type: ignore[arg-type]
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
-    ) -> List[BaseMessage]:
+        self, intermediate_steps: list[tuple[AgentAction, str]]
+    ) -> list[BaseMessage]:
         """Construct the scratchpad that lets the agent continue its thought process."""
-        thoughts: List[BaseMessage] = []
+        thoughts: list[BaseMessage] = []
         for action, observation in intermediate_steps:
             thoughts.append(AIMessage(content=action.log))
             human_message = HumanMessage(
@@ -137,7 +138,7 @@ class ConversationalChatAgent(Agent):
         output_parser: Optional[AgentOutputParser] = None,
         system_message: str = PREFIX,
         human_message: str = SUFFIX,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools.

@@ -1,10 +1,8 @@
-from typing import List, Tuple
-
 from langchain_core.agents import AgentAction
 
 
 def format_log_to_str(
-    intermediate_steps: List[Tuple[AgentAction, str]],
+    intermediate_steps: list[tuple[AgentAction, str]],
     observation_prefix: str = "Observation: ",
     llm_prefix: str = "Thought: ",
 ) -> str:

@@ -1,13 +1,11 @@
-from typing import List, Tuple
-
 from langchain_core.agents import AgentAction
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
 
 
 def format_log_to_messages(
-    intermediate_steps: List[Tuple[AgentAction, str]],
+    intermediate_steps: list[tuple[AgentAction, str]],
     template_tool_response: str = "{observation}",
-) -> List[BaseMessage]:
+) -> list[BaseMessage]:
     """Construct the scratchpad that lets the agent continue its thought process.
 
     Args:
@@ -18,7 +16,7 @@ def format_log_to_messages(
     Returns:
         List[BaseMessage]: The scratchpad.
     """
-    thoughts: List[BaseMessage] = []
+    thoughts: list[BaseMessage] = []
     for action, observation in intermediate_steps:
         thoughts.append(AIMessage(content=action.log))
         human_message = HumanMessage(

@@ -1,5 +1,5 @@
 import json
-from typing import List, Sequence, Tuple
+from collections.abc import Sequence
 
 from langchain_core.agents import AgentAction, AgentActionMessageLog
 from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
@@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage
 
 def _convert_agent_action_to_messages(
     agent_action: AgentAction, observation: str
-) -> List[BaseMessage]:
+) -> list[BaseMessage]:
     """Convert an agent action to a message.
 
     This code is used to reconstruct the original AI message from the agent action.
@@ -54,8 +54,8 @@ def _create_function_message(
 
 
 def format_to_openai_function_messages(
-    intermediate_steps: Sequence[Tuple[AgentAction, str]],
-) -> List[BaseMessage]:
+    intermediate_steps: Sequence[tuple[AgentAction, str]],
+) -> list[BaseMessage]:
     """Convert (AgentAction, tool output) tuples into FunctionMessages.
 
     Args:

@@ -1,5 +1,5 @@
 import json
-from typing import List, Sequence, Tuple
+from collections.abc import Sequence
 
 from langchain_core.agents import AgentAction
 from langchain_core.messages import (
@@ -40,8 +40,8 @@ def _create_tool_message(
 
 
 def format_to_tool_messages(
-    intermediate_steps: Sequence[Tuple[AgentAction, str]],
-) -> List[BaseMessage]:
+    intermediate_steps: Sequence[tuple[AgentAction, str]],
+) -> list[BaseMessage]:
     """Convert (AgentAction, tool output) tuples into ToolMessages.
 
     Args:

@@ -1,10 +1,8 @@
-from typing import List, Tuple
-
 from langchain_core.agents import AgentAction
 
 
 def format_xml(
-    intermediate_steps: List[Tuple[AgentAction, str]],
+    intermediate_steps: list[tuple[AgentAction, str]],
 ) -> str:
     """Format the intermediate steps as XML.
 

@@ -1,6 +1,7 @@
 """Load agent."""
 
-from typing import Any, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import BaseCallbackManager

@@ -1,4 +1,5 @@
-from typing import List, Sequence, Union
+from collections.abc import Sequence
+from typing import Union
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.prompts.chat import ChatPromptTemplate
@@ -15,7 +16,7 @@ def create_json_chat_agent(
     llm: BaseLanguageModel,
     tools: Sequence[BaseTool],
     prompt: ChatPromptTemplate,
-    stop_sequence: Union[bool, List[str]] = True,
+    stop_sequence: Union[bool, list[str]] = True,
     tools_renderer: ToolsRenderer = render_text_description,
     template_tool_response: str = TEMPLATE_TOOL_RESPONSE,
 ) -> Runnable:

@@ -3,7 +3,7 @@
 import json
 import logging
 from pathlib import Path
-from typing import Any, List, Optional, Union
+from typing import Any, Optional, Union
 
 import yaml
 from langchain_core._api import deprecated
@@ -20,7 +20,7 @@ URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/age
 
 
 def _load_agent_from_tools(
-    config: dict, llm: BaseLanguageModel, tools: List[Tool], **kwargs: Any
+    config: dict, llm: BaseLanguageModel, tools: list[Tool], **kwargs: Any
 ) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
     config_type = config.pop("_type")
     if config_type not in AGENT_TO_CLASS:
@@ -35,7 +35,7 @@ def _load_agent_from_tools(
 def load_agent_from_config(
     config: dict,
     llm: Optional[BaseLanguageModel] = None,
-    tools: Optional[List[Tool]] = None,
+    tools: Optional[list[Tool]] = None,
     **kwargs: Any,
 ) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
     """Load agent from Config Dict.
@@ -130,7 +130,7 @@ def _load_agent_from_file(
         with open(file_path) as f:
             config = json.load(f)
     elif file_path.suffix[1:] == "yaml":
-        with open(file_path, "r") as f:
+        with open(file_path) as f:
             config = yaml.safe_load(f)
     else:
         raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.")
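Note: a small non-typing cleanup rides along in this file: `open(file_path, "r")` becomes `open(file_path)`, since "r" (read, text mode) is already the default, matching the json branch a few lines up. Illustrative, with a hypothetical local file:

# These two calls are equivalent; "r" (read, text mode) is the default.
with open("agent.yaml", "r") as f:
    raw = f.read()

with open("agent.yaml") as f:
    raw = f.read()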

@@ -2,7 +2,8 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, List, NamedTuple, Optional, Sequence
+from collections.abc import Sequence
+from typing import Any, Callable, NamedTuple, Optional
 
 from langchain_core._api import deprecated
 from langchain_core.callbacks import BaseCallbackManager
@@ -83,7 +84,7 @@ class ZeroShotAgent(Agent):
         prefix: str = PREFIX,
         suffix: str = SUFFIX,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
     ) -> PromptTemplate:
         """Create prompt in the style of the zero shot agent.
 
@@ -118,7 +119,7 @@ class ZeroShotAgent(Agent):
         prefix: str = PREFIX,
         suffix: str = SUFFIX,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools.
@@ -183,7 +184,7 @@ class MRKLChain(AgentExecutor):
 
     @classmethod
     def from_chains(
-        cls, llm: BaseLanguageModel, chains: List[ChainConfig], **kwargs: Any
+        cls, llm: BaseLanguageModel, chains: list[ChainConfig], **kwargs: Any
     ) -> AgentExecutor:
         """User-friendly way to initialize the MRKL chain.
 

@@ -2,18 +2,14 @@ from __future__ import annotations
 
 import asyncio
 import json
+from collections.abc import Sequence
 from json import JSONDecodeError
 from time import sleep
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
-    Dict,
-    List,
     Optional,
-    Sequence,
-    Tuple,
-    Type,
     Union,
 )
 
@@ -111,7 +107,7 @@ def _get_openai_async_client() -> openai.AsyncOpenAI:
 
 
 def _is_assistants_builtin_tool(
-    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
+    tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
 ) -> bool:
     """Determine if tool corresponds to OpenAI Assistants built-in."""
     assistants_builtin_tools = ("code_interpreter", "file_search")
@@ -123,8 +119,8 @@ def _is_assistants_builtin_tool(
 
 
 def _get_assistants_tool(
-    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
-) -> Dict[str, Any]:
+    tool: Union[dict[str, Any], type[BaseModel], Callable, BaseTool],
+) -> dict[str, Any]:
     """Convert a raw function/class to an OpenAI tool.
 
     Note that OpenAI assistants supports several built-in tools,
@@ -137,14 +133,14 @@ def _get_assistants_tool(
 
 
 OutputType = Union[
-    List[OpenAIAssistantAction],
+    list[OpenAIAssistantAction],
    OpenAIAssistantFinish,
-    List["ThreadMessage"],
-    List["RequiredActionFunctionToolCall"],
+    list["ThreadMessage"],
+    list["RequiredActionFunctionToolCall"],
 ]
 
 
-class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
+class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
     """Run an OpenAI Assistant.
 
     Example using OpenAI tools:
@@ -498,7 +494,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
         return response
 
     def _parse_intermediate_steps(
-        self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
+        self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]]
     ) -> dict:
         last_action, last_output = intermediate_steps[-1]
         run = self._wait_for_run(last_action.run_id, last_action.thread_id)
@@ -652,7 +648,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
         return run
 
     async def _aparse_intermediate_steps(
-        self, intermediate_steps: List[Tuple[OpenAIAssistantAction, str]]
+        self, intermediate_steps: list[tuple[OpenAIAssistantAction, str]]
     ) -> dict:
         last_action, last_output = intermediate_steps[-1]
         run = self._wait_for_run(last_action.run_id, last_action.thread_id)
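Note: this file shows two more PEP 585 details. `Type[BaseModel]` becomes `type[BaseModel]`, since the builtin `type` is itself usable as a generic from 3.9 on, and string forward references like `list["ThreadMessage"]` behave exactly as they did with `List[...]`, because subscripting a builtin just records the argument without evaluating it. Illustrative sketch, not from the repository:

from pydantic import BaseModel

def tool_name(tool_cls: type[BaseModel]) -> str:
    # `type[X]` replaces `typing.Type[X]` on Python 3.9+.
    return tool_cls.__name__

# Builtin generics accept string forward references; the argument is
# stored unevaluated, so the referenced class need not exist yet.
PendingMessages = list["ThreadMessage"]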

@@ -1,6 +1,6 @@
 """Memory used to save agent output AND intermediate steps."""
 
-from typing import Any, Dict, List
+from typing import Any
 
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.messages import BaseMessage, get_buffer_string
@@ -43,19 +43,19 @@ class AgentTokenBufferMemory(BaseChatMemory):  # type: ignore[override]
     format_as_tools: bool = False
 
     @property
-    def buffer(self) -> List[BaseMessage]:
+    def buffer(self) -> list[BaseMessage]:
         """String buffer of memory."""
         return self.chat_memory.messages
 
     @property
-    def memory_variables(self) -> List[str]:
+    def memory_variables(self) -> list[str]:
         """Always return list of memory variables.
 
         :meta private:
         """
         return [self.memory_key]
 
-    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Return history buffer.
 
         Args:
@@ -74,7 +74,7 @@ class AgentTokenBufferMemory(BaseChatMemory):  # type: ignore[override]
         )
         return {self.memory_key: final_buffer}
 
-    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, Any]) -> None:
+    def save_context(self, inputs: dict[str, Any], outputs: dict[str, Any]) -> None:
         """Save context from this conversation to buffer. Pruned.
 
         Args:

@@ -1,6 +1,7 @@
 """Module implements an agent that uses OpenAI's APIs function enabled API."""
 
-from typing import Any, List, Optional, Sequence, Tuple, Type, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Union
 
 from langchain_core._api import deprecated
 from langchain_core.agents import AgentAction, AgentFinish
@@ -51,11 +52,11 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
     llm: BaseLanguageModel
     tools: Sequence[BaseTool]
     prompt: BasePromptTemplate
-    output_parser: Type[OpenAIFunctionsAgentOutputParser] = (
+    output_parser: type[OpenAIFunctionsAgentOutputParser] = (
         OpenAIFunctionsAgentOutputParser
     )
 
-    def get_allowed_tools(self) -> List[str]:
+    def get_allowed_tools(self) -> list[str]:
         """Get allowed tools."""
         return [t.name for t in self.tools]
 
@@ -81,19 +82,19 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         return self
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Get input keys. Input refers to user input here."""
         return ["input"]
 
     @property
-    def functions(self) -> List[dict]:
+    def functions(self) -> list[dict]:
         """Get functions."""
 
         return [dict(convert_to_openai_function(t)) for t in self.tools]
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         with_functions: bool = True,
         **kwargs: Any,
@@ -135,7 +136,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -168,7 +169,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         """Return response when agent has been stopped due to max iterations.
@@ -213,7 +214,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         system_message: Optional[SystemMessage] = SystemMessage(
             content="You are a helpful AI assistant."
         ),
-        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
     ) -> ChatPromptTemplate:
         """Create prompt for this agent.
 
@@ -227,7 +228,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
             A prompt template to pass into this agent.
         """
         _prompts = extra_prompt_messages or []
-        messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
+        messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
         if system_message:
             messages = [system_message]
         else:
@@ -248,7 +249,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
         llm: BaseLanguageModel,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
-        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
         system_message: Optional[SystemMessage] = SystemMessage(
             content="You are a helpful AI assistant."
         ),

@@ -1,8 +1,9 @@
 """Module implements an agent that uses OpenAI's APIs function enabled API."""
 
 import json
+from collections.abc import Sequence
 from json import JSONDecodeError
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, Optional, Union
 
 from langchain_core._api import deprecated
 from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
@@ -34,7 +35,7 @@ from langchain.agents.format_scratchpad.openai_functions import (
 _FunctionsAgentAction = AgentActionMessageLog
 
 
-def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]:
+def _parse_ai_message(message: BaseMessage) -> Union[list[AgentAction], AgentFinish]:
     """Parse an AI message."""
     if not isinstance(message, AIMessage):
         raise TypeError(f"Expected an AI message got {type(message)}")
@@ -58,7 +59,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFin
                 f"the `arguments` JSON does not contain `actions` key."
             )
 
-        final_tools: List[AgentAction] = []
+        final_tools: list[AgentAction] = []
         for tool_schema in tools:
             if "action" in tool_schema:
                 _tool_input = tool_schema["action"]
@@ -112,7 +113,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
     tools: Sequence[BaseTool]
     prompt: BasePromptTemplate
 
-    def get_allowed_tools(self) -> List[str]:
+    def get_allowed_tools(self) -> list[str]:
         """Get allowed tools."""
         return [t.name for t in self.tools]
 
@@ -127,12 +128,12 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         return self
 
     @property
-    def input_keys(self) -> List[str]:
+    def input_keys(self) -> list[str]:
         """Get input keys. Input refers to user input here."""
         return ["input"]
 
     @property
-    def functions(self) -> List[dict]:
+    def functions(self) -> list[dict]:
         """Get the functions for the agent."""
         enum_vals = [t.name for t in self.tools]
         tool_selection = {
@@ -194,10 +195,10 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Union[List[AgentAction], AgentFinish]:
+    ) -> Union[list[AgentAction], AgentFinish]:
         """Given input, decided what to do.
 
         Args:
@@ -224,10 +225,10 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
-    ) -> Union[List[AgentAction], AgentFinish]:
+    ) -> Union[list[AgentAction], AgentFinish]:
         """Async given input, decided what to do.
 
         Args:
@@ -258,7 +259,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         system_message: Optional[SystemMessage] = SystemMessage(
             content="You are a helpful AI assistant."
         ),
-        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
     ) -> BasePromptTemplate:
         """Create prompt for this agent.
 
@@ -272,7 +273,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
             A prompt template to pass into this agent.
         """
         _prompts = extra_prompt_messages or []
-        messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
+        messages: list[Union[BaseMessagePromptTemplate, BaseMessage]]
         if system_message:
             messages = [system_message]
         else:
@@ -293,7 +294,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
         llm: BaseLanguageModel,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
-        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
         system_message: Optional[SystemMessage] = SystemMessage(
             content="You are a helpful AI assistant."
         ),
@ -1,4 +1,5 @@
from typing import Optional, Sequence
from collections.abc import Sequence
from typing import Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate

@ -1,6 +1,6 @@
import json
from json import JSONDecodeError
from typing import List, Union
from typing import Union

from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
@ -77,7 +77,7 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
)

def parse_result(
self, result: List[Generation], *, partial: bool = False
self, result: list[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")

@ -1,4 +1,4 @@
from typing import List, Union
from typing import Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import BaseMessage
@ -15,12 +15,12 @@ OpenAIToolAgentAction = ToolAgentAction

def parse_ai_message_to_openai_tool_action(
message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
tool_actions = parse_ai_message_to_tool_action(message)
if isinstance(tool_actions, AgentFinish):
return tool_actions
final_actions: List[AgentAction] = []
final_actions: list[AgentAction] = []
for action in tool_actions:
if isinstance(action, ToolAgentAction):
final_actions.append(
@ -54,12 +54,12 @@ class OpenAIToolsAgentOutputParser(MultiActionAgentOutputParser):
return "openai-tools-agent-output-parser"

def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
self, result: list[Generation], *, partial: bool = False
) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return parse_ai_message_to_openai_tool_action(message)

def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")
@ -1,6 +1,7 @@
import json
import re
from typing import Pattern, Union
from re import Pattern
from typing import Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException

@ -1,4 +1,5 @@
from typing import Sequence, Union
from collections.abc import Sequence
from typing import Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException

@ -1,6 +1,6 @@
import json
from json import JSONDecodeError
from typing import List, Union
from typing import Union

from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.exceptions import OutputParserException
@ -21,12 +21,12 @@ class ToolAgentAction(AgentActionMessageLog): # type: ignore[override]

def parse_ai_message_to_tool_action(
message: BaseMessage,
) -> Union[List[AgentAction], AgentFinish]:
) -> Union[list[AgentAction], AgentFinish]:
"""Parse an AI message potentially containing tool_calls."""
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message got {type(message)}")

actions: List = []
actions: list = []
if message.tool_calls:
tool_calls = message.tool_calls
else:
@ -91,12 +91,12 @@ class ToolsAgentOutputParser(MultiActionAgentOutputParser):
return "tools-agent-output-parser"

def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[List[AgentAction], AgentFinish]:
self, result: list[Generation], *, partial: bool = False
) -> Union[list[AgentAction], AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return parse_ai_message_to_tool_action(message)

def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]:
def parse(self, text: str) -> Union[list[AgentAction], AgentFinish]:
raise ValueError("Can only parse messages")
@ -1,6 +1,7 @@
from __future__ import annotations

from typing import List, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Optional, Union

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
@ -20,7 +21,7 @@ def create_react_agent(
output_parser: Optional[AgentOutputParser] = None,
tools_renderer: ToolsRenderer = render_text_description,
*,
stop_sequence: Union[bool, List[str]] = True,
stop_sequence: Union[bool, list[str]] = True,
) -> Runnable:
"""Create an agent that uses ReAct prompting.

@ -2,7 +2,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, List, Optional, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional

from langchain_core._api import deprecated
from langchain_core.documents import Document
@ -65,7 +66,7 @@ class ReActDocstoreAgent(Agent):
return "Observation: "

@property
def _stop(self) -> List[str]:
def _stop(self) -> list[str]:
return ["\nObservation:"]

@property
@ -122,7 +123,7 @@ class DocstoreExplorer:
return self._paragraphs[0]

@property
def _paragraphs(self) -> List[str]:
def _paragraphs(self) -> list[str]:
if self.document is None:
raise ValueError("Cannot get paragraphs without a document")
return self.document.page_content.split("\n\n")
@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any

from langchain_core.agents import AgentAction
from langchain_core.prompts.chat import ChatPromptTemplate
@ -12,7 +12,7 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
return False

def _construct_agent_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
if len(intermediate_steps) == 0:
return ""
@ -26,7 +26,7 @@ class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
f"you return as final answer):\n{thoughts}"
)

def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
intermediate_steps = kwargs.pop("intermediate_steps")
kwargs["agent_scratchpad"] = self._construct_agent_scratchpad(
intermediate_steps
@ -2,7 +2,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Sequence, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Union

from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel

@ -1,5 +1,6 @@
import re
from typing import Any, List, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Optional, Union

from langchain_core._api import deprecated
from langchain_core.agents import AgentAction
@ -49,7 +50,7 @@ class StructuredChatAgent(Agent):
return "Thought:"

def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
self, intermediate_steps: list[tuple[AgentAction, str]]
) -> str:
agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
if not isinstance(agent_scratchpad, str):
@ -74,7 +75,7 @@ class StructuredChatAgent(Agent):
return StructuredChatOutputParserWithRetries.from_llm(llm=llm)

@property
def _stop(self) -> List[str]:
def _stop(self) -> list[str]:
return ["Observation:"]

@classmethod
@ -85,8 +86,8 @@ class StructuredChatAgent(Agent):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
) -> BasePromptTemplate:
tool_strings = []
for tool in tools:
@ -117,8 +118,8 @@ class StructuredChatAgent(Agent):
suffix: str = SUFFIX,
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
memory_prompts: Optional[List[BasePromptTemplate]] = None,
input_variables: Optional[list[str]] = None,
memory_prompts: Optional[list[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
@ -157,7 +158,7 @@ def create_structured_chat_agent(
prompt: ChatPromptTemplate,
tools_renderer: ToolsRenderer = render_text_description_and_args,
*,
stop_sequence: Union[bool, List[str]] = True,
stop_sequence: Union[bool, list[str]] = True,
) -> Runnable:
"""Create an agent aimed at supporting tools with multiple inputs.
@ -3,7 +3,8 @@ from __future__ import annotations
import json
import logging
import re
from typing import Optional, Pattern, Union
from re import Pattern
from typing import Optional, Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException

@ -1,4 +1,5 @@
from typing import Callable, List, Sequence, Tuple
from collections.abc import Sequence
from typing import Callable

from langchain_core.agents import AgentAction
from langchain_core.language_models import BaseLanguageModel
@ -12,7 +13,7 @@ from langchain.agents.format_scratchpad.tools import (
)
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser

MessageFormatter = Callable[[Sequence[Tuple[AgentAction, str]]], List[BaseMessage]]
MessageFormatter = Callable[[Sequence[tuple[AgentAction, str]]], list[BaseMessage]]


def create_tool_calling_agent(
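The `MessageFormatter` hunk above also moves `Sequence` to `collections.abc`, since the `typing` aliases for the container ABCs are deprecated from 3.9. A minimal sketch of the same alias shape with stand-in types (the names below are illustrative, not langchain's):

from collections.abc import Sequence
from typing import Callable

# Alias mirroring MessageFormatter: (action, observation) steps in,
# formatted strings out.
Formatter = Callable[[Sequence[tuple[str, str]]], list[str]]

def join_steps(steps: Sequence[tuple[str, str]]) -> list[str]:
    return [f"{action}: {observation}" for action, observation in steps]

fmt: Formatter = join_steps
print(fmt([("search", "ok")]))  # ['search: ok']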
@ -1,6 +1,6 @@
"""Interface for tools."""

from typing import List, Optional
from typing import Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
@ -20,7 +20,7 @@ class InvalidTool(BaseTool): # type: ignore[override]
def _run(
self,
requested_tool_name: str,
available_tool_names: List[str],
available_tool_names: list[str],
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
@ -33,7 +33,7 @@ class InvalidTool(BaseTool): # type: ignore[override]
async def _arun(
self,
requested_tool_name: str,
available_tool_names: List[str],
available_tool_names: list[str],
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""

@ -1,4 +1,4 @@
from typing import Dict, Type, Union
from typing import Union

from langchain.agents.agent import BaseSingleActionAgent
from langchain.agents.agent_types import AgentType
@ -12,9 +12,9 @@ from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.structured_chat.base import StructuredChatAgent

AGENT_TYPE = Union[Type[BaseSingleActionAgent], Type[OpenAIMultiFunctionsAgent]]
AGENT_TYPE = Union[type[BaseSingleActionAgent], type[OpenAIMultiFunctionsAgent]]

AGENT_TO_CLASS: Dict[AgentType, AGENT_TYPE] = {
AGENT_TO_CLASS: dict[AgentType, AGENT_TYPE] = {
AgentType.ZERO_SHOT_REACT_DESCRIPTION: ZeroShotAgent,
AgentType.REACT_DOCSTORE: ReActDocstoreAgent,
AgentType.SELF_ASK_WITH_SEARCH: SelfAskWithSearchAgent,
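The types.py hunk swaps `typing.Type` for the subscriptable builtin `type` in the agent registry. A sketch of the same registry pattern with toy classes (hypothetical names, not langchain's):

from typing import Union

class Reader: ...
class Writer: ...

# `type[X]` means "the class X or a subclass of it", same as typing.Type[X].
HandlerType = Union[type[Reader], type[Writer]]
REGISTRY: dict[str, HandlerType] = {"read": Reader, "write": Writer}

handler_cls = REGISTRY["read"]  # look up the class, then instantiate it
print(handler_cls().__class__.__name__)  # Reader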
@ -1,4 +1,4 @@
from typing import Sequence
from collections.abc import Sequence

from langchain_core.tools import BaseTool

@ -1,4 +1,5 @@
from typing import Any, List, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Union

from langchain_core._api import deprecated
from langchain_core.agents import AgentAction, AgentFinish
@ -38,13 +39,13 @@ class XMLAgent(BaseSingleActionAgent):

"""

tools: List[BaseTool]
tools: list[BaseTool]
"""List of tools this agent has access to."""
llm_chain: LLMChain
"""Chain to use to predict action."""

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
return ["input"]

@staticmethod
@ -60,7 +61,7 @@ class XMLAgent(BaseSingleActionAgent):

def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -84,7 +85,7 @@ class XMLAgent(BaseSingleActionAgent):

async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
intermediate_steps: list[tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
@ -113,7 +114,7 @@ def create_xml_agent(
prompt: BasePromptTemplate,
tools_renderer: ToolsRenderer = render_text_description,
*,
stop_sequence: Union[bool, List[str]] = True,
stop_sequence: Union[bool, list[str]] = True,
) -> Runnable:
"""Create an agent that uses XML to format its logic.
@ -1,7 +1,8 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
from collections.abc import AsyncIterator
from typing import Any, Literal, Union, cast

from langchain_core.callbacks import AsyncCallbackHandler
from langchain_core.outputs import LLMResult
@ -25,7 +26,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
self.done = asyncio.Event()

async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()

@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core.outputs import LLMResult

@ -30,7 +30,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
def __init__(
self,
*,
answer_prefix_tokens: Optional[List[str]] = None,
answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
) -> None:
@ -62,7 +62,7 @@ class AsyncFinalIteratorCallbackHandler(AsyncIteratorCallbackHandler):
self.answer_reached = False

async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
# If two calls are made in a row, this resets the state
self.done.clear()
@ -1,7 +1,7 @@
"""Callback Handler streams to stdout on new llm token."""

import sys
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core.callbacks import StreamingStdOutCallbackHandler

@ -31,7 +31,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
def __init__(
self,
*,
answer_prefix_tokens: Optional[List[str]] = None,
answer_prefix_tokens: Optional[list[str]] = None,
strip_tokens: bool = True,
stream_prefix: bool = False,
) -> None:
@ -63,7 +63,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
self.answer_reached = False

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.answer_reached = False
@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Any, Optional
from urllib.parse import urlparse

from langchain_core._api import deprecated
@ -20,7 +21,7 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain


def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
def _extract_scheme_and_domain(url: str) -> tuple[str, str]:
"""Extract the scheme + domain from a given URL.

Args:
@ -215,7 +216,7 @@ try:
"""

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@ -223,7 +224,7 @@ try:
return [self.question_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect output key.

:meta private:
@ -243,7 +244,7 @@ try:

@model_validator(mode="before")
@classmethod
def validate_limit_to_domains(cls, values: Dict) -> Any:
def validate_limit_to_domains(cls, values: dict) -> Any:
"""Check that allowed domains are valid."""
# This check must be a pre=True check, so that a default of None
# won't be set to limit_to_domains if it's not provided.
@ -275,9 +276,9 @@ try:

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.question_key]
api_url = self.api_request_chain.predict(
@ -308,9 +309,9 @@ try:

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = (
run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
)
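The api/base hunk narrows only the annotation (`Tuple[str, str]` to `tuple[str, str]`); behavior is unchanged. A hypothetical re-implementation of a function with that signature, assuming standard urlparse semantics rather than the file's actual body:

from urllib.parse import urlparse

def scheme_and_domain(url: str) -> tuple[str, str]:
    # urlparse splits the URL into components; scheme is "https" and
    # netloc is the domain in the example below.
    parsed = urlparse(url)
    return parsed.scheme, parsed.netloc

print(scheme_and_domain("https://example.com/path"))  # ('https', 'example.com')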
@ -1,12 +1,13 @@
"""Base interface that all chains should implement."""

import builtins
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union, cast
from typing import Any, Optional, Union, cast

import yaml
from langchain_core._api import deprecated
@ -46,7 +47,7 @@ def _get_verbosity() -> bool:
return get_verbose()


class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
class Chain(RunnableSerializable[dict[str, Any], dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components.

Chains should be used to encode a sequence of calls to components like
@ -86,13 +87,13 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
"""Whether or not run in verbose mode. In verbose mode, some intermediate logs
will be printed to the console. Defaults to the global `verbose` value,
accessible via `langchain.globals.get_verbose()`."""
tags: Optional[List[str]] = None
tags: Optional[list[str]] = None
"""Optional list of tags associated with the chain. Defaults to None.
These tags will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None.
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
@ -107,7 +108,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model(  # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys}
@ -115,7 +116,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model(  # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
@ -123,10 +124,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

def invoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
config = ensure_config(config)
callbacks = config.get("callbacks")
tags = config.get("tags")
@ -162,7 +163,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
else self._call(inputs)
)

final_outputs: Dict[str, Any] = self.prep_outputs(
final_outputs: dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
@ -176,10 +177,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

async def ainvoke(
self,
input: Dict[str, Any],
input: dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
config = ensure_config(config)
callbacks = config.get("callbacks")
tags = config.get("tags")
@ -213,7 +214,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
if new_arg_supported
else await self._acall(inputs)
)
final_outputs: Dict[str, Any] = await self.aprep_outputs(
final_outputs: dict[str, Any] = await self.aprep_outputs(
inputs, outputs, return_only_outputs
)
except BaseException as e:
@ -231,7 +232,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

@model_validator(mode="before")
@classmethod
def raise_callback_manager_deprecation(cls, values: Dict) -> Any:
def raise_callback_manager_deprecation(cls, values: dict) -> Any:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
if values.get("callbacks") is not None:
@ -261,15 +262,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

@property
@abstractmethod
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Keys expected to be in the chain input."""

@property
@abstractmethod
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Keys expected to be in the chain output."""

def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
def _validate_inputs(self, inputs: dict[str, Any]) -> None:
"""Check that all inputs are present."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
@ -289,7 +290,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")

def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
def _validate_outputs(self, outputs: dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@ -297,9 +298,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Execute the chain.
This is a private method that is not user-facing. It is only called within
@ -319,9 +320,9 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Asynchronously execute the chain.

This is a private method that is not user-facing. It is only called within
@ -345,15 +346,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@deprecated("0.1.0", alternative="invoke", removal="1.0")
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
inputs: Union[dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Execute the chain.

Args:
@ -396,15 +397,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
@deprecated("0.1.0", alternative="ainvoke", removal="1.0")
async def acall(
self,
inputs: Union[Dict[str, Any], Any],
inputs: Union[dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Asynchronously execute the chain.

Args:
@ -445,10 +446,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
inputs: dict[str, str],
outputs: dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Validate and prepare chain outputs, and save info about this run to memory.

Args:
@ -471,10 +472,10 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

async def aprep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
inputs: dict[str, str],
outputs: dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Validate and prepare chain outputs, and save info about this run to memory.

Args:
@ -495,7 +496,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
else:
return {**inputs, **outputs}

def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.

Args:
@ -519,7 +520,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
inputs = dict(inputs, **external_context)
return inputs

async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
async def aprep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Prepare chain inputs, including adding inputs from memory.

Args:
@ -557,8 +558,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Convenience method for executing chain.
@ -628,8 +629,8 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Convenience method for executing chain.
@ -695,7 +696,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
f" but not both. Got args: {args} and kwargs: {kwargs}."
)

def dict(self, **kwargs: Any) -> Dict:
def dict(self, **kwargs: Any) -> dict:
"""Dictionary representation of chain.

Expects `Chain._chain_type` property to be implemented and for memory to be
@ -763,7 +764,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):

@deprecated("0.1.0", alternative="batch", removal="1.0")
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
self, input_list: list[builtins.dict[str, Any]], callbacks: Callbacks = None
) -> list[builtins.dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
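chains/base.py is the one file in this commit that needs `import builtins`: `Chain` defines a method named `dict`, and once that method exists in the class body, a later bare `dict[...]` annotation would subscript the method object instead of the builtin and fail at class-creation time. A minimal sketch of the shadowing problem with a hypothetical class:

import builtins
from typing import Any

class Config:
    def dict(self) -> builtins.dict[str, Any]:
        return {"verbose": True}

    # A bare `dict[str, Any]` annotation here would raise a TypeError when
    # the class is created, because `dict` now names the method above.
    def merge(self, extra: builtins.dict[str, Any]) -> builtins.dict[str, Any]:
        return {**self.dict(), **extra}

print(Config().merge({"tags": []}))  # {'verbose': True, 'tags': []}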
@ -1,7 +1,7 @@
"""Base interface for chains combining documents."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -47,22 +47,22 @@ class BaseCombineDocumentsChain(Chain, ABC):

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return create_model(
"CombineDocumentsInput",
**{self.input_key: (List[Document], None)},  # type: ignore[call-overload]
**{self.input_key: (list[Document], None)},  # type: ignore[call-overload]
)

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return create_model(
"CombineDocumentsOutput",
**{self.output_key: (str, None)},  # type: ignore[call-overload]
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@ -70,14 +70,14 @@ class BaseCombineDocumentsChain(Chain, ABC):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.

:meta private:
"""
return [self.output_key]

def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
"""Return the prompt length given the documents passed in.

This can be used by a caller to determine whether passing in a list
@ -96,7 +96,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
return None

@abstractmethod
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
def combine_docs(self, docs: list[Document], **kwargs: Any) -> tuple[str, dict]:
"""Combine documents into a single string.

Args:
@ -111,8 +111,8 @@ class BaseCombineDocumentsChain(Chain, ABC):

@abstractmethod
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], **kwargs: Any
) -> tuple[str, dict]:
"""Combine documents into a single string.

Args:
@ -127,9 +127,9 @@ class BaseCombineDocumentsChain(Chain, ABC):

def _call(
self,
inputs: Dict[str, List[Document]],
inputs: dict[str, list[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Prepare inputs, call combine docs, prepare outputs."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
@ -143,9 +143,9 @@ class BaseCombineDocumentsChain(Chain, ABC):

async def _acall(
self,
inputs: Dict[str, List[Document]],
inputs: dict[str, list[Document]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Prepare inputs, call combine docs, prepare outputs."""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
@ -229,7 +229,7 @@ class AnalyzeDocumentChain(Chain):
combine_docs_chain: BaseCombineDocumentsChain

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@ -237,7 +237,7 @@ class AnalyzeDocumentChain(Chain):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.

:meta private:
@ -246,7 +246,7 @@ class AnalyzeDocumentChain(Chain):

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return create_model(
"AnalyzeDocumentChain",
**{self.input_key: (str, None)},  # type: ignore[call-overload]
@ -254,20 +254,20 @@ class AnalyzeDocumentChain(Chain):

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return self.combine_docs_chain.get_output_schema(config)

def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Split document into chunks and pass to CombineDocumentsChain."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
document = inputs[self.input_key]
docs = self.text_splitter.create_documents([document])
# Other keys are assumed to be needed for LLM prediction
other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys: dict = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys[self.combine_docs_chain.input_key] = docs
return self.combine_docs_chain(
other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -113,20 +113,20 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
if self.return_intermediate_steps:
return create_model(
"MapReduceDocumentsOutput",
**{
self.output_key: (str, None),
"intermediate_steps": (List[str], None),
"intermediate_steps": (list[str], None),
},  # type: ignore[call-overload]
)

return super().get_output_schema(config)

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@ -143,7 +143,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

@model_validator(mode="before")
@classmethod
def get_reduce_chain(cls, values: Dict) -> Any:
def get_reduce_chain(cls, values: dict) -> Any:
"""For backwards compatibility."""
if "combine_document_chain" in values:
if "reduce_documents_chain" in values:
@ -167,7 +167,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

@model_validator(mode="before")
@classmethod
def get_return_intermediate_steps(cls, values: Dict) -> Any:
def get_return_intermediate_steps(cls, values: dict) -> Any:
"""For backwards compatibility."""
if "return_map_steps" in values:
values["return_intermediate_steps"] = values["return_map_steps"]
@ -176,7 +176,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

@model_validator(mode="before")
@classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any:
def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided."""
if "llm_chain" not in values:
raise ValueError("llm_chain must be provided")
@ -227,11 +227,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

def combine_docs(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Combine documents in a map reduce manner.

Combine by mapping first chain over all documents, then reducing the results.
@ -258,11 +258,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):

async def acombine_docs(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Combine documents in a map reduce manner.

Combine by mapping first chain over all documents, then reducing the results.
@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast

from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -79,7 +80,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""Key in output of llm_chain to rank on."""
answer_key: str
"""Key in output of llm_chain to return as answer."""
metadata_keys: Optional[List[str]] = None
metadata_keys: Optional[list[str]] = None
"""Additional metadata from the chosen document to return."""
return_intermediate_steps: bool = False
"""Return intermediate steps.
@ -92,19 +93,19 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
schema: Dict[str, Any] = {
) -> type[BaseModel]:
schema: dict[str, Any] = {
self.output_key: (str, None),
}
if self.return_intermediate_steps:
schema["intermediate_steps"] = (List[str], None)
schema["intermediate_steps"] = (list[str], None)
if self.metadata_keys:
schema.update({key: (Any, None) for key in self.metadata_keys})

return create_model("MapRerankOutput", **schema)

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@ -140,7 +141,7 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):

@model_validator(mode="before")
@classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any:
def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided."""
if "llm_chain" not in values:
raise ValueError("llm_chain must be provided")
@ -163,8 +164,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
return values

def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Combine documents in a map rerank manner.

Combine by mapping first chain over all documents, then reranking the results.
@ -187,8 +188,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
return self._process_results(docs, results)

async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Combine documents in a map rerank manner.

Combine by mapping first chain over all documents, then reranking the results.
@ -212,10 +213,10 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):

def _process_results(
self,
docs: List[Document],
results: Sequence[Union[str, List[str], Dict[str, str]]],
) -> Tuple[str, dict]:
typed_results = cast(List[dict], results)
docs: list[Document],
results: Sequence[Union[str, list[str], dict[str, str]]],
) -> tuple[str, dict]:
typed_results = cast(list[dict], results)
sorted_res = sorted(
zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
)
@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Callable, List, Optional, Protocol, Tuple
from typing import Any, Callable, Optional, Protocol

from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -15,20 +15,20 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
class CombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""

def __call__(self, docs: List[Document], **kwargs: Any) -> str:
def __call__(self, docs: list[Document], **kwargs: Any) -> str:
"""Interface for the combine_docs method."""


class AsyncCombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""

async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
async def __call__(self, docs: list[Document], **kwargs: Any) -> str:
"""Async interface for the combine_docs method."""


def split_list_of_docs(
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> List[List[Document]]:
docs: list[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> list[list[Document]]:
"""Split Documents into subsets that each meet a cumulative length constraint.

Args:
@ -59,7 +59,7 @@ def split_list_of_docs(


def collapse_docs(
docs: List[Document],
docs: list[Document],
combine_document_func: CombineDocsProtocol,
**kwargs: Any,
) -> Document:
@ -91,7 +91,7 @@ def collapse_docs(


async def acollapse_docs(
docs: List[Document],
docs: list[Document],
combine_document_func: AsyncCombineDocsProtocol,
**kwargs: Any,
) -> Document:
@ -229,11 +229,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):

def combine_docs(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Combine multiple documents recursively.

Args:
@ -258,11 +258,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):

async def acombine_docs(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
) -> tuple[str, dict]:
"""Async combine multiple documents recursively.

Args:
@ -287,16 +287,16 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):

def _collapse(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
) -> tuple[list[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)

def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str:
return self._collapse_chain.run(
input_documents=docs, callbacks=callbacks, **kwargs
)
@ -322,16 +322,16 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):

async def _acollapse(
self,
docs: List[Document],
docs: list[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
) -> tuple[list[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)

async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
async def _collapse_docs_func(docs: list[Document], **kwargs: Any) -> str:
return await self._collapse_chain.arun(
input_documents=docs, callbacks=callbacks, **kwargs
)
@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Dict, List, Tuple
from typing import Any

from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -98,7 +98,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
"""Return the results of the refine steps in the output."""

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@ -115,7 +115,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):

@model_validator(mode="before")
@classmethod
def get_return_intermediate_steps(cls, values: Dict) -> Any:
def get_return_intermediate_steps(cls, values: dict) -> Any:
"""For backwards compatibility."""
if "return_refine_steps" in values:
values["return_intermediate_steps"] = values["return_refine_steps"]
@ -124,7 +124,7 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):

@model_validator(mode="before")
@classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any:
def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided."""
if "initial_llm_chain" not in values:
raise ValueError("initial_llm_chain must be provided")
@ -147,8 +147,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
return values

def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain.

Args:
@ -172,8 +172,8 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
return self._construct_result(refine_steps, res)

async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Async combine by mapping a first chain over all, then stuffing
into a final chain.

@ -197,22 +197,22 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
refine_steps.append(res)
return self._construct_result(refine_steps, res)

def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
def _construct_result(self, refine_steps: list[str], res: str) -> tuple[str, dict]:
if self.return_intermediate_steps:
extra_return_dict = {"intermediate_steps": refine_steps}
else:
extra_return_dict = {}
return res, extra_return_dict

def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
def _construct_refine_inputs(self, doc: Document, res: str) -> dict[str, Any]:
return {
self.document_variable_name: format_document(doc, self.document_prompt),
self.initial_response_name: res,
}

def _construct_initial_inputs(
self, docs: List[Document], **kwargs: Any
) -> Dict[str, Any]:
self, docs: list[Document], **kwargs: Any
) -> dict[str, Any]:
base_info = {"page_content": docs[0].page_content}
base_info.update(docs[0].metadata)
document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
@ -1,6 +1,6 @@
"""Chain that combines documents by stuffing into context."""

from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import Callbacks
@ -29,7 +29,7 @@ def create_stuff_documents_chain(
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
document_variable_name: str = DOCUMENTS_KEY,
) -> Runnable[Dict[str, Any], Any]:
) -> Runnable[dict[str, Any], Any]:
"""Create a chain for passing a list of Documents to a model.

Args:
@ -163,7 +163,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):

@model_validator(mode="before")
@classmethod
def get_default_document_variable_name(cls, values: Dict) -> Any:
def get_default_document_variable_name(cls, values: dict) -> Any:
"""Get default document variable name, if not provided.

If only one variable is present in the llm_chain.prompt,
@ -188,13 +188,13 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return values

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
extra_keys = [
k for k in self.llm_chain.input_keys if k != self.document_variable_name
]
return super().input_keys + extra_keys

def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
def _get_inputs(self, docs: list[Document], **kwargs: Any) -> dict:
"""Construct inputs from kwargs and docs.

Format and then join all the documents together into one input with name
@ -220,7 +220,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
return inputs

def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
"""Return the prompt length given the documents passed in.

This can be used by a caller to determine whether passing in a list
@ -241,8 +241,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return self.llm_chain._get_num_tokens(prompt)

def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM.

Args:
@ -259,8 +259,8 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}

async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
self, docs: list[Document], callbacks: Callbacks = None, **kwargs: Any
) -> tuple[str, dict]:
"""Async stuff all documents into one prompt and pass to LLM.

Args:
@@ -1,6 +1,6 @@
"""Chain for applying constitutional principles to the outputs of another chain."""

from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
@@ -190,15 +190,15 @@ class ConstitutionalChain(Chain):
""" # noqa: E501

chain: LLMChain
constitutional_principles: List[ConstitutionalPrinciple]
constitutional_principles: list[ConstitutionalPrinciple]
critique_chain: LLMChain
revision_chain: LLMChain
return_intermediate_steps: bool = False

@classmethod
def get_principles(
cls, names: Optional[List[str]] = None
) -> List[ConstitutionalPrinciple]:
cls, names: Optional[list[str]] = None
) -> list[ConstitutionalPrinciple]:
if names is None:
return list(PRINCIPLES.values())
else:
@@ -224,12 +224,12 @@ class ConstitutionalChain(Chain):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys."""
return self.chain.input_keys

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Output keys."""
if self.return_intermediate_steps:
return ["output", "critiques_and_revisions", "initial_output"]
@@ -237,9 +237,9 @@ class ConstitutionalChain(Chain):

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
response = self.chain.run(
**inputs,
@@ -305,7 +305,7 @@ class ConstitutionalChain(Chain):
color="yellow",
)

final_output: Dict[str, Any] = {"output": response}
final_output: dict[str, Any] = {"output": response}
if self.return_intermediate_steps:
final_output["initial_output"] = initial_response
final_output["critiques_and_revisions"] = critiques_and_revisions
@@ -1,7 +1,5 @@
"""Chain that carries on a conversation and calls an LLM."""

from typing import List

from langchain_core._api import deprecated
from langchain_core.memory import BaseMemory
from langchain_core.prompts import BasePromptTemplate
@@ -121,7 +119,7 @@ class ConversationChain(LLMChain): # type: ignore[override, override]
return False

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Use this since so some prompt vars come from history."""
return [self.input_key]
@@ -6,7 +6,7 @@ import inspect
import warnings
from abc import abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Optional, Union

from langchain_core._api import deprecated
from langchain_core.callbacks import (
@@ -32,13 +32,13 @@ from langchain.chains.question_answering import load_qa_chain

# Depending on the memory type and configuration, the chat history format may differ.
# This needs to be consolidated.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
CHAT_TURN_TYPE = Union[tuple[str, str], BaseMessage]


_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}


def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
def _get_chat_history(chat_history: list[CHAT_TURN_TYPE]) -> str:
buffer = ""
for dialogue_turn in chat_history:
if isinstance(dialogue_turn, BaseMessage):
@@ -64,7 +64,7 @@ class InputType(BaseModel):

question: str
"""The question to answer."""
chat_history: List[CHAT_TURN_TYPE] = Field(default_factory=list)
chat_history: list[CHAT_TURN_TYPE] = Field(default_factory=list)
"""The chat history to use for retrieval."""


@@ -89,7 +89,7 @@ class BaseConversationalRetrievalChain(Chain):
"""Return the retrieved source documents as part of the final result."""
return_generated_question: bool = False
"""Return the generated question as part of the final result."""
get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None
get_chat_history: Optional[Callable[[list[CHAT_TURN_TYPE]], str]] = None
"""An optional function to get a string of the chat history.
If None is provided, will use a default."""
response_if_no_docs_found: Optional[str] = None
@@ -103,17 +103,17 @@ class BaseConversationalRetrievalChain(Chain):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys."""
return ["question", "chat_history"]

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
) -> type[BaseModel]:
return InputType

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return the output keys.

:meta private:
@@ -129,17 +129,17 @@ class BaseConversationalRetrievalChain(Chain):
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history
@@ -159,7 +159,7 @@ class BaseConversationalRetrievalChain(Chain):
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
else:
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
output: Dict[str, Any] = {}
output: dict[str, Any] = {}
if self.response_if_no_docs_found is not None and len(docs) == 0:
output[self.output_key] = self.response_if_no_docs_found
else:
@@ -182,17 +182,17 @@ class BaseConversationalRetrievalChain(Chain):
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history
@@ -212,7 +212,7 @@ class BaseConversationalRetrievalChain(Chain):
else:
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]

output: Dict[str, Any] = {}
output: dict[str, Any] = {}
if self.response_if_no_docs_found is not None and len(docs) == 0:
output[self.output_key] = self.response_if_no_docs_found
else:
@@ -368,7 +368,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
"""If set, enforces that the documents returned are less than this limit.
This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""

def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
num_docs = len(docs)

if self.max_tokens_limit and isinstance(
@@ -388,10 +388,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
docs = self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()}
@@ -401,10 +401,10 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
docs = await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()}
@@ -420,7 +420,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
chain_type: str = "stuff",
verbose: bool = False,
condense_question_llm: Optional[BaseLanguageModel] = None,
combine_docs_chain_kwargs: Optional[Dict] = None,
combine_docs_chain_kwargs: Optional[dict] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseConversationalRetrievalChain:
@@ -485,7 +485,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
warnings.warn(
"`ChatVectorDBChain` is deprecated - "
"please use `from langchain.chains import ConversationalRetrievalChain`"
@@ -495,10 +495,10 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
vectordbkwargs = inputs.get("vectordbkwargs", {})
full_kwargs = {**self.search_kwargs, **vectordbkwargs}
@@ -509,10 +509,10 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
raise NotImplementedError("ChatVectorDBChain does not support async")

@@ -523,7 +523,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
vectorstore: VectorStore,
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
chain_type: str = "stuff",
combine_docs_chain_kwargs: Optional[Dict] = None,
combine_docs_chain_kwargs: Optional[dict] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseConversationalRetrievalChain:
@@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
@@ -44,8 +44,8 @@ class ElasticsearchDatabaseChain(Chain):
"""Elasticsearch database to connect to of type elasticsearch.Elasticsearch."""
top_k: int = 10
"""Number of results to return from the query"""
ignore_indices: Optional[List[str]] = None
include_indices: Optional[List[str]] = None
ignore_indices: Optional[list[str]] = None
include_indices: Optional[list[str]] = None
input_key: str = "question" #: :meta private:
output_key: str = "result" #: :meta private:
sample_documents_in_index_info: int = 3
@@ -66,7 +66,7 @@ class ElasticsearchDatabaseChain(Chain):
return self

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the singular input key.

:meta private:
@@ -74,7 +74,7 @@ class ElasticsearchDatabaseChain(Chain):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return the singular output key.

:meta private:
@@ -84,7 +84,7 @@ class ElasticsearchDatabaseChain(Chain):
else:
return [self.output_key, INTERMEDIATE_STEPS_KEY]

def _list_indices(self) -> List[str]:
def _list_indices(self) -> list[str]:
all_indices = [
index["index"] for index in self.database.cat.indices(format="json")
]
@@ -96,7 +96,7 @@ class ElasticsearchDatabaseChain(Chain):

return all_indices

def _get_indices_infos(self, indices: List[str]) -> str:
def _get_indices_infos(self, indices: list[str]) -> str:
mappings = self.database.indices.get_mapping(index=",".join(indices))
if self.sample_documents_in_index_info > 0:
for k, v in mappings.items():
@@ -114,15 +114,15 @@ class ElasticsearchDatabaseChain(Chain):
]
)

def _search(self, indices: List[str], query: str) -> str:
def _search(self, indices: list[str], query: str) -> str:
result = self.database.search(index=",".join(indices), body=query)
return str(result)

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
input_text = f"{inputs[self.input_key]}\nESQuery:"
_run_manager.on_text(input_text, verbose=self.verbose)
@@ -134,7 +134,7 @@ class ElasticsearchDatabaseChain(Chain):
"indices_info": indices_info,
"stop": ["\nESResult:"],
}
intermediate_steps: List = []
intermediate_steps: list = []
try:
intermediate_steps.append(query_inputs) # input: es generation
es_cmd = self.query_chain.invoke(
@@ -163,7 +163,7 @@ class ElasticsearchDatabaseChain(Chain):

intermediate_steps.append(final_result) # output: final answer
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
chain_result: Dict[str, Any] = {self.output_key: final_result}
chain_result: dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result
@@ -1,5 +1,3 @@
from typing import List

from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.few_shot import FewShotPromptTemplate
@@ -9,7 +7,7 @@ TEST_GEN_TEMPLATE_SUFFIX = "Add another example."


def generate_example(
examples: List[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate
examples: list[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate
) -> str:
"""Return another example given a list of examples for a prompt."""
prompt = FewShotPromptTemplate(
@@ -2,7 +2,8 @@ from __future__ import annotations

import logging
import re
from typing import Any, Dict, List, Optional, Sequence, Tuple
from collections.abc import Sequence
from typing import Any, Optional

from langchain_core.callbacks import (
CallbackManagerForChainRun,
@@ -26,7 +27,7 @@ from langchain.chains.llm import LLMChain
logger = logging.getLogger(__name__)


def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
def _extract_tokens_and_log_probs(response: AIMessage) -> tuple[list[str], list[float]]:
"""Extract tokens and log probabilities from chat model response."""
tokens = []
log_probs = []
@@ -47,7 +48,7 @@ class QuestionGeneratorChain(LLMChain):
return False

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys for the chain."""
return ["user_input", "context", "response"]

@@ -58,7 +59,7 @@ def _low_confidence_spans(
min_prob: float,
min_token_gap: int,
num_pad_tokens: int,
) -> List[str]:
) -> list[str]:
try:
import numpy as np

@@ -117,22 +118,22 @@ class FlareChain(Chain):
"""Whether to start with retrieval."""

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys for the chain."""
return ["user_input"]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Output keys for the chain."""
return ["response"]

def _do_generation(
self,
questions: List[str],
questions: list[str],
user_input: str,
response: str,
_run_manager: CallbackManagerForChainRun,
) -> Tuple[str, bool]:
) -> tuple[str, bool]:
callbacks = _run_manager.get_child()
docs = []
for question in questions:
@@ -153,12 +154,12 @@ class FlareChain(Chain):

def _do_retrieval(
self,
low_confidence_spans: List[str],
low_confidence_spans: list[str],
_run_manager: CallbackManagerForChainRun,
user_input: str,
response: str,
initial_response: str,
) -> Tuple[str, bool]:
) -> tuple[str, bool]:
question_gen_inputs = [
{
"user_input": user_input,
@@ -187,9 +188,9 @@ class FlareChain(Chain):

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

user_input = inputs[self.input_keys[0]]
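Where a module used `Sequence`, `Iterator`, or `Mapping`, the hunks above split the import: the abstract container types now come from `collections.abc`, since their `typing` aliases are deprecated once the PEP 585 generics are available. A minimal sketch of the resulting import style (illustrative function name, not from the diff):

from collections.abc import Sequence
from typing import Optional

def first_or_none(items: Sequence[int]) -> Optional[int]:
    # Sequence still works as a generic annotation; only its origin moved.
    return items[0] if items else None

print(first_or_none([3, 1, 2]))  # -> 3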
@@ -1,16 +1,14 @@
from typing import Tuple

from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate


class FinishedOutputParser(BaseOutputParser[Tuple[str, bool]]):
class FinishedOutputParser(BaseOutputParser[tuple[str, bool]]):
"""Output parser that checks if the output is finished."""

finished_value: str = "FINISHED"
"""Value that indicates the output is finished."""

def parse(self, text: str) -> Tuple[str, bool]:
def parse(self, text: str) -> tuple[str, bool]:
cleaned = text.strip()
finished = self.finished_value in cleaned
return cleaned.replace(self.finished_value, ""), finished
@@ -6,7 +6,7 @@ https://arxiv.org/abs/2212.10496
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
@@ -38,23 +38,23 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys for Hyde's LLM chain."""
return self.llm_chain.input_schema.model_json_schema()["required"]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Output keys for Hyde's LLM chain."""
if isinstance(self.llm_chain, LLMChain):
return self.llm_chain.output_keys
else:
return ["text"]

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Call the base embeddings."""
return self.base_embeddings.embed_documents(texts)

def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
def combine_embeddings(self, embeddings: list[list[float]]) -> list[float]:
"""Combine embeddings into final embeddings."""
try:
import numpy as np
@@ -73,7 +73,7 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
num_vectors = len(embeddings)
return [sum(dim_values) / num_vectors for dim_values in zip(*embeddings)]

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Generate a hypothetical document and embedded it."""
var_name = self.input_keys[0]
result = self.llm_chain.invoke({var_name: text})
@@ -86,9 +86,9 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Call the internal llm chain."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
return self.llm_chain.invoke(
@@ -3,7 +3,8 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast

from langchain_core._api import deprecated
from langchain_core.callbacks import (
@@ -100,7 +101,7 @@ class LLMChain(Chain):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Will be whatever keys the prompt expects.

:meta private:
@@ -108,7 +109,7 @@ class LLMChain(Chain):
return self.prompt.input_variables

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Will always return text key.

:meta private:
@@ -120,15 +121,15 @@ class LLMChain(Chain):

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]

def generate(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
@@ -143,9 +144,9 @@ class LLMChain(Chain):
)
else:
results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
cast(List, prompts), {"callbacks": callbacks}
cast(list, prompts), {"callbacks": callbacks}
)
generations: List[List[Generation]] = []
generations: list[list[Generation]] = []
for res in results:
if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)])
@@ -155,7 +156,7 @@ class LLMChain(Chain):

async def agenerate(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
@@ -170,9 +171,9 @@ class LLMChain(Chain):
)
else:
results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
cast(List, prompts), {"callbacks": callbacks}
cast(list, prompts), {"callbacks": callbacks}
)
generations: List[List[Generation]] = []
generations: list[list[Generation]] = []
for res in results:
if isinstance(res, BaseMessage):
generations.append([ChatGeneration(message=res)])
@@ -182,9 +183,9 @@ class LLMChain(Chain):

def prep_prompts(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
) -> tuple[list[PromptValue], Optional[list[str]]]:
"""Prepare prompts from inputs."""
stop = None
if len(input_list) == 0:
@@ -208,9 +209,9 @@ class LLMChain(Chain):

async def aprep_prompts(
self,
input_list: List[Dict[str, Any]],
input_list: list[dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
) -> tuple[list[PromptValue], Optional[list[str]]]:
"""Prepare prompts from inputs."""
stop = None
if len(input_list) == 0:
@@ -233,8 +234,8 @@ class LLMChain(Chain):
return prompts, stop

def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> list[dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
@@ -254,8 +255,8 @@ class LLMChain(Chain):
return outputs

async def aapply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> list[dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
@@ -278,7 +279,7 @@ class LLMChain(Chain):
def _run_output_key(self) -> str:
return self.output_key

def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
def create_outputs(self, llm_result: LLMResult) -> list[dict[str, Any]]:
"""Create outputs from response."""
result = [
# Get the text of the top generated string.
@@ -294,9 +295,9 @@ class LLMChain(Chain):

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
response = await self.agenerate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]

@@ -336,7 +337,7 @@ class LLMChain(Chain):

def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
) -> Union[str, list[str], dict[str, Any]]:
"""Call predict and then parse the results."""
warnings.warn(
"The predict_and_parse method is deprecated, "
@@ -350,7 +351,7 @@ class LLMChain(Chain):

async def apredict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
) -> Union[str, list[str], dict[str, str]]:
"""Call apredict and then parse the results."""
warnings.warn(
"The apredict_and_parse method is deprecated, "
@@ -363,8 +364,8 @@ class LLMChain(Chain):
return result

def apply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, list[str], dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The apply_and_parse method is deprecated, "
@@ -374,8 +375,8 @@ class LLMChain(Chain):
return self._parse_generation(result)

def _parse_generation(
self, generation: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
self, generation: list[dict[str, str]]
) -> Sequence[Union[str, list[str], dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key])
@@ -385,8 +386,8 @@ class LLMChain(Chain):
return generation

async def aapply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
self, input_list: list[dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, list[str], dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The aapply_and_parse method is deprecated, "
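The rewrite also reaches runtime expressions, not just annotations: in the hunks above, `cast(List, prompts)` becomes `cast(list, prompts)`, because `typing.cast` accepts the builtin type object directly. A minimal sketch (assumed variable names):

from typing import cast

prompts: list[str] = ["hello", "world"]
batch_input = cast(list, prompts)  # was: cast(List, prompts)
assert batch_input is prompts  # cast is a no-op at runtime; only the type checker cares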
@@ -3,7 +3,7 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
@@ -107,7 +107,7 @@ class LLMCheckerChain(Chain):

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMCheckerChain with an llm is deprecated. "
@@ -135,7 +135,7 @@ class LLMCheckerChain(Chain):
return values

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the singular input key.

:meta private:
@@ -143,7 +143,7 @@ class LLMCheckerChain(Chain):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return the singular output key.

:meta private:
@@ -152,9 +152,9 @@ class LLMCheckerChain(Chain):

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
@@ -5,7 +5,7 @@ from __future__ import annotations
import math
import re
import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import (
@@ -163,7 +163,7 @@ class LLMMathChain(Chain):

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
try:
import numexpr # noqa: F401
except ImportError:
@@ -183,7 +183,7 @@ class LLMMathChain(Chain):
return values

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@@ -191,7 +191,7 @@ class LLMMathChain(Chain):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Expect output key.

:meta private:
@@ -221,7 +221,7 @@ class LLMMathChain(Chain):

def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
) -> dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
@@ -243,7 +243,7 @@ class LLMMathChain(Chain):
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
) -> dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
@@ -263,9 +263,9 @@ class LLMMathChain(Chain):

def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict(
@@ -277,9 +277,9 @@ class LLMMathChain(Chain):

async def _acall(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict(
@@ -4,7 +4,7 @@ from __future__ import annotations

import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
@@ -112,7 +112,7 @@ class LLMSummarizationCheckerChain(Chain):

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMSummarizationCheckerChain with an llm is "
@@ -131,7 +131,7 @@ class LLMSummarizationCheckerChain(Chain):
return values

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return the singular input key.

:meta private:
@@ -139,7 +139,7 @@ class LLMSummarizationCheckerChain(Chain):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return the singular output key.

:meta private:
@@ -148,9 +148,9 @@ class LLMSummarizationCheckerChain(Chain):

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
all_true = False
count = 0
@@ -702,7 +702,7 @@ def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
with open(file_path) as f:
config = json.load(f)
elif file_path.suffix.endswith((".yaml", ".yml")):
with open(file_path, "r") as f:
with open(file_path) as f:
config = yaml.safe_load(f)
else:
raise ValueError("File type must be json or yaml")
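The hunk above is not a typing change: pyupgrade also drops the redundant "r" argument, since reading in text mode is already `open()`'s default. A small sketch of the resulting style (hypothetical helper name, not from the diff):

from pathlib import Path

def read_config_text(file_path: Path) -> str:  # hypothetical helper
    with open(file_path) as f:  # equivalent to open(file_path, "r")
        return f.read()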
@@ -6,7 +6,8 @@ then combines the results with another one.

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional
from collections.abc import Mapping
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun, Callbacks
@@ -84,7 +85,7 @@ class MapReduceChain(Chain):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@@ -92,7 +93,7 @@ class MapReduceChain(Chain):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.

:meta private:
@@ -101,15 +102,15 @@ class MapReduceChain(Chain):

def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
# Split the larger text into smaller chunks.
doc_text = inputs.pop(self.input_key)
texts = self.text_splitter.split_text(doc_text)
docs = [Document(page_content=text) for text in texts]
_inputs: Dict[str, Any] = {
_inputs: dict[str, Any] = {
**inputs,
self.combine_documents_chain.input_key: docs,
}
@@ -1,6 +1,6 @@
"""Pass input through a moderation endpoint."""

from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@@ -42,7 +42,7 @@ class OpenAIModerationChain(Chain):

@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict) -> Any:
def validate_environment(cls, values: dict) -> Any:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
@@ -78,7 +78,7 @@ class OpenAIModerationChain(Chain):
return values

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@@ -86,7 +86,7 @@ class OpenAIModerationChain(Chain):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.

:meta private:
@@ -108,9 +108,9 @@ class OpenAIModerationChain(Chain):

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
text = inputs[self.input_key]
if self.openai_pre_1_0:
results = self.client.create(text)
@@ -122,9 +122,9 @@ class OpenAIModerationChain(Chain):

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if self.openai_pre_1_0:
return await super()._acall(inputs, run_manager=run_manager)
text = inputs[self.input_key]
@@ -3,7 +3,7 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.caches import BaseCache as BaseCache
@@ -68,7 +68,7 @@ class NatBotChain(Chain):

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
if "llm" in values:
warnings.warn(
"Directly instantiating an NatBotChain with an llm is deprecated. "
@@ -97,7 +97,7 @@ class NatBotChain(Chain):
return cls(llm_chain=llm_chain, objective=objective, **kwargs)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect url and browser content.

:meta private:
@@ -105,7 +105,7 @@ class NatBotChain(Chain):
return [self.input_url_key, self.input_browser_content_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return command.

:meta private:
@@ -114,9 +114,9 @@ class NatBotChain(Chain):

def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
url = inputs[self.input_url_key]
browser_content = inputs[self.input_browser_content_key]
@@ -1,12 +1,10 @@
"""Methods for creating chains that use OpenAI function-calling APIs."""

from collections.abc import Sequence
from typing import (
Any,
Callable,
Dict,
Optional,
Sequence,
Type,
Union,
)

@@ -45,7 +43,7 @@ __all__ = [

@deprecated(since="0.1.1", removal="1.0", alternative="create_openai_fn_runnable")
def create_openai_fn_chain(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
@@ -128,7 +126,7 @@ def create_openai_fn_chain(
raise ValueError("Need to pass in at least one function. Received zero.")
openai_functions = [convert_to_openai_function(f) for f in functions]
output_parser = output_parser or get_openai_output_parser(functions)
llm_kwargs: Dict[str, Any] = {
llm_kwargs: dict[str, Any] = {
"functions": openai_functions,
}
if len(openai_functions) == 1 and enforce_single_function_usage:
@@ -148,7 +146,7 @@ def create_openai_fn_chain(
since="0.1.1", removal="1.0", alternative="ChatOpenAI.with_structured_output"
)
def create_structured_output_chain(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
output_schema: Union[dict[str, Any], type[BaseModel]],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
*,
@@ -1,4 +1,4 @@
from typing import Iterator, List
from collections.abc import Iterator

from langchain_core._api import deprecated
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
@@ -21,7 +21,7 @@ class FactWithEvidence(BaseModel):
"""

fact: str = Field(..., description="Body of the sentence, as part of a response")
substring_quote: List[str] = Field(
substring_quote: list[str] = Field(
...,
description=(
"Each source should be a direct quote from the context, "
@@ -54,7 +54,7 @@ class QuestionAnswer(BaseModel):
each sentence contains a body and a list of sources."""

question: str = Field(..., description="Question that was asked")
answer: List[FactWithEvidence] = Field(
answer: list[FactWithEvidence] = Field(
...,
description=(
"Body of the answer, each fact should be "
@@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@@ -83,7 +83,7 @@ def create_extraction_chain(
schema: dict,
llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None,
tags: Optional[List[str]] = None,
tags: Optional[list[str]] = None,
verbose: bool = False,
) -> Chain:
"""Creates a chain that extracts information from a passage.
@@ -170,7 +170,7 @@ def create_extraction_chain_pydantic(
"""

class PydanticSchema(BaseModel):
info: List[pydantic_schema] # type: ignore
info: list[pydantic_schema] # type: ignore

if hasattr(pydantic_schema, "model_json_schema"):
openai_schema = pydantic_schema.model_json_schema()
@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import re
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import requests
from langchain_core._api import deprecated
@@ -70,7 +70,7 @@ def _format_url(url: str, path_params: dict) -> str:
return url.format(**new_params)


def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -> dict:
def _openapi_params_to_json_schema(params: list[Parameter], spec: OpenAPISpec) -> dict:
properties = {}
required = []
for p in params:
@@ -89,7 +89,7 @@ def _openapi_params_to_json_schema(params: List[Parameter], spec: OpenAPISpec) -

def openapi_spec_to_openai_fn(
spec: OpenAPISpec,
) -> Tuple[List[Dict[str, Any]], Callable]:
) -> tuple[list[dict[str, Any]], Callable]:
"""Convert a valid OpenAPI spec to the JSON Schema format expected for OpenAI
functions.

@@ -208,18 +208,18 @@ class SimpleRequestChain(Chain):
"""Key to use for the input of the request."""

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return [self.output_key]

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run the logic of this chain and return the output."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
name = inputs[self.input_key].pop("name")
@@ -257,10 +257,10 @@ def get_openapi_chain(
llm: Optional[BaseLanguageModel] = None,
prompt: Optional[BasePromptTemplate] = None,
request_chain: Optional[Chain] = None,
llm_chain_kwargs: Optional[Dict] = None,
llm_chain_kwargs: Optional[dict] = None,
verbose: bool = False,
headers: Optional[Dict] = None,
params: Optional[Dict] = None,
headers: Optional[dict] = None,
params: Optional[dict] = None,
**kwargs: Any,
) -> SequentialChain:
"""Create a chain for querying an API from a OpenAPI spec.
@@ -1,4 +1,4 @@
from typing import Any, List, Optional, Type, Union, cast
from typing import Any, Optional, Union, cast

from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@@ -21,7 +21,7 @@ class AnswerWithSources(BaseModel):
"""An answer to the question, with sources."""

answer: str = Field(..., description="Answer to the question that was asked")
sources: List[str] = Field(
sources: list[str] = Field(
..., description="List of sources used to answer the question"
)

@@ -37,7 +37,7 @@ class AnswerWithSources(BaseModel):
)
def create_qa_with_structure_chain(
llm: BaseLanguageModel,
schema: Union[dict, Type[BaseModel]],
schema: Union[dict, type[BaseModel]],
output_parser: str = "base",
prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = None,
verbose: bool = False,
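`typing.Type[X]` follows the same rule as the container aliases: on 3.9+ it becomes the builtin `type[X]`, as in the `schema: Union[dict, type[BaseModel]]` signature above. A minimal sketch with a stand-in class (illustrative names only):

from typing import Union

class Schema:  # stand-in for a pydantic BaseModel subclass
    pass

def describe(schema: Union[dict, type[Schema]]) -> str:
    # type[Schema] (builtin) replaces typing.Type[Schema] on 3.9+
    return schema.__name__ if isinstance(schema, type) else "inline dict schema"

print(describe(Schema))  # -> Schema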
@@ -1,7 +1,7 @@
from typing import Any, Dict
from typing import Any


def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
def _resolve_schema_references(schema: Any, definitions: dict[str, Any]) -> Any:
"""
Resolve the $ref keys in a JSON schema object using the provided definitions.
"""
@@ -1,4 +1,4 @@
from typing import List, Type, Union
from typing import Union

from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@@ -51,7 +51,7 @@ If a property is not present and is not required in the function parameters, do
),
)
def create_extraction_chain_pydantic(
pydantic_schemas: Union[List[Type[BaseModel]], Type[BaseModel]],
pydantic_schemas: Union[list[type[BaseModel]], type[BaseModel]],
llm: BaseLanguageModel,
system_message: str = _EXTRACTION_TEMPLATE,
) -> Runnable:
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple
from typing import Callable

from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.chat_models import BaseChatModel
@@ -21,8 +21,8 @@ class ConditionalPromptSelector(BasePromptSelector):

default_prompt: BasePromptTemplate
"""Default prompt to use if no conditionals match."""
conditionals: List[
Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
conditionals: list[
tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
] = Field(default_factory=list)
"""List of conditionals and prompts to use if the conditionals match."""
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import CallbackManagerForChainRun
@@ -103,18 +103,18 @@ class QAGenerationChain(Chain):
raise NotImplementedError

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return [self.output_key]

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, List]:
) -> dict[str, list]:
docs = self.text_splitter.create_documents([inputs[self.input_key]])
results = self.llm_chain.generate(
[{"text": d.page_content} for d in docs], run_manager=run_manager
@@ -5,7 +5,7 @@ from __future__ import annotations
import inspect
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import (
@@ -103,7 +103,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@@ -111,7 +111,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
return [self.question_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.

:meta private:
@@ -123,13 +123,13 @@ class BaseQAWithSourcesChain(Chain, ABC):

@model_validator(mode="before")
@classmethod
def validate_naming(cls, values: Dict) -> Any:
def validate_naming(cls, values: dict) -> Any:
"""Fix backwards compatibility in naming."""
if "combine_document_chain" in values:
values["combine_documents_chain"] = values.pop("combine_document_chain")
return values

def _split_sources(self, answer: str) -> Tuple[str, str]:
def _split_sources(self, answer: str) -> tuple[str, str]:
"""Split sources from answer."""
if re.search(r"SOURCES?:", answer, re.IGNORECASE):
answer, sources = re.split(
@@ -143,17 +143,17 @@ class BaseQAWithSourcesChain(Chain, ABC):
@abstractmethod
def _get_docs(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs to run questioning over."""

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
@@ -167,7 +167,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
answer, sources = self._split_sources(answer)
result: Dict[str, Any] = {
result: dict[str, Any] = {
self.answer_key: answer,
self.sources_answer_key: sources,
}
@@ -178,17 +178,17 @@ class BaseQAWithSourcesChain(Chain, ABC):
@abstractmethod
async def _aget_docs(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs to run questioning over."""

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
@@ -201,7 +201,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
answer, sources = self._split_sources(answer)
result: Dict[str, Any] = {
result: dict[str, Any] = {
self.answer_key: answer,
self.sources_answer_key: sources,
}
@@ -225,7 +225,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
input_docs_key: str = "docs" #: :meta private:

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@@ -234,19 +234,19 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):

def _get_docs(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs to run questioning over."""
return inputs.pop(self.input_docs_key)

async def _aget_docs(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs to run questioning over."""
return inputs.pop(self.input_docs_key)
@ -2,7 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping, Optional, Protocol
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
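The hunk just above shows the commit's second recurring rewrite: abstract container types such as `Mapping` and `Sequence` are now imported from `collections.abc`, since the `typing` versions are deprecated aliases on 3.9+ and the `collections.abc` classes are directly subscriptable there. A standalone sketch under that assumption (the `order_total` function is illustrative, not from the diff):

    from collections.abc import Mapping, Sequence  # was: from typing import Mapping, Sequence


    def order_total(prices: Mapping[str, float], order: Sequence[str]) -> float:
        # collections.abc types accept subscripts in annotations since Python 3.9
        return sum(prices[item] for item in order)


    print(order_total({"apple": 1.5, "pear": 2.0}, ["apple", "apple", "pear"]))  # 5.0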
@ -1,6 +1,6 @@
"""Question-answering with sources over an index."""

from typing import Any, Dict, List
from typing import Any

from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -25,7 +25,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
"""Restrict the docs to return from store based on tokens,
enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true"""

def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
num_docs = len(docs)

if self.reduce_k_below_max_tokens and isinstance(
@ -43,8 +43,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]

def _get_docs(
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> list[Document]:
question = inputs[self.question_key]
docs = self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()}
@ -52,8 +52,8 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
return self._reduce_tokens_below_limit(docs)

async def _aget_docs(
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> list[Document]:
question = inputs[self.question_key]
docs = await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()}

@ -1,7 +1,7 @@
"""Question-answering with sources over a vector database."""

import warnings
from typing import Any, Dict, List
from typing import Any

from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -27,10 +27,10 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
max_tokens_limit: int = 3375
"""Restrict the docs to return from store based on tokens,
enforced only for StuffDocumentChain and if reduce_k_below_max_tokens is to true"""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
search_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""

def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
def _reduce_tokens_below_limit(self, docs: list[Document]) -> list[Document]:
num_docs = len(docs)

if self.reduce_k_below_max_tokens and isinstance(
@ -48,8 +48,8 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]

def _get_docs(
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
self, inputs: dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> list[Document]:
question = inputs[self.question_key]
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
@ -57,13 +57,13 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return self._reduce_tokens_below_limit(docs)

async def _aget_docs(
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
self, inputs: dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> list[Document]:
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
warnings.warn(
"`VectorDBQAWithSourcesChain` is deprecated - "
"please use `from langchain.chains import RetrievalQAWithSourcesChain`"

@ -3,7 +3,8 @@
from __future__ import annotations

import json
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union, cast

from langchain_core._api import deprecated
from langchain_core.exceptions import OutputParserException
@ -172,7 +173,7 @@ def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")


def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
def construct_examples(input_output_pairs: Sequence[tuple[str, dict]]) -> list[dict]:
"""Construct examples from input-output pairs.

Args:
@ -267,7 +268,7 @@ def load_query_constructor_chain(
llm: BaseLanguageModel,
document_contents: str,
attribute_info: Sequence[Union[AttributeInfo, dict]],
examples: Optional[List] = None,
examples: Optional[list] = None,
allowed_comparators: Sequence[Comparator] = tuple(Comparator),
allowed_operators: Sequence[Operator] = tuple(Operator),
enable_limit: bool = False,

@ -1,6 +1,7 @@
import datetime
import warnings
from typing import Any, Literal, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Any, Literal, Optional, Union

from langchain_core.utils import check_package_version
from typing_extensions import TypedDict

@ -1,6 +1,7 @@
"""Load question answering chains."""

from typing import Any, Mapping, Optional, Protocol
from collections.abc import Mapping
from typing import Any, Optional, Protocol

from langchain_core._api import deprecated
from langchain_core.callbacks import BaseCallbackManager, Callbacks

@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, Union
from typing import Any, Union

from langchain_core.retrievers import (
BaseRetriever,
@ -11,7 +11,7 @@ from langchain_core.runnables import Runnable, RunnablePassthrough

def create_retrieval_chain(
retriever: Union[BaseRetriever, Runnable[dict, RetrieverOutput]],
combine_docs_chain: Runnable[Dict[str, Any], str],
combine_docs_chain: Runnable[dict[str, Any], str],
) -> Runnable:
"""Create retrieval chain that retrieves documents and then passes them on.
@ -5,7 +5,7 @@ from __future__ import annotations
import inspect
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -54,7 +54,7 @@ class BaseRetrievalQA(Chain):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Input keys.

:meta private:
@ -62,7 +62,7 @@ class BaseRetrievalQA(Chain):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Output keys.

:meta private:
@ -123,14 +123,14 @@ class BaseRetrievalQA(Chain):
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get documents to do question answering over."""

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run get_relevant_text and llm on input query.

If chain has 'return_source_documents' as 'True', returns
@ -166,14 +166,14 @@ class BaseRetrievalQA(Chain):
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get documents to do question answering over."""

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run get_relevant_text and llm on input query.

If chain has 'return_source_documents' as 'True', returns
@ -266,7 +266,7 @@ class RetrievalQA(BaseRetrievalQA):
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
return self.retriever.invoke(
question, config={"callbacks": run_manager.get_child()}
@ -277,7 +277,7 @@ class RetrievalQA(BaseRetrievalQA):
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
return await self.retriever.ainvoke(
question, config={"callbacks": run_manager.get_child()}
@ -307,12 +307,12 @@ class VectorDBQA(BaseRetrievalQA):
"""Number of documents to query for."""
search_type: str = "similarity"
"""Search type to use over vectorstore. `similarity` or `mmr`."""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
search_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""

@model_validator(mode="before")
@classmethod
def raise_deprecation(cls, values: Dict) -> Any:
def raise_deprecation(cls, values: dict) -> Any:
warnings.warn(
"`VectorDBQA` is deprecated - "
"please use `from langchain.chains import RetrievalQA`"
@ -321,7 +321,7 @@ class VectorDBQA(BaseRetrievalQA):

@model_validator(mode="before")
@classmethod
def validate_search_type(cls, values: Dict) -> Any:
def validate_search_type(cls, values: dict) -> Any:
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
@ -334,7 +334,7 @@ class VectorDBQA(BaseRetrievalQA):
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(
@ -353,7 +353,7 @@ class VectorDBQA(BaseRetrievalQA):
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
) -> list[Document]:
"""Get docs."""
raise NotImplementedError("VectorDBQA does not support async")

@ -3,7 +3,8 @@
from __future__ import annotations

from abc import ABC
from typing import Any, Dict, List, Mapping, NamedTuple, Optional
from collections.abc import Mapping
from typing import Any, NamedTuple, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -17,17 +18,17 @@ from langchain.chains.base import Chain

class Route(NamedTuple):
destination: Optional[str]
next_inputs: Dict[str, Any]
next_inputs: dict[str, Any]


class RouterChain(Chain, ABC):
"""Chain that outputs the name of a destination chain and the inputs to it."""

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return ["destination", "next_inputs"]

def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
def route(self, inputs: dict[str, Any], callbacks: Callbacks = None) -> Route:
"""
Route inputs to a destination chain.

@ -42,7 +43,7 @@ class RouterChain(Chain, ABC):
return Route(result["destination"], result["next_inputs"])

async def aroute(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
self, inputs: dict[str, Any], callbacks: Callbacks = None
) -> Route:
result = await self.acall(inputs, callbacks=callbacks)
return Route(result["destination"], result["next_inputs"])
@ -67,7 +68,7 @@ class MultiRouteChain(Chain):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Will be whatever keys the router chain prompt expects.

:meta private:
@ -75,7 +76,7 @@ class MultiRouteChain(Chain):
return self.router_chain.input_keys

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Will always return text key.

:meta private:
@ -84,9 +85,9 @@ class MultiRouteChain(Chain):

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
route = self.router_chain.route(inputs, callbacks=callbacks)
@ -109,9 +110,9 @@ class MultiRouteChain(Chain):

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
route = await self.router_chain.aroute(inputs, callbacks=callbacks)
@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from collections.abc import Sequence
from typing import Any, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -18,7 +19,7 @@ class EmbeddingRouterChain(RouterChain):
"""Chain that uses embeddings to route between options."""

vectorstore: VectorStore
routing_keys: List[str] = ["query"]
routing_keys: list[str] = ["query"]

model_config = ConfigDict(
arbitrary_types_allowed=True,
@ -26,7 +27,7 @@ class EmbeddingRouterChain(RouterChain):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Will be whatever keys the LLM chain prompt expects.

:meta private:
@ -35,18 +36,18 @@ class EmbeddingRouterChain(RouterChain):

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_input = ", ".join([inputs[k] for k in self.routing_keys])
results = self.vectorstore.similarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_input = ", ".join([inputs[k] for k in self.routing_keys])
results = await self.vectorstore.asimilarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
@ -54,8 +55,8 @@ class EmbeddingRouterChain(RouterChain):
@classmethod
def from_names_and_descriptions(
cls,
names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
vectorstore_cls: Type[VectorStore],
names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
vectorstore_cls: type[VectorStore],
embeddings: Embeddings,
**kwargs: Any,
) -> EmbeddingRouterChain:
@ -72,8 +73,8 @@ class EmbeddingRouterChain(RouterChain):
@classmethod
async def afrom_names_and_descriptions(
cls,
names_and_descriptions: Sequence[Tuple[str, Sequence[str]]],
vectorstore_cls: Type[VectorStore],
names_and_descriptions: Sequence[tuple[str, Sequence[str]]],
vectorstore_cls: type[VectorStore],
embeddings: Embeddings,
**kwargs: Any,
) -> EmbeddingRouterChain:

@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional, Type, cast
from typing import Any, Optional, cast

from langchain_core._api import deprecated
from langchain_core.callbacks import (
@ -114,42 +114,42 @@ class LLMRouterChain(RouterChain):
return self

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Will be whatever keys the LLM chain prompt expects.

:meta private:
"""
return self.llm_chain.input_keys

def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
def _validate_outputs(self, outputs: dict[str, Any]) -> None:
super()._validate_outputs(outputs)
if not isinstance(outputs["next_inputs"], dict):
raise ValueError

def _call(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()

prediction = self.llm_chain.predict(callbacks=callbacks, **inputs)
output = cast(
Dict[str, Any],
dict[str, Any],
self.llm_chain.prompt.output_parser.parse(prediction),
)
return output

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
output = cast(
Dict[str, Any],
dict[str, Any],
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
)
return output
@ -163,14 +163,14 @@ class LLMRouterChain(RouterChain):
return cls(llm_chain=llm_chain, **kwargs)


class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
class RouterOutputParser(BaseOutputParser[dict[str, str]]):
"""Parser for output of router chain in the multi-prompt chain."""

default_destination: str = "DEFAULT"
next_inputs_type: Type = str
next_inputs_type: type = str
next_inputs_inner_key: str = "input"

def parse(self, text: str) -> Dict[str, Any]:
def parse(self, text: str) -> dict[str, Any]:
try:
expected_keys = ["destination", "next_inputs"]
parsed = parse_and_check_json_markdown(text, expected_keys)

@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
@ -142,14 +142,14 @@ class MultiPromptChain(MultiRouteChain):
""" # noqa: E501

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return ["text"]

@classmethod
def from_prompts(
cls,
llm: BaseLanguageModel,
prompt_infos: List[Dict[str, str]],
prompt_infos: list[dict[str, str]],
default_chain: Optional[Chain] = None,
**kwargs: Any,
) -> MultiPromptChain:

@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional
from collections.abc import Mapping
from typing import Any, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
@ -31,14 +32,14 @@ class MultiRetrievalQAChain(MultiRouteChain): # type: ignore[override]
"""Default chain to use when router doesn't map input to one of the destinations."""

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
return ["result"]

@classmethod
def from_retrievers(
cls,
llm: BaseLanguageModel,
retriever_infos: List[Dict[str, Any]],
retriever_infos: list[dict[str, Any]],
default_retriever: Optional[BaseRetriever] = None,
default_prompt: Optional[PromptTemplate] = None,
default_chain: Optional[Chain] = None,
@ -1,6 +1,6 @@
"""Chain pipeline where the outputs of one step feed directly into next."""

from typing import Any, Dict, List, Optional
from typing import Any, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -16,9 +16,9 @@ from langchain.chains.base import Chain
class SequentialChain(Chain):
"""Chain where the outputs of one chain feed directly into next."""

chains: List[Chain]
input_variables: List[str]
output_variables: List[str] #: :meta private:
chains: list[Chain]
input_variables: list[str]
output_variables: list[str] #: :meta private:
return_all: bool = False

model_config = ConfigDict(
@ -27,7 +27,7 @@ class SequentialChain(Chain):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Return expected input keys to the chain.

:meta private:
@ -35,7 +35,7 @@ class SequentialChain(Chain):
return self.input_variables

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.

:meta private:
@ -44,7 +44,7 @@ class SequentialChain(Chain):

@model_validator(mode="before")
@classmethod
def validate_chains(cls, values: Dict) -> Any:
def validate_chains(cls, values: dict) -> Any:
"""Validate that the correct inputs exist for all chains."""
chains = values["chains"]
input_variables = values["input_variables"]
@ -97,9 +97,9 @@ class SequentialChain(Chain):

def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
known_values = inputs.copy()
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
for i, chain in enumerate(self.chains):
@ -110,9 +110,9 @@ class SequentialChain(Chain):

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
known_values = inputs.copy()
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
@ -127,7 +127,7 @@ class SequentialChain(Chain):
class SimpleSequentialChain(Chain):
"""Simple chain where the outputs of one step feed directly into next."""

chains: List[Chain]
chains: list[Chain]
strip_outputs: bool = False
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
@ -138,7 +138,7 @@ class SimpleSequentialChain(Chain):
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input key.

:meta private:
@ -146,7 +146,7 @@ class SimpleSequentialChain(Chain):
return [self.input_key]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output key.

:meta private:
@ -171,9 +171,9 @@ class SimpleSequentialChain(Chain):

def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
@ -190,9 +190,9 @@ class SimpleSequentialChain(Chain):

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])

@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union

from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
@ -27,7 +27,7 @@ class SQLInputWithTables(TypedDict):
"""Input for a SQL Chain."""

question: str
table_names_to_use: List[str]
table_names_to_use: list[str]


def create_sql_query_chain(
@ -35,7 +35,7 @@ def create_sql_query_chain(
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
) -> Runnable[Union[SQLInput, SQLInputWithTables, dict[str, Any]], str]:
"""Create a chain that generates SQL queries.

*Security Note*: This chain generates SQL queries for the given database.

@ -1,5 +1,6 @@
import json
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
from collections.abc import Sequence
from typing import Any, Callable, Literal, Optional, Union

from langchain_core._api import deprecated
from langchain_core.output_parsers import (
@ -63,7 +64,7 @@ from pydantic import BaseModel
),
)
def create_openai_fn_runnable(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
llm: Runnable,
prompt: Optional[BasePromptTemplate] = None,
*,
@ -135,7 +136,7 @@ def create_openai_fn_runnable(
if not functions:
raise ValueError("Need to pass in at least one function. Received zero.")
openai_functions = [convert_to_openai_function(f) for f in functions]
llm_kwargs_: Dict[str, Any] = {"functions": openai_functions, **llm_kwargs}
llm_kwargs_: dict[str, Any] = {"functions": openai_functions, **llm_kwargs}
if len(openai_functions) == 1 and enforce_single_function_usage:
llm_kwargs_["function_call"] = {"name": openai_functions[0]["name"]}
output_parser = output_parser or get_openai_output_parser(functions)
@ -181,7 +182,7 @@ def create_openai_fn_runnable(
),
)
def create_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
output_schema: Union[dict[str, Any], type[BaseModel]],
llm: Runnable,
prompt: Optional[BasePromptTemplate] = None,
*,
@ -437,7 +438,7 @@ def create_structured_output_runnable(


def _create_openai_tools_runnable(
tool: Union[Dict[str, Any], Type[BaseModel], Callable],
tool: Union[dict[str, Any], type[BaseModel], Callable],
llm: Runnable,
*,
prompt: Optional[BasePromptTemplate],
@ -446,7 +447,7 @@ def _create_openai_tools_runnable(
first_tool_only: bool,
) -> Runnable:
oai_tool = convert_to_openai_tool(tool)
llm_kwargs: Dict[str, Any] = {"tools": [oai_tool]}
llm_kwargs: dict[str, Any] = {"tools": [oai_tool]}
if enforce_tool_usage:
llm_kwargs["tool_choice"] = {
"type": "function",
@ -462,7 +463,7 @@ def _create_openai_tools_runnable(


def _get_openai_tool_output_parser(
tool: Union[Dict[str, Any], Type[BaseModel], Callable],
tool: Union[dict[str, Any], type[BaseModel], Callable],
*,
first_tool_only: bool = False,
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
@ -479,7 +480,7 @@ def _get_openai_tool_output_parser(


def get_openai_output_parser(
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable]],
functions: Sequence[Union[dict[str, Any], type[BaseModel], Callable]],
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
"""Get the appropriate function output parser given the user functions.

@ -496,7 +497,7 @@ def get_openai_output_parser(
"""
if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = {
pydantic_schema: Union[dict, type[BaseModel]] = {
convert_to_openai_function(fn)["name"]: fn for fn in functions
}
else:
@ -510,7 +511,7 @@ def get_openai_output_parser(


def _create_openai_json_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
output_schema: Union[dict[str, Any], type[BaseModel]],
llm: Runnable,
prompt: Optional[BasePromptTemplate] = None,
*,
@ -537,7 +538,7 @@ def _create_openai_json_runnable(


def _create_openai_functions_structured_output_runnable(
output_schema: Union[Dict[str, Any], Type[BaseModel]],
output_schema: Union[dict[str, Any], type[BaseModel]],
llm: Runnable,
prompt: Optional[BasePromptTemplate] = None,
*,
@ -1,6 +1,7 @@
"""Load summarizing chains."""

from typing import Any, Mapping, Optional, Protocol
from collections.abc import Mapping
from typing import Any, Optional, Protocol

from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel

@ -2,7 +2,8 @@

import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from collections.abc import Awaitable
from typing import Any, Callable, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
@ -26,13 +27,13 @@ class TransformChain(Chain):
output_variables["entities"], transform=func())
"""

input_variables: List[str]
input_variables: list[str]
"""The keys expected by the transform's input dictionary."""
output_variables: List[str]
output_variables: list[str]
"""The keys returned by the transform's output dictionary."""
transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
transform_cb: Callable[[dict[str, str]], dict[str, str]] = Field(alias="transform")
"""The transform function."""
atransform_cb: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = (
atransform_cb: Optional[Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]] = (
Field(None, alias="atransform")
)
"""The async coroutine transform function."""
@ -47,7 +48,7 @@ class TransformChain(Chain):
logger.warning(msg)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Expect input keys.

:meta private:
@ -55,7 +56,7 @@ class TransformChain(Chain):
return self.input_variables

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Return output keys.

:meta private:
@ -64,16 +65,16 @@ class TransformChain(Chain):

def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
) -> dict[str, str]:
return self.transform_cb(inputs)

async def _acall(
self,
inputs: Dict[str, Any],
inputs: dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
if self.atransform_cb is not None:
return await self.atransform_cb(inputs)
else:
@ -1,19 +1,13 @@
from __future__ import annotations

import warnings
from collections.abc import AsyncIterator, Iterator, Sequence
from importlib import util
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
overload,
@ -73,7 +67,7 @@ def init_chat_model(
model: Optional[str] = None,
*,
model_provider: Optional[str] = None,
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = ...,
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = ...,
config_prefix: Optional[str] = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
@ -87,7 +81,7 @@ def init_chat_model(
*,
model_provider: Optional[str] = None,
configurable_fields: Optional[
Union[Literal["any"], List[str], Tuple[str, ...]]
Union[Literal["any"], list[str], tuple[str, ...]]
] = None,
config_prefix: Optional[str] = None,
**kwargs: Any,
@ -514,7 +508,7 @@ def _attempt_infer_model_provider(model_name: str) -> Optional[str]:
return None


def _parse_model(model: str, model_provider: Optional[str]) -> Tuple[str, str]:
def _parse_model(model: str, model_provider: Optional[str]) -> tuple[str, str]:
if (
not model_provider
and ":" in model
@ -554,12 +548,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
self,
*,
default_config: Optional[dict] = None,
configurable_fields: Union[Literal["any"], List[str], Tuple[str, ...]] = "any",
configurable_fields: Union[Literal["any"], list[str], tuple[str, ...]] = "any",
config_prefix: str = "",
queued_declarative_operations: Sequence[Tuple[str, Tuple, Dict]] = (),
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
) -> None:
self._default_config: dict = default_config or {}
self._configurable_fields: Union[Literal["any"], List[str]] = (
self._configurable_fields: Union[Literal["any"], list[str]] = (
configurable_fields
if configurable_fields == "any"
else list(configurable_fields)
@ -569,7 +563,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
if config_prefix and not config_prefix.endswith("_")
else config_prefix
)
self._queued_declarative_operations: List[Tuple[str, Tuple, Dict]] = list(
self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
queued_declarative_operations
)

@ -670,7 +664,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
list[AnyMessage],
]

def invoke(
@ -708,12 +702,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):

def batch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Any]:
) -> list[Any]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
@ -731,12 +725,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):

async def abatch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
inputs: list[LanguageModelInput],
config: Optional[Union[RunnableConfig, list[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Any]:
) -> list[Any]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
@ -759,7 +753,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Any, Exception]]]:
) -> Iterator[tuple[int, Union[Any, Exception]]]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
@ -782,7 +776,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> AsyncIterator[Tuple[int, Any]]:
) -> AsyncIterator[tuple[int, Any]]:
config = config or None
# If <= 1 config use the underlying models batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
@ -808,8 +802,7 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Any]:
for x in self._model(config).transform(input, config=config, **kwargs):
yield x
yield from self._model(config).transform(input, config=config, **kwargs)

async def atransform(
self,
@ -915,13 +908,13 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
# Explicitly added to satisfy downstream linters.
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
tools: Sequence[Union[dict[str, Any], type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs)

# Explicitly added to satisfy downstream linters.
def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
return self.__getattr__("with_structured_output")(schema, **kwargs)
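One hunk in the `_ConfigurableModel.transform` method above is cleanup rather than a typing change: a `for`/`yield` loop that re-yields items from the underlying model's `transform()` is collapsed into `yield from`, the generator-delegation form. A self-contained sketch of why the two are interchangeable here (the `upper_stream` generator is hypothetical, not from the diff):

    from collections.abc import Iterator


    def upper_stream(chunks: list[str]) -> Iterator[str]:
        for chunk in chunks:
            yield chunk.upper()


    def relay_old(chunks: list[str]) -> Iterator[str]:
        # Old form: explicit loop re-yielding each item
        for x in upper_stream(chunks):
            yield x


    def relay_new(chunks: list[str]) -> Iterator[str]:
        # New form: delegate the whole iteration to the inner generator
        yield from upper_stream(chunks)


    assert list(relay_old(["a", "b"])) == list(relay_new(["a", "b"])) == ["A", "B"]

For a plain re-yield loop the rewrite is behavior-preserving; `yield from` is also the more general form, forwarding `send()` and `throw()` to the delegate, which the manual loop does not.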
@ -1,6 +1,6 @@
import functools
from importlib import util
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Union

from langchain_core._api import beta
from langchain_core.embeddings import Embeddings
@ -25,7 +25,7 @@ def _get_provider_list() -> str:
)


def _parse_model_string(model_name: str) -> Tuple[str, str]:
def _parse_model_string(model_name: str) -> tuple[str, str]:
"""Parse a model string into provider and model name components.

The model string should be in the format 'provider:model-name', where provider
@ -78,7 +78,7 @@ def _parse_model_string(model_name: str) -> Tuple[str, str]:

def _infer_model_and_provider(
model: str, *, provider: Optional[str] = None
) -> Tuple[str, str]:
) -> tuple[str, str]:
if not model.strip():
raise ValueError("Model name cannot be empty")
if provider is None and ":" in model:
@ -122,7 +122,7 @@ def init_embeddings(
*,
provider: Optional[str] = None,
**kwargs: Any,
) -> Union[Embeddings, Runnable[Any, List[float]]]:
) -> Union[Embeddings, Runnable[Any, list[float]]]:
"""Initialize an embeddings model from a model name and optional provider.

**Note:** Must have the integration package corresponding to the model provider

@ -12,8 +12,9 @@ from __future__ import annotations
import hashlib
import json
import uuid
from collections.abc import Sequence
from functools import partial
from typing import Callable, List, Optional, Sequence, Union, cast
from typing import Callable, Optional, Union, cast

from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore, ByteStore
@ -45,9 +46,9 @@ def _value_serializer(value: Sequence[float]) -> bytes:
return json.dumps(value).encode()


def _value_deserializer(serialized_value: bytes) -> List[float]:
def _value_deserializer(serialized_value: bytes) -> list[float]:
"""Deserialize a value."""
return cast(List[float], json.loads(serialized_value.decode()))
return cast(list[float], json.loads(serialized_value.decode()))


class CacheBackedEmbeddings(Embeddings):
@ -88,10 +89,10 @@ class CacheBackedEmbeddings(Embeddings):
def __init__(
self,
underlying_embeddings: Embeddings,
document_embedding_store: BaseStore[str, List[float]],
document_embedding_store: BaseStore[str, list[float]],
*,
batch_size: Optional[int] = None,
query_embedding_store: Optional[BaseStore[str, List[float]]] = None,
query_embedding_store: Optional[BaseStore[str, list[float]]] = None,
) -> None:
"""Initialize the embedder.

@ -108,7 +109,7 @@ class CacheBackedEmbeddings(Embeddings):
self.underlying_embeddings = underlying_embeddings
self.batch_size = batch_size

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of texts.

The method first checks the cache for the embeddings.
@ -121,10 +122,10 @@ class CacheBackedEmbeddings(Embeddings):
Returns:
A list of embeddings for the given texts.
"""
vectors: List[Union[List[float], None]] = self.document_embedding_store.mget(
vectors: list[Union[list[float], None]] = self.document_embedding_store.mget(
texts
)
all_missing_indices: List[int] = [
all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None
]

@ -138,10 +139,10 @@ class CacheBackedEmbeddings(Embeddings):
vectors[index] = updated_vector

return cast(
List[List[float]], vectors
list[list[float]], vectors
) # Nones should have been resolved by now

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed a list of texts.

The method first checks the cache for the embeddings.
@ -154,10 +155,10 @@ class CacheBackedEmbeddings(Embeddings):
Returns:
A list of embeddings for the given texts.
"""
vectors: List[
Union[List[float], None]
vectors: list[
Union[list[float], None]
] = await self.document_embedding_store.amget(texts)
all_missing_indices: List[int] = [
all_missing_indices: list[int] = [
i for i, vector in enumerate(vectors) if vector is None
]

@ -175,10 +176,10 @@ class CacheBackedEmbeddings(Embeddings):
vectors[index] = updated_vector

return cast(
List[List[float]], vectors
list[list[float]], vectors
) # Nones should have been resolved by now

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> list[float]:
"""Embed query text.

By default, this method does not cache queries. To enable caching, set the
@ -201,7 +202,7 @@ class CacheBackedEmbeddings(Embeddings):
self.query_embedding_store.mset([(text, vector)])
return vector

async def aembed_query(self, text: str) -> List[float]:
async def aembed_query(self, text: str) -> list[float]:
"""Embed query text.

By default, this method does not cache queries. To enable caching, set the
@ -250,7 +251,7 @@ class CacheBackedEmbeddings(Embeddings):
"""
namespace = namespace
key_encoder = _create_key_encoder(namespace)
document_embedding_store = EncoderBackedStore[str, List[float]](
document_embedding_store = EncoderBackedStore[str, list[float]](
document_embedding_cache,
key_encoder,
_value_serializer,
@ -261,7 +262,7 @@ class CacheBackedEmbeddings(Embeddings):
elif query_embedding_cache is False:
query_embedding_store = None
else:
query_embedding_store = EncoderBackedStore[str, List[float]](
query_embedding_store = EncoderBackedStore[str, list[float]](
query_embedding_cache,
key_encoder,
_value_serializer,
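The `CacheBackedEmbeddings` hunks above are worth a note because some of the rewritten generics sit in executable positions, not just annotations: `cast(list[float], ...)` and `EncoderBackedStore[str, list[float]](...)` are evaluated at runtime, and they only work because PEP 585 makes builtin types subscriptable outside annotations as well. A small standalone sketch of the same deserializer pattern (illustrative, not the langchain implementation):

    import json
    from typing import cast


    def value_deserializer(serialized: bytes) -> list[float]:
        # cast() is a runtime no-op, but list[float] here is actually
        # evaluated, which is legal on Python 3.9+ under PEP 585
        return cast(list[float], json.loads(serialized.decode()))


    print(value_deserializer(b"[0.1, 0.2, 0.3]"))  # [0.1, 0.2, 0.3]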
@ -6,13 +6,10 @@ chain (LLMChain) to generate the reasoning and scores.
"""

import re
from collections.abc import Sequence
from typing import (
Any,
Dict,
List,
Optional,
Sequence,
Tuple,
TypedDict,
Union,
cast,
@ -145,7 +142,7 @@ class TrajectoryEvalChain(AgentTrajectoryEvaluator, LLMEvalChain):
# 0
"""

agent_tools: Optional[List[BaseTool]] = None
agent_tools: Optional[list[BaseTool]] = None
"""A list of tools available to the agent."""
eval_chain: LLMChain
"""The language model chain used for evaluation."""
@ -184,7 +181,7 @@ Description: {tool.description}"""

@staticmethod
def get_agent_trajectory(
steps: Union[str, Sequence[Tuple[AgentAction, str]]],
steps: Union[str, Sequence[tuple[AgentAction, str]]],
) -> str:
"""Get the agent trajectory as a formatted string.

@ -263,7 +260,7 @@ The following is the expected answer. Use this to measure correctness:
)

@property
def input_keys(self) -> List[str]:
def input_keys(self) -> list[str]:
"""Get the input keys for the chain.

Returns:
@ -272,7 +269,7 @@ The following is the expected answer. Use this to measure correctness:
return ["question", "agent_trajectory", "answer", "reference"]

@property
def output_keys(self) -> List[str]:
def output_keys(self) -> list[str]:
"""Get the output keys for the chain.

Returns:
@ -280,16 +277,16 @@ The following is the expected answer. Use this to measure correctness:
"""
return ["score", "reasoning"]

def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]:
def prep_inputs(self, inputs: Union[dict[str, Any], Any]) -> dict[str, str]:
"""Validate and prep inputs."""
inputs["reference"] = self._format_reference(inputs.get("reference"))
return super().prep_inputs(inputs)

def _call(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run the chain and generate the output.

Args:
@ -311,9 +308,9 @@ The following is the expected answer. Use this to measure correctness:

async def _acall(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Run the chain and generate the output.

Args:
@ -338,11 +335,11 @@ The following is the expected answer. Use this to measure correctness:
*,
prediction: str,
input: str,
agent_trajectory: Sequence[Tuple[AgentAction, str]],
agent_trajectory: Sequence[tuple[AgentAction, str]],
reference: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
@ -380,11 +377,11 @@ The following is the expected answer. Use this to measure correctness:
*,
prediction: str,
input: str,
agent_trajectory: Sequence[Tuple[AgentAction, str]],
agent_trajectory: Sequence[tuple[AgentAction, str]],
reference: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:

@ -4,7 +4,7 @@ from __future__ import annotations

import logging
import re
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union

from langchain_core.callbacks.manager import Callbacks
from langchain_core.language_models import BaseLanguageModel
@ -49,7 +49,7 @@ _SUPPORTED_CRITERIA = {


def resolve_pairwise_criteria(
criteria: Optional[Union[CRITERIA_TYPE, str, List[CRITERIA_TYPE]]],
criteria: Optional[Union[CRITERIA_TYPE, str, list[CRITERIA_TYPE]]],
) -> dict:
"""Resolve the criteria for the pairwise evaluator.

@ -113,7 +113,7 @@ class PairwiseStringResultOutputParser(BaseOutputParser[dict]): # type: ignore[
"""
return "pairwise_string_result"

def parse(self, text: str) -> Dict[str, Any]:
def parse(self, text: str) -> dict[str, Any]:
"""Parse the output text.

Args:
@ -314,8 +314,8 @@ Performance may be significantly worse with other models."
input: Optional[str] = None,
reference: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict:
@ -356,8 +356,8 @@ Performance may be significantly worse with other models."
reference: Optional[str] = None,
input: Optional[str] = None,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None,
include_run_info: bool = False,
**kwargs: Any,
) -> dict: