mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-18 08:03:36 +00:00
Callbacks Refactor [base] (#3256)
Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com> Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.agents import AgentExecutor, AgentType, initialize_agent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
@@ -15,7 +15,12 @@ class FakeListLLM(LLM):
|
||||
responses: List[str]
|
||||
i: int = -1
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Increment counter, and then return response in that index."""
|
||||
self.i += 1
|
||||
print(f"=== Mock Response #{self.i} ===")
|
||||
@@ -82,135 +87,57 @@ def test_agent_stopped_early() -> None:
|
||||
assert output == "Agent stopped due to iteration limit or time limit."
|
||||
|
||||
|
||||
def test_agent_with_callbacks_global() -> None:
|
||||
def test_agent_with_callbacks() -> None:
|
||||
"""Test react chain with callbacks by setting verbose globally."""
|
||||
import langchain
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
|
||||
langchain.verbose = True
|
||||
handler = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler])
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nFinal Answer: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
|
||||
# Only fake LLM gets callbacks for handler2
|
||||
fake_llm = FakeListLLM(responses=responses, callbacks=[handler2])
|
||||
tools = [
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
callback_manager=manager,
|
||||
),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True,
|
||||
callback_manager=manager,
|
||||
)
|
||||
|
||||
output = agent.run("when was langchain made")
|
||||
output = agent.run("when was langchain made", callbacks=[handler1])
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.chain_starts == handler.chain_ends == 3
|
||||
assert handler.llm_starts == handler.llm_ends == 2
|
||||
assert handler.tool_starts == 2
|
||||
assert handler.tool_ends == 1
|
||||
assert handler1.chain_starts == handler1.chain_ends == 3
|
||||
assert handler1.llm_starts == handler1.llm_ends == 2
|
||||
assert handler1.tool_starts == 1
|
||||
assert handler1.tool_ends == 1
|
||||
# 1 extra agent action
|
||||
assert handler.starts == 7
|
||||
assert handler1.starts == 7
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
assert handler1.ends == 7
|
||||
assert handler1.errors == 0
|
||||
# during LLMChain
|
||||
assert handler.text == 2
|
||||
assert handler1.text == 2
|
||||
|
||||
|
||||
def test_agent_with_callbacks_local() -> None:
|
||||
"""Test react chain with callbacks by setting verbose locally."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
handler = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler])
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nFinal Answer: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True)
|
||||
tools = [
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
callback_manager=manager,
|
||||
),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True,
|
||||
callback_manager=manager,
|
||||
assert handler2.llm_starts == 2
|
||||
assert handler2.llm_ends == 2
|
||||
assert (
|
||||
handler2.chain_starts
|
||||
== handler2.tool_starts
|
||||
== handler2.tool_ends
|
||||
== handler2.chain_ends
|
||||
== 0
|
||||
)
|
||||
|
||||
agent.agent.llm_chain.verbose = True # type: ignore
|
||||
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run
|
||||
assert handler.chain_starts == handler.chain_ends == 3
|
||||
assert handler.llm_starts == handler.llm_ends == 2
|
||||
assert handler.tool_starts == 2
|
||||
assert handler.tool_ends == 1
|
||||
# 1 extra agent action
|
||||
assert handler.starts == 7
|
||||
# 1 extra agent end
|
||||
assert handler.ends == 7
|
||||
assert handler.errors == 0
|
||||
# during LLMChain
|
||||
assert handler.text == 2
|
||||
|
||||
|
||||
def test_agent_with_callbacks_not_verbose() -> None:
|
||||
"""Test react chain with callbacks but not verbose."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
handler = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler])
|
||||
tool = "Search"
|
||||
responses = [
|
||||
f"FooBarBaz\nAction: {tool}\nAction Input: misalignment",
|
||||
"Oh well\nFinal Answer: curses foiled again",
|
||||
]
|
||||
fake_llm = FakeListLLM(responses=responses, callback_manager=manager)
|
||||
tools = [
|
||||
Tool(
|
||||
name="Search",
|
||||
func=lambda x: x,
|
||||
description="Useful for searching",
|
||||
),
|
||||
]
|
||||
agent = initialize_agent(
|
||||
tools,
|
||||
fake_llm,
|
||||
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
callback_manager=manager,
|
||||
)
|
||||
|
||||
output = agent.run("when was langchain made")
|
||||
assert output == "curses foiled again"
|
||||
|
||||
# 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_agent_tool_return_direct() -> None:
|
||||
"""Test agent using tools that return directly."""
|
||||
|
@@ -4,6 +4,7 @@ from typing import Any, List, Mapping, Optional, Union
|
||||
|
||||
from langchain.agents.react.base import ReActChain, ReActDocstoreAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
@@ -32,7 +33,12 @@ class FakeListLLM(LLM):
|
||||
"""Return type of llm."""
|
||||
return "fake_list"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Increment counter, and then return response in that index."""
|
||||
self.i += 1
|
||||
return self.responses[self.i]
|
||||
|
@@ -171,7 +171,7 @@ def test_decorated_function_schema_equivalent() -> None:
|
||||
def test_structured_args_decorator_no_infer_schema() -> None:
|
||||
"""Test functionality with structured arguments parsed as a decorator."""
|
||||
|
||||
@tool
|
||||
@tool(infer_schema=False)
|
||||
def structured_tool_input(
|
||||
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
|
||||
) -> str:
|
||||
@@ -182,7 +182,8 @@ def test_structured_args_decorator_no_infer_schema() -> None:
|
||||
assert structured_tool_input.name == "structured_tool_input"
|
||||
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
|
||||
expected_result = "1, 0.001, {'foo': 'bar'}"
|
||||
assert structured_tool_input.run(args) == expected_result
|
||||
with pytest.raises(ValueError):
|
||||
assert structured_tool_input.run(args) == expected_result
|
||||
|
||||
|
||||
def test_structured_single_str_decorator_no_infer_schema() -> None:
|
||||
|
@@ -1,10 +1,9 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
@@ -17,12 +16,72 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
ignore_llm_: bool = False
|
||||
ignore_chain_: bool = False
|
||||
ignore_agent_: bool = False
|
||||
always_verbose_: bool = False
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return self.always_verbose_
|
||||
# add finer-grained counters for easier debugging of failing tests
|
||||
chain_starts: int = 0
|
||||
chain_ends: int = 0
|
||||
llm_starts: int = 0
|
||||
llm_ends: int = 0
|
||||
llm_streams: int = 0
|
||||
tool_starts: int = 0
|
||||
tool_ends: int = 0
|
||||
agent_actions: int = 0
|
||||
agent_ends: int = 0
|
||||
|
||||
|
||||
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
"""Base fake callback handler mixin for testing."""
|
||||
|
||||
def on_llm_start_common(self) -> None:
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_llm_end_common(self) -> None:
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_llm_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
|
||||
def on_llm_new_token_common(self) -> None:
|
||||
self.llm_streams += 1
|
||||
|
||||
def on_chain_start_common(self) -> None:
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_chain_end_common(self) -> None:
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_chain_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
|
||||
def on_tool_start_common(self) -> None:
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_tool_end_common(self) -> None:
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_tool_error_common(self) -> None:
|
||||
self.errors += 1
|
||||
|
||||
def on_agent_action_common(self) -> None:
|
||||
self.agent_actions += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_agent_finish_common(self) -> None:
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_text_common(self) -> None:
|
||||
self.text += 1
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
@@ -39,164 +98,209 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
# add finer-grained counters for easier debugging of failing tests
|
||||
chain_starts: int = 0
|
||||
chain_ends: int = 0
|
||||
llm_starts: int = 0
|
||||
llm_ends: int = 0
|
||||
llm_streams: int = 0
|
||||
tool_starts: int = 0
|
||||
tool_ends: int = 0
|
||||
agent_ends: int = 0
|
||||
|
||||
|
||||
class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
|
||||
"""Fake callback handler for testing."""
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_start_common()
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.llm_streams += 1
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_end_common()
|
||||
|
||||
def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_llm_error_common()
|
||||
|
||||
def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_chain_start_common()
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_chain_end_common()
|
||||
|
||||
def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_chain_error_common()
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_tool_start_common()
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_tool_end_common()
|
||||
|
||||
def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_tool_error_common()
|
||||
|
||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run when agent is ending."""
|
||||
self.text += 1
|
||||
def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_agent_action_common()
|
||||
|
||||
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run on agent action."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self.on_text_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
|
||||
return self
|
||||
|
||||
|
||||
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
|
||||
class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Fake async callback handler for testing."""
|
||||
|
||||
@property
|
||||
def ignore_llm(self) -> bool:
|
||||
"""Whether to ignore LLM callbacks."""
|
||||
return self.ignore_llm_
|
||||
|
||||
@property
|
||||
def ignore_chain(self) -> bool:
|
||||
"""Whether to ignore chain callbacks."""
|
||||
return self.ignore_chain_
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
async def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
self.llm_starts += 1
|
||||
self.starts += 1
|
||||
self.on_llm_start_common()
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run when LLM generates a new token."""
|
||||
self.llm_streams += 1
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.llm_ends += 1
|
||||
self.ends += 1
|
||||
async def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_llm_end_common()
|
||||
|
||||
async def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.errors += 1
|
||||
self.on_llm_error_common()
|
||||
|
||||
async def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
self.on_chain_start_common()
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running."""
|
||||
self.chain_ends += 1
|
||||
self.ends += 1
|
||||
async def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_chain_end_common()
|
||||
|
||||
async def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain errors."""
|
||||
self.errors += 1
|
||||
self.on_chain_error_common()
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
self.on_tool_start_common()
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running."""
|
||||
self.tool_ends += 1
|
||||
self.ends += 1
|
||||
async def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_tool_end_common()
|
||||
|
||||
async def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool errors."""
|
||||
self.errors += 1
|
||||
self.on_tool_error_common()
|
||||
|
||||
async def on_text(self, text: str, **kwargs: Any) -> None:
|
||||
"""Run when agent is ending."""
|
||||
self.text += 1
|
||||
async def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_agent_action_common()
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
|
||||
"""Run when agent ends running."""
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
async def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
|
||||
"""Run on agent action."""
|
||||
self.tool_starts += 1
|
||||
self.starts += 1
|
||||
async def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.on_text_common()
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":
|
||||
return self
|
||||
|
@@ -1,15 +1,12 @@
|
||||
"""Test CallbackManager."""
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import (
|
||||
AsyncCallbackManager,
|
||||
BaseCallbackManager,
|
||||
CallbackManager,
|
||||
)
|
||||
from langchain.callbacks.shared import SharedCallbackManager
|
||||
from langchain.schema import AgentFinish, LLMResult
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
BaseFakeCallbackHandler,
|
||||
FakeAsyncCallbackHandler,
|
||||
@@ -18,19 +15,26 @@ from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
|
||||
|
||||
def _test_callback_manager(
|
||||
manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
manager: CallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [])
|
||||
manager.on_llm_end(LLMResult(generations=[]))
|
||||
manager.on_llm_error(Exception())
|
||||
manager.on_chain_start({"name": "foo"}, {})
|
||||
manager.on_chain_end({})
|
||||
manager.on_chain_error(Exception())
|
||||
manager.on_tool_start({}, "")
|
||||
manager.on_tool_end("")
|
||||
manager.on_tool_error(Exception())
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
run_manager = manager.on_llm_start({}, [])
|
||||
run_manager.on_llm_end(LLMResult(generations=[]))
|
||||
run_manager.on_llm_error(Exception())
|
||||
run_manager.on_llm_new_token("foo")
|
||||
run_manager.on_text("foo")
|
||||
|
||||
run_manager_chain = manager.on_chain_start({"name": "foo"}, {})
|
||||
run_manager_chain.on_chain_end({})
|
||||
run_manager_chain.on_chain_error(Exception())
|
||||
run_manager_chain.on_agent_action(AgentAction(tool_input="foo", log="", tool=""))
|
||||
run_manager_chain.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
run_manager_chain.on_text("foo")
|
||||
|
||||
run_manager_tool = manager.on_tool_start({}, "")
|
||||
run_manager_tool.on_tool_end("")
|
||||
run_manager_tool.on_tool_error(Exception())
|
||||
run_manager_tool.on_text("foo")
|
||||
_check_num_calls(handlers)
|
||||
|
||||
|
||||
@@ -38,75 +42,62 @@ async def _test_callback_manager_async(
|
||||
manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
await manager.on_llm_start({}, [])
|
||||
await manager.on_llm_end(LLMResult(generations=[]))
|
||||
await manager.on_llm_error(Exception())
|
||||
await manager.on_chain_start({"name": "foo"}, {})
|
||||
await manager.on_chain_end({})
|
||||
await manager.on_chain_error(Exception())
|
||||
await manager.on_tool_start({}, "")
|
||||
await manager.on_tool_end("")
|
||||
await manager.on_tool_error(Exception())
|
||||
await manager.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
run_manager = await manager.on_llm_start({}, [])
|
||||
await run_manager.on_llm_end(LLMResult(generations=[]))
|
||||
await run_manager.on_llm_error(Exception())
|
||||
await run_manager.on_llm_new_token("foo")
|
||||
await run_manager.on_text("foo")
|
||||
|
||||
run_manager_chain = await manager.on_chain_start({"name": "foo"}, {})
|
||||
await run_manager_chain.on_chain_end({})
|
||||
await run_manager_chain.on_chain_error(Exception())
|
||||
await run_manager_chain.on_agent_action(
|
||||
AgentAction(tool_input="foo", log="", tool="")
|
||||
)
|
||||
await run_manager_chain.on_agent_finish(AgentFinish(log="", return_values={}))
|
||||
await run_manager_chain.on_text("foo")
|
||||
|
||||
run_manager_tool = await manager.on_tool_start({}, "")
|
||||
await run_manager_tool.on_tool_end("")
|
||||
await run_manager_tool.on_tool_error(Exception())
|
||||
await run_manager_tool.on_text("foo")
|
||||
_check_num_calls(handlers)
|
||||
|
||||
|
||||
def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None:
|
||||
for handler in handlers:
|
||||
if handler.always_verbose:
|
||||
assert handler.starts == 3
|
||||
assert handler.ends == 4
|
||||
assert handler.errors == 3
|
||||
else:
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def _test_callback_manager_pass_in_verbose(
|
||||
manager: BaseCallbackManager, *handlers: FakeCallbackHandler
|
||||
) -> None:
|
||||
"""Test the CallbackManager."""
|
||||
manager.on_llm_start({}, [], verbose=True)
|
||||
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
|
||||
manager.on_llm_error(Exception(), verbose=True)
|
||||
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
|
||||
manager.on_chain_end({}, verbose=True)
|
||||
manager.on_chain_error(Exception(), verbose=True)
|
||||
manager.on_tool_start({}, "", verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
manager.on_tool_error(Exception(), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish(log="", return_values={}), verbose=True)
|
||||
for handler in handlers:
|
||||
assert handler.starts == 3
|
||||
assert handler.starts == 4
|
||||
assert handler.ends == 4
|
||||
assert handler.errors == 3
|
||||
assert handler.text == 3
|
||||
|
||||
assert handler.llm_starts == 1
|
||||
assert handler.llm_ends == 1
|
||||
assert handler.llm_streams == 1
|
||||
|
||||
assert handler.chain_starts == 1
|
||||
assert handler.chain_ends == 1
|
||||
|
||||
assert handler.tool_starts == 1
|
||||
assert handler.tool_ends == 1
|
||||
|
||||
|
||||
def test_callback_manager() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=False)
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager([handler1, handler2])
|
||||
_test_callback_manager(manager, handler1, handler2)
|
||||
|
||||
|
||||
def test_callback_manager_pass_in_verbose() -> None:
|
||||
"""Test the CallbackManager."""
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager([handler1, handler2])
|
||||
_test_callback_manager_pass_in_verbose(manager, handler1, handler2)
|
||||
|
||||
|
||||
def test_ignore_llm() -> None:
|
||||
"""Test ignore llm param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeCallbackHandler(ignore_llm_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_llm_start({}, [], verbose=True)
|
||||
manager.on_llm_end(LLMResult(generations=[]), verbose=True)
|
||||
manager.on_llm_error(Exception(), verbose=True)
|
||||
run_manager = manager.on_llm_start({}, [])
|
||||
run_manager.on_llm_end(LLMResult(generations=[]))
|
||||
run_manager.on_llm_error(Exception())
|
||||
assert handler1.starts == 0
|
||||
assert handler1.ends == 0
|
||||
assert handler1.errors == 0
|
||||
@@ -117,12 +108,12 @@ def test_ignore_llm() -> None:
|
||||
|
||||
def test_ignore_chain() -> None:
|
||||
"""Test ignore chain param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeCallbackHandler(ignore_chain_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_chain_start({"name": "foo"}, {}, verbose=True)
|
||||
manager.on_chain_end({}, verbose=True)
|
||||
manager.on_chain_error(Exception(), verbose=True)
|
||||
run_manager = manager.on_chain_start({"name": "foo"}, {})
|
||||
run_manager.on_chain_end({})
|
||||
run_manager.on_chain_error(Exception())
|
||||
assert handler1.starts == 0
|
||||
assert handler1.ends == 0
|
||||
assert handler1.errors == 0
|
||||
@@ -133,39 +124,24 @@ def test_ignore_chain() -> None:
|
||||
|
||||
def test_ignore_agent() -> None:
|
||||
"""Test ignore agent param for callback handlers."""
|
||||
handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeCallbackHandler(ignore_agent_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager = CallbackManager(handlers=[handler1, handler2])
|
||||
manager.on_tool_start({}, "", verbose=True)
|
||||
manager.on_tool_end("", verbose=True)
|
||||
manager.on_tool_error(Exception(), verbose=True)
|
||||
manager.on_agent_finish(AgentFinish({}, ""), verbose=True)
|
||||
run_manager = manager.on_tool_start({}, "")
|
||||
run_manager.on_tool_end("")
|
||||
run_manager.on_tool_error(Exception())
|
||||
assert handler1.starts == 0
|
||||
assert handler1.ends == 0
|
||||
assert handler1.errors == 0
|
||||
assert handler2.starts == 1
|
||||
assert handler2.ends == 2
|
||||
assert handler2.ends == 1
|
||||
assert handler2.errors == 1
|
||||
|
||||
|
||||
def test_shared_callback_manager() -> None:
|
||||
"""Test the SharedCallbackManager."""
|
||||
manager1 = SharedCallbackManager()
|
||||
manager2 = SharedCallbackManager()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler2 = FakeCallbackHandler()
|
||||
manager1.add_handler(handler1)
|
||||
manager2.add_handler(handler2)
|
||||
_test_callback_manager(manager1, handler1, handler2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callback_manager() -> None:
|
||||
"""Test the AsyncCallbackManager."""
|
||||
handler1 = FakeAsyncCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeAsyncCallbackHandler()
|
||||
handler2 = FakeAsyncCallbackHandler()
|
||||
manager = AsyncCallbackManager([handler1, handler2])
|
||||
await _test_callback_manager_async(manager, handler1, handler2)
|
||||
@@ -174,8 +150,95 @@ async def test_async_callback_manager() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_callback_manager_sync_handler() -> None:
|
||||
"""Test the AsyncCallbackManager."""
|
||||
handler1 = FakeCallbackHandler(always_verbose_=True)
|
||||
handler1 = FakeCallbackHandler()
|
||||
handler2 = FakeAsyncCallbackHandler()
|
||||
handler3 = FakeAsyncCallbackHandler(always_verbose_=True)
|
||||
handler3 = FakeAsyncCallbackHandler()
|
||||
manager = AsyncCallbackManager([handler1, handler2, handler3])
|
||||
await _test_callback_manager_async(manager, handler1, handler2, handler3)
|
||||
|
||||
|
||||
def test_callback_manager_inheritance() -> None:
|
||||
handler1, handler2, handler3, handler4 = (
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
)
|
||||
|
||||
callback_manager1 = CallbackManager([handler1, handler2])
|
||||
assert callback_manager1.handlers == [handler1, handler2]
|
||||
assert callback_manager1.inheritable_handlers == []
|
||||
|
||||
callback_manager2 = CallbackManager([])
|
||||
assert callback_manager2.handlers == []
|
||||
assert callback_manager2.inheritable_handlers == []
|
||||
|
||||
callback_manager2.set_handlers([handler1, handler2])
|
||||
assert callback_manager2.handlers == [handler1, handler2]
|
||||
assert callback_manager2.inheritable_handlers == [handler1, handler2]
|
||||
|
||||
callback_manager2.set_handlers([handler3, handler4], inherit=False)
|
||||
assert callback_manager2.handlers == [handler3, handler4]
|
||||
assert callback_manager2.inheritable_handlers == []
|
||||
|
||||
callback_manager2.add_handler(handler1)
|
||||
assert callback_manager2.handlers == [handler3, handler4, handler1]
|
||||
assert callback_manager2.inheritable_handlers == [handler1]
|
||||
|
||||
callback_manager2.add_handler(handler2, inherit=False)
|
||||
assert callback_manager2.handlers == [handler3, handler4, handler1, handler2]
|
||||
assert callback_manager2.inheritable_handlers == [handler1]
|
||||
|
||||
run_manager = callback_manager2.on_chain_start({"name": "foo"}, {})
|
||||
child_manager = run_manager.get_child()
|
||||
assert child_manager.handlers == [handler1]
|
||||
assert child_manager.inheritable_handlers == [handler1]
|
||||
|
||||
run_manager_tool = child_manager.on_tool_start({}, "")
|
||||
assert run_manager_tool.handlers == [handler1]
|
||||
assert run_manager_tool.inheritable_handlers == [handler1]
|
||||
|
||||
child_manager2 = run_manager_tool.get_child()
|
||||
assert child_manager2.handlers == [handler1]
|
||||
assert child_manager2.inheritable_handlers == [handler1]
|
||||
|
||||
|
||||
def test_callback_manager_configure() -> None:
|
||||
"""Test callback manager configuration."""
|
||||
handler1, handler2, handler3, handler4 = (
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
FakeCallbackHandler(),
|
||||
)
|
||||
|
||||
inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2]
|
||||
local_callbacks: List[BaseCallbackHandler] = [handler3, handler4]
|
||||
configured_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=inheritable_callbacks,
|
||||
local_callbacks=local_callbacks,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert len(configured_manager.handlers) == 5
|
||||
assert len(configured_manager.inheritable_handlers) == 2
|
||||
assert configured_manager.inheritable_handlers == inheritable_callbacks
|
||||
assert configured_manager.handlers[:4] == inheritable_callbacks + local_callbacks
|
||||
assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler)
|
||||
assert isinstance(configured_manager, CallbackManager)
|
||||
|
||||
async_local_callbacks = AsyncCallbackManager([handler3, handler4])
|
||||
async_configured_manager = AsyncCallbackManager.configure(
|
||||
inheritable_callbacks=inheritable_callbacks,
|
||||
local_callbacks=async_local_callbacks,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
assert len(async_configured_manager.handlers) == 4
|
||||
assert len(async_configured_manager.inheritable_handlers) == 2
|
||||
assert async_configured_manager.inheritable_handlers == inheritable_callbacks
|
||||
assert async_configured_manager.handlers == inheritable_callbacks + [
|
||||
handler3,
|
||||
handler4,
|
||||
]
|
||||
assert isinstance(async_configured_manager, AsyncCallbackManager)
|
||||
|
@@ -1,9 +1,9 @@
|
||||
"""Test Tracer classes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
@@ -12,9 +12,7 @@ from langchain.callbacks.tracers.base import (
|
||||
BaseTracer,
|
||||
ChainRun,
|
||||
LLMRun,
|
||||
SharedTracer,
|
||||
ToolRun,
|
||||
Tracer,
|
||||
TracerException,
|
||||
TracerSession,
|
||||
)
|
||||
@@ -24,88 +22,6 @@ from langchain.schema import LLMResult
|
||||
TEST_SESSION_ID = 2023
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]:
|
||||
return ChainRun(
|
||||
id=None,
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
child_runs=[
|
||||
ToolRun(
|
||||
id=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
child_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
)
|
||||
],
|
||||
),
|
||||
LLMRun(
|
||||
id=None,
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _perform_nested_run(tracer: BaseTracer) -> None:
|
||||
"""Perform a nested run."""
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_tool_end("test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_chain_end(outputs={})
|
||||
|
||||
|
||||
def _add_child_run(
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
parent_run.child_runs.append(child_run)
|
||||
|
||||
|
||||
def _generate_id() -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return None
|
||||
|
||||
|
||||
def load_session(session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return TracerSession(id=1, name=session_name, start_time=datetime.utcnow())
|
||||
@@ -121,7 +37,7 @@ def load_default_session() -> TracerSession:
|
||||
return TracerSession(id=1, name="default", start_time=datetime.utcnow())
|
||||
|
||||
|
||||
class FakeTracer(Tracer):
|
||||
class FakeTracer(BaseTracer):
|
||||
"""Fake tracer that records LangChain execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -133,58 +49,6 @@ class FakeTracer(Tracer):
|
||||
"""Persist a run."""
|
||||
self.runs.append(run)
|
||||
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
_add_child_run(parent_run, child_run)
|
||||
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return _generate_id()
|
||||
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
|
||||
def load_session(self, session_name: str) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_session(session_name)
|
||||
|
||||
def load_default_session(self) -> TracerSession:
|
||||
"""Load a tracing session."""
|
||||
return load_default_session()
|
||||
|
||||
|
||||
class FakeSharedTracer(SharedTracer):
|
||||
"""Fake shared tracer that records LangChain execution."""
|
||||
|
||||
runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
|
||||
|
||||
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
|
||||
"""Persist a run."""
|
||||
with self._lock:
|
||||
self.runs.append(run)
|
||||
|
||||
def remove_runs(self) -> None:
|
||||
"""Remove all runs."""
|
||||
with self._lock:
|
||||
self.runs = []
|
||||
|
||||
def _add_child_run(
|
||||
self,
|
||||
parent_run: Union[ChainRun, ToolRun],
|
||||
child_run: Union[LLMRun, ChainRun, ToolRun],
|
||||
) -> None:
|
||||
"""Add child run to a chain run or tool run."""
|
||||
_add_child_run(parent_run, child_run)
|
||||
|
||||
def _generate_id(self) -> Optional[Union[int, str]]:
|
||||
"""Generate an id for a run."""
|
||||
return _generate_id()
|
||||
|
||||
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
|
||||
"""Persist a tracing session."""
|
||||
return _persist_session(session)
|
||||
@@ -201,12 +65,15 @@ class FakeSharedTracer(SharedTracer):
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run() -> None:
|
||||
"""Test tracer on an LLM run."""
|
||||
uuid = uuid4()
|
||||
compare_run = LLMRun(
|
||||
id=None,
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
@@ -216,20 +83,11 @@ def test_tracer_llm_run() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_session() -> None:
|
||||
"""Test tracer on an LLM run without a session."""
|
||||
tracer = FakeTracer()
|
||||
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_start() -> None:
|
||||
"""Test tracer on an LLM run without a start."""
|
||||
@@ -237,18 +95,21 @@ def test_tracer_llm_run_errors_no_start() -> None:
|
||||
|
||||
tracer.new_session()
|
||||
with pytest.raises(TracerException):
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_multiple_llm_runs() -> None:
|
||||
"""Test the tracer with multiple runs."""
|
||||
uuid = uuid4()
|
||||
compare_run = LLMRun(
|
||||
id=None,
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
@@ -260,8 +121,8 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
tracer.new_session()
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
|
||||
@@ -269,12 +130,15 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chain_run() -> None:
|
||||
"""Test tracer on a Chain run."""
|
||||
uuid = uuid4()
|
||||
compare_run = ChainRun(
|
||||
id=None,
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
@@ -284,20 +148,23 @@ def test_tracer_chain_run() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_chain_end(outputs={})
|
||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
|
||||
tracer.on_chain_end(outputs={}, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_tool_run() -> None:
|
||||
"""Test tracer on a Tool run."""
|
||||
uuid = uuid4()
|
||||
compare_run = ToolRun(
|
||||
id=None,
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
@@ -308,8 +175,8 @@ def test_tracer_tool_run() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_tool_end("test")
|
||||
tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
|
||||
tracer.on_tool_end("test", run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@@ -318,21 +185,109 @@ def test_tracer_nested_run() -> None:
|
||||
"""Test tracer on a nested run."""
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
_perform_nested_run(tracer)
|
||||
assert tracer.runs == [_get_compare_run()]
|
||||
|
||||
chain_uuid = uuid4()
|
||||
tool_uuid = uuid4()
|
||||
llm_uuid1 = uuid4()
|
||||
llm_uuid2 = uuid4()
|
||||
for _ in range(10):
|
||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
|
||||
tracer.on_tool_start(
|
||||
serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_tool_end("test", run_id=tool_uuid)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
|
||||
|
||||
compare_run = ChainRun(
|
||||
uuid=str(chain_uuid),
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=4,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[
|
||||
ToolRun(
|
||||
uuid=str(tool_uuid),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
child_execution_order=3,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output="test",
|
||||
action="{}",
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid1),
|
||||
parent_uuid=str(tool_uuid),
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
uuid=str(llm_uuid2),
|
||||
parent_uuid=str(chain_uuid),
|
||||
error=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
child_execution_order=4,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
),
|
||||
],
|
||||
)
|
||||
assert tracer.runs == [compare_run] * 10
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_on_error() -> None:
|
||||
"""Test tracer on an LLM run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = LLMRun(
|
||||
id=None,
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
prompts=[],
|
||||
response=None,
|
||||
@@ -342,8 +297,8 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@@ -351,13 +306,16 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
def test_tracer_chain_run_on_error() -> None:
|
||||
"""Test tracer on a Chain run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = ChainRun(
|
||||
id=None,
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
inputs={},
|
||||
outputs=None,
|
||||
@@ -367,8 +325,8 @@ def test_tracer_chain_run_on_error() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_chain_error(exception)
|
||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid)
|
||||
tracer.on_chain_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@@ -376,13 +334,16 @@ def test_tracer_chain_run_on_error() -> None:
|
||||
def test_tracer_tool_run_on_error() -> None:
|
||||
"""Test tracer on a Tool run with an error."""
|
||||
exception = Exception("test")
|
||||
uuid = uuid4()
|
||||
|
||||
compare_run = ToolRun(
|
||||
id=None,
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
tool_input="test",
|
||||
output=None,
|
||||
@@ -393,8 +354,8 @@ def test_tracer_tool_run_on_error() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_tool_error(exception)
|
||||
tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid)
|
||||
tracer.on_tool_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@@ -405,37 +366,53 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
|
||||
tracer = FakeTracer()
|
||||
tracer.new_session()
|
||||
chain_uuid = uuid4()
|
||||
tool_uuid = uuid4()
|
||||
llm_uuid1 = uuid4()
|
||||
llm_uuid2 = uuid4()
|
||||
llm_uuid3 = uuid4()
|
||||
|
||||
for _ in range(3):
|
||||
tracer.on_chain_start(serialized={}, inputs={})
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
tracer.on_tool_start(serialized={}, input_str="test")
|
||||
tracer.on_llm_start(serialized={}, prompts=[])
|
||||
tracer.on_llm_error(exception)
|
||||
tracer.on_tool_error(exception)
|
||||
tracer.on_chain_error(exception)
|
||||
tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
|
||||
tracer.on_tool_start(
|
||||
serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={}, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid
|
||||
)
|
||||
tracer.on_llm_error(exception, run_id=llm_uuid3)
|
||||
tracer.on_tool_error(exception, run_id=tool_uuid)
|
||||
tracer.on_chain_error(exception, run_id=chain_uuid)
|
||||
|
||||
compare_run = ChainRun(
|
||||
id=None,
|
||||
uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=5,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
inputs={},
|
||||
outputs=None,
|
||||
child_runs=[
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
uuid=str(llm_uuid1),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=2,
|
||||
child_execution_order=2,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
@@ -443,36 +420,45 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
LLMRun(
|
||||
id=None,
|
||||
uuid=str(llm_uuid2),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]], llm_output=None),
|
||||
),
|
||||
],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[
|
||||
ToolRun(
|
||||
id=None,
|
||||
uuid=str(tool_uuid),
|
||||
parent_uuid=str(chain_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=4,
|
||||
child_execution_order=5,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
tool_input="test",
|
||||
output=None,
|
||||
action="{}",
|
||||
child_runs=[
|
||||
child_llm_runs=[
|
||||
LLMRun(
|
||||
id=None,
|
||||
uuid=str(llm_uuid3),
|
||||
parent_uuid=str(tool_uuid),
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=5,
|
||||
child_execution_order=5,
|
||||
serialized={},
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
@@ -480,43 +466,10 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
response=None,
|
||||
)
|
||||
],
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
),
|
||||
],
|
||||
child_llm_runs=[],
|
||||
child_chain_runs=[],
|
||||
child_tool_runs=[],
|
||||
)
|
||||
|
||||
assert tracer.runs == [compare_run] * 3
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_shared_tracer_nested_run() -> None:
|
||||
"""Test shared tracer on a nested run."""
|
||||
tracer = FakeSharedTracer()
|
||||
tracer.new_session()
|
||||
tracer.remove_runs()
|
||||
_perform_nested_run(tracer)
|
||||
assert tracer.runs == [_get_compare_run()]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_shared_tracer_nested_run_multithreaded() -> None:
|
||||
"""Test shared tracer on a nested run."""
|
||||
tracer = FakeSharedTracer()
|
||||
tracer.remove_runs()
|
||||
tracer.new_session()
|
||||
threads = []
|
||||
num_threads = 10
|
||||
for _ in range(num_threads):
|
||||
thread = threading.Thread(target=_perform_nested_run, args=(tracer,))
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert tracer.runs == [_get_compare_run()] * num_threads
|
||||
|
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import BaseMemory
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
@@ -25,11 +25,9 @@ class FakeMemory(BaseMemory):
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Pass."""
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Pass."""
|
||||
pass
|
||||
|
||||
|
||||
class FakeChain(Chain):
|
||||
@@ -49,7 +47,11 @@ class FakeChain(Chain):
|
||||
"""Output key of bar."""
|
||||
return self.the_output_keys
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
if self.be_correct:
|
||||
return {"bar": "baz"}
|
||||
else:
|
||||
@@ -143,25 +145,10 @@ def test_run_with_callback() -> None:
|
||||
"""Test run method works when callback manager is passed."""
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(
|
||||
callback_manager=CallbackManager(handlers=[handler]), verbose=True
|
||||
callbacks=[handler],
|
||||
)
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_run_with_callback_not_verbose() -> None:
|
||||
"""Test run method works when callback manager is passed and not verbose."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
|
||||
handler = FakeCallbackHandler()
|
||||
chain = FakeChain(callback_manager=CallbackManager(handlers=[handler]))
|
||||
output = chain.run("bar")
|
||||
assert output == "baz"
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
@@ -3,6 +3,10 @@ from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@@ -28,12 +32,18 @@ class FakeLLM(BaseLLM):
|
||||
n: int = 1
|
||||
|
||||
def _generate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
|
||||
|
||||
async def _agenerate(
|
||||
self, prompts: List[str], stop: Optional[List[str]] = None
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> LLMResult:
|
||||
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
|
||||
|
||||
|
@@ -3,8 +3,8 @@ import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm_bash.base import BashOutputParser, LLMBashChain
|
||||
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE
|
||||
from langchain.chains.llm_bash.base import LLMBashChain
|
||||
from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser
|
||||
from langchain.schema import OutputParserException
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
@@ -43,7 +43,7 @@ def test_simple_question() -> None:
|
||||
prompt = _PROMPT_TEMPLATE.format(question=question)
|
||||
queries = {prompt: "```bash\nexpr 1 + 1\n```"}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
output = fake_llm_bash_chain.run(question)
|
||||
assert output == "2\n"
|
||||
|
||||
@@ -71,7 +71,7 @@ echo 'hello world'
|
||||
"""
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
with pytest.raises(OutputParserException):
|
||||
fake_llm_bash_chain.run(question)
|
||||
|
||||
|
@@ -33,7 +33,7 @@ def fake_llm_checker_chain() -> LLMCheckerChain:
|
||||
): "I still don't know.",
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
return LLMCheckerChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
return LLMCheckerChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
|
||||
|
||||
def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None:
|
||||
|
@@ -17,7 +17,7 @@ def fake_llm_math_chain() -> LLMMathChain:
|
||||
_PROMPT_TEMPLATE.format(question="foo"): "foo",
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
return LLMMathChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
return LLMMathChain.from_llm(fake_llm, input_key="q", output_key="a")
|
||||
|
||||
|
||||
def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None:
|
||||
|
@@ -32,7 +32,9 @@ def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain:
|
||||
): "True",
|
||||
}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
return LLMSummarizationCheckerChain(llm=fake_llm, input_key="q", output_key="a")
|
||||
return LLMSummarizationCheckerChain.from_llm(
|
||||
fake_llm, input_key="q", output_key="a"
|
||||
)
|
||||
|
||||
|
||||
def test_simple_text(
|
||||
|
@@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chains.natbot.base import NatBotChain
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
@@ -9,7 +10,12 @@ from langchain.llms.base import LLM
|
||||
class FakeLLM(LLM):
|
||||
"""Fake LLM wrapper for testing purposes."""
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
"""Return `foo` if longer than 10000 words, else `bar`."""
|
||||
if len(prompt) > 10000:
|
||||
return "foo"
|
||||
@@ -28,7 +34,7 @@ class FakeLLM(LLM):
|
||||
|
||||
def test_proper_inputs() -> None:
|
||||
"""Test that natbot shortens inputs correctly."""
|
||||
nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing")
|
||||
nat_bot_chain = NatBotChain.from_llm(FakeLLM(), objective="testing")
|
||||
url = "foo" * 10000
|
||||
browser_content = "foo" * 10000
|
||||
output = nat_bot_chain.execute(url, browser_content)
|
||||
@@ -37,8 +43,8 @@ def test_proper_inputs() -> None:
|
||||
|
||||
def test_variable_key_naming() -> None:
|
||||
"""Test that natbot handles variable key naming correctly."""
|
||||
nat_bot_chain = NatBotChain(
|
||||
llm=FakeLLM(),
|
||||
nat_bot_chain = NatBotChain.from_llm(
|
||||
FakeLLM(),
|
||||
objective="testing",
|
||||
input_url_key="u",
|
||||
input_browser_content_key="b",
|
||||
|
@@ -1,8 +1,9 @@
|
||||
"""Test pipeline functionality."""
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
from langchain.memory.simple import SimpleMemory
|
||||
@@ -24,7 +25,11 @@ class FakeChain(Chain):
|
||||
"""Input keys this chain returns."""
|
||||
return self.output_variables
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
outputs = {}
|
||||
for var in self.output_variables:
|
||||
variables = [inputs[k] for k in self.input_variables]
|
||||
|
@@ -3,6 +3,7 @@ from typing import Any, List, Mapping, Optional, cast
|
||||
|
||||
from pydantic import validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
|
||||
@@ -28,7 +29,12 @@ class FakeLLM(LLM):
|
||||
"""Return type of llm."""
|
||||
return "fake"
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
if self.sequential_responses:
|
||||
return self._get_next_response_in_sequence
|
||||
|
||||
|
@@ -1,5 +1,4 @@
|
||||
"""Test LLM callbacks."""
|
||||
from langchain.callbacks.base import CallbackManager
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
@@ -7,24 +6,9 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
def test_llm_with_callbacks() -> None:
|
||||
"""Test LLM callbacks."""
|
||||
handler = FakeCallbackHandler()
|
||||
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True)
|
||||
llm = FakeLLM(callbacks=[handler], verbose=True)
|
||||
output = llm("foo")
|
||||
assert output == "foo"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_llm_with_callbacks_not_verbose() -> None:
|
||||
"""Test LLM callbacks but not verbose."""
|
||||
import langchain
|
||||
|
||||
langchain.verbose = False
|
||||
|
||||
handler = FakeCallbackHandler()
|
||||
llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]))
|
||||
output = llm("foo")
|
||||
assert output == "foo"
|
||||
assert handler.starts == 0
|
||||
assert handler.ends == 0
|
||||
assert handler.errors == 0
|
||||
|
45
tests/unit_tests/tools/test_signatures.py
Normal file
45
tests/unit_tests/tools/test_signatures.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Test base tool child implementations."""
|
||||
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import List, Type
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.playwright.base import BaseBrowserTool
|
||||
|
||||
|
||||
def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
|
||||
to_skip = {BaseBrowserTool} # Abstract but not recognized
|
||||
subclasses = []
|
||||
for subclass in cls.__subclasses__():
|
||||
if (
|
||||
not getattr(subclass, "__abstract__", None)
|
||||
and not subclass.__name__.startswith("_")
|
||||
and subclass not in to_skip
|
||||
):
|
||||
subclasses.append(subclass)
|
||||
subclasses.extend(get_non_abstract_subclasses(subclass))
|
||||
return subclasses
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", get_non_abstract_subclasses(BaseTool)) # type: ignore
|
||||
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
|
||||
"""Test that tools defined in this repo accept a run manager argument."""
|
||||
# This wouldn't be necessary if the BaseTool had a strict API.
|
||||
if cls._run is not BaseTool._arun:
|
||||
run_func = cls._run
|
||||
params = inspect.signature(run_func).parameters
|
||||
assert "run_manager" in params
|
||||
pattern = re.compile(r"(?!Async)CallbackManagerForToolRun")
|
||||
assert bool(re.search(pattern, str(params["run_manager"].annotation)))
|
||||
assert params["run_manager"].default is None
|
||||
|
||||
if cls._arun is not BaseTool._arun:
|
||||
run_func = cls._arun
|
||||
params = inspect.signature(run_func).parameters
|
||||
assert "run_manager" in params
|
||||
assert "AsyncCallbackManagerForToolRun" in str(params["run_manager"].annotation)
|
||||
assert params["run_manager"].default is None
|
Reference in New Issue
Block a user