diff --git a/langchain/__init__.py b/langchain/__init__.py
index c100d0e30e3..0a607890d6a 100644
--- a/langchain/__init__.py
+++ b/langchain/__init__.py
@@ -4,6 +4,7 @@ from typing import Optional
 
 from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
 from langchain.cache import BaseCache
+from langchain.callbacks import set_default_callback_manager
 from langchain.chains import (
     ConversationChain,
     LLMBashChain,
@@ -33,6 +34,7 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
 logger: BaseLogger = StdOutLogger()
 verbose: bool = False
 llm_cache: Optional[BaseCache] = None
+set_default_callback_manager()
 
 __all__ = [
     "LLMChain",
diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py
index 9169b37585a..5a946527d0d 100644
--- a/langchain/agents/agent.py
+++ b/langchain/agents/agent.py
@@ -3,11 +3,10 @@ from __future__ import annotations
 
 import logging
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import BaseModel, root_validator
 
-import langchain
 from langchain.agents.tools import Tool
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.base import Chain
@@ -17,7 +16,7 @@ from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.few_shot import FewShotPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
-from langchain.schema import AgentAction, AgentFinish
+from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish
 
 logger = logging.getLogger()
 
@@ -47,7 +46,7 @@ class Agent(BaseModel):
 
     def plan(
         self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
-    ) -> Union[AgentFinish, AgentAction]:
+    ) -> AgentAction:
         """Given input, decided what to do.
 
         Args:
@@ -74,7 +73,7 @@ class Agent(BaseModel):
         parsed_output = self._extract_tool_and_input(full_output)
         tool, tool_input = parsed_output
         if tool == self.finish_tool_name:
-            return AgentFinish({"output": tool_input}, full_output)
+            return AgentFinish(tool, tool_input, full_output, {"output": tool_input})
         return AgentAction(tool, tool_input, full_output)
 
     def prepare_for_new_call(self) -> None:
@@ -138,11 +137,15 @@ class Agent(BaseModel):
         llm: BaseLLM,
         tools: List[Tool],
         callback_manager: Optional[BaseCallbackManager] = None,
+        verbose: bool = False,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
         llm_chain = LLMChain(
-            llm=llm, prompt=cls.create_prompt(tools), callback_manager=callback_manager
+            llm=llm,
+            prompt=cls.create_prompt(tools),
+            callback_manager=callback_manager,
+            verbose=verbose,
         )
         return cls(llm_chain=llm_chain)
 
@@ -217,28 +220,34 @@ class AgentExecutor(Chain, BaseModel):
             # If the tool chosen is the finishing tool, then we end and return.
             if isinstance(output, AgentFinish):
                 if self.verbose:
-                    langchain.logger.log_agent_end(output, color="green")
+                    self._get_callback_manager().on_tool_start(
+                        {"name": "Finish"}, output, color="green"
+                    )
+                    self._get_callback_manager().on_tool_end(AGENT_FINISH_OBSERVATION)
                 final_output = output.return_values
                 if self.return_intermediate_steps:
                     final_output["intermediate_steps"] = intermediate_steps
                 return final_output
-            if self.verbose:
-                langchain.logger.log_agent_action(output, color="green")
+            # And then we lookup the tool
             if output.tool in name_to_tool_map:
                 chain = name_to_tool_map[output.tool]
-                self._get_callback_manager().on_tool_start(
-                    {"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
-                )
+                if self.verbose:
+                    self._get_callback_manager().on_tool_start(
+                        {"name": str(chain)[:60] + "..."}, output, color="green"
+                    )
                 # We then call the tool on the tool input to get an observation
                 observation = chain(output.tool_input)
-                self._get_callback_manager().on_tool_end(observation)
                 color = color_mapping[output.tool]
             else:
+                if self.verbose:
+                    self._get_callback_manager().on_tool_start(
+                        {"name": "N/A"}, output, color="green"
+                    )
                 observation = f"{output.tool} is not a valid tool, try another one."
                 color = None
             if self.verbose:
-                langchain.logger.log_agent_observation(
+                self._get_callback_manager().on_tool_end(
                     observation,
                     color=color,
                     observation_prefix=self.agent.observation_prefix,
                     llm_prefix=self.agent.llm_prefix,
diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py
index 3f3e4c81d3c..8a49ccb0a63 100644
--- a/langchain/agents/loading.py
+++ b/langchain/agents/loading.py
@@ -21,6 +21,7 @@ def initialize_agent(
     llm: BaseLLM,
     agent: str = "zero-shot-react-description",
     callback_manager: Optional[BaseCallbackManager] = None,
+    verbose: bool = False,
     **kwargs: Any,
 ) -> AgentExecutor:
     """Load agent given tools and LLM.
@@ -30,6 +31,7 @@ def initialize_agent(
         llm: Language model to use as the agent.
         agent: The agent to use. Valid options are:
             `zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
+        verbose: Whether to emit callback events (e.g. logging) for this agent.
         callback_manager: CallbackManager to use. Global callback manager is used if
             not provided. Defaults to None.
         **kwargs: Additional key word arguments to pass to the agent.
@@ -44,8 +46,12 @@ def initialize_agent(
         )
     agent_cls = AGENT_TO_CLASS[agent]
     agent_obj = agent_cls.from_llm_and_tools(
-        llm, tools, callback_manager=callback_manager
+        llm, tools, callback_manager=callback_manager, verbose=verbose
     )
     return AgentExecutor.from_agent_and_tools(
-        agent=agent_obj, tools=tools, callback_manager=callback_manager, **kwargs
+        agent=agent_obj,
+        tools=tools,
+        callback_manager=callback_manager,
+        verbose=verbose,
+        **kwargs,
     )
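For orientation, a minimal sketch of the new flag in use. The `Tool` mirrors the one in the unit tests below; `OpenAI` stands in for any `BaseLLM` and is an illustrative choice, not part of this patch:

```python
from langchain.agents import initialize_agent
from langchain.agents.tools import Tool
from langchain.llms import OpenAI  # any BaseLLM works here

tools = [Tool("Search", lambda x: x, "Useful for searching")]
agent = initialize_agent(
    tools, OpenAI(), agent="zero-shot-react-description", verbose=True
)
# verbose=True is forwarded to both the inner LLMChain and the
# AgentExecutor, so the default stdout handler logs every step;
# with verbose=False (the default) no callback events fire.
agent.run("when was langchain made")
```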
diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py
index eef25d7e52e..9c9fb210232 100644
--- a/langchain/callbacks/__init__.py
+++ b/langchain/callbacks/__init__.py
@@ -1,8 +1,15 @@
 """Callback handlers that allow listening to events in LangChain."""
 
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.shared import SharedCallbackManager
+from langchain.callbacks.stdout import StdOutCallbackHandler
 
 
 def get_callback_manager() -> BaseCallbackManager:
     """Return the shared callback manager."""
     return SharedCallbackManager()
+
+
+def set_default_callback_manager() -> None:
+    """Set default callback manager."""
+    callback = get_callback_manager()
+    callback.add_handler(StdOutCallbackHandler())
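Because `SharedCallbackManager` is a singleton, `get_callback_manager()` returns the same object everywhere, and `set_default_callback_manager()` (invoked once at import time in `langchain/__init__.py` above) just registers the stdout handler on it. A sketch of attaching an additional handler to the same shared manager:

```python
from langchain.callbacks import get_callback_manager
from langchain.callbacks.stdout import StdOutCallbackHandler

manager = get_callback_manager()              # the shared singleton
manager.add_handler(StdOutCallbackHandler())  # same call set_default_callback_manager() makes
```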
diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py
index 7ce41f3ab51..ca0539b1ff2 100644
--- a/langchain/callbacks/base.py
+++ b/langchain/callbacks/base.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 
-from langchain.schema import LLMResult
+from langchain.schema import AgentAction, LLMResult
 
 
 class BaseCallbackHandler(ABC):
@@ -11,7 +11,7 @@ class BaseCallbackHandler(ABC):
 
     @abstractmethod
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
 
@@ -28,7 +28,7 @@ class BaseCallbackHandler(ABC):
 
     @abstractmethod
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Run when chain starts running."""
 
@@ -42,12 +42,12 @@ class BaseCallbackHandler(ABC):
 
     @abstractmethod
     def on_tool_start(
-        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
 
     @abstractmethod
-    def on_tool_end(self, output: str) -> None:
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
         """Run when tool ends running."""
 
     @abstractmethod
@@ -75,11 +75,11 @@ class CallbackManager(BaseCallbackManager):
         self.handlers = handlers
 
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
         for handler in self.handlers:
-            handler.on_llm_start(serialized, prompts, **extra)
+            handler.on_llm_start(serialized, prompts, **kwargs)
 
     def on_llm_end(
         self,
@@ -95,11 +95,11 @@ class CallbackManager(BaseCallbackManager):
             handler.on_llm_error(error)
 
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Run when chain starts running."""
         for handler in self.handlers:
-            handler.on_chain_start(serialized, inputs, **extra)
+            handler.on_chain_start(serialized, inputs, **kwargs)
 
     def on_chain_end(self, outputs: Dict[str, Any]) -> None:
         """Run when chain ends running."""
@@ -112,16 +112,16 @@ class CallbackManager(BaseCallbackManager):
             handler.on_chain_error(error)
 
     def on_tool_start(
-        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
         for handler in self.handlers:
-            handler.on_tool_start(serialized, action, tool_input, **extra)
+            handler.on_tool_start(serialized, action, **kwargs)
 
-    def on_tool_end(self, output: str) -> None:
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
         """Run when tool ends running."""
         for handler in self.handlers:
-            handler.on_tool_end(output)
+            handler.on_tool_end(output, **kwargs)
 
     def on_tool_error(self, error: Exception) -> None:
         """Run when tool errors."""
diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py
index b285a394714..10225c7769d 100644
--- a/langchain/callbacks/shared.py
+++ b/langchain/callbacks/shared.py
@@ -8,7 +8,7 @@ from langchain.callbacks.base import (
     BaseCallbackManager,
     CallbackManager,
 )
-from langchain.schema import LLMResult
+from langchain.schema import AgentAction, LLMResult
 
 
 class Singleton:
@@ -35,11 +35,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
     _callback_manager: CallbackManager = CallbackManager([])
 
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
         with self._lock:
-            self._callback_manager.on_llm_start(serialized, prompts, **extra)
+            self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
 
     def on_llm_end(
         self,
@@ -55,11 +55,11 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
             self._callback_manager.on_llm_error(error)
 
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
     ) -> None:
         """Run when chain starts running."""
         with self._lock:
-            self._callback_manager.on_chain_start(serialized, inputs, **extra)
+            self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
 
     def on_chain_end(self, outputs: Dict[str, Any]) -> None:
         """Run when chain ends running."""
@@ -72,18 +72,16 @@ class SharedCallbackManager(Singleton, BaseCallbackManager):
             self._callback_manager.on_chain_error(error)
 
     def on_tool_start(
-        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
         with self._lock:
-            self._callback_manager.on_tool_start(
-                serialized, action, tool_input, **extra
-            )
+            self._callback_manager.on_tool_start(serialized, action, **kwargs)
 
-    def on_tool_end(self, output: str) -> None:
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
         """Run when tool ends running."""
         with self._lock:
-            self._callback_manager.on_tool_end(output)
+            self._callback_manager.on_tool_end(output, **kwargs)
 
     def on_tool_error(self, error: Exception) -> None:
         """Run when tool errors."""
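Both managers above fan events out to every registered handler, and the reworked signatures (uniform `**kwargs: Any`, plus a full `AgentAction` for `on_tool_start` instead of the old `action: str, tool_input: str` pair) are what a custom handler now implements. A minimal sketch; the class name and print behavior are illustrative only, not part of this patch:

```python
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult


class ToolTraceHandler(BaseCallbackHandler):
    """Hypothetical handler that only traces tool usage."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

    def on_llm_end(self, response: LLMResult) -> None:
        pass

    def on_llm_error(self, error: Exception) -> None:
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        pass

    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
        pass

    def on_chain_error(self, error: Exception) -> None:
        pass

    def on_tool_start(
        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
    ) -> None:
        # The whole AgentAction is available now, not just tool + input strings
        print(f"[tool start] {action.tool}({action.tool_input!r})")

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        print(f"[tool end] {output}")

    def on_tool_error(self, error: Exception) -> None:
        pass
```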
diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py
new file mode 100644
index 00000000000..f16a287a366
--- /dev/null
+++ b/langchain/callbacks/stdout.py
@@ -0,0 +1,69 @@
+"""Callback Handler that prints to std out."""
+from typing import Any, Dict, List, Optional
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.input import print_text
+from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, LLMResult
+
+
+class StdOutCallbackHandler(BaseCallbackHandler):
+    """Callback Handler that prints to std out."""
+
+    def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Print out the prompts."""
+        print("Prompts after formatting:")
+        for prompt in prompts:
+            print_text(prompt, color="green", end="\n")
+
+    def on_llm_end(self, response: LLMResult) -> None:
+        """Do nothing."""
+        pass
+
+    def on_llm_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_chain_start(
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+    ) -> None:
+        """Print out that we are entering a chain."""
+        class_name = serialized["name"]
+        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")
+
+    def on_chain_end(self, outputs: Dict[str, Any]) -> None:
+        """Print out that we finished a chain."""
+        print("\n\033[1m> Finished chain.\033[0m")
+
+    def on_chain_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
+
+    def on_tool_start(
+        self,
+        serialized: Dict[str, Any],
+        action: AgentAction,
+        color: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Print out the log in specified color."""
+        print_text(action.log, color=color)
+
+    def on_tool_end(
+        self,
+        output: str,
+        color: Optional[str] = None,
+        observation_prefix: Optional[str] = None,
+        llm_prefix: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """If not the final action, print out observation."""
+        if output != AGENT_FINISH_OBSERVATION:
+            print_text(f"\n{observation_prefix}")
+            print_text(output, color=color)
+            print_text(f"\n{llm_prefix}")
+
+    def on_tool_error(self, error: Exception) -> None:
+        """Do nothing."""
+        pass
diff --git a/langchain/chains/base.py b/langchain/chains/base.py
index ce28bf841ab..87bc0f1baef 100644
--- a/langchain/chains/base.py
+++ b/langchain/chains/base.py
@@ -120,16 +120,12 @@ class Chain(BaseModel, ABC):
             inputs = dict(inputs, **external_context)
         self._validate_inputs(inputs)
         if self.verbose:
-            print(
-                f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
+            self._get_callback_manager().on_chain_start(
+                {"name": self.__class__.__name__}, inputs
             )
-        self._get_callback_manager().on_chain_start(
-            {"name": self.__class__.__name__}, inputs
-        )
         outputs = self._call(inputs)
-        self._get_callback_manager().on_chain_end(outputs)
         if self.verbose:
-            print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
+            self._get_callback_manager().on_chain_end(outputs)
         self._validate_outputs(outputs)
         if self.memory is not None:
             self.memory.save_context(inputs, outputs)
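The net effect of the two files above: `Chain.__call__` no longer prints directly, and chain callbacks only fire when `verbose` is set, so the old verbose output is reproduced by the default `StdOutCallbackHandler`. A sketch (`OpenAI` again stands in for any `BaseLLM`):

```python
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    input_variables=["product"], template="Name a company that makes {product}."
)
chain = LLMChain(llm=OpenAI(), prompt=prompt, verbose=True)
chain.run("socks")
# Prints "> Entering new LLMChain chain...", the formatted prompt
# (via on_llm_start), and "> Finished chain." -- none of which
# appear if verbose is left False.
```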
diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py
index 72c1f3a03c0..1cd191b7132 100644
--- a/langchain/chains/llm.py
+++ b/langchain/chains/llm.py
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Sequence, Union
 
 from pydantic import BaseModel, Extra
 
-import langchain
 from langchain.chains.base import Chain
 from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
@@ -54,6 +53,7 @@ class LLMChain(Chain, BaseModel):
 
     def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
         """Generate LLM result from inputs."""
+        self.llm.verbose = self.verbose
         stop = None
         if "stop" in input_list[0]:
             stop = input_list[0]["stop"]
@@ -61,8 +61,6 @@ class LLMChain(Chain, BaseModel):
         for inputs in input_list:
             selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
             prompt = self.prompt.format(**selected_inputs)
-            if self.verbose:
-                langchain.logger.log_llm_inputs(selected_inputs, prompt)
             if "stop" in inputs and inputs["stop"] != stop:
                 raise ValueError(
                     "If `stop` is present in any inputs, should be present in all."
@@ -78,8 +76,6 @@ class LLMChain(Chain, BaseModel):
         for generation in response.generations:
             # Get the text of the top generated string.
             response_str = generation[0].text
-            if self.verbose:
-                langchain.logger.log_llm_response(response_str)
             outputs.append({self.output_key: response_str})
         return outputs
 
diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 06df99e78cf..8b50cfcaf9b 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -5,7 +5,7 @@ from pathlib import Path
 from typing import Any, Dict, List, Mapping, Optional, Union
 
 import yaml
-from pydantic import BaseModel, Extra
+from pydantic import BaseModel, Extra, Field
 
 import langchain
 from langchain.callbacks import get_callback_manager
@@ -13,10 +13,16 @@ from langchain.callbacks.base import BaseCallbackManager
 from langchain.schema import Generation, LLMResult
 
 
+def _get_verbosity() -> bool:
+    return langchain.verbose
+
+
 class BaseLLM(BaseModel, ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
     cache: Optional[bool] = None
+    verbose: bool = Field(default_factory=_get_verbosity)
+    """Whether to print out response text."""
     callback_manager: Optional[BaseCallbackManager] = None
 
     class Config:
@@ -48,11 +54,13 @@ class BaseLLM(BaseModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            self._get_callback_manager().on_llm_start(
-                {"name": self.__class__.__name__}, prompts
-            )
+            if self.verbose:
+                self._get_callback_manager().on_llm_start(
+                    {"name": self.__class__.__name__}, prompts
+                )
             output = self._generate(prompts, stop=stop)
-            self._get_callback_manager().on_llm_end(output)
+            if self.verbose:
+                self._get_callback_manager().on_llm_end(output)
             return output
         params = self._llm_dict()
         params["stop"] = stop
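`Field(default_factory=_get_verbosity)` means the global flag is read once, at model construction time, rather than when the class is defined; an explicit `verbose=` argument still wins. Roughly (`OpenAI` is an illustrative `BaseLLM` subclass):

```python
import langchain
from langchain.llms import OpenAI

langchain.verbose = True
llm = OpenAI()                 # verbose resolved to True at construction
quiet = OpenAI(verbose=False)  # explicit value overrides the global
# Flipping langchain.verbose afterwards does not affect existing instances.
```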
diff --git a/langchain/schema.py b/langchain/schema.py
index a4b4e6267ce..ce2811fded4 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -1,9 +1,13 @@
 """Common schema objects."""
+from dataclasses import dataclass
 from typing import List, NamedTuple, Optional
 
 
+AGENT_FINISH_OBSERVATION = "__agent_finish__"
+
 
-class AgentAction(NamedTuple):
+@dataclass
+class AgentAction:
     """Agent's action to take."""
 
     tool: str
@@ -11,11 +15,11 @@ class AgentAction(NamedTuple):
     log: str
 
 
-class AgentFinish(NamedTuple):
+@dataclass
+class AgentFinish(AgentAction):
     """Agent's return value."""
 
     return_values: dict
-    log: str
 
 
 class Generation(NamedTuple):
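Two consequences of this schema change are worth spelling out: `AgentFinish` is now a subclass of `AgentAction` (which is why `Agent.plan` above can narrow its return annotation to plain `AgentAction`), and `AGENT_FINISH_OBSERVATION` is the sentinel `StdOutCallbackHandler.on_tool_end` checks so the finishing step prints no observation. A quick sketch:

```python
from langchain.schema import AGENT_FINISH_OBSERVATION, AgentAction, AgentFinish

finish = AgentFinish("Final Answer", "done", "full llm output", {"output": "done"})
assert isinstance(finish, AgentAction)  # a finish flows anywhere an action does
assert AGENT_FINISH_OBSERVATION == "__agent_finish__"
```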
diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py
index 002c66fdb9e..038fd154835 100644
--- a/tests/unit_tests/agents/test_agent.py
+++ b/tests/unit_tests/agents/test_agent.py
@@ -74,7 +74,7 @@ def test_agent_with_callbacks() -> None:
         f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
         "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
     ]
-    fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
+    fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
     tools = [
         Tool("Search", lambda x: x, "Useful for searching"),
     ]
@@ -89,7 +89,36 @@ def test_agent_with_callbacks() -> None:
     output = agent.run("when was langchain made")
     assert output == "curses foiled again"
 
-    # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
-    assert handler.starts == 6
-    assert handler.ends == 6
+    # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run, 1 agent finish
+    assert handler.starts == 7
+    assert handler.ends == 7
+    assert handler.errors == 0
+
+
+def test_agent_with_callbacks_not_verbose() -> None:
+    """Test react chain with callbacks but not verbose."""
+    handler = FakeCallbackHandler()
+    manager = CallbackManager([handler])
+    tool = "Search"
+    responses = [
+        f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
+        "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
+    ]
+    fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
+    tools = [
+        Tool("Search", lambda x: x, "Useful for searching"),
+    ]
+    agent = initialize_agent(
+        tools,
+        fake_llm,
+        agent="zero-shot-react-description",
+        callback_manager=manager,
+    )
+
+    output = agent.run("when was langchain made")
+    assert output == "curses foiled again"
+
+    # Not verbose, so no callback events should fire
+    assert handler.starts == 0
+    assert handler.ends == 0
     assert handler.errors == 0
diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py
index 3e50a296d0c..5408042f5f1 100644
--- a/tests/unit_tests/callbacks/fake_callback_handler.py
+++ b/tests/unit_tests/callbacks/fake_callback_handler.py
@@ -2,7 +2,7 @@
 from typing import Any, Dict, List
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult
+from langchain.schema import AgentAction, LLMResult
 
 
 class FakeCallbackHandler(BaseCallbackHandler):
@@ -15,7 +15,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
         self.errors = 0
 
     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **extra: str
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> None:
         """Run when LLM starts running."""
         self.starts += 1
@@ -32,7 +32,7 @@ class FakeCallbackHandler(BaseCallbackHandler):
         self.errors += 1
 
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
+        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
         """Run when chain starts running."""
         self.starts += 1
@@ -46,12 +46,12 @@ class FakeCallbackHandler(BaseCallbackHandler):
         self.errors += 1
 
     def on_tool_start(
-        self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
+        self, serialized: Dict[str, Any], action: AgentAction, **kwargs: Any
     ) -> None:
         """Run when tool starts running."""
         self.starts += 1
 
-    def on_tool_end(self, output: str) -> None:
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
         """Run when tool ends running."""
         self.ends += 1
 
diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py
index 2ddf36efb5a..915bf5808ae 100644
--- a/tests/unit_tests/callbacks/test_callback_manager.py
+++ b/tests/unit_tests/callbacks/test_callback_manager.py
@@ -2,7 +2,7 @@
 
 from langchain.callbacks.base import BaseCallbackManager, CallbackManager
 from langchain.callbacks.shared import SharedCallbackManager
-from langchain.schema import LLMResult
+from langchain.schema import AgentAction, LLMResult
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
@@ -13,10 +13,10 @@ def _test_callback_manager(
     manager.on_llm_start({}, [])
     manager.on_llm_end(LLMResult(generations=[]))
     manager.on_llm_error(Exception())
-    manager.on_chain_start({}, {})
+    manager.on_chain_start({"name": "foo"}, {})
     manager.on_chain_end({})
     manager.on_chain_error(Exception())
-    manager.on_tool_start({}, "", "")
+    manager.on_tool_start({}, AgentAction("", "", ""))
     manager.on_tool_end("")
     manager.on_tool_error(Exception())
     for handler in handlers:
diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py
index 9a677adfc4c..3c17a516aee 100644
--- a/tests/unit_tests/chains/test_base.py
+++ b/tests/unit_tests/chains/test_base.py
@@ -140,9 +140,20 @@ def test_run_arg_with_memory() -> None:
 def test_run_with_callback() -> None:
     """Test run method works when callback manager is passed."""
     handler = FakeCallbackHandler()
-    chain = FakeChain(callback_manager=CallbackManager([handler]))
+    chain = FakeChain(callback_manager=CallbackManager([handler]), verbose=True)
     output = chain.run("bar")
     assert output == "baz"
     assert handler.starts == 1
     assert handler.ends == 1
     assert handler.errors == 0
+
+
+def test_run_with_callback_not_verbose() -> None:
+    """Test run method works when callback manager is passed and not verbose."""
+    handler = FakeCallbackHandler()
+    chain = FakeChain(callback_manager=CallbackManager([handler]))
+    output = chain.run("bar")
+    assert output == "baz"
+    assert handler.starts == 0
+    assert handler.ends == 0
+    assert handler.errors == 0
diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py
index 9587f5fb8be..2c3b45294b3 100644
--- a/tests/unit_tests/llms/test_callbacks.py
+++ b/tests/unit_tests/llms/test_callbacks.py
@@ -7,9 +7,20 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
 
 def test_llm_with_callbacks() -> None:
     """Test LLM callbacks."""
     handler = FakeCallbackHandler()
-    llm = FakeLLM(callback_manager=CallbackManager([handler]))
+    llm = FakeLLM(callback_manager=CallbackManager([handler]), verbose=True)
     output = llm("foo")
     assert output == "foo"
     assert handler.starts == 1
     assert handler.ends == 1
     assert handler.errors == 0
+
+
+def test_llm_with_callbacks_not_verbose() -> None:
+    """Test LLM callbacks but not verbose."""
+    handler = FakeCallbackHandler()
+    llm = FakeLLM(callback_manager=CallbackManager([handler]))
+    output = llm("foo")
+    assert output == "foo"
+    assert handler.starts == 0
+    assert handler.ends == 0
+    assert handler.errors == 0