diff --git a/libs/core/langchain_core/outputs/chat_result.py b/libs/core/langchain_core/outputs/chat_result.py index 87eb094d27f..1cc814310e4 100644 --- a/libs/core/langchain_core/outputs/chat_result.py +++ b/libs/core/langchain_core/outputs/chat_result.py @@ -26,10 +26,11 @@ class ChatResult(BaseModel): """ llm_output: dict | None = None - """For arbitrary LLM provider specific output. + """For arbitrary model provider-specific output. This dictionary is a free-form dictionary that can contain any information that the - provider wants to return. It is not standardized and is provider-specific. + provider wants to return. It is not standardized and keys may vary by provider and + over time. Users should generally avoid relying on this field and instead rely on accessing relevant information from standardized fields present in `AIMessage`. diff --git a/libs/core/langchain_core/outputs/llm_result.py b/libs/core/langchain_core/outputs/llm_result.py index cf8e47ef3a9..df40c419750 100644 --- a/libs/core/langchain_core/outputs/llm_result.py +++ b/libs/core/langchain_core/outputs/llm_result.py @@ -38,10 +38,11 @@ class LLMResult(BaseModel): """ llm_output: dict | None = None - """For arbitrary LLM provider specific output. + """For arbitrary model provider-specific output. This dictionary is a free-form dictionary that can contain any information that the - provider wants to return. It is not standardized and is provider-specific. + provider wants to return. It is not standardized and keys may vary by provider and + over time. Users should generally avoid relying on this field and instead rely on accessing relevant information from standardized fields present in AIMessage. diff --git a/libs/core/langchain_core/tracers/base.py b/libs/core/langchain_core/tracers/base.py index 68579b3ff64..eaf9c04359f 100644 --- a/libs/core/langchain_core/tracers/base.py +++ b/libs/core/langchain_core/tracers/base.py @@ -61,7 +61,13 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): name: str | None = None, **kwargs: Any, ) -> Run: - """Start a trace for an LLM run. + """Start a trace for a chat model run. + + Note: + Naming can be confusing here: there is `on_chat_model_start`, but no + corresponding `on_chat_model_end` callback. Chat model completion is + routed through `on_llm_end` / `_on_llm_end`, which are shared with + text LLM runs. Args: serialized: The serialized model. @@ -191,7 +197,12 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC): @override def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: - """End a trace for an LLM run. + """End a trace for an LLM or chat model run. + + Note: + This is the end callback for both run types. Chat models start with + `on_chat_model_start`, but there is no `on_chat_model_end`; + completion is routed here for callback API compatibility. Args: response: The response. @@ -654,6 +665,14 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): tags: list[str] | None = None, **kwargs: Any, ) -> None: + """End a trace for an LLM or chat model run. + + Note: + This async callback also handles both run types. Async chat models + start with `on_chat_model_start`, but there is no + `on_chat_model_end`; completion is routed here for callback API + compatibility. + """ llm_run = self._complete_llm_run( response=response, run_id=run_id, @@ -874,7 +893,7 @@ class AsyncBaseTracer(_TracerCore, AsyncCallbackHandler, ABC): """Process the LLM Run upon start.""" async def _on_llm_end(self, run: Run) -> None: - """Process the LLM Run.""" + """Process LLM/chat model run completion.""" async def _on_llm_error(self, run: Run) -> None: """Process the LLM Run upon error.""" diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index a5122b7ffdc..2a2b5ce0f62 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from uuid import UUID from langsmith import Client, get_tracing_context @@ -77,11 +77,15 @@ def _get_usage_metadata_from_generations( """Extract and aggregate `usage_metadata` from generations. Iterates through generations to find and aggregate all `usage_metadata` found in - messages. This is typically present in chat model outputs. + messages. This expects the serialized message payload shape produced by tracer + internals: + + `{"message": {"kwargs": {"usage_metadata": {...}}}}` Args: generations: List of generation batches, where each batch is a list of - generation dicts that may contain a `'message'` key with `'usage_metadata'`. + generation dicts that may contain a `'message'` key with + usage metadata. Returns: The aggregated `usage_metadata` dict if found, otherwise `None`. @@ -91,11 +95,24 @@ def _get_usage_metadata_from_generations( for generation in generation_batch: if isinstance(generation, dict) and "message" in generation: message = generation["message"] - if isinstance(message, dict) and "usage_metadata" in message: - output = add_usage(output, message["usage_metadata"]) + usage_metadata = _get_usage_metadata_from_message(message) + if usage_metadata is not None: + output = add_usage(output, usage_metadata) return output +def _get_usage_metadata_from_message(message: Any) -> UsageMetadata | None: + """Extract usage metadata from a generation's message payload.""" + if not isinstance(message, dict): + return None + + kwargs = message.get("kwargs") + if isinstance(kwargs, dict) and isinstance(kwargs.get("usage_metadata"), dict): + return cast("UsageMetadata", kwargs["usage_metadata"]) + + return None + + class LangChainTracer(BaseTracer): """Implementation of the `SharedTracer` that `POSTS` to the LangChain endpoint.""" @@ -294,13 +311,19 @@ class LangChainTracer(BaseTracer): ) def _on_chat_model_start(self, run: Run) -> None: - """Persist an LLM run.""" + """Persist a chat model run. + + Note: + Naming is historical: there is no `_on_chat_model_end` hook. Chat + model completion is handled by `_on_llm_end`, shared with text + LLM runs. + """ if run.parent_run_id is None: run.reference_example_id = self.example_id self._persist_run_single(run) def _on_llm_end(self, run: Run) -> None: - """Process the LLM Run.""" + """Process LLM/chat model run completion.""" # Extract usage_metadata from outputs and store in extra.metadata if run.outputs and "generations" in run.outputs: usage_metadata = _get_usage_metadata_from_generations( diff --git a/libs/core/tests/unit_tests/tracers/test_langchain.py b/libs/core/tests/unit_tests/tracers/test_langchain.py index a2b729e1dae..d2a83e655a6 100644 --- a/libs/core/tests/unit_tests/tracers/test_langchain.py +++ b/libs/core/tests/unit_tests/tracers/test_langchain.py @@ -10,7 +10,8 @@ from langsmith import Client from langsmith.run_trees import RunTree from langsmith.utils import get_env_var, get_tracer_project -from langchain_core.outputs import LLMResult +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.tracers.langchain import ( LangChainTracer, _get_usage_metadata_from_generations, @@ -154,7 +155,8 @@ def test_correct_get_tracer_project( @pytest.mark.parametrize( ("generations", "expected"), [ - # Returns usage_metadata when present + # Returns None for non-serialized message usage_metadata shape + # (earlier regression) ( [ [ @@ -171,6 +173,33 @@ def test_correct_get_tracer_project( } ] ], + None, + ), + # Returns usage_metadata when message is serialized via dumpd + ( + [ + [ + { + "text": "Hello!", + "message": { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": { + "content": "Hello!", + "type": "ai", + "usage_metadata": { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + }, + "tool_calls": [], + "invalid_tool_calls": [], + }, + }, + } + ] + ], {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, ), # Returns None when no usage_metadata @@ -187,22 +216,38 @@ def test_correct_get_tracer_project( { "text": "First", "message": { - "content": "First", - "usage_metadata": { - "input_tokens": 5, - "output_tokens": 10, - "total_tokens": 15, + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": { + "content": "First", + "type": "ai", + "usage_metadata": { + "input_tokens": 5, + "output_tokens": 10, + "total_tokens": 15, + }, + "tool_calls": [], + "invalid_tool_calls": [], }, }, }, { "text": "Second", "message": { - "content": "Second", - "usage_metadata": { - "input_tokens": 50, - "output_tokens": 100, - "total_tokens": 150, + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": { + "content": "Second", + "type": "ai", + "usage_metadata": { + "input_tokens": 50, + "output_tokens": 100, + "total_tokens": 150, + }, + "tool_calls": [], + "invalid_tool_calls": [], }, }, }, @@ -218,11 +263,19 @@ def test_correct_get_tracer_project( { "text": "Has message", "message": { - "content": "Has message", - "usage_metadata": { - "input_tokens": 10, - "output_tokens": 20, - "total_tokens": 30, + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": { + "content": "Has message", + "type": "ai", + "usage_metadata": { + "input_tokens": 10, + "output_tokens": 20, + "total_tokens": 30, + }, + "tool_calls": [], + "invalid_tool_calls": [], }, }, } @@ -232,7 +285,8 @@ def test_correct_get_tracer_project( ), ], ids=[ - "returns_usage_metadata_when_present", + "returns_none_when_non_serialized_message_shape", + "returns_usage_metadata_when_message_serialized", "returns_none_when_no_usage_metadata", "returns_none_when_no_message", "returns_none_for_empty_list", @@ -265,7 +319,18 @@ def test_on_llm_end_stores_usage_metadata_in_run_extra() -> None: [ { "text": "Hello!", - "message": {"content": "Hello!", "usage_metadata": usage_metadata}, + "message": { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": { + "content": "Hello!", + "type": "ai", + "usage_metadata": usage_metadata, + "tool_calls": [], + "invalid_tool_calls": [], + }, + }, } ] ] @@ -285,6 +350,41 @@ def test_on_llm_end_stores_usage_metadata_in_run_extra() -> None: assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata +def test_on_llm_end_stores_usage_metadata_from_serialized_outputs() -> None: + """Store `usage_metadata` from serialized generation message outputs.""" + client = unittest.mock.MagicMock(spec=Client) + client.tracing_queue = None + tracer = LangChainTracer(client=client) + + run_id = UUID("d94d0ff8-cf5a-4100-ab11-1a0efaa8d8d0") + tracer.on_llm_start({"name": "test_llm"}, ["foo"], run_id=run_id) + + usage_metadata = {"input_tokens": 100, "output_tokens": 200, "total_tokens": 300} + response = LLMResult( + generations=[ + [ + ChatGeneration( + message=AIMessage(content="Hello!", usage_metadata=usage_metadata) + ) + ] + ] + ) + run = tracer._complete_llm_run(response=response, run_id=run_id) + + captured_run = None + + def capture_run(r: Run) -> None: + nonlocal captured_run + captured_run = r + + with unittest.mock.patch.object(tracer, "_update_run_single", capture_run): + tracer._on_llm_end(run) + + assert captured_run is not None + assert "metadata" in captured_run.extra + assert captured_run.extra["metadata"]["usage_metadata"] == usage_metadata + + def test_on_llm_end_no_usage_metadata_when_not_present() -> None: """Test that no `usage_metadata` is added when not present in outputs.""" client = unittest.mock.MagicMock(spec=Client) @@ -296,7 +396,24 @@ def test_on_llm_end_no_usage_metadata_when_not_present() -> None: run = tracer.run_map[str(run_id)] run.outputs = { - "generations": [[{"text": "Hello!", "message": {"content": "Hello!"}}]] + "generations": [ + [ + { + "text": "Hello!", + "message": { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": { + "content": "Hello!", + "type": "ai", + "tool_calls": [], + "invalid_tool_calls": [], + }, + }, + } + ] + ] } captured_run = None @@ -334,7 +451,18 @@ def test_on_llm_end_preserves_existing_metadata() -> None: [ { "text": "Hello!", - "message": {"content": "Hello!", "usage_metadata": usage_metadata}, + "message": { + "lc": 1, + "type": "constructor", + "id": ["langchain", "schema", "messages", "AIMessage"], + "kwargs": { + "content": "Hello!", + "type": "ai", + "usage_metadata": usage_metadata, + "tool_calls": [], + "invalid_tool_calls": [], + }, + }, } ] ]