Compare commits


5 Commits

Author          SHA1         Message                                                          Date
Hunter Lovell   8178e9b053   fix: use message when applying output                            2025-12-18 11:19:13 -08:00
Hunter Lovell   02ced19d71   chore(core): flatten generations for LangChainTracer             2025-12-18 11:03:16 -08:00
Hunter Lovell   95116652bf   cr                                                               2025-12-18 10:56:00 -08:00
Hunter Lovell   1b42a9af50   combine usage metadata                                           2025-12-18 10:50:56 -08:00
Hunter Lovell   be295899f9   feat(core): add usage_metadata to metadata in LangChainTracer    2025-12-17 22:16:46 -08:00
2 changed files with 417 additions and 1 deletion

langchain_core/tracers/langchain.py

@@ -21,6 +21,7 @@ from typing_extensions import override
from langchain_core.env import get_runtime_environment
from langchain_core.load import dumpd
from langchain_core.messages.ai import UsageMetadata, add_usage
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
@@ -69,6 +70,32 @@ def _get_executor() -> ThreadPoolExecutor:
    return _EXECUTOR


def _get_usage_metadata_from_generations(
    generations: list[list[dict[str, Any]]],
) -> UsageMetadata | None:
    """Extract and aggregate usage_metadata from generations.

    Iterates through generations to find and aggregate all usage_metadata
    found in messages. This is typically present in chat model outputs.

    Args:
        generations: List of generation batches, where each batch is a list
            of generation dicts that may contain a "message" key with
            "usage_metadata".

    Returns:
        The aggregated usage_metadata dict if found, otherwise None.
    """
    output: UsageMetadata | None = None
    for generation_batch in generations:
        for generation in generation_batch:
            if isinstance(generation, dict) and "message" in generation:
                message = generation["message"]
                if isinstance(message, dict) and "usage_metadata" in message:
                    output = add_usage(output, message["usage_metadata"])
    return output


class LangChainTracer(BaseTracer):
    """Implementation of the SharedTracer that POSTS to the LangChain endpoint."""
@@ -266,6 +293,20 @@ class LangChainTracer(BaseTracer):

    def _on_llm_end(self, run: Run) -> None:
        """Process the LLM Run."""
        # Extract usage_metadata from outputs and store in extra.metadata
        if run.outputs and "generations" in run.outputs:
            generations = run.outputs["generations"]
            usage_metadata = _get_usage_metadata_from_generations(generations)
            if usage_metadata is not None:
                if "metadata" not in run.extra:
                    run.extra["metadata"] = {}
                run.extra["metadata"]["usage_metadata"] = usage_metadata
            # Flatten outputs if there's only a single generation with a message
            if len(generations) == 1 and len(generations[0]) == 1:
                generation = generations[0][0]
                if isinstance(generation, dict) and "message" in generation:
                    run.outputs = generation["message"]
        self._update_run_single(run)

    def _on_llm_error(self, run: Run) -> None:
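Taken together, a sketch of the net effect (not part of the diff; the values are hypothetical and mirror the unit tests below): when a run ends with exactly one generation that carries a message, _on_llm_end now uploads the message itself as run.outputs and mirrors any token usage into run.extra["metadata"]:

# Shape of run.outputs as reported by the model, before _on_llm_end:
outputs_before = {
    "generations": [
        [
            {
                "text": "Hello!",
                "message": {
                    "content": "Hello!",
                    "usage_metadata": {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
                },
            }
        ]
    ]
}
# After _on_llm_end: outputs flattened to just the message...
outputs_after = {
    "content": "Hello!",
    "usage_metadata": {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
}
# ...and usage copied into extra metadata. Multi-generation or multi-batch
# outputs keep the original {"generations": ...} structure.
extra_after = {"metadata": {"usage_metadata": {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}}}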

Unit tests for the tracer changes:

@@ -12,7 +12,10 @@ from langsmith.run_trees import RunTree
from langsmith.utils import get_env_var, get_tracer_project

from langchain_core.outputs import LLMResult
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain import (
    LangChainTracer,
    _get_usage_metadata_from_generations,
)
from langchain_core.tracers.schemas import Run
@@ -145,3 +148,375 @@ def test_correct_get_tracer_project(
    )
    tracer.wait_for_futures()
    assert projects == [expected_project_name]

@pytest.mark.parametrize(
    ("generations", "expected"),
    [
        # Returns usage_metadata when present
        (
            [
                [
                    {
                        "text": "Hello!",
                        "message": {
                            "content": "Hello!",
                            "usage_metadata": {
                                "input_tokens": 10,
                                "output_tokens": 20,
                                "total_tokens": 30,
                            },
                        },
                    }
                ]
            ],
            {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
        ),
        # Returns None when no usage_metadata
        ([[{"text": "Hello!", "message": {"content": "Hello!"}}]], None),
        # Returns None when no message
        ([[{"text": "Hello!"}]], None),
        # Returns None for empty generations
        ([], None),
        ([[]], None),
        # Aggregates usage_metadata across multiple generations
        (
            [
                [
                    {
                        "text": "First",
                        "message": {
                            "content": "First",
                            "usage_metadata": {
                                "input_tokens": 5,
                                "output_tokens": 10,
                                "total_tokens": 15,
                            },
                        },
                    },
                    {
                        "text": "Second",
                        "message": {
                            "content": "Second",
                            "usage_metadata": {
                                "input_tokens": 50,
                                "output_tokens": 100,
                                "total_tokens": 150,
                            },
                        },
                    },
                ]
            ],
            {"input_tokens": 55, "output_tokens": 110, "total_tokens": 165},
        ),
        # Finds usage_metadata across multiple batches
        (
            [
                [{"text": "No message here"}],
                [
                    {
                        "text": "Has message",
                        "message": {
                            "content": "Has message",
                            "usage_metadata": {
                                "input_tokens": 10,
                                "output_tokens": 20,
                                "total_tokens": 30,
                            },
                        },
                    }
                ],
            ],
            {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
        ),
    ],
    ids=[
        "returns_usage_metadata_when_present",
        "returns_none_when_no_usage_metadata",
        "returns_none_when_no_message",
        "returns_none_for_empty_list",
        "returns_none_for_empty_batch",
        "aggregates_across_multiple_generations",
        "finds_across_multiple_batches",
    ],
)
def test_get_usage_metadata_from_generations(
    generations: list[list[dict[str, Any]]], expected: dict[str, Any] | None
) -> None:
    """Test _get_usage_metadata_from_generations utility function."""
    result = _get_usage_metadata_from_generations(generations)
    assert result == expected

def test_on_llm_end_stores_usage_metadata_in_run_extra() -> None:
    """Test that usage_metadata is stored in run.extra.metadata on llm end."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)
    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)
    run = tracer.run_map[str(run_id)]
    usage_metadata = {"input_tokens": 100, "output_tokens": 200, "total_tokens": 300}
    run.outputs = {
        "generations": [
            [
                {
                    "text": "Hello!",
                    "message": {"content": "Hello!", "usage_metadata": usage_metadata},
                }
            ]
        ]
    }
    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)
    assert captured_run is not None
    assert "metadata" in captured_run.extra
    assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata

def test_on_llm_end_no_usage_metadata_when_not_present() -> None:
    """Test that no usage_metadata is added when not present in outputs."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)
    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)
    run = tracer.run_map[str(run_id)]
    run.outputs = {
        "generations": [[{"text": "Hello!", "message": {"content": "Hello!"}}]]
    }
    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)
    assert captured_run is not None
    extra_metadata = captured_run.extra.get("metadata", {})
    assert "usage_metadata" not in extra_metadata

def test_on_llm_end_preserves_existing_metadata() -> None:
    """Test that existing metadata is preserved when adding usage_metadata."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)
    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start(
        {"name": "test_llm"},
        ["foo"],
        run_id=run_id,
        metadata={"existing_key": "existing_value"},
    )
    run = tracer.run_map[str(run_id)]
    usage_metadata = {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}
    run.outputs = {
        "generations": [
            [
                {
                    "text": "Hello!",
                    "message": {"content": "Hello!", "usage_metadata": usage_metadata},
                }
            ]
        ]
    }
    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)
    assert captured_run is not None
    assert "metadata" in captured_run.extra
    assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata
    assert captured_run.extra["metadata"]["existing_key"] == "existing_value"

def test_on_llm_end_flattens_single_generation() -> None:
    """Test that outputs are flattened to just the message when generations is 1x1."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)
    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)
    run = tracer.run_map[str(run_id)]
    message = {"content": "Hello!"}
    generation = {"text": "Hello!", "message": message}
    run.outputs = {"generations": [[generation]]}
    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)
    assert captured_run is not None
    # Should be flattened to just the message
    assert captured_run.outputs == message

def test_on_llm_end_does_not_flatten_single_generation_without_message() -> None:
    """Test that outputs are not flattened when single generation has no message."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)
    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)
    run = tracer.run_map[str(run_id)]
    # Generation with only text, no message
    generation = {"text": "Hello!"}
    run.outputs = {"generations": [[generation]]}
    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)
    assert captured_run is not None
    # Should NOT be flattened - keep original structure when no message
    assert captured_run.outputs == {"generations": [[generation]]}

def test_on_llm_end_does_not_flatten_multiple_generations_in_batch() -> None:
    """Test that outputs are not flattened when there are multiple generations."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)
    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)
    run = tracer.run_map[str(run_id)]
    generation1 = {"text": "Hello!", "message": {"content": "Hello!"}}
    generation2 = {"text": "Hi there!", "message": {"content": "Hi there!"}}
    run.outputs = {"generations": [[generation1, generation2]]}
    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)
    assert captured_run is not None
    # Should NOT be flattened - keep original structure
    assert captured_run.outputs == {"generations": [[generation1, generation2]]}

def test_on_llm_end_does_not_flatten_multiple_batches() -> None:
    """Test that outputs are not flattened when there are multiple batches."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)
    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo", "bar"], run_id=run_id)
    run = tracer.run_map[str(run_id)]
    generation1 = {"text": "Response 1", "message": {"content": "Response 1"}}
    generation2 = {"text": "Response 2", "message": {"content": "Response 2"}}
    run.outputs = {"generations": [[generation1], [generation2]]}
    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)
    assert captured_run is not None
    # Should NOT be flattened - keep original structure
    assert captured_run.outputs == {"generations": [[generation1], [generation2]]}

def test_on_llm_end_does_not_flatten_multiple_batches_multiple_generations() -> None:
    """Test outputs not flattened with multiple batches and multiple generations."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)
    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo", "bar"], run_id=run_id)
    run = tracer.run_map[str(run_id)]
    gen1a = {"text": "1a", "message": {"content": "1a"}}
    gen1b = {"text": "1b", "message": {"content": "1b"}}
    gen2a = {"text": "2a", "message": {"content": "2a"}}
    gen2b = {"text": "2b", "message": {"content": "2b"}}
    run.outputs = {"generations": [[gen1a, gen1b], [gen2a, gen2b]]}
    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)
    assert captured_run is not None
    # Should NOT be flattened - keep original structure
    assert captured_run.outputs == {"generations": [[gen1a, gen1b], [gen2a, gen2b]]}

def test_on_llm_end_handles_empty_generations() -> None:
    """Test that empty generations are handled without error."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)
    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)
    run = tracer.run_map[str(run_id)]
    run.outputs = {"generations": []}
    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)
    assert captured_run is not None
    # Should keep original structure when empty
    assert captured_run.outputs == {"generations": []}