Add on_chat_message_start (#4499)

### Add on_chat_message_start to callback manager and base tracer

Goal: trace messages directly to permit reloading as chat messages
(store in an integration-agnostic way)

Add an `on_chat_message_start` method. Fall back to `on_llm_start()` for
handlers that don't have it implemented.

Does so in a way that does not break backwards compatibility (for now)
This commit is contained in:
Zander Chase
2023-05-11 11:06:39 -07:00
committed by GitHub
parent bbf76dbb52
commit 4ee47926ca
12 changed files with 311 additions and 140 deletions

View File

@@ -1,9 +1,12 @@
"""A fake callback handler for testing purposes."""
from typing import Any
from itertools import chain
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import BaseMessage
class BaseFakeCallbackHandler(BaseModel):
@@ -16,6 +19,7 @@ class BaseFakeCallbackHandler(BaseModel):
ignore_llm_: bool = False
ignore_chain_: bool = False
ignore_agent_: bool = False
ignore_chat_model_: bool = False
# add finer-grained counters for easier debugging of failing tests
chain_starts: int = 0
@@ -27,6 +31,7 @@ class BaseFakeCallbackHandler(BaseModel):
tool_ends: int = 0
agent_actions: int = 0
agent_ends: int = 0
chat_model_starts: int = 0
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@@ -47,6 +52,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
self.llm_streams += 1
def on_chain_start_common(self) -> None:
    """Record a chain-start callback: bump the chain and overall start counters."""
    # Dropped leftover debug print("CHAIN START") — it polluted test output.
    self.chain_starts += 1
    self.starts += 1
@@ -69,6 +75,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
self.errors += 1
def on_agent_action_common(self) -> None:
    """Record an agent-action callback: bump the agent-action and overall start counters."""
    # Dropped leftover debug print("AGENT ACTION") — it polluted test output.
    self.agent_actions += 1
    self.starts += 1
@@ -76,6 +83,11 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
self.agent_ends += 1
self.ends += 1
def on_chat_model_start_common(self) -> None:
    """Record a chat-model-start callback: bump the chat-model and overall start counters."""
    # Dropped leftover debug print("STARTING CHAT MODEL") — it polluted test output.
    self.chat_model_starts += 1
    self.starts += 1
def on_text_common(self) -> None:
    """Record that a text callback fired."""
    self.text = self.text + 1
@@ -193,6 +205,20 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
return self
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
    """Fake handler that implements the v2 ``on_chat_model_start`` callback."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        # Every entry in every batch must be a real BaseMessage instance.
        for batch in messages:
            for message in batch:
                assert isinstance(message, BaseMessage)
        self.on_chat_model_start_common()
class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Fake async callback handler for testing."""

View File

@@ -10,6 +10,7 @@ from uuid import UUID, uuid4
import pytest
from freezegun import freeze_time
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import (
BaseTracer,
ChainRun,
@@ -96,6 +97,33 @@ def test_tracer_llm_run() -> None:
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
    """Test tracer on a Chat Model run."""
    run_id = uuid4()
    tracer = FakeTracer()
    tracer.new_session()
    # The chat-model start falls back to an LLMRun record with an empty
    # prompt buffer, since messages=[[]] serializes to a single "" prompt.
    expected = LLMRun(
        uuid=str(run_id),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={},
        prompts=[""],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    manager = CallbackManager(handlers=[tracer])
    run_manager = manager.on_chat_model_start(serialized={}, messages=[[]], run_id=run_id)
    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
    assert tracer.runs == [expected]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
"""Test tracer on an LLM run without a start."""

View File

@@ -1,70 +0,0 @@
"""Test LangChain+ Client Utils."""
from typing import List
from langchain.client.utils import parse_chat_messages
from langchain.schema import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
def test_parse_chat_messages() -> None:
    """Test that chat messages are parsed correctly."""
    text = (
        "Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
    )
    parsed = parse_chat_messages(text)
    assert parsed == [
        HumanMessage(content="I am human roar"),
        AIMessage(content="I am AI beep boop"),
        SystemMessage(content="I am a system message"),
    ]
def test_parse_chat_messages_empty_input() -> None:
    """Test that an empty input string returns an empty list."""
    expected: List[BaseMessage] = []
    assert parse_chat_messages("") == expected
def test_parse_chat_messages_multiline_messages() -> None:
    """Test that multiline messages are parsed correctly."""
    text = (
        "Human: I am a human\nand I roar\nAI: I am an AI\nand I"
        " beep boop\nSystem: I am a system\nand a message"
    )
    # Continuation lines belong to the message opened by the preceding role.
    assert parse_chat_messages(text) == [
        HumanMessage(content="I am a human\nand I roar"),
        AIMessage(content="I am an AI\nand I beep boop"),
        SystemMessage(content="I am a system\nand a message"),
    ]
def test_parse_chat_messages_custom_roles() -> None:
    """Test that custom roles are parsed correctly."""
    text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
    parsed = parse_chat_messages(text, roles=["Client", "Agent"])
    # Non-standard roles come back as generic ChatMessage objects.
    assert parsed == [
        ChatMessage(role="Client", content="I need help"),
        ChatMessage(role="Agent", content="I'm here to help"),
        ChatMessage(role="Client", content="Thank you"),
    ]
def test_parse_chat_messages_embedded_roles() -> None:
    """Test that messages with embedded role references are parsed correctly."""
    text = (
        "Human: Oh ai what if you said AI: foo bar?"
        "\nAI: Well, that would be interesting!"
    )
    # A role marker mid-line must not split the message.
    assert parse_chat_messages(text) == [
        HumanMessage(content="Oh ai what if you said AI: foo bar?"),
        AIMessage(content="Well, that would be interesting!"),
    ]

View File

@@ -0,0 +1,32 @@
"""Fake Chat Model wrapper for testing purposes."""
from typing import List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
class FakeChatModel(SimpleChatModel):
    """Fake Chat Model wrapper for testing purposes."""

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        # Sync path: always answer with the same canned string.
        return "fake response"

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        # Async path: wrap the same canned reply in a ChatResult.
        reply = AIMessage(content="fake response")
        return ChatResult(generations=[ChatGeneration(message=reply)])

View File

@@ -1,5 +1,10 @@
"""Test LLM callbacks."""
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from langchain.schema import HumanMessage
from tests.unit_tests.callbacks.fake_callback_handler import (
FakeCallbackHandler,
FakeCallbackHandlerWithChatStart,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -12,3 +17,30 @@ def test_llm_with_callbacks() -> None:
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_chat_model_with_v1_callbacks() -> None:
    """Test chat model callbacks fall back to on_llm_start."""
    handler = FakeCallbackHandler()
    model = FakeChatModel(callbacks=[handler], verbose=True)
    result = model([HumanMessage(content="foo")])
    assert result.content == "fake response"
    # A v1 handler has no on_chat_model_start, so the LLM counters fire.
    assert (handler.starts, handler.ends, handler.errors) == (1, 1, 0)
    assert handler.llm_starts == 1
    assert handler.llm_ends == 1
def test_chat_model_with_v2_callbacks() -> None:
    """Test that a handler implementing on_chat_model_start receives it directly.

    Note: the original docstring was copy-pasted from the v1 test and claimed
    the callbacks "fall back to on_llm_start" — the assertions below show the
    opposite: llm_starts stays 0 and chat_model_starts is incremented.
    """
    handler = FakeCallbackHandlerWithChatStart()
    llm = FakeChatModel(callbacks=[handler], verbose=True)
    output = llm([HumanMessage(content="foo")])
    assert output.content == "fake response"
    assert handler.starts == 1
    assert handler.ends == 1
    assert handler.errors == 0
    # No fallback: the v2 hook is used instead of on_llm_start.
    assert handler.llm_starts == 0
    assert handler.llm_ends == 1
    assert handler.chat_model_starts == 1