allow for optional CallbackManager in LLM, Chain, and Agent (#482)

This commit is contained in:
Ankush Gola 2022-12-29 20:30:31 -05:00 committed by GitHub
parent 46b31626b5
commit 36922318d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 185 additions and 90 deletions

View File

@ -9,7 +9,7 @@ from pydantic import BaseModel, root_validator
import langchain import langchain
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.callbacks import get_callback_manager from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping from langchain.input import get_color_mapping
@ -133,10 +133,17 @@ class Agent(BaseModel):
pass pass
@classmethod @classmethod
def from_llm_and_tools(cls, llm: BaseLLM, tools: List[Tool]) -> Agent: def from_llm_and_tools(
cls,
llm: BaseLLM,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
) -> Agent:
"""Construct an agent from an LLM and tools.""" """Construct an agent from an LLM and tools."""
cls._validate_tools(tools) cls._validate_tools(tools)
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools)) llm_chain = LLMChain(
llm=llm, prompt=cls.create_prompt(tools), callback_manager=callback_manager
)
return cls(llm_chain=llm_chain) return cls(llm_chain=llm_chain)
def return_stopped_response(self) -> dict: def return_stopped_response(self) -> dict:
@ -154,10 +161,16 @@ class AgentExecutor(Chain, BaseModel):
@classmethod @classmethod
def from_agent_and_tools( def from_agent_and_tools(
cls, agent: Agent, tools: List[Tool], **kwargs: Any cls,
agent: Agent,
tools: List[Tool],
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> AgentExecutor: ) -> AgentExecutor:
"""Create from agent and tools.""" """Create from agent and tools."""
return cls(agent=agent, tools=tools, **kwargs) return cls(
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
)
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
@ -214,12 +227,12 @@ class AgentExecutor(Chain, BaseModel):
# And then we lookup the tool # And then we lookup the tool
if output.tool in name_to_tool_map: if output.tool in name_to_tool_map:
chain = name_to_tool_map[output.tool] chain = name_to_tool_map[output.tool]
get_callback_manager().on_tool_start( self._get_callback_manager().on_tool_start(
{"name": str(chain)[:60] + "..."}, output.tool, output.tool_input {"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
) )
# We then call the tool on the tool input to get an observation # We then call the tool on the tool input to get an observation
observation = chain(output.tool_input) observation = chain(output.tool_input)
get_callback_manager().on_tool_end(observation) self._get_callback_manager().on_tool_end(observation)
color = color_mapping[output.tool] color = color_mapping[output.tool]
else: else:
observation = f"{output.tool} is not a valid tool, try another one." observation = f"{output.tool} is not a valid tool, try another one."

View File

@ -1,11 +1,12 @@
"""Load agent.""" """Load agent."""
from typing import Any, List from typing import Any, List, Optional
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.react.base import ReActDocstoreAgent from langchain.agents.react.base import ReActDocstoreAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool from langchain.agents.tools import Tool
from langchain.callbacks.base import BaseCallbackManager
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
AGENT_TO_CLASS = { AGENT_TO_CLASS = {
@ -19,6 +20,7 @@ def initialize_agent(
tools: List[Tool], tools: List[Tool],
llm: BaseLLM, llm: BaseLLM,
agent: str = "zero-shot-react-description", agent: str = "zero-shot-react-description",
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any, **kwargs: Any,
) -> AgentExecutor: ) -> AgentExecutor:
"""Load agent given tools and LLM. """Load agent given tools and LLM.
@ -28,6 +30,8 @@ def initialize_agent(
llm: Language model to use as the agent. llm: Language model to use as the agent.
agent: The agent to use. Valid options are: agent: The agent to use. Valid options are:
`zero-shot-react-description`, `react-docstore`, `self-ask-with-search`. `zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
callback_manager: CallbackManager to use. Global callback manager is used if
not provided. Defaults to None.
**kwargs: Additional key word arguments to pass to the agent. **kwargs: Additional key word arguments to pass to the agent.
Returns: Returns:
@ -39,5 +43,9 @@ def initialize_agent(
f"Valid types are: {AGENT_TO_CLASS.keys()}." f"Valid types are: {AGENT_TO_CLASS.keys()}."
) )
agent_cls = AGENT_TO_CLASS[agent] agent_cls = AGENT_TO_CLASS[agent]
agent_obj = agent_cls.from_llm_and_tools(llm, tools) agent_obj = agent_cls.from_llm_and_tools(
return AgentExecutor.from_agent_and_tools(agent=agent_obj, tools=tools, **kwargs) llm, tools, callback_manager=callback_manager
)
return AgentExecutor.from_agent_and_tools(
agent=agent_obj, tools=tools, callback_manager=callback_manager, **kwargs
)

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel, Extra, Field
import langchain import langchain
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
class Memory(BaseModel, ABC): class Memory(BaseModel, ABC):
@ -43,9 +44,21 @@ class Chain(BaseModel, ABC):
"""Base interface that all chains should implement.""" """Base interface that all chains should implement."""
memory: Optional[Memory] = None memory: Optional[Memory] = None
callback_manager: Optional[BaseCallbackManager] = None
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
verbose: bool = Field(default_factory=_get_verbosity) class Config:
"""Whether to print out response text.""" """Configuration for this pydantic object."""
arbitrary_types_allowed = True
def _get_callback_manager(self) -> BaseCallbackManager:
"""Get the callback manager."""
if self.callback_manager is not None:
return self.callback_manager
return get_callback_manager()
@property @property
@abstractmethod @abstractmethod
@ -110,9 +123,11 @@ class Chain(BaseModel, ABC):
print( print(
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m" f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
) )
get_callback_manager().on_chain_start({"name": self.__class__.__name__}, inputs) self._get_callback_manager().on_chain_start(
{"name": self.__class__.__name__}, inputs
)
outputs = self._call(inputs) outputs = self._call(inputs)
get_callback_manager().on_chain_end(outputs) self._get_callback_manager().on_chain_end(outputs)
if self.verbose: if self.verbose:
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m") print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
self._validate_outputs(outputs) self._validate_outputs(outputs)

View File

@ -9,6 +9,7 @@ from pydantic import BaseModel, Extra
import langchain import langchain
from langchain.callbacks import get_callback_manager from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import Generation, LLMResult from langchain.schema import Generation, LLMResult
@ -16,11 +17,13 @@ class BaseLLM(BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string.""" """LLM wrapper should take in a prompt and return a string."""
cache: Optional[bool] = None cache: Optional[bool] = None
callback_manager: Optional[BaseCallbackManager] = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True
@abstractmethod @abstractmethod
def _generate( def _generate(
@ -28,6 +31,12 @@ class BaseLLM(BaseModel, ABC):
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompts.""" """Run the LLM on the given prompts."""
def _get_callback_manager(self) -> BaseCallbackManager:
"""Get the callback manager."""
if self.callback_manager is not None:
return self.callback_manager
return get_callback_manager()
def generate( def generate(
self, prompts: List[str], stop: Optional[List[str]] = None self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult: ) -> LLMResult:
@ -39,11 +48,11 @@ class BaseLLM(BaseModel, ABC):
raise ValueError( raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`." "Asked to cache, but no cache found at `langchain.cache`."
) )
get_callback_manager().on_llm_start( self._get_callback_manager().on_llm_start(
{"name": self.__class__.__name__}, prompts {"name": self.__class__.__name__}, prompts
) )
output = self._generate(prompts, stop=stop) output = self._generate(prompts, stop=stop)
get_callback_manager().on_llm_end(output) self._get_callback_manager().on_llm_end(output)
return output return output
params = self._llm_dict() params = self._llm_dict()
params["stop"] = stop params["stop"] = stop
@ -58,11 +67,11 @@ class BaseLLM(BaseModel, ABC):
else: else:
missing_prompts.append(prompt) missing_prompts.append(prompt)
missing_prompt_idxs.append(i) missing_prompt_idxs.append(i)
get_callback_manager().on_llm_start( self._get_callback_manager().on_llm_start(
{"name": self.__class__.__name__}, missing_prompts {"name": self.__class__.__name__}, missing_prompts
) )
new_results = self._generate(missing_prompts, stop=stop) new_results = self._generate(missing_prompts, stop=stop)
get_callback_manager().on_llm_end(new_results) self._get_callback_manager().on_llm_end(new_results)
for i, result in enumerate(new_results.generations): for i, result in enumerate(new_results.generations):
existing_prompts[i] = result existing_prompts[i] = result
prompt = prompts[i] prompt = prompts[i]

View File

@ -4,8 +4,10 @@ from typing import Any, List, Mapping, Optional
from pydantic import BaseModel from pydantic import BaseModel
from langchain.agents import Tool, initialize_agent from langchain.agents import AgentExecutor, Tool, initialize_agent
from langchain.callbacks.base import CallbackManager
from langchain.llms.base import LLM from langchain.llms.base import LLM
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeListLLM(LLM, BaseModel): class FakeListLLM(LLM, BaseModel):
@ -31,8 +33,8 @@ class FakeListLLM(LLM, BaseModel):
return "fake_list" return "fake_list"
def test_agent_bad_action() -> None: def _get_agent(**kwargs: Any) -> AgentExecutor:
"""Test react chain when bad action given.""" """Get agent for testing."""
bad_action_name = "BadAction" bad_action_name = "BadAction"
responses = [ responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
@ -44,30 +46,50 @@ def test_agent_bad_action() -> None:
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"), Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
] ]
agent = initialize_agent( agent = initialize_agent(
tools, fake_llm, agent="zero-shot-react-description", verbose=True tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs
) )
return agent
def test_agent_bad_action() -> None:
"""Test react chain when bad action given."""
agent = _get_agent()
output = agent.run("when was langchain made") output = agent.run("when was langchain made")
assert output == "curses foiled again" assert output == "curses foiled again"
def test_agent_stopped_early() -> None: def test_agent_stopped_early() -> None:
"""Test react chain when bad action given.""" """Test react chain when bad action given."""
bad_action_name = "BadAction" agent = _get_agent(max_iterations=0)
output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations."
def test_agent_with_callbacks() -> None:
"""Test react chain with callbacks."""
handler = FakeCallbackHandler()
manager = CallbackManager([handler])
tool = "Search"
responses = [ responses = [
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment", f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
"Oh well\nAction: Final Answer\nAction Input: curses foiled again", "Oh well\nAction: Final Answer\nAction Input: curses foiled again",
] ]
fake_llm = FakeListLLM(responses=responses) fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
tools = [ tools = [
Tool("Search", lambda x: x, "Useful for searching"), Tool("Search", lambda x: x, "Useful for searching"),
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
] ]
agent = initialize_agent( agent = initialize_agent(
tools, tools,
fake_llm, fake_llm,
agent="zero-shot-react-description", agent="zero-shot-react-description",
verbose=True, verbose=True,
max_iterations=0, callback_manager=manager,
) )
output = agent.run("when was langchain made") output = agent.run("when was langchain made")
assert output == "Agent stopped due to max iterations." assert output == "curses foiled again"
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
assert handler.starts == 6
assert handler.ends == 6
assert handler.errors == 0

View File

@ -0,0 +1,60 @@
"""A fake callback handler for testing purposes."""
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult
class FakeCallbackHandler(BaseCallbackHandler):
"""Fake callback handler for testing."""
def __init__(self) -> None:
"""Initialize the mock callback handler."""
self.starts = 0
self.ends = 0
self.errors = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
) -> None:
"""Run when LLM starts running."""
self.starts += 1
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
self.ends += 1
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
) -> None:
"""Run when chain starts running."""
self.starts += 1
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
self.ends += 1
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
) -> None:
"""Run when tool starts running."""
self.starts += 1
def on_tool_end(self, output: str) -> None:
"""Run when tool ends running."""
self.ends += 1
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
self.errors += 1

