Mirror of https://github.com/hwchase17/langchain.git, synced 2025-06-24 15:43:54 +00:00
consolidate run functions (#126)
Consolidating the logic for when a chain is able to run with a single input text and a single output text. Open to feedback on naming, logic, and usefulness.
This commit is contained in: parent 1fe3a4f724, commit f23b3ceb49
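For illustration, a sketch of the call-site change based on the diff below (chain construction elided; variable names are assumptions): chains with a single text input and a single text output now share one generic entry point, Chain.run, instead of one bespoke method per chain.

    # Before this commit: one bespoke method per chain.
    db_chain.query("How many employees are there?")
    serpapi_chain.search("What is the capital of Idaho?")

    # After: single-input, single-output chains share Chain.run,
    # which validates the key counts and unpacks the lone output.
    db_chain.run("How many employees are there?")
    serpapi_chain.run("What is the capital of Idaho?")

    # NatBotChain takes two inputs (url, browser content), so its
    # text-in method is renamed to execute rather than run.
    llm_command = nat_bot_chain.execute(url, browser_content)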
@@ -46,7 +46,7 @@ if __name__ == "__main__":
     try:
         while True:
             browser_content = "\n".join(_crawler.crawl())
-            llm_command = nat_bot_chain.run(_crawler.page.url, browser_content)
+            llm_command = nat_bot_chain.execute(_crawler.page.url, browser_content)
             if not quiet:
                 print("URL: " + _crawler.page.url)
                 print("Objective: " + objective)
@@ -58,7 +58,7 @@
    }
   ],
   "source": [
-   "db_chain.query(\"How many employees are there?\")"
+   "db_chain.run(\"How many employees are there?\")"
   ]
  },
  {
@@ -86,7 +86,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.8.7"
+  "version": "3.7.6"
  }
 },
 "nbformat": 4,
@@ -35,7 +35,7 @@ class Chain(BaseModel, ABC):
         )

     @abstractmethod
-    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         """Run the logic of this chain and return the output."""

     def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
@@ -43,8 +43,22 @@ class Chain(BaseModel, ABC):
         self._validate_inputs(inputs)
         if self.verbose:
             print("\n\n\033[1m> Entering new chain...\033[0m")
-        outputs = self._run(inputs)
+        outputs = self._call(inputs)
         if self.verbose:
             print("\n\033[1m> Finished chain.\033[0m")
         self._validate_outputs(outputs)
         return {**inputs, **outputs}
+
+    def run(self, text: str) -> str:
+        """Run text in, text out (if applicable)."""
+        if len(self.input_keys) != 1:
+            raise ValueError(
+                f"`run` not supported when there is not exactly "
+                f"one input key, got {self.input_keys}."
+            )
+        if len(self.output_keys) != 1:
+            raise ValueError(
+                f"`run` not supported when there is not exactly "
+                f"one output key, got {self.output_keys}."
+            )
+        return self({self.input_keys[0]: text})[self.output_keys[0]]
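A minimal sketch of the consolidated contract (a hypothetical chain, not part of this commit; the import path is assumed): subclasses now override only _call, and any chain that declares exactly one input key and one output key inherits the text-in/text-out run helper above for free.

    from typing import Dict, List

    from pydantic import BaseModel

    from langchain.chains.base import Chain  # assumed import path


    class UpperCaseChain(Chain, BaseModel):
        """Hypothetical chain that upper-cases its single text input."""

        @property
        def input_keys(self) -> List[str]:
            """Expect a single "text" input."""
            return ["text"]

        @property
        def output_keys(self) -> List[str]:
            """Produce a single "output" key."""
            return ["output"]

        def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
            # The only method a subclass has to implement now.
            return {"output": inputs["text"].upper()}


    # __call__ returns the full dict; run() unpacks the single output key.
    assert UpperCaseChain().run("hello") == "HELLO"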
@@ -48,7 +48,7 @@ class LLMChain(Chain, BaseModel):
         """
         return [self.output_key]

-    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
         prompt = self.prompt.format(**selected_inputs)
@@ -48,7 +48,7 @@ class LLMMathChain(Chain, BaseModel):
         """
         return [self.output_key]

-    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
         python_executor = PythonChain()
         chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
@@ -66,19 +66,3 @@ class LLMMathChain(Chain, BaseModel):
         else:
             raise ValueError(f"unknown format from LLM: {t}")
         return {self.output_key: answer}
-
-    def run(self, question: str) -> str:
-        """Understand user question and execute math in Python if necessary.
-
-        Args:
-            question: User question that contains a math question to parse and answer.
-
-        Returns:
-            The answer to the question.
-
-        Example:
-            .. code-block:: python
-
-                answer = llm_math.run("What is one plus one?")
-        """
-        return self({self.input_key: question})[self.output_key]
@@ -57,7 +57,7 @@ class MapReduceChain(Chain, BaseModel):
         """
         return [self.output_key]

-    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         # Split the larger text into smaller chunks.
         docs = self.text_splitter.split_text(
             inputs[self.input_key],
@@ -76,7 +76,3 @@ class MapReduceChain(Chain, BaseModel):
         inputs = {self.reduce_llm.prompt.input_variables[0]: summary_str}
         output = self.reduce_llm.predict(**inputs)
         return {self.output_key: output}
-
-    def run(self, text: str) -> str:
-        """Run the map-reduce logic on the input text."""
-        return self({self.input_key: text})[self.output_key]
@@ -147,7 +147,7 @@ class MRKLChain(Chain, BaseModel):
         """
         return [self.output_key]

-    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
         chained_input = ChainedInput(
             f"{inputs[self.input_key]}\nThought:", verbose=self.verbose
@@ -168,7 +168,3 @@ class MRKLChain(Chain, BaseModel):
             chained_input.add("\nObservation: ")
             chained_input.add(ca, color=color_mapping[action])
             chained_input.add("\nThought:")
-
-    def run(self, _input: str) -> str:
-        """Run input through the MRKL system."""
-        return self({self.input_key: _input})[self.output_key]
@@ -57,7 +57,7 @@ class NatBotChain(Chain, BaseModel):
         """
         return [self.output_key]

-    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
         url = inputs[self.input_url_key]
         browser_content = inputs[self.input_browser_content_key]
@@ -71,7 +71,7 @@ class NatBotChain(Chain, BaseModel):
         self.previous_command = llm_cmd
         return {self.output_key: llm_cmd}

-    def run(self, url: str, browser_content: str) -> str:
+    def execute(self, url: str, browser_content: str) -> str:
         """Figure out next browser command to run.

         Args:
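(Why execute rather than run: the generic run helper added to Chain raises a ValueError unless a chain has exactly one input key and one output key, and NatBotChain needs both a url and the browser content, so its two-argument entry point gets its own name instead of shadowing the base method with an incompatible signature.)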
@@ -41,7 +41,7 @@ class PythonChain(Chain, BaseModel):
         """
         return [self.output_key]

-    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         python_repl = PythonREPL()
         old_stdout = sys.stdout
         sys.stdout = mystdout = StringIO()
@@ -49,20 +49,3 @@ class PythonChain(Chain, BaseModel):
         sys.stdout = old_stdout
         output = mystdout.getvalue()
         return {self.output_key: output}
-
-    def run(self, code: str) -> str:
-        """Run code in python interpreter.
-
-        Args:
-            code: Code snippet to execute, should print out the answer.
-
-        Returns:
-            Answer from running the code and printing out the answer.
-
-        Example:
-
-            .. code-block:: python
-
-                answer = python_chain.run("print(1+1)")
-        """
-        return self({self.input_key: code})[self.output_key]
@@ -72,9 +72,9 @@ class ReActChain(Chain, BaseModel):

         :meta private:
         """
-        return ["full_logic", self.output_key]
+        return [self.output_key]

-    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         question = inputs[self.input_key]
         llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
         chained_input = ChainedInput(f"{question}\nThought 1:", verbose=self.verbose)
@@ -98,27 +98,10 @@ class ReActChain(Chain, BaseModel):
                     raise ValueError("Cannot lookup without a successful search first")
                 observation = document.lookup(directive)
             elif action == "Finish":
-                return {"full_logic": chained_input.input, self.output_key: directive}
+                return {self.output_key: directive}
             else:
                 raise ValueError(f"Got unknown action directive: {action}")
             chained_input.add(f"\nObservation {i}: ")
             chained_input.add(observation, color="yellow")
             chained_input.add(f"\nThought {i + 1}:")
             i += 1
-
-    def run(self, question: str) -> str:
-        """Run ReAct framework.
-
-        Args:
-            question: Question to be answered.
-
-        Returns:
-            Final answer from thinking through the ReAct framework.
-
-        Example:
-            .. code-block:: python
-
-                question = "Were Scott Derrickson and Ed Wood of the same nationality?"
-                answer = react.run(question)
-        """
-        return self({self.input_key: question})[self.output_key]
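(With full_logic dropped from the output keys, ReActChain has exactly one output key, so the generic run helper on Chain now applies and the bespoke run method above can go; the matching full_logic assertion disappears from the tests at the end of this diff.)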
@@ -114,7 +114,7 @@ class SelfAskWithSearchChain(Chain, BaseModel):
         """
         return [self.output_key]

-    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
         chained_input.add("\nAre follow up questions needed here:")
         llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
@@ -125,7 +125,7 @@ class SelfAskWithSearchChain(Chain, BaseModel):
         chained_input.add(ret_text, color="green")
         while followup in get_last_line(ret_text):
             question = extract_question(ret_text, followup)
-            external_answer = self.search_chain.search(question)
+            external_answer = self.search_chain.run(question)
             if external_answer is not None:
                 chained_input.add(intermediate + " ")
                 chained_input.add(external_answer + ".", color="yellow")
@@ -147,19 +147,3 @@ class SelfAskWithSearchChain(Chain, BaseModel):
             chained_input.add(ret_text, color="green")

         return {self.output_key: ret_text}
-
-    def run(self, question: str) -> str:
-        """Run self ask with search chain.
-
-        Args:
-            question: Question to run self-ask-with-search with.
-
-        Returns:
-            The final answer
-
-        Example:
-            .. code-block:: python
-
-                answer = selfask.run("What is the capital of Idaho?")
-        """
-        return self({self.input_key: question})[self.output_key]
@@ -88,7 +88,7 @@ class SerpAPIChain(Chain, BaseModel):
             )
         return values

-    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         params = {
             "api_key": self.serpapi_api_key,
             "engine": "google",
@@ -116,19 +116,3 @@
         else:
             toret = None
         return {self.output_key: toret}
-
-    def search(self, search_question: str) -> str:
-        """Run search query against SerpAPI.
-
-        Args:
-            search_question: Question to run against the SerpAPI.
-
-        Returns:
-            Answer from the search engine.
-
-        Example:
-            .. code-block:: python
-
-                answer = serpapi.search("What is the capital of Idaho?")
-        """
-        return self({self.input_key: search_question})[self.output_key]
@@ -51,7 +51,7 @@ class SQLDatabaseChain(Chain, BaseModel):
         """
         return [self.output_key]

-    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
         chained_input = ChainedInput(
             inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose
@@ -72,19 +72,3 @@ class SQLDatabaseChain(Chain, BaseModel):
         final_result = llm_chain.predict(**llm_inputs)
         chained_input.add(final_result, color="green")
         return {self.output_key: final_result}
-
-    def query(self, query: str) -> str:
-        """Run natural language query against a SQL database.
-
-        Args:
-            query: natural language query to run against the SQL database
-
-        Returns:
-            The final answer as derived from the SQL database.
-
-        Example:
-            .. code-block:: python
-
-                answer = db_chain.query("How many customers are there?")
-        """
-        return self({self.input_key: query})[self.output_key]
@@ -52,7 +52,7 @@ class VectorDBQA(Chain, BaseModel):
         """
         return [self.output_key]

-    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         question = inputs[self.input_key]
         llm_chain = LLMChain(llm=self.llm, prompt=prompt)
         docs = self.vectorstore.similarity_search(question)
@@ -5,5 +5,5 @@ from langchain.chains.serpapi import SerpAPIChain
 def test_call() -> None:
     """Test that call gives the correct answer."""
     chain = SerpAPIChain()
-    output = chain.search("What was Obama's first name?")
+    output = chain.run("What was Obama's first name?")
     assert output == "Barack Hussein Obama II"
@@ -25,6 +25,6 @@ def test_sql_database_run() -> None:
         conn.execute(stmt)
     db = SQLDatabase(engine)
     db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
-    output = db_chain.query("What company does Harrison work at?")
+    output = db_chain.run("What company does Harrison work at?")
     expected_output = " Harrison works at Foo."
     assert output == expected_output
@@ -22,7 +22,7 @@ class FakeChain(Chain, BaseModel):
         """Output key of bar."""
         return ["bar"]

-    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         if self.be_correct:
             return {"bar": "baz"}
         else:
@@ -26,7 +26,7 @@ def test_proper_inputs() -> None:
     nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing")
     url = "foo" * 10000
     browser_content = "foo" * 10000
-    output = nat_bot_chain.run(url, browser_content)
+    output = nat_bot_chain.execute(url, browser_content)
     assert output == "bar"

@@ -39,5 +39,5 @@ def test_variable_key_naming() -> None:
         input_browser_content_key="b",
         output_key="c",
     )
-    output = nat_bot_chain.run("foo", "foo")
+    output = nat_bot_chain.execute("foo", "foo")
     assert output == "bar"
@@ -92,18 +92,6 @@ def test_react_chain() -> None:
     inputs = {"question": "when was langchain made"}
     output = react_chain(inputs)
     assert output["answer"] == "2022"
-    expected_full_output = (
-        "when was langchain made\n"
-        "Thought 1:I should probably search\n"
-        "Action 1: Search[langchain]\n"
-        "Observation 1: This is a page about LangChain.\n"
-        "Thought 2:I should probably lookup\n"
-        "Action 2: Lookup[made]\n"
-        "Observation 2: (Result 1/1) Made in 2022.\n"
-        "Thought 3:Ah okay now I know the answer\n"
-        "Action 3: Finish[2022]"
-    )
-    assert output["full_logic"] == expected_full_output


 def test_react_chain_bad_action() -> None: