mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-28 17:38:36 +00:00
allow for optional CallbackManager in LLM, Chain, and Agent (#482)
This commit is contained in:
parent
46b31626b5
commit
36922318d3
@ -9,7 +9,7 @@ from pydantic import BaseModel, root_validator
|
||||
|
||||
import langchain
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping
|
||||
@ -133,10 +133,17 @@ class Agent(BaseModel):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_llm_and_tools(cls, llm: BaseLLM, tools: List[Tool]) -> Agent:
|
||||
def from_llm_and_tools(
|
||||
cls,
|
||||
llm: BaseLLM,
|
||||
tools: List[Tool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
) -> Agent:
|
||||
"""Construct an agent from an LLM and tools."""
|
||||
cls._validate_tools(tools)
|
||||
llm_chain = LLMChain(llm=llm, prompt=cls.create_prompt(tools))
|
||||
llm_chain = LLMChain(
|
||||
llm=llm, prompt=cls.create_prompt(tools), callback_manager=callback_manager
|
||||
)
|
||||
return cls(llm_chain=llm_chain)
|
||||
|
||||
def return_stopped_response(self) -> dict:
|
||||
@ -154,10 +161,16 @@ class AgentExecutor(Chain, BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_agent_and_tools(
|
||||
cls, agent: Agent, tools: List[Tool], **kwargs: Any
|
||||
cls,
|
||||
agent: Agent,
|
||||
tools: List[Tool],
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Create from agent and tools."""
|
||||
return cls(agent=agent, tools=tools, **kwargs)
|
||||
return cls(
|
||||
agent=agent, tools=tools, callback_manager=callback_manager, **kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@ -214,12 +227,12 @@ class AgentExecutor(Chain, BaseModel):
|
||||
# And then we lookup the tool
|
||||
if output.tool in name_to_tool_map:
|
||||
chain = name_to_tool_map[output.tool]
|
||||
get_callback_manager().on_tool_start(
|
||||
self._get_callback_manager().on_tool_start(
|
||||
{"name": str(chain)[:60] + "..."}, output.tool, output.tool_input
|
||||
)
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = chain(output.tool_input)
|
||||
get_callback_manager().on_tool_end(observation)
|
||||
self._get_callback_manager().on_tool_end(observation)
|
||||
color = color_mapping[output.tool]
|
||||
else:
|
||||
observation = f"{output.tool} is not a valid tool, try another one."
|
||||
|
@ -1,11 +1,12 @@
|
||||
"""Load agent."""
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.react.base import ReActDocstoreAgent
|
||||
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
AGENT_TO_CLASS = {
|
||||
@ -19,6 +20,7 @@ def initialize_agent(
|
||||
tools: List[Tool],
|
||||
llm: BaseLLM,
|
||||
agent: str = "zero-shot-react-description",
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Load agent given tools and LLM.
|
||||
@ -28,6 +30,8 @@ def initialize_agent(
|
||||
llm: Language model to use as the agent.
|
||||
agent: The agent to use. Valid options are:
|
||||
`zero-shot-react-description`, `react-docstore`, `self-ask-with-search`.
|
||||
callback_manager: CallbackManager to use. Global callback manager is used if
|
||||
not provided. Defaults to None.
|
||||
**kwargs: Additional key word arguments to pass to the agent.
|
||||
|
||||
Returns:
|
||||
@ -39,5 +43,9 @@ def initialize_agent(
|
||||
f"Valid types are: {AGENT_TO_CLASS.keys()}."
|
||||
)
|
||||
agent_cls = AGENT_TO_CLASS[agent]
|
||||
agent_obj = agent_cls.from_llm_and_tools(llm, tools)
|
||||
return AgentExecutor.from_agent_and_tools(agent=agent_obj, tools=tools, **kwargs)
|
||||
agent_obj = agent_cls.from_llm_and_tools(
|
||||
llm, tools, callback_manager=callback_manager
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent_obj, tools=tools, callback_manager=callback_manager, **kwargs
|
||||
)
|
||||
|
@ -6,6 +6,7 @@ from pydantic import BaseModel, Extra, Field
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
|
||||
|
||||
class Memory(BaseModel, ABC):
|
||||
@ -43,9 +44,21 @@ class Chain(BaseModel, ABC):
|
||||
"""Base interface that all chains should implement."""
|
||||
|
||||
memory: Optional[Memory] = None
|
||||
callback_manager: Optional[BaseCallbackManager] = None
|
||||
verbose: bool = Field(
|
||||
default_factory=_get_verbosity
|
||||
) # Whether to print the response text
|
||||
|
||||
verbose: bool = Field(default_factory=_get_verbosity)
|
||||
"""Whether to print out response text."""
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_callback_manager(self) -> BaseCallbackManager:
|
||||
"""Get the callback manager."""
|
||||
if self.callback_manager is not None:
|
||||
return self.callback_manager
|
||||
return get_callback_manager()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@ -110,9 +123,11 @@ class Chain(BaseModel, ABC):
|
||||
print(
|
||||
f"\n\n\033[1m> Entering new {self.__class__.__name__} chain...\033[0m"
|
||||
)
|
||||
get_callback_manager().on_chain_start({"name": self.__class__.__name__}, inputs)
|
||||
self._get_callback_manager().on_chain_start(
|
||||
{"name": self.__class__.__name__}, inputs
|
||||
)
|
||||
outputs = self._call(inputs)
|
||||
get_callback_manager().on_chain_end(outputs)
|
||||
self._get_callback_manager().on_chain_end(outputs)
|
||||
if self.verbose:
|
||||
print(f"\n\033[1m> Finished {self.__class__.__name__} chain.\033[0m")
|
||||
self._validate_outputs(outputs)
|
||||
|
@ -9,6 +9,7 @@ from pydantic import BaseModel, Extra
|
||||
|
||||
import langchain
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.schema import Generation, LLMResult
|
||||
|
||||
|
||||
@ -16,11 +17,13 @@ class BaseLLM(BaseModel, ABC):
|
||||
"""LLM wrapper should take in a prompt and return a string."""
|
||||
|
||||
cache: Optional[bool] = None
|
||||
callback_manager: Optional[BaseCallbackManager] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def _generate(
|
||||
@ -28,6 +31,12 @@ class BaseLLM(BaseModel, ABC):
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompts."""
|
||||
|
||||
def _get_callback_manager(self) -> BaseCallbackManager:
|
||||
"""Get the callback manager."""
|
||||
if self.callback_manager is not None:
|
||||
return self.callback_manager
|
||||
return get_callback_manager()
|
||||
|
||||
def generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
) -> LLMResult:
|
||||
@ -39,11 +48,11 @@ class BaseLLM(BaseModel, ABC):
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
get_callback_manager().on_llm_start(
|
||||
self._get_callback_manager().on_llm_start(
|
||||
{"name": self.__class__.__name__}, prompts
|
||||
)
|
||||
output = self._generate(prompts, stop=stop)
|
||||
get_callback_manager().on_llm_end(output)
|
||||
self._get_callback_manager().on_llm_end(output)
|
||||
return output
|
||||
params = self._llm_dict()
|
||||
params["stop"] = stop
|
||||
@ -58,11 +67,11 @@ class BaseLLM(BaseModel, ABC):
|
||||
else:
|
||||
missing_prompts.append(prompt)
|
||||
missing_prompt_idxs.append(i)
|
||||
get_callback_manager().on_llm_start(
|
||||
self._get_callback_manager().on_llm_start(
|
||||
{"name": self.__class__.__name__}, missing_prompts
|
||||
)
|
||||
new_results = self._generate(missing_prompts, stop=stop)
|
||||
get_callback_manager().on_llm_end(new_results)
|
||||
self._get_callback_manager().on_llm_end(new_results)
|
||||
for i, result in enumerate(new_results.generations):
|
||||
existing_prompts[i] = result
|
||||
prompt = prompts[i]
|
||||
|
@ -4,8 +4,10 @@ from typing import Any, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents import Tool, initialize_agent
|
||||
from langchain.agents import AgentExecutor, Tool, initialize_agent
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.llms.base import LLM
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
class FakeListLLM(LLM, BaseModel):
|
||||
@ -31,8 +33,8 @@ class FakeListLLM(LLM, BaseModel):
|
||||
return "fake_list"
|
||||
|
||||
|
||||
def test_agent_bad_action() -> None:
|
||||
"""Test react chain when bad action given."""
|
||||
def _get_agent(**kwargs: Any) -> AgentExecutor:
|
||||
"""Get agent for testing."""
|
||||
bad_action_name = "BadAction"
|
||||
responses = [
|
||||
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||
@ -44,30 +46,50 @@ def test_agent_bad_action() -> None:
|
||||
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools, fake_llm, agent="zero-shot-react-description", verbose=True
|
||||
tools, fake_llm, agent="zero-shot-react-description", verbose=True, **kwargs
|
||||
)
|
||||
return agent
|
||||
|
||||
|
||||
def test_agent_bad_action() -> None:
|
||||
"""Test react chain when bad action given."""
|
||||
agent = _get_agent()
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
|
||||
def test_agent_stopped_early() -> None:
|
||||
"""Test react chain when bad action given."""
|
||||
bad_action_name = "BadAction"
|
||||
agent = _get_agent(max_iterations=0)
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "Agent stopped due to max iterations."
|
||||
|
||||
|
||||
def test_agent_with_callbacks() -> None:
|
||||
"""Test react chain with callbacks."""
|
||||
handler = FakeCallbackHandler()
|
||||
manager = CallbackManager([handler])
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"I'm turning evil\nAction: {bad_action_name}\nAction Input: misalignment",
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nAction: Final Answer\nAction Input: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses)
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
|
||||
tools = [
|
||||
Tool("Search", lambda x: x, "Useful for searching"),
|
||||
Tool("Lookup", lambda x: x, "Useful for looking up things in a table"),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent="zero-shot-react-description",
|
||||
verbose=True,
|
||||
max_iterations=0,
|
||||
callback_manager=manager,
|
||||
)
|
||||
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "Agent stopped due to max iterations."
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.starts == 6
|
||||
assert handler.ends == 6
|
||||
assert handler.errors == 0
|
||||
|
60
tests/unit_tests/callbacks/fake_callback_handler.py
Normal file
60
tests/unit_tests/callbacks/fake_callback_handler.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock callback handler."""
|
||||
self.starts = 0
|
||||
self.ends = 0
|
||||
self.errors = 0
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
@ -1,69 +1,9 @@
|
||||
"""Test CallbackManager."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
BaseCallbackHandler,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.base import BaseCallbackManager, CallbackManager
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the mock callback handler."""
|
||||
self.starts = 0
|
||||
self.ends = 0
|
||||
self.errors = 0
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **extra: str
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_llm_error(self, error: Exception) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **extra: str
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_chain_error(self, error: Exception) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], action: str, tool_input: str, **extra: str
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.starts += 1
|
||||
|
||||
def on_tool_end(self, output: str) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.ends += 1
|
||||
|
||||
def on_tool_error(self, error: Exception) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
def _test_callback_manager(
|
||||
|
@ -4,7 +4,9 @@ from typing import Any, Dict, List
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.chains.base import Chain, Memory
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
class FakeMemory(Memory, BaseModel):
|
||||
@ -133,3 +135,14 @@ def test_run_arg_with_memory() -> None:
|
||||
"""Test run method works when arg is passed."""
|
||||
chain = FakeChain(the_input_keys=["foo", "baz"], memory=FakeMemory())
|
||||
chain.run("bar")
|
||||
|
||||
|
||||
def test_run_with_callback() -> None:
|
||||
"""Test run method works when callback manager is passed."""
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(callback_manager=CallbackManager([handler]))
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
15
tests/unit_tests/llms/test_callbacks.py
Normal file
15
tests/unit_tests/llms/test_callbacks.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Test LLM callbacks."""
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
def test_llm_with_callbacks() -> None:
|
||||
"""Test LLM callbacks."""
|
||||
handler = FakeCallbackHandler()
|
||||
llm = FakeLLM(callback_manager=CallbackManager([handler]))
|
||||
output = llm("foo")
|
||||
assert output == "foo"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
Loading…
Reference in New Issue
Block a user