diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 2c95b2edca7..9169b37585a 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, root_validator import langchain from langchain.agents.tools import Tool -from langchain.callbacks import get_callback_manager +from langchain.callbacks.base import BaseCallbackManager from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping @@ -133,10 +133,17 @@ class Agent(BaseModel): pass @classmethod - def from_llm_and_tools(cls, llm: BaseLLM, tools: List[Tool]) -> Agent: + def from_llm_and_tools( + cls, + llm: BaseLLM, + tools: List[Tool], + callback_manager: Optional[BaseCallbackManager] = None, + ) -> Agent: """Construct an agent from an LLM and tools.""" cls._validate_tools(tools) - llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) + llm_chain = LLMChain( + llm=llm, prompt=cls.create_prompt(tools), callback_manager=callback_manager + ) return cls(llm_chain=llm_chain) def return_stopped_response(self) -> dict: @@ -154,10 +161,16 @@ class AgentExecutor(Chain, BaseModel): @classmethod def from_agent_and_tools( - cls, agent: Agent, tools: List[Tool], **kwargs: Any + cls, + agent: Agent, + tools: List[Tool], + callback_manager: Optional[BaseCallbackManager] = None, + **kwargs: Any, ) -> AgentExecutor: """Create from agent and tools.""" - return cls(agent=agent, tools=tools, **kwargs) + return cls( + agent=agent, tools=tools, callback_manager=callback_manager, **kwargs + ) @property def input_keys(self) -> List[str]: @@ -214,12 +227,12 @@ class AgentExecutor(Chain, BaseModel): # And then we lookup the tool if output.tool in name_to_tool_map: chain = name_to_tool_map[output.tool] - get_callback_manager().on_tool_start( + self._get_callback_manager().on_tool_start( {"name": str(chain)[:60] + "..."}, output.tool, output.tool_input ) # We then call the tool on the tool input to get an observation observation = chain(output.tool_input) - get_callback_manager().on_tool_end(observation) + self._get_callback_manager().on_tool_end(observation) color = color_mapping[output.tool] else: observation = f"{output.tool} is not a valid tool, try another one." diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index f1823b74589..3f3e4c81d3c 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,11 +1,12 @@ """Load agent.""" -from typing import Any, List +from typing import Any, List, Optional from langchain.agents.agent import AgentExecutor from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.react.base import ReActDocstoreAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool +from langchain.callbacks.base import BaseCallbackManager from langchain.llms.base import BaseLLM AGENT_TO_CLASS = { @@ -19,6 +20,7 @@ def initialize_agent( tools: List[Tool], llm: BaseLLM, agent: str = "zero-shot-react-description", + callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any, ) -> AgentExecutor: """Load agent given tools and LLM. @@ -28,6 +30,8 @@ def initialize_agent( llm: Language model to use as the agent. agent: The agent to use. Valid options are: `zero-shot-react-description`, `react-docstore`, `self-ask-with-search`. + callback_manager: CallbackManager to use. Global callback manager is used if + not provided. Defaults to None. **kwargs: Additional key word arguments to pass to the agent. Returns: @@ -39,5 +43,9 @@ def initialize_agent( f"Valid types are: {AGENT_TO_CLASS.keys()}." ) agent_cls = AGENT_TO_CLASS[agent] - agent_obj = agent_cls.from_llm_and_tools(llm, tools) - return AgentExecutor.from_agent_and_tools(agent=agent_obj, tools=tools, **kwargs) + agent_obj = agent_cls.from_llm_and_tools( + llm, tools, callback_manager=callback_manager + ) + return AgentExecutor.from_agent_and_tools( + agent=agent_obj, tools=tools, callback_manager=callback_manager, **kwargs + ) diff --git a/langchain/chains/base.py b/langchain/chains/base.py index c4ce68d6236..ce28bf841ab 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Extra, Field import langchain from langchain.callbacks import get_callback_manager +from langchain.callbacks.base import BaseCallbackManager class Memory(BaseModel, ABC): @@ -43,9 +44,21 @@ class Chain(BaseModel, ABC): """Base interface that all chains should implement.""" memory: Optional[Memory] = None + callback_manager: Optional[BaseCallbackManager] = None + verbose: bool = Field( + default_factory=_get_verbosity + ) # Whether to print the response text - verbose: bool = Field(default_factory=_get_verbosity) - """Whether to print out response text.""" + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def _get_callback_manager(self) -> BaseCallbackManager: + """Get the callback manager.""" + if self.callback_manager is not None: + return self.callback_manager + return get_callback_manager() @property @abstractmethod @@ -110,9 +123,11 @@ class Chain(BaseModel, ABC): print( f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m" ) - get_callback_manager().on_chain_start({"name": self.__class__.__name__}, inputs) + self._get_callback_manager().on_chain_start( + {"name": self.__class__.__name__}, inputs + ) outputs = self._call(inputs) - get_callback_manager().on_chain_end(outputs) + self._get_callback_manager().on_chain_end(outputs) if self.verbose: print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m") self._validate_outputs(outputs) diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 1a1ca392259..06df99e78cf 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Extra import langchain from langchain.callbacks import get_callback_manager +from langchain.callbacks.base import BaseCallbackManager from langchain.schema import Generation, LLMResult @@ -16,11 +17,13 @@ class BaseLLM(BaseModel, ABC): """LLM wrapper should take in a prompt and return a string.""" cache: Optional[bool] = None + callback_manager: Optional[BaseCallbackManager] = None class Config: """Configuration for this pydantic object.""" extra = Extra.forbid + arbitrary_types_allowed = True @abstractmethod def _generate( @@ -28,6 +31,12 @@ class BaseLLM(BaseModel, ABC): ) -> LLMResult: """Run the LLM on the given prompts.""" + def _get_callback_manager(self) -> BaseCallbackManager: + """Get the callback manager.""" + if self.callback_manager is not None: + return self.callback_manager + return get_callback_manager() + def generate( self, prompts: List[str], stop: Optional[List[str]] = None ) -> LLMResult: @@ -39,11 +48,11 @@ class BaseLLM(BaseModel, ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." ) - get_callback_manager().on_llm_start( + self._get_callback_manager().on_llm_start( {"name": self.__class__.__name__}, prompts ) output = self._generate(prompts, stop=stop) - get_callback_manager().on_llm_end(output) + self._get_callback_manager().on_llm_end(output) return output params = self._llm_dict() params["stop"] = stop @@ -58,11 +67,11 @@ class BaseLLM(BaseModel, ABC): else: missing_prompts.append(prompt) missing_prompt_idxs.append(i) - get_callback_manager().on_llm_start( + self._get_callback_manager().on_llm_start( {"name": self.__class__.__name__}, missing_prompts ) new_results = self._generate(missing_prompts, stop=stop) - get_callback_manager().on_llm_end(new_results) + self._get_callback_manager().on_llm_end(new_results) for i, result in enumerate(new_results.generations): existing_prompts[i] = result prompt = prompts[i] diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 8d1dfdb3d68..002c66fdb9e 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -4,8 +4,10 @@ from typing import Any, List, Mapping, Optional from pydantic import BaseModel -from langchain.agents import Tool, initialize_agent +from langchain.agents import AgentExecutor, Tool, initialize_agent +from langchain.callbacks.base import CallbackManager from langchain.llms.base import LLM +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler class FakeListLLM(LLM, BaseModel): @@ -31,8 +33,8 @@ class FakeListLLM(LLM, BaseModel): return "fake_list" -def test_agent_bad_action() -> None: - """Test react chain when bad action given.""" +def _get_agent(**kwargs: Any) -> AgentExecutor: + """Get agent for testing.""" bad_action_name = "BadAction" responses = [ f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", @@ -44,30 +46,50 @@ def test_agent_bad_action() -> None: Tool("Lookup", lambda x: x, "Useful for looking up things in a table"), ] agent = initialize_agent( - tools, fake_llm, agent="zero-shot-react-description", verbose=True + tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs ) + return agent + + +def test_agent_bad_action() -> None: + """Test react chain when bad action given.""" + agent = _get_agent() output = agent.run("when was langchain made") assert output == "curses foiled again" def test_agent_stopped_early() -> None: """Test react chain when bad action given.""" - bad_action_name = "BadAction" + agent = _get_agent(max_iterations=0) + output = agent.run("when was langchain made") + assert output == "Agent stopped due to max iterations." + + +def test_agent_with_callbacks() -> None: + """Test react chain with callbacks.""" + handler = FakeCallbackHandler() + manager = CallbackManager([handler]) + tool = "Search" responses = [ - f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", + f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", "Oh well\nAction: Final Answer\nAction Input: curses foiled again", ] - fake_llm = FakeListLLM(responses=responses) + fake_llm = FakeListLLM(responses=responses, callback_manager=manager) tools = [ Tool("Search", lambda x: x, "Useful for searching"), - Tool("Lookup", lambda x: x, "Useful for looking up things in a table"), ] agent = initialize_agent( tools, fake_llm, agent="zero-shot-react-description", verbose=True, - max_iterations=0, + callback_manager=manager, ) + output = agent.run("when was langchain made") - assert output == "Agent stopped due to max iterations." + assert output == "curses foiled again" + + # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run + assert handler.starts == 6 + assert handler.ends == 6 + assert handler.errors == 0 diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py new file mode 100644 index 00000000000..3e50a296d0c --- /dev/null +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -0,0 +1,60 @@ +"""A fake callback handler for testing purposes.""" +from typing import Any, Dict, List + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import LLMResult + + +class FakeCallbackHandler(BaseCallbackHandler): + """Fake callback handler for testing.""" + + def __init__(self) -> None: + """Initialize the mock callback handler.""" + self.starts = 0 + self.ends = 0 + self.errors = 0 + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **extra: str + ) -> None: + """Run when LLM starts running.""" + self.starts += 1 + + def on_llm_end( + self, + response: LLMResult, + ) -> None: + """Run when LLM ends running.""" + self.ends += 1 + + def on_llm_error(self, error: Exception) -> None: + """Run when LLM errors.""" + self.errors += 1 + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str + ) -> None: + """Run when chain starts running.""" + self.starts += 1 + + def on_chain_end(self, outputs: Dict[str, Any]) -> None: + """Run when chain ends running.""" + self.ends += 1 + + def on_chain_error(self, error: Exception) -> None: + """Run when chain errors.""" + self.errors += 1 + + def on_tool_start( + self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str + ) -> None: + """Run when tool starts running.""" + self.starts += 1 + + def on_tool_end(self, output: str) -> None: + """Run when tool ends running.""" + self.ends += 1 + + def on_tool_error(self, error: Exception) -> None: + """Run when tool errors.""" + self.errors += 1 diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index fcab8dda954..2ddf36efb5a 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -1,69 +1,9 @@ """Test CallbackManager.""" -from typing import Any, Dict, List - -from langchain.callbacks.base import ( - BaseCallbackHandler, - BaseCallbackManager, - CallbackManager, -) +from langchain.callbacks.base import BaseCallbackManager, CallbackManager from langchain.callbacks.shared import SharedCallbackManager from langchain.schema import LLMResult - - -class FakeCallbackHandler(BaseCallbackHandler): - """Fake callback handler for testing.""" - - def __init__(self) -> None: - """Initialize the mock callback handler.""" - self.starts = 0 - self.ends = 0 - self.errors = 0 - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **extra: str - ) -> None: - """Run when LLM starts running.""" - self.starts += 1 - - def on_llm_end( - self, - response: LLMResult, - ) -> None: - """Run when LLM ends running.""" - self.ends += 1 - - def on_llm_error(self, error: Exception) -> None: - """Run when LLM errors.""" - self.errors += 1 - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str - ) -> None: - """Run when chain starts running.""" - self.starts += 1 - - def on_chain_end(self, outputs: Dict[str, Any]) -> None: - """Run when chain ends running.""" - self.ends += 1 - - def on_chain_error(self, error: Exception) -> None: - """Run when chain errors.""" - self.errors += 1 - - def on_tool_start( - self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str - ) -> None: - """Run when tool starts running.""" - self.starts += 1 - - def on_tool_end(self, output: str) -> None: - """Run when tool ends running.""" - self.ends += 1 - - def on_tool_error(self, error: Exception) -> None: - """Run when tool errors.""" - self.errors += 1 +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler def _test_callback_manager( diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 69dfd6bf7cc..9a677adfc4c 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -4,7 +4,9 @@ from typing import Any, Dict, List import pytest from pydantic import BaseModel +from langchain.callbacks.base import CallbackManager from langchain.chains.base import Chain, Memory +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler class FakeMemory(Memory, BaseModel): @@ -133,3 +135,14 @@ def test_run_arg_with_memory() -> None: """Test run method works when arg is passed.""" chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory()) chain.run("bar") + + +def test_run_with_callback() -> None: + """Test run method works when callback manager is passed.""" + handler = FakeCallbackHandler() + chain = FakeChain(callback_manager=CallbackManager([handler])) + output = chain.run("bar") + assert output == "baz" + assert handler.starts == 1 + assert handler.ends == 1 + assert handler.errors == 0 diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py new file mode 100644 index 00000000000..9587f5fb8be --- /dev/null +++ b/tests/unit_tests/llms/test_callbacks.py @@ -0,0 +1,15 @@ +"""Test LLM callbacks.""" +from langchain.callbacks.base import CallbackManager +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +from tests.unit_tests.llms.fake_llm import FakeLLM + + +def test_llm_with_callbacks() -> None: + """Test LLM callbacks.""" + handler = FakeCallbackHandler() + llm = FakeLLM(callback_manager=CallbackManager([handler])) + output = llm("foo") + assert output == "foo" + assert handler.starts == 1 + assert handler.ends == 1 + assert handler.errors == 0