From 175a248506f75a06ab4b63c88e2da11561e01662 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Fri, 30 Dec 2022 13:55:30 -0500 Subject: [PATCH] Harrison/get rid of prints (#490) deprecate all prints in favor of callback_manager.on_text (open to better naming) --- langchain/agents/agent.py | 2 +- langchain/callbacks/base.py | 6 +++--- langchain/callbacks/shared.py | 4 ++-- langchain/callbacks/stdout.py | 10 +++++++--- langchain/chains/api/base.py | 5 ++--- langchain/chains/llm_bash/base.py | 9 ++++----- langchain/chains/llm_math/base.py | 9 ++++----- langchain/chains/pal/base.py | 3 +-- langchain/chains/sequential.py | 6 ++++-- langchain/chains/sql_database/base.py | 13 ++++++------- tests/unit_tests/callbacks/fake_callback_handler.py | 2 +- 11 files changed, 35 insertions(+), 34 deletions(-) diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 62ea2e983b5..6a2c0ea1fdc 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -218,7 +218,7 @@ class AgentExecutor(Chain, BaseModel): # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): if self.verbose: - self.callback_manager.on_agent_end(output.log, color="green") + self.callback_manager.on_text(output.log, color="green") final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index fa036016708..b8709cce194 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -55,7 +55,7 @@ class BaseCallbackHandler(ABC): """Run when tool errors.""" @abstractmethod - def on_agent_end(self, log: str, **kwargs: Any) -> None: + def on_text(self, text: str, **kwargs: Any) -> None: """Run when agent ends.""" @@ -132,10 +132,10 @@ class CallbackManager(BaseCallbackManager): for handler in self.handlers: handler.on_tool_error(error) - def on_agent_end(self, log: str, **kwargs: Any) -> None: + def on_text(self, text: str, **kwargs: Any) -> None: """Run when agent ends.""" for handler in self.handlers: - handler.on_agent_end(log, **kwargs) + handler.on_text(text, **kwargs) def add_handler(self, handler: BaseCallbackHandler) -> None: """Add a handler to the callback manager.""" diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py index 006df55a824..e255bd34417 100644 --- a/langchain/callbacks/shared.py +++ b/langchain/callbacks/shared.py @@ -88,10 +88,10 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): with self._lock: self._callback_manager.on_tool_error(error) - def on_agent_end(self, log: str, **kwargs: Any) -> None: + def on_text(self, text: str, **kwargs: Any) -> None: """Run when agent ends.""" with self._lock: - self._callback_manager.on_agent_end(log, **kwargs) + self._callback_manager.on_text(text, **kwargs) def add_handler(self, callback: BaseCallbackHandler) -> None: """Add a callback to the callback manager.""" diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py index ee72c8cc7ad..d651a6f53b4 100644 --- a/langchain/callbacks/stdout.py +++ b/langchain/callbacks/stdout.py @@ -67,8 +67,12 @@ class StdOutCallbackHandler(BaseCallbackHandler): """Do nothing.""" pass - def on_agent_end( - self, log: str, color: Optional[str] = None, **kwargs: Any + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Optional[str], ) -> None: """Run when agent ends.""" - print_text(log, color=color) + print_text(text, color=color, end=end) diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 07591bd425e..4bb81879a1f 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -8,7 +8,6 @@ from pydantic import BaseModel, root_validator from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.requests import RequestsWrapper @@ -67,10 +66,10 @@ class APIChain(Chain, BaseModel): question=question, api_docs=self.api_docs ) if self.verbose: - print_text(api_url, color="green", end="\n") + self.callback_manager.on_text(api_url, color="green", end="\n") api_response = self.requests_wrapper.run(api_url) if self.verbose: - print_text(api_response, color="yellow", end="\n") + self.callback_manager.on_text(api_response, color="yellow", end="\n") answer = self.api_answer_chain.predict( question=question, api_docs=self.api_docs, diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index 76963713023..5d9d8b610c3 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.prompt import PROMPT -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.utilities.bash import BashProcess @@ -52,11 +51,11 @@ class LLMBashChain(Chain, BaseModel): llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) bash_executor = BashProcess() if self.verbose: - print_text(inputs[self.input_key]) + self.callback_manager.on_text(inputs[self.input_key]) t = llm_executor.predict(question=inputs[self.input_key]) if self.verbose: - print_text(t, color="green") + self.callback_manager.on_text(t, color="green") t = t.strip() if t.startswith("```bash"): @@ -69,8 +68,8 @@ class LLMBashChain(Chain, BaseModel): output = bash_executor.run(command_list) if self.verbose: - print_text("\nAnswer: ") - print_text(output, color="yellow") + self.callback_manager.on_text("\nAnswer: ") + self.callback_manager.on_text(output, color="yellow") else: raise ValueError(f"unknown format from LLM: {t}") diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index a0485e9e3ac..383bcac96f3 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.python import PythonREPL @@ -52,17 +51,17 @@ class LLMMathChain(Chain, BaseModel): llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) python_executor = PythonREPL() if self.verbose: - print_text(inputs[self.input_key]) + self.callback_manager.on_text(inputs[self.input_key]) t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"]) if self.verbose: - print_text(t, color="green") + self.callback_manager.on_text(t, color="green") t = t.strip() if t.startswith("```python"): code = t[9:-4] output = python_executor.run(code) if self.verbose: - print_text("\nAnswer: ") - print_text(output, color="yellow") + self.callback_manager.on_text("\nAnswer: ") + self.callback_manager.on_text(output, color="yellow") answer = "Answer: " + output elif t.startswith("Answer:"): answer = t diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 08335e2fb27..858f6a1f292 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -12,7 +12,6 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.python import PythonREPL @@ -53,7 +52,7 @@ class PALChain(Chain, BaseModel): llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) code = llm_chain.predict(stop=[self.stop], **inputs) if self.verbose: - print_text(code, color="green", end="\n") + self.callback_manager.on_text(code, color="green", end="\n") repl = PythonREPL() res = repl.run(code + f"\n{self.get_answer_expr}") return {self.output_key: res.strip()} diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index 1bcc723d40d..9db4be411e3 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -5,7 +5,7 @@ from typing import Dict, List from pydantic import BaseModel, Extra, root_validator from langchain.chains.base import Chain -from langchain.input import get_color_mapping, print_text +from langchain.input import get_color_mapping class SequentialChain(Chain, BaseModel): @@ -133,5 +133,7 @@ class SimpleSequentialChain(Chain, BaseModel): if self.strip_outputs: _input = _input.strip() if self.verbose: - print_text(_input, color=color_mapping[str(i)], end="\n") + self.callback_manager.on_text( + _input, color=color_mapping[str(i)], end="\n" + ) return {self.output_key: _input} diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index 103993f35dd..3c56cac7d69 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -6,7 +6,6 @@ from pydantic import BaseModel, Extra from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sql_database.prompt import PROMPT -from langchain.input import print_text from langchain.llms.base import BaseLLM from langchain.sql_database import SQLDatabase @@ -55,7 +54,7 @@ class SQLDatabaseChain(Chain, BaseModel): llm_chain = LLMChain(llm=self.llm, prompt=PROMPT) input_text = f"{inputs[self.input_key]} \nSQLQuery:" if self.verbose: - print_text(input_text) + self.callback_manager.on_text(input_text) llm_inputs = { "input": input_text, "dialect": self.database.dialect, @@ -64,15 +63,15 @@ class SQLDatabaseChain(Chain, BaseModel): } sql_cmd = llm_chain.predict(**llm_inputs) if self.verbose: - print_text(sql_cmd, color="green") + self.callback_manager.on_text(sql_cmd, color="green") result = self.database.run(sql_cmd) if self.verbose: - print_text("\nSQLResult: ") - print_text(result, color="yellow") - print_text("\nAnswer:") + self.callback_manager.on_text("\nSQLResult: ") + self.callback_manager.on_text(result, color="yellow") + self.callback_manager.on_text("\nAnswer:") input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:" llm_inputs["input"] = input_text final_result = llm_chain.predict(**llm_inputs) if self.verbose: - print_text(final_result, color="green") + self.callback_manager.on_text(final_result, color="green") return {self.output_key: final_result} diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 18f88d1f2fd..fc62aa3212b 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -59,6 +59,6 @@ class FakeCallbackHandler(BaseCallbackHandler): """Run when tool errors.""" self.errors += 1 - def on_agent_end(self, log: str, **kwargs: Any) -> None: + def on_text(self, text: str, **kwargs: Any) -> None: """Run when agent is ending.""" self.ends += 1