From 5d43246694b4ef591945866ba8598fc200667288 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 29 Dec 2022 21:34:47 -0500 Subject: [PATCH] WIP: stdout callback (#479) first pass at stdout callback for the most part, went pretty smoothly. aside from the code here, here are some comments/observations. 1. should somehow default to stdouthandler so i dont have to do ``` from langchain.callbacks import get_callback_manager from langchain.callbacks.stdout import StdOutCallbackHandler get_callback_manager().add_handler(StdOutCallbackHandler()) ``` 2. I kept around the verbosity flag. 1) this is pretty important for getting the stdout to look good for agents (and other things). 2) I actually added this for LLM class since it didn't have it. 3. The only part that isn't basically perfectly moved over is the end of the agent run. Here's a screenshot of the new stdout tracing ![Screen Shot 2022-12-29 at 4 03 50 PM](https://user-images.githubusercontent.com/11986836/210011538-6a74551a-2e61-437b-98d3-674212dede56.png) Noticing it is missing logging of the final thought, eg before this is what it looked like ![Screen Shot 2022-12-29 at 4 13 07 PM](https://user-images.githubusercontent.com/11986836/210011635-de68b3f5-e2b0-4cd3-9f1a-3afe970a8716.png) The reason its missing is that this was previously logged as part of agent end (lines 205 and 206) this is probably only relevant for the std out logger? any thoughts for how to get it back in? --- langchain/__init__.py | 2 + langchain/agents/agent.py | 37 ++++++---- langchain/agents/loading.py | 10 ++- langchain/callbacks/__init__.py | 7 ++ langchain/callbacks/base.py | 26 +++---- langchain/callbacks/shared.py | 20 +++--- langchain/callbacks/stdout.py | 69 +++++++++++++++++++ langchain/chains/base.py | 10 +-- langchain/chains/llm.py | 6 +- langchain/llms/base.py | 18 +++-- langchain/schema.py | 10 ++- tests/unit_tests/agents/test_agent.py | 37 ++++++++-- .../callbacks/fake_callback_handler.py | 10 +-- .../callbacks/test_callback_manager.py | 6 +- tests/unit_tests/chains/test_base.py | 13 +++- tests/unit_tests/llms/test_callbacks.py | 13 +++- 16 files changed, 220 insertions(+), 74 deletions(-) create mode 100644 langchain/callbacks/stdout.py diff --git a/langchain/__init__.py b/langchain/__init__.py index c100d0e30e3..0a607890d6a 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -4,6 +4,7 @@ from typing import Optional from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.cache import BaseCache +from langchain.callbacks import set_default_callback_manager from langchain.chains import ( ConversationChain, LLMBashChain, @@ -33,6 +34,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch logger: BaseLogger = StdOutLogger() verbose: bool = False llm_cache: Optional[BaseCache] = None +set_default_callback_manager() __all__ = [ "LLMChain", diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 9169b37585a..5a946527d0d 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -3,11 +3,10 @@ from __future__ import annotations import logging from abc import abstractmethod -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple from pydantic import BaseModel, root_validator -import langchain from langchain.agents.tools import Tool from langchain.callbacks.base import BaseCallbackManager from langchain.chains.base import Chain @@ -17,7 +16,7 @@ from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate -from langchain.schema import AgentAction, AgentFinish +from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish logger = logging.getLogger() @@ -47,7 +46,7 @@ class Agent(BaseModel): def plan( self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any - ) -> Union[AgentFinish, AgentAction]: + ) -> AgentAction: """Given input, decided what to do. Args: @@ -74,7 +73,7 @@ class Agent(BaseModel): parsed_output = self._extract_tool_and_input(full_output) tool, tool_input = parsed_output if tool == self.finish_tool_name: - return AgentFinish({"output": tool_input}, full_output) + return AgentFinish(tool, tool_input, full_output, {"output": tool_input}) return AgentAction(tool, tool_input, full_output) def prepare_for_new_call(self) -> None: @@ -138,11 +137,15 @@ class Agent(BaseModel): llm: BaseLLM, tools: List[Tool], callback_manager: Optional[BaseCallbackManager] = None, + verbose: bool = False, ) -> Agent: """Construct an agent from an LLM and tools.""" cls._validate_tools(tools) llm_chain = LLMChain( - llm=llm, prompt=cls.create_prompt(tools), callback_manager=callback_manager + llm=llm, + prompt=cls.create_prompt(tools), + callback_manager=callback_manager, + verbose=verbose, ) return cls(llm_chain=llm_chain) @@ -217,28 +220,34 @@ class AgentExecutor(Chain, BaseModel): # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): if self.verbose: - langchain.logger.log_agent_end(output, color="green") + self._get_callback_manager().on_tool_start( + {"name": "Finish"}, output, color="green" + ) + self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION) final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps return final_output - if self.verbose: - langchain.logger.log_agent_action(output, color="green") + # And then we lookup the tool if output.tool in name_to_tool_map: chain = name_to_tool_map[output.tool] - self._get_callback_manager().on_tool_start( - {"name": str(chain)[:60] + "..."}, output.tool, output.tool_input - ) + if self.verbose: + self._get_callback_manager().on_tool_start( + {"name": str(chain)[:60] + "..."}, output, color="green" + ) # We then call the tool on the tool input to get an observation observation = chain(output.tool_input) - self._get_callback_manager().on_tool_end(observation) color = color_mapping[output.tool] else: + if self.verbose: + self._get_callback_manager().on_tool_start( + {"name": "N/A"}, output, color="green" + ) observation = f"{output.tool} is not a valid tool, try another one." color = None if self.verbose: - langchain.logger.log_agent_observation( + self._get_callback_manager().on_tool_end( observation, color=color, observation_prefix=self.agent.observation_prefix, diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index 3f3e4c81d3c..8a49ccb0a63 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -21,6 +21,7 @@ def initialize_agent( llm: BaseLLM, agent: str = "zero-shot-react-description", callback_manager: Optional[BaseCallbackManager] = None, + verbose: bool = False, **kwargs: Any, ) -> AgentExecutor: """Load agent given tools and LLM. @@ -30,6 +31,7 @@ def initialize_agent( llm: Language model to use as the agent. agent: The agent to use. Valid options are: `zero-shot-react-description`, `react-docstore`, `self-ask-with-search`. + verbose: Whether to use the callback manager for this particular agent. callback_manager: CallbackManager to use. Global callback manager is used if not provided. Defaults to None. **kwargs: Additional key word arguments to pass to the agent. @@ -44,8 +46,12 @@ def initialize_agent( ) agent_cls = AGENT_TO_CLASS[agent] agent_obj = agent_cls.from_llm_and_tools( - llm, tools, callback_manager=callback_manager + llm, tools, callback_manager=callback_manager, verbose=verbose ) return AgentExecutor.from_agent_and_tools( - agent=agent_obj, tools=tools, callback_manager=callback_manager, **kwargs + agent=agent_obj, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **kwargs, ) diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py index eef25d7e52e..9c9fb210232 100644 --- a/langchain/callbacks/__init__.py +++ b/langchain/callbacks/__init__.py @@ -1,8 +1,15 @@ """Callback handlers that allow listening to events in LangChain.""" from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.shared import SharedCallbackManager +from langchain.callbacks.stdout import StdOutCallbackHandler def get_callback_manager() -> BaseCallbackManager: """Return the shared callback manager.""" return SharedCallbackManager() + + +def set_default_callback_manager() -> None: + """Set default callback manager.""" + callback = get_callback_manager() + callback.add_handler(StdOutCallbackHandler()) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 7ce41f3ab51..ca0539b1ff2 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List -from langchain.schema import LLMResult +from langchain.schema import AgentAction, LLMResult class BaseCallbackHandler(ABC): @@ -11,7 +11,7 @@ class BaseCallbackHandler(ABC): @abstractmethod def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **extra: str + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Run when LLM starts running.""" @@ -28,7 +28,7 @@ class BaseCallbackHandler(ABC): @abstractmethod def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: """Run when chain starts running.""" @@ -42,12 +42,12 @@ class BaseCallbackHandler(ABC): @abstractmethod def on_tool_start( - self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any ) -> None: """Run when tool starts running.""" @abstractmethod - def on_tool_end(self, output: str) -> None: + def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running.""" @abstractmethod @@ -75,11 +75,11 @@ class CallbackManager(BaseCallbackManager): self.handlers = handlers def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **extra: str + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Run when LLM starts running.""" for handler in self.handlers: - handler.on_llm_start(serialized, prompts, **extra) + handler.on_llm_start(serialized, prompts, **kwargs) def on_llm_end( self, @@ -95,11 +95,11 @@ class CallbackManager(BaseCallbackManager): handler.on_llm_error(error) def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: """Run when chain starts running.""" for handler in self.handlers: - handler.on_chain_start(serialized, inputs, **extra) + handler.on_chain_start(serialized, inputs, **kwargs) def on_chain_end(self, outputs: Dict[str, Any]) -> None: """Run when chain ends running.""" @@ -112,16 +112,16 @@ class CallbackManager(BaseCallbackManager): handler.on_chain_error(error) def on_tool_start( - self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any ) -> None: """Run when tool starts running.""" for handler in self.handlers: - handler.on_tool_start(serialized, action, tool_input, **extra) + handler.on_tool_start(serialized, action, **kwargs) - def on_tool_end(self, output: str) -> None: + def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running.""" for handler in self.handlers: - handler.on_tool_end(output) + handler.on_tool_end(output, **kwargs) def on_tool_error(self, error: Exception) -> None: """Run when tool errors.""" diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py index b285a394714..10225c7769d 100644 --- a/langchain/callbacks/shared.py +++ b/langchain/callbacks/shared.py @@ -8,7 +8,7 @@ from langchain.callbacks.base import ( BaseCallbackManager, CallbackManager, ) -from langchain.schema import LLMResult +from langchain.schema import AgentAction, LLMResult class Singleton: @@ -35,11 +35,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): _callback_manager: CallbackManager = CallbackManager([]) def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **extra: str + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Run when LLM starts running.""" with self._lock: - self._callback_manager.on_llm_start(serialized, prompts, **extra) + self._callback_manager.on_llm_start(serialized, prompts, **kwargs) def on_llm_end( self, @@ -55,11 +55,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): self._callback_manager.on_llm_error(error) def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: """Run when chain starts running.""" with self._lock: - self._callback_manager.on_chain_start(serialized, inputs, **extra) + self._callback_manager.on_chain_start(serialized, inputs, **kwargs) def on_chain_end(self, outputs: Dict[str, Any]) -> None: """Run when chain ends running.""" @@ -72,18 +72,16 @@ class SharedCallbackManager(Singleton, BaseCallbackManager): self._callback_manager.on_chain_error(error) def on_tool_start( - self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any ) -> None: """Run when tool starts running.""" with self._lock: - self._callback_manager.on_tool_start( - serialized, action, tool_input, **extra - ) + self._callback_manager.on_tool_start(serialized, action, **kwargs) - def on_tool_end(self, output: str) -> None: + def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running.""" with self._lock: - self._callback_manager.on_tool_end(output) + self._callback_manager.on_tool_end(output, **kwargs) def on_tool_error(self, error: Exception) -> None: """Run when tool errors.""" diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py new file mode 100644 index 00000000000..f16a287a366 --- /dev/null +++ b/langchain/callbacks/stdout.py @@ -0,0 +1,69 @@ +"""Callback Handler that prints to std out.""" +from typing import Any, Dict, List, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.input import print_text +from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult + + +class StdOutCallbackHandler(BaseCallbackHandler): + """Callback Handler that prints to std out.""" + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Print out the prompts.""" + print("Prompts after formatting:") + for prompt in prompts: + print_text(prompt, color="green", end="\n") + + def on_llm_end(self, response: LLMResult) -> None: + """Do nothing.""" + pass + + def on_llm_error(self, error: Exception) -> None: + """Do nothing.""" + pass + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Print out that we are entering a chain.""" + class_name = serialized["name"] + print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m") + + def on_chain_end(self, outputs: Dict[str, Any]) -> None: + """Print out that we finished a chain.""" + print("\n\033[1m> Finished chain.\033[0m") + + def on_chain_error(self, error: Exception) -> None: + """Do nothing.""" + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + action: AgentAction, + color: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Print out the log in specified color.""" + print_text(action.log, color=color) + + def on_tool_end( + self, + output: str, + color: Optional[str] = None, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """If not the final action, print out observation.""" + if output != AGENT_FINISH_OBSERVATION: + print_text(f"\n{observation_prefix}") + print_text(output, color=color) + print_text(f"\n{llm_prefix}") + + def on_tool_error(self, error: Exception) -> None: + """Do nothing.""" + pass diff --git a/langchain/chains/base.py b/langchain/chains/base.py index ce28bf841ab..87bc0f1baef 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -120,16 +120,12 @@ class Chain(BaseModel, ABC): inputs = dict(inputs, **external_context) self._validate_inputs(inputs) if self.verbose: - print( - f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m" + self._get_callback_manager().on_chain_start( + {"name": self.__class__.__name__}, inputs ) - self._get_callback_manager().on_chain_start( - {"name": self.__class__.__name__}, inputs - ) outputs = self._call(inputs) - self._get_callback_manager().on_chain_end(outputs) if self.verbose: - print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m") + self._get_callback_manager().on_chain_end(outputs) self._validate_outputs(outputs) if self.memory is not None: self.memory.save_context(inputs, outputs) diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 72c1f3a03c0..1cd191b7132 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -3,7 +3,6 @@ from typing import Any, Dict, List, Sequence, Union from pydantic import BaseModel, Extra -import langchain from langchain.chains.base import Chain from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate @@ -54,6 +53,7 @@ class LLMChain(Chain, BaseModel): def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult: """Generate LLM result from inputs.""" + self.llm.verbose = self.verbose stop = None if "stop" in input_list[0]: stop = input_list[0]["stop"] @@ -61,8 +61,6 @@ class LLMChain(Chain, BaseModel): for inputs in input_list: selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} prompt = self.prompt.format(**selected_inputs) - if self.verbose: - langchain.logger.log_llm_inputs(selected_inputs, prompt) if "stop" in inputs and inputs["stop"] != stop: raise ValueError( "If `stop` is present in any inputs, should be present in all." @@ -78,8 +76,6 @@ class LLMChain(Chain, BaseModel): for generation in response.generations: # Get the text of the top generated string. response_str = generation[0].text - if self.verbose: - langchain.logger.log_llm_response(response_str) outputs.append({self.output_key: response_str}) return outputs diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 06df99e78cf..8b50cfcaf9b 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Union import yaml -from pydantic import BaseModel, Extra +from pydantic import BaseModel, Extra, Field import langchain from langchain.callbacks import get_callback_manager @@ -13,10 +13,16 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.schema import Generation, LLMResult +def _get_verbosity() -> bool: + return langchain.verbose + + class BaseLLM(BaseModel, ABC): """LLM wrapper should take in a prompt and return a string.""" cache: Optional[bool] = None + verbose: bool = Field(default_factory=_get_verbosity) + """Whether to print out response text.""" callback_manager: Optional[BaseCallbackManager] = None class Config: @@ -48,11 +54,13 @@ class BaseLLM(BaseModel, ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) - self._get_callback_manager().on_llm_start( - {"name": self.__class__.__name__}, prompts - ) + if self.verbose: + self._get_callback_manager().on_llm_start( + {"name": self.__class__.__name__}, prompts + ) output = self._generate(prompts, stop=stop) - self._get_callback_manager().on_llm_end(output) + if self.verbose: + self._get_callback_manager().on_llm_end(output) return output params = self._llm_dict() params["stop"] = stop diff --git a/langchain/schema.py b/langchain/schema.py index a4b4e6267ce..ce2811fded4 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -1,9 +1,13 @@ """Common schema objects.""" +from dataclasses import dataclass from typing import List, NamedTuple, Optional +AGENT_FINISH_OBSERVATION = "__agent_finish__" -class AgentAction(NamedTuple): + +@dataclass +class AgentAction: """Agent's action to take.""" tool: str @@ -11,11 +15,11 @@ class AgentAction(NamedTuple): log: str -class AgentFinish(NamedTuple): +@dataclass +class AgentFinish(AgentAction): """Agent's return value.""" return_values: dict - log: str class Generation(NamedTuple): diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 002c66fdb9e..038fd154835 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -74,7 +74,7 @@ def test_agent_with_callbacks() -> None: f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", "Oh well\nAction: Final Answer\nAction Input: curses foiled again", ] - fake_llm = FakeListLLM(responses=responses, callback_manager=manager) + fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True) tools = [ Tool("Search", lambda x: x, "Useful for searching"), ] @@ -89,7 +89,36 @@ def test_agent_with_callbacks() -> None: output = agent.run("when was langchain made") assert output == "curses foiled again" - # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run - assert handler.starts == 6 - assert handler.ends == 6 + # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 ending + assert handler.starts == 7 + assert handler.ends == 7 + assert handler.errors == 0 + + +def test_agent_with_callbacks_not_verbose() -> None: + """Test react chain with callbacks but not verbose.""" + handler = FakeCallbackHandler() + manager = CallbackManager([handler]) + tool = "Search" + responses = [ + f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", + "Oh well\nAction: Final Answer\nAction Input: curses foiled again", + ] + fake_llm = FakeListLLM(responses=responses, callback_manager=manager) + tools = [ + Tool("Search", lambda x: x, "Useful for searching"), + ] + agent = initialize_agent( + tools, + fake_llm, + agent="zero-shot-react-description", + callback_manager=manager, + ) + + output = agent.run("when was langchain made") + assert output == "curses foiled again" + + # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler.starts == 0 + assert handler.ends == 0 assert handler.errors == 0 diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 3e50a296d0c..5408042f5f1 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import LLMResult +from langchain.schema import AgentAction, LLMResult class FakeCallbackHandler(BaseCallbackHandler): @@ -15,7 +15,7 @@ class FakeCallbackHandler(BaseCallbackHandler): self.errors = 0 def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **extra: str + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: """Run when LLM starts running.""" self.starts += 1 @@ -32,7 +32,7 @@ class FakeCallbackHandler(BaseCallbackHandler): self.errors += 1 def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any ) -> None: """Run when chain starts running.""" self.starts += 1 @@ -46,12 +46,12 @@ class FakeCallbackHandler(BaseCallbackHandler): self.errors += 1 def on_tool_start( - self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str + self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any ) -> None: """Run when tool starts running.""" self.starts += 1 - def on_tool_end(self, output: str) -> None: + def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running.""" self.ends += 1 diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 2ddf36efb5a..915bf5808ae 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -2,7 +2,7 @@ from langchain.callbacks.base import BaseCallbackManager, CallbackManager from langchain.callbacks.shared import SharedCallbackManager -from langchain.schema import LLMResult +from langchain.schema import AgentAction, LLMResult from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -13,10 +13,10 @@ def _test_callback_manager( manager.on_llm_start({}, []) manager.on_llm_end(LLMResult(generations=[])) manager.on_llm_error(Exception()) - manager.on_chain_start({}, {}) + manager.on_chain_start({"name": "foo"}, {}) manager.on_chain_end({}) manager.on_chain_error(Exception()) - manager.on_tool_start({}, "", "") + manager.on_tool_start({}, AgentAction("", "", "")) manager.on_tool_end("") manager.on_tool_error(Exception()) for handler in handlers: diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 9a677adfc4c..3c17a516aee 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -140,9 +140,20 @@ def test_run_arg_with_memory() -> None: def test_run_with_callback() -> None: """Test run method works when callback manager is passed.""" handler = FakeCallbackHandler() - chain = FakeChain(callback_manager=CallbackManager([handler])) + chain = FakeChain(callback_manager=CallbackManager([handler]), verbose=True) output = chain.run("bar") assert output == "baz" assert handler.starts == 1 assert handler.ends == 1 assert handler.errors == 0 + + +def test_run_with_callback_not_verbose() -> None: + """Test run method works when callback manager is passed and not verbose.""" + handler = FakeCallbackHandler() + chain = FakeChain(callback_manager=CallbackManager([handler])) + output = chain.run("bar") + assert output == "baz" + assert handler.starts == 0 + assert handler.ends == 0 + assert handler.errors == 0 diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py index 9587f5fb8be..2c3b45294b3 100644 --- a/tests/unit_tests/llms/test_callbacks.py +++ b/tests/unit_tests/llms/test_callbacks.py @@ -7,9 +7,20 @@ from tests.unit_tests.llms.fake_llm import FakeLLM def test_llm_with_callbacks() -> None: """Test LLM callbacks.""" handler = FakeCallbackHandler() - llm = FakeLLM(callback_manager=CallbackManager([handler])) + llm = FakeLLM(callback_manager=CallbackManager([handler]), verbose=True) output = llm("foo") assert output == "foo" assert handler.starts == 1 assert handler.ends == 1 assert handler.errors == 0 + + +def test_llm_with_callbacks_not_verbose() -> None: + """Test LLM callbacks but not verbose.""" + handler = FakeCallbackHandler() + llm = FakeLLM(callback_manager=CallbackManager([handler])) + output = llm("foo") + assert output == "foo" + assert handler.starts == 0 + assert handler.ends == 0 + assert handler.errors == 0