Mirror of https://github.com/hwchase17/langchain.git, synced 2026-02-12 12:11:34 +00:00

Compare commits: cc/release...hunter/fla (5 commits)
| SHA1 |
|---|
| 8178e9b053 |
| 02ced19d71 |
| 95116652bf |
| 1b42a9af50 |
| be295899f9 |
In `langchain_core.tracers.langchain`, the diff first brings in the usage-metadata types and the aggregation helper:

@@ -21,6 +21,7 @@ from typing_extensions import override

```python
from langchain_core.env import get_runtime_environment
from langchain_core.load import dumpd
from langchain_core.messages.ai import UsageMetadata, add_usage
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
```
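The newly imported `add_usage` is the merging primitive this diff builds on: it sums corresponding token counts of two `UsageMetadata` mappings, and a `None` side is treated as absent. A minimal sketch of that behavior, with invented token counts (the expected sums match the aggregation test case further down):

```python
from langchain_core.messages.ai import add_usage

# Invented token counts, purely for illustration.
first = {"input_tokens": 5, "output_tokens": 10, "total_tokens": 15}
second = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150}

# Corresponding fields are summed.
assert add_usage(first, second) == {
    "input_tokens": 55,
    "output_tokens": 110,
    "total_tokens": 165,
}

# A None side is treated as "no usage", which is why None can seed an aggregate.
assert add_usage(None, first) == first
```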
The new module-level helper walks every generation batch and folds each message's `usage_metadata` into a single aggregate:

@@ -69,6 +70,32 @@ def _get_executor() -> ThreadPoolExecutor:

```python
    return _EXECUTOR


def _get_usage_metadata_from_generations(
    generations: list[list[dict[str, Any]]],
) -> UsageMetadata | None:
    """Extract and aggregate usage_metadata from generations.

    Iterates through generations to find and aggregate all usage_metadata
    found in messages. This is typically present in chat model outputs.

    Args:
        generations: List of generation batches, where each batch is a list
            of generation dicts that may contain a "message" key with
            "usage_metadata".

    Returns:
        The aggregated usage_metadata dict if found, otherwise None.
    """
    output: UsageMetadata | None = None
    for generation_batch in generations:
        for generation in generation_batch:
            if isinstance(generation, dict) and "message" in generation:
                message = generation["message"]
                if isinstance(message, dict) and "usage_metadata" in message:
                    output = add_usage(output, message["usage_metadata"])
    return output


class LangChainTracer(BaseTracer):
    """Implementation of the SharedTracer that POSTS to the LangChain endpoint."""
```
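Once the patch is applied, the helper can be exercised directly. This sketch reuses the shapes and token counts from the parametrized test cases below; the import path is the one the test file itself uses:

```python
from langchain_core.tracers.langchain import _get_usage_metadata_from_generations

# Shaped like run.outputs["generations"]: one batch with two generations,
# each carrying usage_metadata on its message.
generations = [
    [
        {
            "text": "First",
            "message": {
                "content": "First",
                "usage_metadata": {"input_tokens": 5, "output_tokens": 10, "total_tokens": 15},
            },
        },
        {
            "text": "Second",
            "message": {
                "content": "Second",
                "usage_metadata": {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150},
            },
        },
    ]
]

# Per-generation counts are summed via add_usage.
assert _get_usage_metadata_from_generations(generations) == {
    "input_tokens": 55,
    "output_tokens": 110,
    "total_tokens": 165,
}

# No generations -> no usage metadata.
assert _get_usage_metadata_from_generations([]) is None
```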
`_on_llm_end` then stores the aggregate in `run.extra["metadata"]` and, for the common single-generation case, flattens `run.outputs` down to the message dict before the run is posted:

@@ -266,6 +293,20 @@ class LangChainTracer(BaseTracer):

```python
    def _on_llm_end(self, run: Run) -> None:
        """Process the LLM Run."""
        # Extract usage_metadata from outputs and store in extra.metadata
        if run.outputs and "generations" in run.outputs:
            generations = run.outputs["generations"]
            usage_metadata = _get_usage_metadata_from_generations(generations)
            if usage_metadata is not None:
                if "metadata" not in run.extra:
                    run.extra["metadata"] = {}
                run.extra["metadata"]["usage_metadata"] = usage_metadata

            # Flatten outputs if there's only a single generation with a message
            if len(generations) == 1 and len(generations[0]) == 1:
                generation = generations[0][0]
                if isinstance(generation, dict) and "message" in generation:
                    run.outputs = generation["message"]
        self._update_run_single(run)

    def _on_llm_error(self, run: Run) -> None:
```
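The flattening rule only rewrites `run.outputs` in the one-batch, one-generation case; anything else keeps the original `generations` structure, as the tests below verify. A self-contained sketch of just that data transformation, using an invented payload:

```python
# One batch containing one generation with a "message": eligible for flattening.
outputs = {"generations": [[{"text": "Hello!", "message": {"content": "Hello!"}}]]}

generations = outputs["generations"]
if len(generations) == 1 and len(generations[0]) == 1:
    generation = generations[0][0]
    if isinstance(generation, dict) and "message" in generation:
        # Traced outputs collapse to just the message dict.
        outputs = generation["message"]

assert outputs == {"content": "Hello!"}
```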
In the tracer unit tests, the import is widened to pull in the new helper:

```diff
@@ -12,7 +12,10 @@ from langsmith.run_trees import RunTree
 from langsmith.utils import get_env_var, get_tracer_project
 
 from langchain_core.outputs import LLMResult
-from langchain_core.tracers.langchain import LangChainTracer
+from langchain_core.tracers.langchain import (
+    LangChainTracer,
+    _get_usage_metadata_from_generations,
+)
 from langchain_core.tracers.schemas import Run
```
The accompanying tests cover the helper's aggregation behavior, metadata storage and preservation, and each branch of the flattening rule:

@@ -145,3 +148,375 @@ def test_correct_get_tracer_project(

```python
    )
    tracer.wait_for_futures()
    assert projects == [expected_project_name]


@pytest.mark.parametrize(
    ("generations", "expected"),
    [
        # Returns usage_metadata when present
        (
            [
                [
                    {
                        "text": "Hello!",
                        "message": {
                            "content": "Hello!",
                            "usage_metadata": {
                                "input_tokens": 10,
                                "output_tokens": 20,
                                "total_tokens": 30,
                            },
                        },
                    }
                ]
            ],
            {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
        ),
        # Returns None when no usage_metadata
        ([[{"text": "Hello!", "message": {"content": "Hello!"}}]], None),
        # Returns None when no message
        ([[{"text": "Hello!"}]], None),
        # Returns None for empty generations
        ([], None),
        ([[]], None),
        # Aggregates usage_metadata across multiple generations
        (
            [
                [
                    {
                        "text": "First",
                        "message": {
                            "content": "First",
                            "usage_metadata": {
                                "input_tokens": 5,
                                "output_tokens": 10,
                                "total_tokens": 15,
                            },
                        },
                    },
                    {
                        "text": "Second",
                        "message": {
                            "content": "Second",
                            "usage_metadata": {
                                "input_tokens": 50,
                                "output_tokens": 100,
                                "total_tokens": 150,
                            },
                        },
                    },
                ]
            ],
            {"input_tokens": 55, "output_tokens": 110, "total_tokens": 165},
        ),
        # Finds usage_metadata across multiple batches
        (
            [
                [{"text": "No message here"}],
                [
                    {
                        "text": "Has message",
                        "message": {
                            "content": "Has message",
                            "usage_metadata": {
                                "input_tokens": 10,
                                "output_tokens": 20,
                                "total_tokens": 30,
                            },
                        },
                    }
                ],
            ],
            {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
        ),
    ],
    ids=[
        "returns_usage_metadata_when_present",
        "returns_none_when_no_usage_metadata",
        "returns_none_when_no_message",
        "returns_none_for_empty_list",
        "returns_none_for_empty_batch",
        "aggregates_across_multiple_generations",
        "finds_across_multiple_batches",
    ],
)
def test_get_usage_metadata_from_generations(
    generations: list[list[dict[str, Any]]], expected: dict[str, Any] | None
) -> None:
    """Test _get_usage_metadata_from_generations utility function."""
    result = _get_usage_metadata_from_generations(generations)
    assert result == expected


def test_on_llm_end_stores_usage_metadata_in_run_extra() -> None:
    """Test that usage_metadata is stored in run.extra.metadata on llm end."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)

    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)

    run = tracer.run_map[str(run_id)]
    usage_metadata = {"input_tokens": 100, "output_tokens": 200, "total_tokens": 300}
    run.outputs = {
        "generations": [
            [
                {
                    "text": "Hello!",
                    "message": {"content": "Hello!", "usage_metadata": usage_metadata},
                }
            ]
        ]
    }

    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)

    assert captured_run is not None
    assert "metadata" in captured_run.extra
    assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata


def test_on_llm_end_no_usage_metadata_when_not_present() -> None:
    """Test that no usage_metadata is added when not present in outputs."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)

    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)

    run = tracer.run_map[str(run_id)]
    run.outputs = {
        "generations": [[{"text": "Hello!", "message": {"content": "Hello!"}}]]
    }

    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)

    assert captured_run is not None
    extra_metadata = captured_run.extra.get("metadata", {})
    assert "usage_metadata" not in extra_metadata


def test_on_llm_end_preserves_existing_metadata() -> None:
    """Test that existing metadata is preserved when adding usage_metadata."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)

    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start(
        {"name": "test_llm"},
        ["foo"],
        run_id=run_id,
        metadata={"existing_key": "existing_value"},
    )

    run = tracer.run_map[str(run_id)]
    usage_metadata = {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}
    run.outputs = {
        "generations": [
            [
                {
                    "text": "Hello!",
                    "message": {"content": "Hello!", "usage_metadata": usage_metadata},
                }
            ]
        ]
    }

    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)

    assert captured_run is not None
    assert "metadata" in captured_run.extra
    assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata
    assert captured_run.extra["metadata"]["existing_key"] == "existing_value"


def test_on_llm_end_flattens_single_generation() -> None:
    """Test that outputs are flattened to just the message when generations is 1x1."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)

    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)

    run = tracer.run_map[str(run_id)]
    message = {"content": "Hello!"}
    generation = {"text": "Hello!", "message": message}
    run.outputs = {"generations": [[generation]]}

    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)

    assert captured_run is not None
    # Should be flattened to just the message
    assert captured_run.outputs == message


def test_on_llm_end_does_not_flatten_single_generation_without_message() -> None:
    """Test that outputs are not flattened when single generation has no message."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)

    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)

    run = tracer.run_map[str(run_id)]
    # Generation with only text, no message
    generation = {"text": "Hello!"}
    run.outputs = {"generations": [[generation]]}

    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)

    assert captured_run is not None
    # Should NOT be flattened - keep original structure when no message
    assert captured_run.outputs == {"generations": [[generation]]}


def test_on_llm_end_does_not_flatten_multiple_generations_in_batch() -> None:
    """Test that outputs are not flattened when there are multiple generations."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)

    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)

    run = tracer.run_map[str(run_id)]
    generation1 = {"text": "Hello!", "message": {"content": "Hello!"}}
    generation2 = {"text": "Hi there!", "message": {"content": "Hi there!"}}
    run.outputs = {"generations": [[generation1, generation2]]}

    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)

    assert captured_run is not None
    # Should NOT be flattened - keep original structure
    assert captured_run.outputs == {"generations": [[generation1, generation2]]}


def test_on_llm_end_does_not_flatten_multiple_batches() -> None:
    """Test that outputs are not flattened when there are multiple batches."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)

    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo", "bar"], run_id=run_id)

    run = tracer.run_map[str(run_id)]
    generation1 = {"text": "Response 1", "message": {"content": "Response 1"}}
    generation2 = {"text": "Response 2", "message": {"content": "Response 2"}}
    run.outputs = {"generations": [[generation1], [generation2]]}

    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)

    assert captured_run is not None
    # Should NOT be flattened - keep original structure
    assert captured_run.outputs == {"generations": [[generation1], [generation2]]}


def test_on_llm_end_does_not_flatten_multiple_batches_multiple_generations() -> None:
    """Test outputs not flattened with multiple batches and multiple generations."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)

    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo", "bar"], run_id=run_id)

    run = tracer.run_map[str(run_id)]
    gen1a = {"text": "1a", "message": {"content": "1a"}}
    gen1b = {"text": "1b", "message": {"content": "1b"}}
    gen2a = {"text": "2a", "message": {"content": "2a"}}
    gen2b = {"text": "2b", "message": {"content": "2b"}}
    run.outputs = {"generations": [[gen1a, gen1b], [gen2a, gen2b]]}

    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)

    assert captured_run is not None
    # Should NOT be flattened - keep original structure
    assert captured_run.outputs == {"generations": [[gen1a, gen1b], [gen2a, gen2b]]}


def test_on_llm_end_handles_empty_generations() -> None:
    """Test that empty generations are handled without error."""
    client = unittest.mock.MagicMock(spec=Client)
    client.tracing_queue = None
    tracer = LangChainTracer(client=client)

    run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
    tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id)

    run = tracer.run_map[str(run_id)]
    run.outputs = {"generations": []}

    captured_run = None

    def capture_run(r: Run) -> None:
        nonlocal captured_run
        captured_run = r

    with unittest.mock.patch.object(tracer, "_update_run_single", capture_run):
        tracer._on_llm_end(run)

    assert captured_run is not None
    # Should keep original structure when empty
    assert captured_run.outputs == {"generations": []}
```