mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-14 05:56:40 +00:00
Add on_chat_message_start (#4499)
### Add on_chat_message_start to callback manager and base tracer Goal: trace messages directly to permit reloading as chat messages (store in an integration-agnostic way) Add an `on_chat_message_start` method. Fall back to `on_llm_start()` for handlers that don't have it implemented. Does so in a non-backwards-compat breaking way (for now)
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
"""A fake callback handler for testing purposes."""
|
||||
from typing import Any
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
@@ -16,6 +19,7 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
ignore_llm_: bool = False
|
||||
ignore_chain_: bool = False
|
||||
ignore_agent_: bool = False
|
||||
ignore_chat_model_: bool = False
|
||||
|
||||
# add finer-grained counters for easier debugging of failing tests
|
||||
chain_starts: int = 0
|
||||
@@ -27,6 +31,7 @@ class BaseFakeCallbackHandler(BaseModel):
|
||||
tool_ends: int = 0
|
||||
agent_actions: int = 0
|
||||
agent_ends: int = 0
|
||||
chat_model_starts: int = 0
|
||||
|
||||
|
||||
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
@@ -47,6 +52,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
self.llm_streams += 1
|
||||
|
||||
def on_chain_start_common(self) -> None:
|
||||
print("CHAIN START")
|
||||
self.chain_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
@@ -69,6 +75,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
self.errors += 1
|
||||
|
||||
def on_agent_action_common(self) -> None:
|
||||
print("AGENT ACTION")
|
||||
self.agent_actions += 1
|
||||
self.starts += 1
|
||||
|
||||
@@ -76,6 +83,11 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
|
||||
self.agent_ends += 1
|
||||
self.ends += 1
|
||||
|
||||
def on_chat_model_start_common(self) -> None:
|
||||
print("STARTING CHAT MODEL")
|
||||
self.chat_model_starts += 1
|
||||
self.starts += 1
|
||||
|
||||
def on_text_common(self) -> None:
|
||||
self.text += 1
|
||||
|
||||
@@ -193,6 +205,20 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
return self
|
||||
|
||||
|
||||
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
assert all(isinstance(m, BaseMessage) for m in chain(*messages))
|
||||
self.on_chat_model_start_common()
|
||||
|
||||
|
||||
class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Fake async callback handler for testing."""
|
||||
|
||||
|
@@ -10,6 +10,7 @@ from uuid import UUID, uuid4
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.callbacks.tracers.base import (
|
||||
BaseTracer,
|
||||
ChainRun,
|
||||
@@ -96,6 +97,33 @@ def test_tracer_llm_run() -> None:
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_chat_model_run() -> None:
|
||||
"""Test tracer on a Chat Model run."""
|
||||
uuid = uuid4()
|
||||
compare_run = LLMRun(
|
||||
uuid=str(uuid),
|
||||
parent_uuid=None,
|
||||
start_time=datetime.utcnow(),
|
||||
end_time=datetime.utcnow(),
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={},
|
||||
prompts=[""],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
manager = CallbackManager(handlers=[tracer])
|
||||
run_manager = manager.on_chat_model_start(serialized={}, messages=[[]], run_id=uuid)
|
||||
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_tracer_llm_run_errors_no_start() -> None:
|
||||
"""Test tracer on an LLM run without a start."""
|
||||
|
@@ -1,70 +0,0 @@
|
||||
"""Test LangChain+ Client Utils."""
|
||||
|
||||
from typing import List
|
||||
|
||||
from langchain.client.utils import parse_chat_messages
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_chat_messages() -> None:
|
||||
"""Test that chat messages are parsed correctly."""
|
||||
input_text = (
|
||||
"Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
|
||||
)
|
||||
expected = [
|
||||
HumanMessage(content="I am human roar"),
|
||||
AIMessage(content="I am AI beep boop"),
|
||||
SystemMessage(content="I am a system message"),
|
||||
]
|
||||
assert parse_chat_messages(input_text) == expected
|
||||
|
||||
|
||||
def test_parse_chat_messages_empty_input() -> None:
|
||||
"""Test that an empty input string returns an empty list."""
|
||||
input_text = ""
|
||||
expected: List[BaseMessage] = []
|
||||
assert parse_chat_messages(input_text) == expected
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiline_messages() -> None:
|
||||
"""Test that multiline messages are parsed correctly."""
|
||||
input_text = (
|
||||
"Human: I am a human\nand I roar\nAI: I am an AI\nand I"
|
||||
" beep boop\nSystem: I am a system\nand a message"
|
||||
)
|
||||
expected = [
|
||||
HumanMessage(content="I am a human\nand I roar"),
|
||||
AIMessage(content="I am an AI\nand I beep boop"),
|
||||
SystemMessage(content="I am a system\nand a message"),
|
||||
]
|
||||
assert parse_chat_messages(input_text) == expected
|
||||
|
||||
|
||||
def test_parse_chat_messages_custom_roles() -> None:
|
||||
"""Test that custom roles are parsed correctly."""
|
||||
input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
|
||||
expected = [
|
||||
ChatMessage(role="Client", content="I need help"),
|
||||
ChatMessage(role="Agent", content="I'm here to help"),
|
||||
ChatMessage(role="Client", content="Thank you"),
|
||||
]
|
||||
assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected
|
||||
|
||||
|
||||
def test_parse_chat_messages_embedded_roles() -> None:
|
||||
"""Test that messages with embedded role references are parsed correctly."""
|
||||
input_text = (
|
||||
"Human: Oh ai what if you said AI: foo bar?"
|
||||
"\nAI: Well, that would be interesting!"
|
||||
)
|
||||
expected = [
|
||||
HumanMessage(content="Oh ai what if you said AI: foo bar?"),
|
||||
AIMessage(content="Well, that would be interesting!"),
|
||||
]
|
||||
assert parse_chat_messages(input_text) == expected
|
32
tests/unit_tests/llms/fake_chat_model.py
Normal file
32
tests/unit_tests/llms/fake_chat_model.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Fake Chat Model wrapper for testing purposes."""
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import SimpleChatModel
|
||||
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
|
||||
|
||||
|
||||
class FakeChatModel(SimpleChatModel):
|
||||
"""Fake Chat Model wrapper for testing purposes."""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
) -> str:
|
||||
return "fake response"
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
) -> ChatResult:
|
||||
output_str = "fake response"
|
||||
message = AIMessage(content=output_str)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
@@ -1,5 +1,10 @@
|
||||
"""Test LLM callbacks."""
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
from langchain.schema import HumanMessage
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
FakeCallbackHandler,
|
||||
FakeCallbackHandlerWithChatStart,
|
||||
)
|
||||
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@@ -12,3 +17,30 @@ def test_llm_with_callbacks() -> None:
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
|
||||
|
||||
def test_chat_model_with_v1_callbacks() -> None:
|
||||
"""Test chat model callbacks fall back to on_llm_start."""
|
||||
handler = FakeCallbackHandler()
|
||||
llm = FakeChatModel(callbacks=[handler], verbose=True)
|
||||
output = llm([HumanMessage(content="foo")])
|
||||
assert output.content == "fake response"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
assert handler.llm_starts == 1
|
||||
assert handler.llm_ends == 1
|
||||
|
||||
|
||||
def test_chat_model_with_v2_callbacks() -> None:
|
||||
"""Test chat model callbacks fall back to on_llm_start."""
|
||||
handler = FakeCallbackHandlerWithChatStart()
|
||||
llm = FakeChatModel(callbacks=[handler], verbose=True)
|
||||
output = llm([HumanMessage(content="foo")])
|
||||
assert output.content == "fake response"
|
||||
assert handler.starts == 1
|
||||
assert handler.ends == 1
|
||||
assert handler.errors == 0
|
||||
assert handler.llm_starts == 0
|
||||
assert handler.llm_ends == 1
|
||||
assert handler.chat_model_starts == 1
|
||||
|
Reference in New Issue
Block a user