This commit is contained in:
Harrison Chase 2022-12-10 23:58:02 -08:00
parent aacd417076
commit 81383474c4
6 changed files with 29 additions and 13 deletions

View File

@ -32,7 +32,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 2,
"id": "07e96d99", "id": "07e96d99",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -63,7 +63,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 3,
"id": "a069c4b6", "id": "a069c4b6",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -73,7 +73,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": 4,
"id": "e603cd7d", "id": "e603cd7d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -121,7 +121,7 @@
"\"Harry Styles, Olivia Wilde's boyfriend, is 28 years old and his age raised to the 0.23 power is 2.1520202182226886.\"" "\"Harry Styles, Olivia Wilde's boyfriend, is 28 years old and his age raised to the 0.23 power is 2.1520202182226886.\""
] ]
}, },
"execution_count": 29, "execution_count": 4,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -132,7 +132,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 25, "execution_count": 5,
"id": "a5c07010", "id": "a5c07010",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -170,7 +170,7 @@
"\"Alanis Morissette's album 'Jagged Little Pill' is in the FooBar database.\"" "\"Alanis Morissette's album 'Jagged Little Pill' is in the FooBar database.\""
] ]
}, },
"execution_count": 25, "execution_count": 5,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -182,7 +182,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "3f13b1c3", "id": "af016a70",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, List, Optional, Tuple from typing import Any, ClassVar, Dict, List, Optional, Tuple
from pydantic import BaseModel from pydantic import BaseModel, root_validator
from langchain.agents.input import ChainedInput from langchain.agents.input import ChainedInput
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
@ -41,6 +41,16 @@ class Agent(Chain, BaseModel, ABC):
""" """
return [self.output_key] return [self.output_key]
@root_validator()
def validate_prompt(cls, values: Dict) -> Dict:
"""Validate that prompt matches format."""
prompt = values["llm_chain"].prompt
if "agent_scratchpad" not in prompt.input_variables:
raise ValueError(
"`agent_scratchpad` should be a variable in prompt.input_variables"
)
return values
@property @property
@abstractmethod @abstractmethod
def observation_prefix(self) -> str: def observation_prefix(self) -> str:

View File

@ -47,4 +47,6 @@ Action 4: Finish[yes]
SUFFIX = """\n\nSetup: {input} SUFFIX = """\n\nSetup: {input}
{agent_scratchpad}""" {agent_scratchpad}"""
TEXTWORLD_PROMPT = PromptTemplate.from_examples(EXAMPLES, SUFFIX, ["input", "agent_scratchpad"]) TEXTWORLD_PROMPT = PromptTemplate.from_examples(
EXAMPLES, SUFFIX, ["input", "agent_scratchpad"]
)

View File

@ -110,4 +110,6 @@ Action 3: Finish[yes]""",
SUFFIX = """\n\nQuestion: {input} SUFFIX = """\n\nQuestion: {input}
{agent_scratchpad}""" {agent_scratchpad}"""
WIKI_PROMPT = PromptTemplate.from_examples(EXAMPLES, SUFFIX, ["input", "agent_scratchpad"]) WIKI_PROMPT = PromptTemplate.from_examples(
EXAMPLES, SUFFIX, ["input", "agent_scratchpad"]
)

View File

@ -39,4 +39,6 @@ So the final answer is: No
Question: {input} Question: {input}
{agent_scratchpad}""" {agent_scratchpad}"""
PROMPT = PromptTemplate(input_variables=["input", "agent_scratchpad"], template=_DEFAULT_TEMPLATE) PROMPT = PromptTemplate(
input_variables=["input", "agent_scratchpad"], template=_DEFAULT_TEMPLATE
)

View File

@ -56,7 +56,7 @@ def test_predict_until_observation_normal() -> None:
Tool("Lookup", lambda x: x), Tool("Lookup", lambda x: x),
] ]
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools) agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
output = agent.get_action("") output = agent.get_action("", {"input": ""})
assert output.log == outputs[0] assert output.log == outputs[0]
assert output.tool == "Search" assert output.tool == "Search"
assert output.tool_input == "foo" assert output.tool_input == "foo"
@ -71,7 +71,7 @@ def test_predict_until_observation_repeat() -> None:
Tool("Lookup", lambda x: x), Tool("Lookup", lambda x: x),
] ]
agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools) agent = ReActDocstoreAgent.from_llm_and_tools(fake_llm, tools)
output = agent.get_action("") output = agent.get_action("", {"input": ""})
assert output.log == "foo\nAction 1: Search[foo]" assert output.log == "foo\nAction 1: Search[foo]"
assert output.tool == "Search" assert output.tool == "Search"
assert output.tool_input == "foo" assert output.tool_input == "foo"