fix(core): extract usage metadata from serialized tracer message outputs (#35526)

Fixes missing `run.metadata.usage_metadata` population in
`LangChainTracer` for real LLM/chat traces following #34414

- Fix extraction to read usage from the serialized tracer message shape:
`outputs.generations[*][*].message.kwargs.usage_metadata` (see the sketch
after this list)
- Remove non-serialized direct message shape handling
(`message.usage_metadata`) from the extractor, matching the real tracer
output path
- Clarify tracer docstrings around chat callback naming
(`on_chat_model_start` + shared `on_llm_end`) to reduce ambiguity
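
A minimal sketch of that serialized shape (sample values mirroring the test
fixtures in this change):

```python
# Sample of run.outputs after tracer serialization; values are illustrative,
# the nesting is the dumpd(...) constructor format.
outputs = {
    "generations": [
        [
            {
                "text": "Hello!",
                "message": {
                    "lc": 1,
                    "type": "constructor",
                    "id": ["langchain", "schema", "messages", "AIMessage"],
                    "kwargs": {
                        "content": "Hello!",
                        "type": "ai",
                        "usage_metadata": {
                            "input_tokens": 10,
                            "output_tokens": 20,
                            "total_tokens": 30,
                        },
                    },
                },
            }
        ]
    ]
}

# The path the fixed extractor walks:
usage = outputs["generations"][0][0]["message"]["kwargs"]["usage_metadata"]
assert usage["total_tokens"] == 30
```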

## Why

#34414 introduced duplicating usage data into `run.metadata.usage_metadata`,
but the extractor read the non-serialized `message.usage_metadata` shape.

In the real tracer flow, messages are serialized with `dumpd(...)` during
run completion, so usage metadata lives under
`message.kwargs.usage_metadata`. Because of this mismatch, the duplication
never triggered in real traces.
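
The mismatch is easy to reproduce (both imports are real `langchain_core`
APIs; the message content and token counts are sample data):

```python
from langchain_core.load import dumpd
from langchain_core.messages import AIMessage

msg = AIMessage(
    content="Hello!",
    usage_metadata={"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
)
serialized = dumpd(msg)

# The old extractor looked for a top-level key, which the serialized
# form does not have:
print(serialized.get("usage_metadata"))  # None

# In the dumpd constructor format, constructor arguments -- including
# usage_metadata -- live under "kwargs":
print(serialized["kwargs"]["usage_metadata"]["total_tokens"])  # 30
```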
Mason Daugherty, 2026-03-02 17:43:33 -05:00 (committed by GitHub)
parent d2c86df128, commit 61fd90a2f3
5 changed files with 207 additions and 35 deletions


@@ -26,10 +26,11 @@ class ChatResult(BaseModel):
     """
     llm_output: dict | None = None
-    """For arbitrary LLM provider specific output.
+    """For arbitrary model provider-specific output.
 
     This dictionary is a free-form dictionary that can contain any information that the
-    provider wants to return. It is not standardized and is provider-specific.
+    provider wants to return. It is not standardized and keys may vary by provider and
+    over time.
 
     Users should generally avoid relying on this field and instead rely on accessing
     relevant information from standardized fields present in `AIMessage`.


@@ -38,10 +38,11 @@ class LLMResult(BaseModel):
     """
     llm_output: dict | None = None
-    """For arbitrary LLM provider specific output.
+    """For arbitrary model provider-specific output.
 
     This dictionary is a free-form dictionary that can contain any information that the
-    provider wants to return. It is not standardized and is provider-specific.
+    provider wants to return. It is not standardized and keys may vary by provider and
+    over time.
 
     Users should generally avoid relying on this field and instead rely on accessing
     relevant information from standardized fields present in AIMessage.
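
Both docstring hunks make the same point; a short illustration of why the
standardized field is preferable (the `llm_output` keys mentioned in the
comment are made-up provider examples):

```python
from langchain_core.messages import AIMessage

# Standardized: usage lives on the message itself, same shape everywhere.
msg = AIMessage(
    content="Hi",
    usage_metadata={"input_tokens": 3, "output_tokens": 1, "total_tokens": 4},
)
assert msg.usage_metadata is not None
print(msg.usage_metadata["total_tokens"])  # 4

# Provider-specific: llm_output keys vary by provider and over time, e.g.
# {"token_usage": {...}} for one provider, {"usage": {...}} for another,
# so code reading it must special-case each provider.
```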


@@ -61,7 +61,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
         name: str | None = None,
         **kwargs: Any,
     ) -> Run:
-        """Start a trace for an LLM run.
+        """Start a trace for a chat model run.
+
+        Note:
+            Naming can be confusing here: there is `on_chat_model_start`, but no
+            corresponding `on_chat_model_end` callback. Chat model completion is
+            routed through `on_llm_end` / `_on_llm_end`, which are shared with
+            text LLM runs.
 
         Args:
             serialized: The serialized model.
@@ -191,7 +197,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
     @override
     def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
-        """End a trace for an LLM run.
+        """End a trace for an LLM or chat model run.
+
+        Note:
+            This is the end callback for both run types. Chat models start with
+            `on_chat_model_start`, but there is no `on_chat_model_end`;
+            completion is routed here for callback API compatibility.
 
         Args:
             response: The response.
@@ -654,6 +665,14 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
         tags: list[str] | None = None,
         **kwargs: Any,
     ) -> None:
+        """End a trace for an LLM or chat model run.
+
+        Note:
+            This async callback also handles both run types. Async chat models
+            start with `on_chat_model_start`, but there is no
+            `on_chat_model_end`; completion is routed here for callback API
+            compatibility.
+        """
         llm_run = self._complete_llm_run(
             response=response,
             run_id=run_id,
@@ -874,7 +893,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC):
         """Process the LLM Run upon start."""
 
     async def _on_llm_end(self, run: Run) -> None:
-        """Process the LLM Run."""
+        """Process LLM/chat model run completion."""
 
     async def _on_llm_error(self, run: Run) -> None:
         """Process the LLM Run upon error."""


@@ -5,7 +5,7 @@ from __future__ import annotations
 import logging
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 from uuid import UUID
 
 from langsmith import Client, get_tracing_context
@@ -77,11 +77,15 @@ def _get_usage_metadata_from_generations(
     """Extract and aggregate `usage_metadata` from generations.
 
     Iterates through generations to find and aggregate all `usage_metadata` found in
-    messages. This is typically present in chat model outputs.
+    messages. This expects the serialized message payload shape produced by tracer
+    internals:
+
+    `{"message": {"kwargs": {"usage_metadata": {...}}}}`
 
     Args:
         generations: List of generation batches, where each batch is a list of
-            generation dicts that may contain a `'message'` key with `'usage_metadata'`.
+            generation dicts that may contain a `'message'` key with
+            usage metadata.
 
     Returns:
         The aggregated `usage_metadata` dict if found, otherwise `None`.
@@ -91,11 +95,24 @@ def _get_usage_metadata_from_generations(
         for generation in generation_batch:
             if isinstance(generation, dict) and "message" in generation:
                 message = generation["message"]
-                if isinstance(message, dict) and "usage_metadata" in message:
-                    output = add_usage(output, message["usage_metadata"])
+                usage_metadata = _get_usage_metadata_from_message(message)
+                if usage_metadata is not None:
+                    output = add_usage(output, usage_metadata)
     return output
 
 
+def _get_usage_metadata_from_message(message: Any) -> UsageMetadata | None:
+    """Extract usage metadata from a generation's message payload."""
+    if not isinstance(message, dict):
+        return None
+
+    kwargs = message.get("kwargs")
+    if isinstance(kwargs, dict) and isinstance(kwargs.get("usage_metadata"), dict):
+        return cast("UsageMetadata", kwargs["usage_metadata"])
+
+    return None
+
+
 class LangChainTracer(BaseTracer):
     """Implementation of the `SharedTracer` that `POSTS` to the LangChain endpoint."""
@@ -294,13 +311,19 @@ class LangChainTracer(BaseTracer):
         )
 
     def _on_chat_model_start(self, run: Run) -> None:
-        """Persist an LLM run."""
+        """Persist a chat model run.
+
+        Note:
+            Naming is historical: there is no `_on_chat_model_end` hook. Chat
+            model completion is handled by `_on_llm_end`, shared with text
+            LLM runs.
+        """
         if run.parent_run_id is None:
             run.reference_example_id = self.example_id
         self._persist_run_single(run)
 
     def _on_llm_end(self, run: Run) -> None:
-        """Process the LLM Run."""
+        """Process LLM/chat model run completion."""
         # Extract usage_metadata from outputs and store in extra.metadata
         if run.outputs and "generations" in run.outputs:
             usage_metadata = _get_usage_metadata_from_generations(
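
For context on the aggregation in `_get_usage_metadata_from_generations`,
`add_usage` (a real `langchain_core.messages.ai` helper) sums token counts
field-wise; a quick check with sample numbers:

```python
from langchain_core.messages.ai import UsageMetadata, add_usage

# Two generations' worth of usage, as the aggregation loop would see them.
first: UsageMetadata = {"input_tokens": 5, "output_tokens": 10, "total_tokens": 15}
second: UsageMetadata = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150}

total = add_usage(first, second)
assert total["total_tokens"] == 165  # counts are summed field-wise

# add_usage tolerates None, which is why the extractor can start from
# output = None and fold usage in only when a message actually carries it.
assert add_usage(None, first)["input_tokens"] == 5
```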


@@ -10,7 +10,8 @@ from langsmith import Client
 from langsmith.run_trees import RunTree
 from langsmith.utils import get_env_var, get_tracer_project
 
-from langchain_core.outputs import LLMResult
+from langchain_core.messages import AIMessage
+from langchain_core.outputs import ChatGeneration, LLMResult
 from langchain_core.tracers.langchain import (
     LangChainTracer,
     _get_usage_metadata_from_generations,
@@ -154,7 +155,8 @@ def test_correct_get_tracer_project(
 @pytest.mark.parametrize(
     ("generations", "expected"),
     [
-        # Returns usage_metadata when present
+        # Returns None for non-serialized message usage_metadata shape
+        # (earlier regression)
         (
             [
                 [
@@ -171,6 +173,33 @@ def test_correct_get_tracer_project(
                     }
                 ]
             ],
-            {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
+            None,
         ),
+        # Returns usage_metadata when message is serialized via dumpd
+        (
+            [
+                [
+                    {
+                        "text": "Hello!",
+                        "message": {
+                            "lc": 1,
+                            "type": "constructor",
+                            "id": ["langchain", "schema", "messages", "AIMessage"],
+                            "kwargs": {
+                                "content": "Hello!",
+                                "type": "ai",
+                                "usage_metadata": {
+                                    "input_tokens": 10,
+                                    "output_tokens": 20,
+                                    "total_tokens": 30,
+                                },
+                                "tool_calls": [],
+                                "invalid_tool_calls": [],
+                            },
+                        },
+                    }
+                ]
+            ],
+            {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
+        ),
         # Returns None when no usage_metadata
@@ -187,22 +216,38 @@ def test_correct_get_tracer_project(
                     {
                         "text": "First",
                         "message": {
-                            "content": "First",
-                            "usage_metadata": {
-                                "input_tokens": 5,
-                                "output_tokens": 10,
-                                "total_tokens": 15,
-                            },
+                            "lc": 1,
+                            "type": "constructor",
+                            "id": ["langchain", "schema", "messages", "AIMessage"],
+                            "kwargs": {
+                                "content": "First",
+                                "type": "ai",
+                                "usage_metadata": {
+                                    "input_tokens": 5,
+                                    "output_tokens": 10,
+                                    "total_tokens": 15,
+                                },
+                                "tool_calls": [],
+                                "invalid_tool_calls": [],
+                            },
                         },
                     },
                     {
                         "text": "Second",
                         "message": {
-                            "content": "Second",
-                            "usage_metadata": {
-                                "input_tokens": 50,
-                                "output_tokens": 100,
-                                "total_tokens": 150,
-                            },
+                            "lc": 1,
+                            "type": "constructor",
+                            "id": ["langchain", "schema", "messages", "AIMessage"],
+                            "kwargs": {
+                                "content": "Second",
+                                "type": "ai",
+                                "usage_metadata": {
+                                    "input_tokens": 50,
+                                    "output_tokens": 100,
+                                    "total_tokens": 150,
+                                },
+                                "tool_calls": [],
+                                "invalid_tool_calls": [],
+                            },
                         },
                     },
@@ -218,11 +263,19 @@ def test_correct_get_tracer_project(
                     {
                         "text": "Has message",
                         "message": {
-                            "content": "Has message",
-                            "usage_metadata": {
-                                "input_tokens": 10,
-                                "output_tokens": 20,
-                                "total_tokens": 30,
-                            },
+                            "lc": 1,
+                            "type": "constructor",
+                            "id": ["langchain", "schema", "messages", "AIMessage"],
+                            "kwargs": {
+                                "content": "Has message",
+                                "type": "ai",
+                                "usage_metadata": {
+                                    "input_tokens": 10,
+                                    "output_tokens": 20,
+                                    "total_tokens": 30,
+                                },
+                                "tool_calls": [],
+                                "invalid_tool_calls": [],
+                            },
                         },
                     }
@@ -232,7 +285,8 @@ def test_correct_get_tracer_project(
         ),
     ],
     ids=[
-        "returns_usage_metadata_when_present",
+        "returns_none_when_non_serialized_message_shape",
+        "returns_usage_metadata_when_message_serialized",
         "returns_none_when_no_usage_metadata",
         "returns_none_when_no_message",
         "returns_none_for_empty_list",
@@ -265,7 +319,18 @@ def test_on_llm_end_stores_usage_metadata_in_run_extra() -> None:
             [
                 {
                     "text": "Hello!",
-                    "message": {"content": "Hello!", "usage_metadata": usage_metadata},
+                    "message": {
+                        "lc": 1,
+                        "type": "constructor",
+                        "id": ["langchain", "schema", "messages", "AIMessage"],
+                        "kwargs": {
+                            "content": "Hello!",
+                            "type": "ai",
+                            "usage_metadata": usage_metadata,
+                            "tool_calls": [],
+                            "invalid_tool_calls": [],
+                        },
+                    },
                 }
             ]
         ]
@@ -285,6 +350,41 @@ def test_on_llm_end_stores_usage_metadata_in_run_extra() -> None:
     assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata
 
 
+def test_on_llm_end_stores_usage_metadata_from_serialized_outputs() -> None:
+    """Store `usage_metadata` from serialized generation message outputs."""
+    client = unittest.mock.MagicMock(spec=Client)
+    client.tracing_queue = None
+    tracer = LangChainTracer(client=client)
+
+    run_id = UUID("d94d0ff8-cf5a-4100-ab11-1a0efaa8d8d0")
+    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)
+
+    usage_metadata = {"input_tokens": 100, "output_tokens": 200, "total_tokens": 300}
+    response = LLMResult(
+        generations=[
+            [
+                ChatGeneration(
+                    message=AIMessage(content="Hello!", usage_metadata=usage_metadata)
+                )
+            ]
+        ]
+    )
+    run = tracer._complete_llm_run(response=response, run_id=run_id)
+
+    captured_run = None
+
+    def capture_run(r: Run) -> None:
+        nonlocal captured_run
+        captured_run = r
+
+    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
+        tracer._on_llm_end(run)
+
+    assert captured_run is not None
+    assert "metadata" in captured_run.extra
+    assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata
+
+
 def test_on_llm_end_no_usage_metadata_when_not_present() -> None:
     """Test that no `usage_metadata` is added when not present in outputs."""
     client = unittest.mock.MagicMock(spec=Client)
@@ -296,7 +396,24 @@ def test_on_llm_end_no_usage_metadata_when_not_present() -> None:
     run = tracer.run_map[str(run_id)]
     run.outputs = {
-        "generations": [[{"text": "Hello!", "message": {"content": "Hello!"}}]]
+        "generations": [
+            [
+                {
+                    "text": "Hello!",
+                    "message": {
+                        "lc": 1,
+                        "type": "constructor",
+                        "id": ["langchain", "schema", "messages", "AIMessage"],
+                        "kwargs": {
+                            "content": "Hello!",
+                            "type": "ai",
+                            "tool_calls": [],
+                            "invalid_tool_calls": [],
+                        },
+                    },
+                }
+            ]
+        ]
     }
 
     captured_run = None
@@ -334,7 +451,18 @@ def test_on_llm_end_preserves_existing_metadata() -> None:
             [
                 {
                     "text": "Hello!",
-                    "message": {"content": "Hello!", "usage_metadata": usage_metadata},
+                    "message": {
+                        "lc": 1,
+                        "type": "constructor",
+                        "id": ["langchain", "schema", "messages", "AIMessage"],
+                        "kwargs": {
+                            "content": "Hello!",
+                            "type": "ai",
+                            "usage_metadata": usage_metadata,
+                            "tool_calls": [],
+                            "invalid_tool_calls": [],
+                        },
+                    },
                 }
             ]
         ]