mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
fix(core): extract usage metadata from serialized tracer message outputs (#35526)
Fixes missing `run.metadata.usage_metadata` population in `LangChainTracer` for real LLM/chat traces, following #34414.

- Fix extraction to read usage from the serialized tracer message shape: `outputs.generations[*][*].message.kwargs.usage_metadata`
- Remove handling of the non-serialized direct message shape (`message.usage_metadata`) from the extractor, to match the real tracer output path
- Clarify tracer docstrings around chat callback naming (`on_chat_model_start` + shared `on_llm_end`) to reduce ambiguity

## Why

The change in #34414 introduced duplication of usage into `run.metadata.usage_metadata`, but the extractor read `message.usage_metadata`. In the real tracer flow, messages are serialized with `dumpd(...)` during run completion, so usage metadata lives under `message.kwargs.usage_metadata`. Because of this mismatch, the duplication never triggered in real traces.
This commit is contained in:
@@ -26,10 +26,11 @@ class ChatResult(BaseModel):
|
||||
"""
|
||||
|
||||
llm_output: dict | None = None
|
||||
"""For arbitrary LLM provider specific output.
|
||||
"""For arbitrary model provider-specific output.
|
||||
|
||||
This dictionary is a free-form dictionary that can contain any information that the
|
||||
provider wants to return. It is not standardized and is provider-specific.
|
||||
provider wants to return. It is not standardized and keys may vary by provider and
|
||||
over time.
|
||||
|
||||
Users should generally avoid relying on this field and instead rely on accessing
|
||||
relevant information from standardized fields present in `AIMessage`.
|
||||
|
||||
@@ -38,10 +38,11 @@ class LLMResult(BaseModel):
|
||||
"""
|
||||
|
||||
llm_output: dict | None = None
|
||||
"""For arbitrary LLM provider specific output.
|
||||
"""For arbitrary model provider-specific output.
|
||||
|
||||
This dictionary is a free-form dictionary that can contain any information that the
|
||||
provider wants to return. It is not standardized and is provider-specific.
|
||||
provider wants to return. It is not standardized and keys may vary by provider and
|
||||
over time.
|
||||
|
||||
Users should generally avoid relying on this field and instead rely on accessing
|
||||
relevant information from standardized fields present in AIMessage.
|
||||
|
||||
@@ -61,7 +61,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
name: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for an LLM run.
|
||||
"""Start a trace for a chat model run.
|
||||
|
||||
Note:
|
||||
Naming can be confusing here: there is `on_chat_model_start`, but no
|
||||
corresponding `on_chat_model_end` callback. Chat model completion is
|
||||
routed through `on_llm_end` / `_on_llm_end`, which are shared with
|
||||
text LLM runs.
|
||||
|
||||
Args:
|
||||
serialized: The serialized model.
|
||||
@@ -191,7 +197,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
|
||||
@override
|
||||
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
|
||||
"""End a trace for an LLM run.
|
||||
"""End a trace for an LLM or chat model run.
|
||||
|
||||
Note:
|
||||
This is the end callback for both run types. Chat models start with
|
||||
`on_chat_model_start`, but there is no `on_chat_model_end`;
|
||||
completion is routed here for callback API compatibility.
|
||||
|
||||
Args:
|
||||
response: The response.
|
||||
@@ -654,6 +665,14 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""End a trace for an LLM or chat model run.
|
||||
|
||||
Note:
|
||||
This async callback also handles both run types. Async chat models
|
||||
start with `on_chat_model_start`, but there is no
|
||||
`on_chat_model_end`; completion is routed here for callback API
|
||||
compatibility.
|
||||
"""
|
||||
llm_run = self._complete_llm_run(
|
||||
response=response,
|
||||
run_id=run_id,
|
||||
@@ -874,7 +893,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
|
||||
"""Process the LLM Run upon start."""
|
||||
|
||||
async def _on_llm_end(self, run: Run) -> None:
    """Handle completion of an LLM or chat model run.

    This hook covers both run types. Its body is this docstring only, so
    awaiting it performs no work here (presumably subclasses override it;
    confirm against the concrete tracers).

    Args:
        run: The run that has finished.
    """
|
||||
|
||||
async def _on_llm_error(self, run: Run) -> None:
|
||||
"""Process the LLM Run upon error."""
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith import Client, get_tracing_context
|
||||
@@ -77,11 +77,15 @@ def _get_usage_metadata_from_generations(
|
||||
"""Extract and aggregate `usage_metadata` from generations.
|
||||
|
||||
Iterates through generations to find and aggregate all `usage_metadata` found in
|
||||
messages. This is typically present in chat model outputs.
|
||||
messages. This expects the serialized message payload shape produced by tracer
|
||||
internals:
|
||||
|
||||
`{"message": {"kwargs": {"usage_metadata": {...}}}}`
|
||||
|
||||
Args:
|
||||
generations: List of generation batches, where each batch is a list of
|
||||
generation dicts that may contain a `'message'` key with `'usage_metadata'`.
|
||||
generation dicts that may contain a `'message'` key with
|
||||
usage metadata.
|
||||
|
||||
Returns:
|
||||
The aggregated `usage_metadata` dict if found, otherwise `None`.
|
||||
@@ -91,11 +95,24 @@ def _get_usage_metadata_from_generations(
|
||||
for generation in generation_batch:
|
||||
if isinstance(generation, dict) and "message" in generation:
|
||||
message = generation["message"]
|
||||
if isinstance(message, dict) and "usage_metadata" in message:
|
||||
output = add_usage(output, message["usage_metadata"])
|
||||
usage_metadata = _get_usage_metadata_from_message(message)
|
||||
if usage_metadata is not None:
|
||||
output = add_usage(output, usage_metadata)
|
||||
return output
|
||||
|
||||
|
||||
def _get_usage_metadata_from_message(message: Any) -> UsageMetadata | None:
    """Return the usage metadata stored in a serialized message payload.

    Only the serialized shape is recognized, where usage lives under the
    message's `'kwargs'` mapping:

    `{"kwargs": {"usage_metadata": {...}}}`

    A top-level `'usage_metadata'` key on the message itself is ignored.

    Args:
        message: The `'message'` value of a generation dict. Any type is
            accepted; non-dict values yield `None`.

    Returns:
        The `usage_metadata` dict found under `message["kwargs"]`, or `None` when
        the message is not a dict, `kwargs` is not a dict, or `usage_metadata` is
        missing or not a dict.
    """
    payload = message.get("kwargs") if isinstance(message, dict) else None
    if not isinstance(payload, dict):
        return None

    usage = payload.get("usage_metadata")
    if not isinstance(usage, dict):
        return None

    # The runtime check only establishes `dict`; the cast narrows the type for
    # static checkers and returns the very same object.
    return cast("UsageMetadata", usage)
|
||||
|
||||
|
||||
class LangChainTracer(BaseTracer):
|
||||
"""Implementation of the `SharedTracer` that `POSTS` to the LangChain endpoint."""
|
||||
|
||||
@@ -294,13 +311,19 @@ class LangChainTracer(BaseTracer):
|
||||
)
|
||||
|
||||
def _on_chat_model_start(self, run: Run) -> None:
    """Persist a chat model run when it starts.

    A run without a parent has its `reference_example_id` set from this
    tracer's `example_id` before being handed to `_persist_run_single`. Runs
    that have a parent are persisted as they are.

    Note:
        The naming is historical: there is no `_on_chat_model_end` hook. Chat
        model completion is handled by `_on_llm_end`, which is shared with
        text LLM runs.

    Args:
        run: The chat model run to persist.
    """
    is_root_run = run.parent_run_id is None
    if is_root_run:
        run.reference_example_id = self.example_id
    self._persist_run_single(run)
|
||||
|
||||
def _on_llm_end(self, run: Run) -> None:
|
||||
"""Process the LLM Run."""
|
||||
"""Process LLM/chat model run completion."""
|
||||
# Extract usage_metadata from outputs and store in extra.metadata
|
||||
if run.outputs and "generations" in run.outputs:
|
||||
usage_metadata = _get_usage_metadata_from_generations(
|
||||
|
||||
@@ -10,7 +10,8 @@ from langsmith import Client
|
||||
from langsmith.run_trees import RunTree
|
||||
from langsmith.utils import get_env_var, get_tracer_project
|
||||
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.tracers.langchain import (
|
||||
LangChainTracer,
|
||||
_get_usage_metadata_from_generations,
|
||||
@@ -154,7 +155,8 @@ def test_correct_get_tracer_project(
|
||||
@pytest.mark.parametrize(
|
||||
("generations", "expected"),
|
||||
[
|
||||
# Returns usage_metadata when present
|
||||
# Returns None for non-serialized message usage_metadata shape
|
||||
# (earlier regression)
|
||||
(
|
||||
[
|
||||
[
|
||||
@@ -171,6 +173,33 @@ def test_correct_get_tracer_project(
|
||||
}
|
||||
]
|
||||
],
|
||||
None,
|
||||
),
|
||||
# Returns usage_metadata when message is serialized via dumpd
|
||||
(
|
||||
[
|
||||
[
|
||||
{
|
||||
"text": "Hello!",
|
||||
"message": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": "Hello!",
|
||||
"type": "ai",
|
||||
"usage_metadata": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"total_tokens": 30,
|
||||
},
|
||||
"tool_calls": [],
|
||||
"invalid_tool_calls": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
],
|
||||
{"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
|
||||
),
|
||||
# Returns None when no usage_metadata
|
||||
@@ -187,22 +216,38 @@ def test_correct_get_tracer_project(
|
||||
{
|
||||
"text": "First",
|
||||
"message": {
|
||||
"content": "First",
|
||||
"usage_metadata": {
|
||||
"input_tokens": 5,
|
||||
"output_tokens": 10,
|
||||
"total_tokens": 15,
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": "First",
|
||||
"type": "ai",
|
||||
"usage_metadata": {
|
||||
"input_tokens": 5,
|
||||
"output_tokens": 10,
|
||||
"total_tokens": 15,
|
||||
},
|
||||
"tool_calls": [],
|
||||
"invalid_tool_calls": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Second",
|
||||
"message": {
|
||||
"content": "Second",
|
||||
"usage_metadata": {
|
||||
"input_tokens": 50,
|
||||
"output_tokens": 100,
|
||||
"total_tokens": 150,
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": "Second",
|
||||
"type": "ai",
|
||||
"usage_metadata": {
|
||||
"input_tokens": 50,
|
||||
"output_tokens": 100,
|
||||
"total_tokens": 150,
|
||||
},
|
||||
"tool_calls": [],
|
||||
"invalid_tool_calls": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -218,11 +263,19 @@ def test_correct_get_tracer_project(
|
||||
{
|
||||
"text": "Has message",
|
||||
"message": {
|
||||
"content": "Has message",
|
||||
"usage_metadata": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"total_tokens": 30,
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": "Has message",
|
||||
"type": "ai",
|
||||
"usage_metadata": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 20,
|
||||
"total_tokens": 30,
|
||||
},
|
||||
"tool_calls": [],
|
||||
"invalid_tool_calls": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -232,7 +285,8 @@ def test_correct_get_tracer_project(
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"returns_usage_metadata_when_present",
|
||||
"returns_none_when_non_serialized_message_shape",
|
||||
"returns_usage_metadata_when_message_serialized",
|
||||
"returns_none_when_no_usage_metadata",
|
||||
"returns_none_when_no_message",
|
||||
"returns_none_for_empty_list",
|
||||
@@ -265,7 +319,18 @@ def test_on_llm_end_stores_usage_metadata_in_run_extra() -> None:
|
||||
[
|
||||
{
|
||||
"text": "Hello!",
|
||||
"message": {"content": "Hello!", "usage_metadata": usage_metadata},
|
||||
"message": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": "Hello!",
|
||||
"type": "ai",
|
||||
"usage_metadata": usage_metadata,
|
||||
"tool_calls": [],
|
||||
"invalid_tool_calls": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
]
|
||||
@@ -285,6 +350,41 @@ def test_on_llm_end_stores_usage_metadata_in_run_extra() -> None:
|
||||
assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata
|
||||
|
||||
|
||||
def test_on_llm_end_stores_usage_metadata_from_serialized_outputs() -> None:
    """Check that usage metadata reaches run metadata via real run completion.

    Instead of hand-writing `run.outputs`, this test completes the run from an
    `LLMResult` holding a real `AIMessage`, so `_on_llm_end` sees whatever
    output shape `_complete_llm_run` produces (presumably the serialized
    message form; confirm against the tracer core).
    """
    mock_client = unittest.mock.MagicMock(spec=Client)
    mock_client.tracing_queue = None
    tracer = LangChainTracer(client=mock_client)

    run_id = UUID("d94d0ff8-cf5a-4100-ab11-1a0efaa8d8d0")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)

    usage_metadata = {"input_tokens": 100, "output_tokens": 200, "total_tokens": 300}
    ai_message = AIMessage(content="Hello!", usage_metadata=usage_metadata)
    response = LLMResult(generations=[[ChatGeneration(message=ai_message)]])
    run = tracer._complete_llm_run(response=response, run_id=run_id)

    # Record every run passed to the update hook instead of sending it anywhere.
    updated_runs: list[Run] = []
    with unittest.mock.patch.object(
        tracer, "_update_run_single", updated_runs.append
    ):
        tracer._on_llm_end(run)

    assert updated_runs
    captured_run = updated_runs[-1]
    assert "metadata" in captured_run.extra
    assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata
|
||||
|
||||
|
||||
def test_on_llm_end_no_usage_metadata_when_not_present() -> None:
|
||||
"""Test that no `usage_metadata` is added when not present in outputs."""
|
||||
client = unittest.mock.MagicMock(spec=Client)
|
||||
@@ -296,7 +396,24 @@ def test_on_llm_end_no_usage_metadata_when_not_present() -> None:
|
||||
|
||||
run = tracer.run_map[str(run_id)]
|
||||
run.outputs = {
|
||||
"generations": [[{"text": "Hello!", "message": {"content": "Hello!"}}]]
|
||||
"generations": [
|
||||
[
|
||||
{
|
||||
"text": "Hello!",
|
||||
"message": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": "Hello!",
|
||||
"type": "ai",
|
||||
"tool_calls": [],
|
||||
"invalid_tool_calls": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
|
||||
captured_run = None
|
||||
@@ -334,7 +451,18 @@ def test_on_llm_end_preserves_existing_metadata() -> None:
|
||||
[
|
||||
{
|
||||
"text": "Hello!",
|
||||
"message": {"content": "Hello!", "usage_metadata": usage_metadata},
|
||||
"message": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": "Hello!",
|
||||
"type": "ai",
|
||||
"usage_metadata": usage_metadata,
|
||||
"tool_calls": [],
|
||||
"invalid_tool_calls": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user