View File

@ -1,69 +1,9 @@
"""Test CallbackManager.""" """Test CallbackManager."""
from typing import Any, Dict, List from langchain.callbacks.base import BaseCallbackManager, CallbackManager
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.shared import SharedCallbackManager from langchain.callbacks.shared import SharedCallbackManager
from langchain.schema import LLMResult from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeCallbackHandler(BaseCallbackHandler):
"""Fake callback handler for testing."""
def __init__(self) -> None:
"""Initialize the mock callback handler."""
self.starts = 0
self.ends = 0
self.errors = 0
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
) -> None:
"""Run when LLM starts running."""
self.starts += 1
def on_llm_end(
self,
response: LLMResult,
) -> None:
"""Run when LLM ends running."""
self.ends += 1
def on_llm_error(self, error: Exception) -> None:
"""Run when LLM errors."""
self.errors += 1
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
) -> None:
"""Run when chain starts running."""
self.starts += 1
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
"""Run when chain ends running."""
self.ends += 1
def on_chain_error(self, error: Exception) -> None:
"""Run when chain errors."""
self.errors += 1
def on_tool_start(
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
) -> None:
"""Run when tool starts running."""
self.starts += 1
def on_tool_end(self, output: str) -> None:
"""Run when tool ends running."""
self.ends += 1
def on_tool_error(self, error: Exception) -> None:
"""Run when tool errors."""
self.errors += 1
def _test_callback_manager( def _test_callback_manager(

View File

@ -4,7 +4,9 @@ from typing import Any, Dict, List
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from langchain.callbacks.base import CallbackManager
from langchain.chains.base import Chain, Memory from langchain.chains.base import Chain, Memory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
class FakeMemory(Memory, BaseModel): class FakeMemory(Memory, BaseModel):
@ -133,3 +135,14 @@ def test_run_arg_with_memory() -> None:
"""Test run method works when arg is passed.""" """Test run method works when arg is passed."""
chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory()) chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory())
chain.run("bar") chain.run("bar")
def test_run_with_callback() -> None:
"""Test run method works when callback manager is passed."""
handler = FakeCallbackHandler()
chain = FakeChain(callback_manager=CallbackManager([handler]))
output = chain.run("bar")
assert output == "baz"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0

View File

@ -0,0 +1,15 @@
"""Test LLM callbacks."""
from langchain.callbacks.base import CallbackManager
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from tests.unit_tests.llms.fake_llm import FakeLLM
def test_llm_with_callbacks() -> None:
"""Test LLM callbacks."""
handler = FakeCallbackHandler()
llm = FakeLLM(callback_manager=CallbackManager([handler]))
output = llm("foo")
assert output == "foo"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0