mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-18 21:41:24 +00:00
agent refactor (#2801)
This commit is contained in:
parent
abfca72c0b
commit
b2bc5ef56a
265
docs/modules/agents/agents/examples/mrkl_chat_002.ipynb
Normal file
265
docs/modules/agents/agents/examples/mrkl_chat_002.ipynb
Normal file
@ -0,0 +1,265 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f1390152",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# MRKL Chat\n",
|
||||
"\n",
|
||||
"This notebook showcases using an agent to replicate the MRKL chain using an agent optimized for chat models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "39ea3638",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This uses the example Chinook database.\n",
|
||||
"To set it up follow the instructions on https://database.guide/2-sample-databases-sqlite/, placing the `.db` file in a notebooks folder at the root of this repository."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "ac561cc4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain import OpenAI, LLMMathChain, SerpAPIWrapper, SQLDatabase, SQLDatabaseChain\n",
|
||||
"from langchain.agents import initialize_agent, Tool\n",
|
||||
"from langchain.agents import AgentType\n",
|
||||
"from langchain.chat_models import ChatOpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "07e96d99",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = ChatOpenAI(temperature=0)\n",
|
||||
"llm1 = OpenAI(temperature=0)\n",
|
||||
"search = SerpAPIWrapper()\n",
|
||||
"llm_math_chain = LLMMathChain(llm=llm1, verbose=True)\n",
|
||||
"db = SQLDatabase.from_uri(\"sqlite:///../../../../../notebooks/Chinook.db\")\n",
|
||||
"db_chain = SQLDatabaseChain(llm=llm1, database=db, verbose=True)\n",
|
||||
"tools = [\n",
|
||||
" Tool(\n",
|
||||
" name = \"Search\",\n",
|
||||
" func=search.run,\n",
|
||||
" description=\"useful for when you need to answer questions about current events. You should ask targeted questions\"\n",
|
||||
" ),\n",
|
||||
" Tool(\n",
|
||||
" name=\"Calculator\",\n",
|
||||
" func=llm_math_chain.run,\n",
|
||||
" description=\"useful for when you need to answer questions about math\"\n",
|
||||
" ),\n",
|
||||
" Tool(\n",
|
||||
" name=\"FooBar DB\",\n",
|
||||
" func=db_chain.run,\n",
|
||||
" description=\"useful for when you need to answer questions about FooBar. Input should be in the form of a question containing full context\"\n",
|
||||
" )\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "a069c4b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mrkl = initialize_agent(tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION_V2, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "e603cd7d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mThought: I need to use the Search tool to find out who is Leo DiCaprio's girlfriend and then use the Calculator tool to raise her current age to the 0.43 power.\n",
|
||||
"\n",
|
||||
"Action:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"Search\",\n",
|
||||
" \"action_input\": \"Who is Leo DiCaprio's girlfriend?\"\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"\u001b[0m\n",
|
||||
"Observation:\u001b[36;1m\u001b[1;3mLeo briefly dated Blake Lively when she was 23 and he was 36, then jumped to Erin Heatherton, who was 22, and was with Toni Garrin from 2013 to ...\u001b[0m\u001b[32;1m\u001b[1;3mI need to search again to find out who is Leo DiCaprio's current girlfriend. \n",
|
||||
"\n",
|
||||
"Action:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"Search\",\n",
|
||||
" \"action_input\": \"Who is Leo DiCaprio's current girlfriend?\"\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"\u001b[0m\n",
|
||||
"Observation:\u001b[36;1m\u001b[1;3mLeonardo DiCaprio started dating 20-year-old model Camila Morrone in late 2017. They're still together today - and she's 22 now.\u001b[0m\u001b[32;1m\u001b[1;3mNow I need to use the Calculator tool to raise Camila Morrone's current age to the 0.43 power.\n",
|
||||
"\n",
|
||||
"Action:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"Calculator\",\n",
|
||||
" \"action_input\": \"22^0.43\"\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
|
||||
"22^0.43\u001b[32;1m\u001b[1;3m\n",
|
||||
"```python\n",
|
||||
"import math\n",
|
||||
"print(math.pow(22, 0.43))\n",
|
||||
"```\n",
|
||||
"\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m3.777824273683966\n",
|
||||
"\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"Observation:\u001b[33;1m\u001b[1;3mAnswer: 3.777824273683966\n",
|
||||
"\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n",
|
||||
"\n",
|
||||
"Final Answer: 3.777824273683966\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'3.777824273683966'"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mrkl.run(\"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "a5c07010",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mQuestion: What is the full name of the artist who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\n",
|
||||
"Thought: I should use the Search tool to find the answer to the first part of the question and then use the FooBar DB tool to find the answer to the second part.\n",
|
||||
"Action:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"Search\",\n",
|
||||
" \"action_input\": \"Who recently released an album called 'The Storm Before the Calm'\"\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\u001b[0m\n",
|
||||
"Observation:\u001b[36;1m\u001b[1;3mAlanis Morissette\u001b[0m\u001b[32;1m\u001b[1;3mNow that I know the artist's name, I can use the FooBar DB tool to find out if they are in the database and what albums of theirs are in it.\n",
|
||||
"Action:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"action\": \"FooBar DB\",\n",
|
||||
" \"action_input\": \"What albums does Alanis Morissette have in the database?\"\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
|
||||
"What albums does Alanis Morissette have in the database? \n",
|
||||
"SQLQuery:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/sqlalchemy/sql/sqltypes.py:726: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.\n",
|
||||
" util.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[32;1m\u001b[1;3m SELECT Title FROM Album WHERE ArtistId IN (SELECT ArtistId FROM Artist WHERE Name = 'Alanis Morissette') LIMIT 5;\u001b[0m\n",
|
||||
"SQLResult: \u001b[33;1m\u001b[1;3m[('Jagged Little Pill',)]\u001b[0m\n",
|
||||
"Answer:\u001b[32;1m\u001b[1;3m Alanis Morissette has the album 'Jagged Little Pill' in the database.\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"Observation:\u001b[38;5;200m\u001b[1;3m Alanis Morissette has the album 'Jagged Little Pill' in the database.\u001b[0m\u001b[32;1m\u001b[1;3mI have found the answer to both parts of the question.\n",
|
||||
"Final Answer: The artist who recently released an album called 'The Storm Before the Calm' is Alanis Morissette. The album 'Jagged Little Pill' is in the FooBar database.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"The artist who recently released an album called 'The Storm Before the Calm' is Alanis Morissette. The album 'Jagged Little Pill' is in the FooBar database.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mrkl.run(\"What is the full name of the artist who recently released an album called 'The Storm Before the Calm' and are they in the FooBar database? If so, what albums of theirs are in the FooBar database?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "af016a70",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -99,6 +99,16 @@ class BaseSingleActionAgent(BaseModel):
|
||||
f"Got unsupported early_stopping_method `{early_stopping_method}`"
|
||||
)
|
||||
|
||||
@classmethod
def from_llm_and_tools(
    cls,
    llm: BaseLanguageModel,
    tools: Sequence[BaseTool],
    callback_manager: Optional[BaseCallbackManager] = None,
    **kwargs: Any,
) -> BaseSingleActionAgent:
    """Construct an agent from an LLM and a sequence of tools.

    Abstract hook for subclasses: the base class provides no default
    construction logic, so concrete agents (e.g. ChatAgentV2) must
    override this classmethod.
    """
    raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _agent_type(self) -> str:
|
||||
"""Return Identifier of agent type."""
|
||||
|
@ -8,3 +8,4 @@ class AgentType(str, Enum):
|
||||
CONVERSATIONAL_REACT_DESCRIPTION = "conversational-react-description"
|
||||
CHAT_ZERO_SHOT_REACT_DESCRIPTION = "chat-zero-shot-react-description"
|
||||
CHAT_CONVERSATIONAL_REACT_DESCRIPTION = "chat-conversational-react-description"
|
||||
CHAT_ZERO_SHOT_REACT_DESCRIPTION_V2 = "chat-zero-shot-react-description-002"
|
||||
|
0
langchain/agents/chat_v2/__init__.py
Normal file
0
langchain/agents/chat_v2/__init__.py
Normal file
53
langchain/agents/chat_v2/base.py
Normal file
53
langchain/agents/chat_v2/base.py
Normal file
@ -0,0 +1,53 @@
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
from langchain.agents.agent import AgentOutputParser, LLMSingleActionAgent
|
||||
from langchain.agents.chat_v2.prompt import (
|
||||
FORMAT_INSTRUCTIONS,
|
||||
PREFIX,
|
||||
SUFFIX,
|
||||
ChatOutputParser,
|
||||
create_prompt,
|
||||
)
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class ChatAgentV2(LLMSingleActionAgent):
    """Zero-shot ReAct agent optimized for chat models (v2).

    Builds an ``LLMSingleActionAgent`` whose prompt asks the model to emit
    a single JSON blob per action (see ``langchain.agents.chat_v2.prompt``).
    """

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        output_parser: Optional[AgentOutputParser] = None,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> LLMSingleActionAgent:
        """Construct an agent from an LLM and tools.

        Args:
            llm: Language model driving the agent.
            tools: Tools the agent may call.
            callback_manager: Optional callback manager forwarded to the LLMChain.
            prefix: System-prompt preamble placed before the tool list.
            suffix: Text appended after the format instructions.
            format_instructions: Template describing the JSON action format;
                must contain a ``{tool_names}`` placeholder.
            input_variables: Prompt input variables; defaults (inside
                ``create_prompt``) to ``["input", "intermediate_steps"]``.
            output_parser: Parser for the LLM output; defaults to
                ``ChatOutputParser``.
            stop: Stop sequences; defaults to ``["Observation:"]`` so the
                model cannot hallucinate tool observations.
            **kwargs: Extra fields forwarded to the agent constructor.
        """
        # Stop as soon as the model tries to write its own "Observation:".
        _stop = stop or ["Observation:"]
        _output_parser = output_parser or ChatOutputParser()
        prompt = create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            format_instructions=format_instructions,
            input_variables=input_variables,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        return cls(
            llm_chain=llm_chain, output_parser=_output_parser, stop=_stop, **kwargs
        )

    @property
    def _agent_type(self) -> str:
        """Serialization identifier — intentionally unsupported.

        Was a bare ``raise ValueError`` with no message; give callers a
        diagnosable error instead.
        """
        raise ValueError("ChatAgentV2 does not support serialization.")
|
84
langchain/agents/chat_v2/prompt.py
Normal file
84
langchain/agents/chat_v2/prompt.py
Normal file
@ -0,0 +1,84 @@
|
||||
# flake8: noqa
|
||||
import json
|
||||
from langchain.prompts.chat import (
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.agents.schema import AgentScratchPadChatPromptTemplate
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.schema import AgentAction, AgentFinish
|
||||
from langchain.tools.base import BaseTool
|
||||
from typing import Sequence, Optional, List, Union
|
||||
from langchain.agents.agent import AgentOutputParser
|
||||
|
||||
# System-prompt preamble; the per-tool descriptions are appended after it by
# `create_prompt`.
PREFIX = """Answer the following questions as best you can. You have access to the following tools:"""
# ReAct-style format instructions. `{tool_names}` is filled in by
# `create_prompt` via str.format; the quadruple braces in the example blob
# survive that format call as double braces, which the downstream prompt
# template presumably reduces to single literal braces — TODO confirm against
# the template's formatting.
FORMAT_INSTRUCTIONS = """The way you use the tools is by specifying a json blob.
Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).

The only values that should be in the "action" field are: {tool_names}

The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

ALWAYS use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action:
```
$JSON_BLOB
```
Observation: the result of the action
... (this Thought/Action/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question"""
# Closing reminder; `Final Answer` must match ChatOutputParser's sentinel.
SUFFIX = """Begin! Reminder to always use the exact characters `Final Answer` when responding."""
|
||||
|
||||
|
||||
def create_prompt(
    tools: Sequence[BaseTool],
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    format_instructions: str = FORMAT_INSTRUCTIONS,
    input_variables: Optional[List[str]] = None,
) -> BasePromptTemplate:
    """Assemble the chat prompt for the v2 chat agent.

    Builds a system message from ``prefix`` + one ``name: description`` line
    per tool + the format instructions (with tool names substituted) +
    ``suffix``, followed by a human message carrying the user input and the
    agent scratchpad.
    """
    tool_descriptions = "\n".join(f"{t.name}: {t.description}" for t in tools)
    names = ", ".join(t.name for t in tools)
    instructions = format_instructions.format(tool_names=names)
    system_template = "\n\n".join([prefix, tool_descriptions, instructions, suffix])
    if input_variables is None:
        input_variables = ["input", "intermediate_steps"]
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{input}\n\n{agent_scratchpad}"),
    ]
    return AgentScratchPadChatPromptTemplate(
        input_variables=input_variables, messages=messages
    )
|
||||
|
||||
|
||||
class ChatOutputParser(AgentOutputParser):
    """Parse the chat model's ReAct output into an AgentAction or AgentFinish."""

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse ``text`` emitted by the LLM.

        Returns an ``AgentFinish`` when the output contains ``Final Answer:``,
        otherwise extracts the first ``` fenced JSON blob and returns the
        corresponding ``AgentAction``.

        Raises:
            ValueError: if no parseable action blob is found.
        """
        # A final answer takes precedence over any action blob in the text.
        if "Final Answer:" in text:
            return AgentFinish(
                # Return values is generally always a dictionary with a single `output` key
                # It is not recommended to try anything else at the moment :)
                return_values={"output": text.split("Final Answer:")[-1].strip()},
                log=text,
            )
        try:
            # Take the first fenced block. The previous strict unpack
            # (`_, action, _ = text.split("```")`) raised whenever the model
            # emitted more than one fence or trailing backticks.
            action = text.split("```")[1]
            response = json.loads(action.strip())
            return AgentAction(
                tool=response["action"], tool_input=response["action_input"], log=text
            )
        except Exception as exc:
            # Chain the cause so JSON/KeyError details survive for debugging.
            raise ValueError(f"Could not parse LLM output: {text}") from exc
|
@ -5,9 +5,10 @@ from typing import Any, List, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.agents.agent import BaseSingleActionAgent
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.chat.base import ChatAgent
|
||||
from langchain.agents.chat_v2.base import ChatAgentV2
|
||||
from langchain.agents.conversational.base import ConversationalAgent
|
||||
from langchain.agents.conversational_chat.base import ConversationalChatAgent
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
@ -25,6 +26,7 @@ AGENT_TO_CLASS = {
|
||||
AgentType.CONVERSATIONAL_REACT_DESCRIPTION: ConversationalAgent,
|
||||
AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION: ChatAgent,
|
||||
AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION: ConversationalChatAgent,
|
||||
AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION_V2: ChatAgentV2,
|
||||
}
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
|
||||
@ -32,7 +34,7 @@ URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/age
|
||||
|
||||
def _load_agent_from_tools(
|
||||
config: dict, llm: BaseLLM, tools: List[Tool], **kwargs: Any
|
||||
) -> Agent:
|
||||
) -> BaseSingleActionAgent:
|
||||
config_type = config.pop("_type")
|
||||
if config_type not in AGENT_TO_CLASS:
|
||||
raise ValueError(f"Loading {config_type} agent not supported")
|
||||
@ -49,7 +51,7 @@ def load_agent_from_config(
|
||||
llm: Optional[BaseLLM] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent:
|
||||
) -> BaseSingleActionAgent:
|
||||
"""Load agent from Config Dict."""
|
||||
if "_type" not in config:
|
||||
raise ValueError("Must specify an agent Type in config")
|
||||
@ -82,7 +84,7 @@ def load_agent_from_config(
|
||||
return agent_cls(**combined_config) # type: ignore
|
||||
|
||||
|
||||
def load_agent(path: Union[str, Path], **kwargs: Any) -> Agent:
|
||||
def load_agent(path: Union[str, Path], **kwargs: Any) -> BaseSingleActionAgent:
|
||||
"""Unified method for loading a agent from LangChainHub or local fs."""
|
||||
if hub_result := try_load_from_hub(
|
||||
path, _load_agent_from_file, "agents", {"json", "yaml"}
|
||||
@ -92,7 +94,9 @@ def load_agent(path: Union[str, Path], **kwargs: Any) -> Agent:
|
||||
return _load_agent_from_file(path, **kwargs)
|
||||
|
||||
|
||||
def _load_agent_from_file(file: Union[str, Path], **kwargs: Any) -> Agent:
|
||||
def _load_agent_from_file(
|
||||
file: Union[str, Path], **kwargs: Any
|
||||
) -> BaseSingleActionAgent:
|
||||
"""Load agent from file."""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file, str):
|
||||
|
28
langchain/agents/schema.py
Normal file
28
langchain/agents/schema.py
Normal file
@ -0,0 +1,28 @@
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from langchain.schema import AgentAction
|
||||
|
||||
|
||||
class AgentScratchPadChatPromptTemplate(ChatPromptTemplate):
    """Chat prompt template that renders agent intermediate steps.

    Replaces the ``intermediate_steps`` input variable with a formatted
    ``agent_scratchpad`` string before the template is filled in.
    """

    def _construct_agent_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        """Format (action, observation) pairs into a scratchpad string."""
        if not intermediate_steps:
            return ""
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\nObservation: {observation}\nThought: "
        # Fixed: the first two fragments carried pointless f-prefixes with no
        # placeholders; only the last fragment interpolates.
        return (
            "This was your previous work "
            "(but I haven't seen any of it! I only see what "
            f"you return as final answer):\n{thoughts}"
        )

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
        """Swap ``intermediate_steps`` for the rendered ``agent_scratchpad``."""
        intermediate_steps = kwargs.pop("intermediate_steps")
        kwargs["agent_scratchpad"] = self._construct_agent_scratchpad(
            intermediate_steps
        )
        return kwargs
|
Loading…
Reference in New Issue
Block a user