diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index 8827230179f..e9d22e32c28 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -7,7 +7,7 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT from langchain.chains.python import PythonChain -from langchain.input import ChainedInput +from langchain.input import print_text from langchain.llms.base import LLM @@ -51,15 +51,18 @@ class LLMMathChain(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) python_executor = PythonChain() - chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose) - t = llm_executor.predict(question=chained_input.input, stop=["```output"]) - chained_input.add(t, color="green") + if self.verbose: + print_text(inputs[self.input_key]) + t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"]) + if self.verbose: + print_text(t, color="green") t = t.strip() if t.startswith("```python"): code = t[9:-4] output = python_executor.run(code) - chained_input.add("\nAnswer: ") - chained_input.add(output, color="yellow") + if self.verbose: + print_text("\nAnswer: ") + print_text(output, color="yellow") answer = "Answer: " + output elif t.startswith("Answer:"): answer = t diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 4a13f3938ff..7ceab5fb048 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sql_database.prompt import PROMPT -from langchain.input import ChainedInput +from langchain.input import print_text from langchain.llms.base import LLM from langchain.sql_database import SQLDatabase @@ -53,22 +53,26 @@ class SQLDatabaseChain(Chain, BaseModel): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) - chained_input = ChainedInput( - f"{inputs[self.input_key]} \nSQLQuery:", verbose=self.verbose - ) + input_text = f"{inputs[self.input_key]} \nSQLQuery:" + if self.verbose: + print_text(input_text) llm_inputs = { - "input": chained_input.input, + "input": input_text, "dialect": self.database.dialect, "table_info": self.database.table_info, "stop": ["\nSQLResult:"], } sql_cmd = llm_chain.predict(**llm_inputs) - chained_input.add(sql_cmd, color="green") + if self.verbose: + print_text(sql_cmd, color="green") result = self.database.run(sql_cmd) - chained_input.add("\nSQLResult: ") - chained_input.add(result, color="yellow") - chained_input.add("\nAnswer:") - llm_inputs["input"] = chained_input.input + if self.verbose: + print_text("\nSQLResult: ") + print_text(result, color="yellow") + print_text("\nAnswer:") + input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:" + llm_inputs["input"] = input_text final_result = llm_chain.predict(**llm_inputs) - chained_input.add(final_result, color="green") + if self.verbose: + print_text(final_result, color="green") return {self.output_key: final_result} diff --git a/langchain/input.py b/langchain/input.py index ef7053ad315..782cd41a8bf 100644 --- a/langchain/input.py +++ b/langchain/input.py @@ -36,13 +36,13 @@ class ChainedInput: """Initialize with verbose flag and initial text.""" self._verbose = verbose if self._verbose: - print_text(text, None) + print_text(text, color=None) self._input = text def add(self, text: str, color: Optional[str] = None) -> None: """Add text to input, print if in verbose mode.""" if self._verbose: - print_text(text, color) + print_text(text, color=color) self._input += text @property