consolidate run functions (#126)

consolidating logic for when a chain is able to run with single input
text, single output text

open to feedback on naming, logic, usefulness
This commit is contained in:
Harrison Chase 2022-11-13 18:14:35 -08:00 committed by GitHub
parent 1fe3a4f724
commit f23b3ceb49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 39 additions and 143 deletions

View File

@ -46,7 +46,7 @@ if __name__ == "__main__":
try: try:
while True: while True:
browser_content = "\n".join(_crawler.crawl()) browser_content = "\n".join(_crawler.crawl())
llm_command = nat_bot_chain.run(_crawler.page.url, browser_content) llm_command = nat_bot_chain.execute(_crawler.page.url, browser_content)
if not quiet: if not quiet:
print("URL: " + _crawler.page.url) print("URL: " + _crawler.page.url)
print("Objective: " + objective) print("Objective: " + objective)

View File

@ -58,7 +58,7 @@
} }
], ],
"source": [ "source": [
"db_chain.query(\"How many employees are there?\")" "db_chain.run(\"How many employees are there?\")"
] ]
}, },
{ {

View File

@ -86,7 +86,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.7" "version": "3.7.6"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -35,7 +35,7 @@ class Chain(BaseModel, ABC):
) )
@abstractmethod @abstractmethod
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output.""" """Run the logic of this chain and return the output."""
def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]: def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
@ -43,8 +43,22 @@ class Chain(BaseModel, ABC):
self._validate_inputs(inputs) self._validate_inputs(inputs)
if self.verbose: if self.verbose:
print("\n\n\033[1m> Entering new chain...\033[0m") print("\n\n\033[1m> Entering new chain...\033[0m")
outputs = self._run(inputs) outputs = self._call(inputs)
if self.verbose: if self.verbose:
print("\n\033[1m> Finished chain.\033[0m") print("\n\033[1m> Finished chain.\033[0m")
self._validate_outputs(outputs) self._validate_outputs(outputs)
return {**inputs, **outputs} return {**inputs, **outputs}
def run(self, text: str) -> str:
"""Run text in, text out (if applicable)."""
if len(self.input_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one input key, got {self.input_keys}."
)
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key, got {self.output_keys}."
)
return self({self.input_keys[0]: text})[self.output_keys[0]]

View File

@ -48,7 +48,7 @@ class LLMChain(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format(**selected_inputs) prompt = self.prompt.format(**selected_inputs)

View File

@ -48,7 +48,7 @@ class LLMMathChain(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
python_executor = PythonChain() python_executor = PythonChain()
chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose) chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
@ -66,19 +66,3 @@ class LLMMathChain(Chain, BaseModel):
else: else:
raise ValueError(f"unknown format from LLM: {t}") raise ValueError(f"unknown format from LLM: {t}")
return {self.output_key: answer} return {self.output_key: answer}
def run(self, question: str) -> str:
"""Understand user question and execute math in Python if necessary.
Args:
question: User question that contains a math question to parse and answer.
Returns:
The answer to the question.
Example:
.. code-block:: python
answer = llm_math.run("What is one plus one?")
"""
return self({self.input_key: question})[self.output_key]

View File

@ -57,7 +57,7 @@ class MapReduceChain(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
# Split the larger text into smaller chunks. # Split the larger text into smaller chunks.
docs = self.text_splitter.split_text( docs = self.text_splitter.split_text(
inputs[self.input_key], inputs[self.input_key],
@ -76,7 +76,3 @@ class MapReduceChain(Chain, BaseModel):
inputs = {self.reduce_llm.prompt.input_variables[0]: summary_str} inputs = {self.reduce_llm.prompt.input_variables[0]: summary_str}
output = self.reduce_llm.predict(**inputs) output = self.reduce_llm.predict(**inputs)
return {self.output_key: output} return {self.output_key: output}
def run(self, text: str) -> str:
"""Run the map-reduce logic on the input text."""
return self({self.input_key: text})[self.output_key]

View File

@ -147,7 +147,7 @@ class MRKLChain(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
chained_input = ChainedInput( chained_input = ChainedInput(
f"{inputs[self.input_key]}\nThought:", verbose=self.verbose f"{inputs[self.input_key]}\nThought:", verbose=self.verbose
@ -168,7 +168,3 @@ class MRKLChain(Chain, BaseModel):
chained_input.add("\nObservation: ") chained_input.add("\nObservation: ")
chained_input.add(ca, color=color_mapping[action]) chained_input.add(ca, color=color_mapping[action])
chained_input.add("\nThought:") chained_input.add("\nThought:")
def run(self, _input: str) -> str:
"""Run input through the MRKL system."""
return self({self.input_key: _input})[self.output_key]

View File

@ -57,7 +57,7 @@ class NatBotChain(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
url = inputs[self.input_url_key] url = inputs[self.input_url_key]
browser_content = inputs[self.input_browser_content_key] browser_content = inputs[self.input_browser_content_key]
@ -71,7 +71,7 @@ class NatBotChain(Chain, BaseModel):
self.previous_command = llm_cmd self.previous_command = llm_cmd
return {self.output_key: llm_cmd} return {self.output_key: llm_cmd}
def run(self, url: str, browser_content: str) -> str: def execute(self, url: str, browser_content: str) -> str:
"""Figure out next browser command to run. """Figure out next browser command to run.
Args: Args:

View File

@ -41,7 +41,7 @@ class PythonChain(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
python_repl = PythonREPL() python_repl = PythonREPL()
old_stdout = sys.stdout old_stdout = sys.stdout
sys.stdout = mystdout = StringIO() sys.stdout = mystdout = StringIO()
@ -49,20 +49,3 @@ class PythonChain(Chain, BaseModel):
sys.stdout = old_stdout sys.stdout = old_stdout
output = mystdout.getvalue() output = mystdout.getvalue()
return {self.output_key: output} return {self.output_key: output}
def run(self, code: str) -> str:
"""Run code in python interpreter.
Args:
code: Code snippet to execute, should print out the answer.
Returns:
Answer from running the code and printing out the answer.
Example:
.. code-block:: python
answer = python_chain.run("print(1+1)")
"""
return self({self.input_key: code})[self.output_key]

View File

@ -72,9 +72,9 @@ class ReActChain(Chain, BaseModel):
:meta private: :meta private:
""" """
return ["full_logic", self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
question = inputs[self.input_key] question = inputs[self.input_key]
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
chained_input = ChainedInput(f"{question}\nThought 1:", verbose=self.verbose) chained_input = ChainedInput(f"{question}\nThought 1:", verbose=self.verbose)
@ -98,27 +98,10 @@ class ReActChain(Chain, BaseModel):
raise ValueError("Cannot lookup without a successful search first") raise ValueError("Cannot lookup without a successful search first")
observation = document.lookup(directive) observation = document.lookup(directive)
elif action == "Finish": elif action == "Finish":
return {"full_logic": chained_input.input, self.output_key: directive} return {self.output_key: directive}
else: else:
raise ValueError(f"Got unknown action directive: {action}") raise ValueError(f"Got unknown action directive: {action}")
chained_input.add(f"\nObservation {i}: ") chained_input.add(f"\nObservation {i}: ")
chained_input.add(observation, color="yellow") chained_input.add(observation, color="yellow")
chained_input.add(f"\nThought {i + 1}:") chained_input.add(f"\nThought {i + 1}:")
i += 1 i += 1
def run(self, question: str) -> str:
"""Run ReAct framework.
Args:
question: Question to be answered.
Returns:
Final answer from thinking through the ReAct framework.
Example:
.. code-block:: python
question = "Were Scott Derrickson and Ed Wood of the same nationality?"
answer = react.run(question)
"""
return self({self.input_key: question})[self.output_key]

View File

@ -114,7 +114,7 @@ class SelfAskWithSearchChain(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose) chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
chained_input.add("\nAre follow up questions needed here:") chained_input.add("\nAre follow up questions needed here:")
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
@ -125,7 +125,7 @@ class SelfAskWithSearchChain(Chain, BaseModel):
chained_input.add(ret_text, color="green") chained_input.add(ret_text, color="green")
while followup in get_last_line(ret_text): while followup in get_last_line(ret_text):
question = extract_question(ret_text, followup) question = extract_question(ret_text, followup)
external_answer = self.search_chain.search(question) external_answer = self.search_chain.run(question)
if external_answer is not None: if external_answer is not None:
chained_input.add(intermediate + " ") chained_input.add(intermediate + " ")
chained_input.add(external_answer + ".", color="yellow") chained_input.add(external_answer + ".", color="yellow")
@ -147,19 +147,3 @@ class SelfAskWithSearchChain(Chain, BaseModel):
chained_input.add(ret_text, color="green") chained_input.add(ret_text, color="green")
return {self.output_key: ret_text} return {self.output_key: ret_text}
def run(self, question: str) -> str:
"""Run self ask with search chain.
Args:
question: Question to run self-ask-with-search with.
Returns:
The final answer
Example:
.. code-block:: python
answer = selfask.run("What is the capital of Idaho?")
"""
return self({self.input_key: question})[self.output_key]

View File

@ -88,7 +88,7 @@ class SerpAPIChain(Chain, BaseModel):
) )
return values return values
def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
params = { params = {
"api_key": self.serpapi_api_key, "api_key": self.serpapi_api_key,
"engine": "google", "engine": "google",
@ -116,19 +116,3 @@ class SerpAPIChain(Chain, BaseModel):
else: else:
toret = None toret = None
return {self.output_key: toret} return {self.output_key: toret}
def search(self, search_question: str) -> str:
"""Run search query against SerpAPI.
Args:
search_question: Question to run against the SerpAPI.
Returns:
Answer from the search engine.
Example:
.. code-block:: python
answer = serpapi.search("What is the capital of Idaho?")
"""
return self({self.input_key: search_question})[self.output_key]

View File

@ -51,7 +51,7 @@ class SQLDatabaseChain(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
chained_input = ChainedInput( chained_input = ChainedInput(
inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose
@ -72,19 +72,3 @@ class SQLDatabaseChain(Chain, BaseModel):
final_result = llm_chain.predict(**llm_inputs) final_result = llm_chain.predict(**llm_inputs)
chained_input.add(final_result, color="green") chained_input.add(final_result, color="green")
return {self.output_key: final_result} return {self.output_key: final_result}
def query(self, query: str) -> str:
"""Run natural language query against a SQL database.
Args:
query: natural language query to run against the SQL database
Returns:
The final answer as derived from the SQL database.
Example:
.. code-block:: python
answer = db_chain.query("How many customers are there?")
"""
return self({self.input_key: query})[self.output_key]

View File

@ -52,7 +52,7 @@ class VectorDBQA(Chain, BaseModel):
""" """
return [self.output_key] return [self.output_key]
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
question = inputs[self.input_key] question = inputs[self.input_key]
llm_chain = LLMChain(llm=self.llm, prompt=prompt) llm_chain = LLMChain(llm=self.llm, prompt=prompt)
docs = self.vectorstore.similarity_search(question) docs = self.vectorstore.similarity_search(question)

View File

@ -5,5 +5,5 @@ from langchain.chains.serpapi import SerpAPIChain
def test_call() -> None: def test_call() -> None:
"""Test that call gives the correct answer.""" """Test that call gives the correct answer."""
chain = SerpAPIChain() chain = SerpAPIChain()
output = chain.search("What was Obama's first name?") output = chain.run("What was Obama's first name?")
assert output == "Barack Hussein Obama II" assert output == "Barack Hussein Obama II"

View File

@ -25,6 +25,6 @@ def test_sql_database_run() -> None:
conn.execute(stmt) conn.execute(stmt)
db = SQLDatabase(engine) db = SQLDatabase(engine)
db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db) db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
output = db_chain.query("What company does Harrison work at?") output = db_chain.run("What company does Harrison work at?")
expected_output = " Harrison works at Foo." expected_output = " Harrison works at Foo."
assert output == expected_output assert output == expected_output

View File

@ -22,7 +22,7 @@ class FakeChain(Chain, BaseModel):
"""Output key of bar.""" """Output key of bar."""
return ["bar"] return ["bar"]
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
if self.be_correct: if self.be_correct:
return {"bar": "baz"} return {"bar": "baz"}
else: else:

View File

@ -26,7 +26,7 @@ def test_proper_inputs() -> None:
nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing") nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing")
url = "foo" * 10000 url = "foo" * 10000
browser_content = "foo" * 10000 browser_content = "foo" * 10000
output = nat_bot_chain.run(url, browser_content) output = nat_bot_chain.execute(url, browser_content)
assert output == "bar" assert output == "bar"
@ -39,5 +39,5 @@ def test_variable_key_naming() -> None:
input_browser_content_key="b", input_browser_content_key="b",
output_key="c", output_key="c",
) )
output = nat_bot_chain.run("foo", "foo") output = nat_bot_chain.execute("foo", "foo")
assert output == "bar" assert output == "bar"

View File

@ -92,18 +92,6 @@ def test_react_chain() -> None:
inputs = {"question": "when was langchain made"} inputs = {"question": "when was langchain made"}
output = react_chain(inputs) output = react_chain(inputs)
assert output["answer"] == "2022" assert output["answer"] == "2022"
expected_full_output = (
"when was langchain made\n"
"Thought 1:I should probably search\n"
"Action 1: Search[langchain]\n"
"Observation 1: This is a page about LangChain.\n"
"Thought 2:I should probably lookup\n"
"Action 2: Lookup[made]\n"
"Observation 2: (Result 1/1) Made in 2022.\n"
"Thought 3:Ah okay now I know the answer\n"
"Action 3: Finish[2022]"
)
assert output["full_logic"] == expected_full_output
def test_react_chain_bad_action() -> None: def test_react_chain_bad_action() -> None: