Compare commits

..

5 Commits

Author SHA1 Message Date
Eugene Yurtsev
03241dc80c x 2026-04-16 13:51:45 -04:00
Eugene Yurtsev
9c4de86bca Merge branch 'master' into jacob/traceablemetadata 2026-04-16 13:40:26 -04:00
jacoblee93
bc0e99d045 Remove test 2026-04-15 15:39:31 -07:00
jacoblee93
91b1ef049c Feedback 2026-04-15 15:10:16 -07:00
jacoblee93
c993ba06bb Add chat model and LLM invocation params to traceable metadata 2026-04-15 13:43:41 -07:00
5 changed files with 239 additions and 6 deletions

View File

@@ -2,6 +2,7 @@ import re
from collections.abc import Sequence
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypedDict,
TypeVar,
@@ -14,6 +15,21 @@ from langchain_core.messages.content import (
)
def _filter_invocation_params_for_tracing(params: dict[str, Any]) -> dict[str, Any]:
"""Filter out large/inappropriate fields from invocation params for tracing.
Removes fields like tools, functions, messages, response_format that can be large.
Args:
params: The invocation parameters to filter.
Returns:
The filtered parameters with large fields removed.
"""
excluded_keys = {"tools", "functions", "messages", "response_format"}
return {k: v for k, v in params.items() if k not in excluded_keys}
def is_openai_data_block(
block: dict, filter_: Literal["image", "audio", "file"] | None = None
) -> bool:

View File

@@ -25,6 +25,7 @@ from langchain_core.callbacks import (
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models._utils import (
_filter_invocation_params_for_tracing,
_normalize_messages,
_update_message_content_to_blocks,
)
@@ -567,6 +568,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
self.tags,
inheritable_metadata,
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
(run_manager,) = callback_manager.on_chat_model_start(
self._serialized,
@@ -695,6 +699,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
self.tags,
inheritable_metadata,
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
(run_manager,) = await callback_manager.on_chat_model_start(
self._serialized,
@@ -972,6 +979,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
self.tags,
inheritable_metadata,
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
messages_to_trace = [
_format_for_tracing(message_list) for message_list in messages
@@ -1095,6 +1105,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
self.tags,
inheritable_metadata,
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
messages_to_trace = [

View File

@@ -42,6 +42,7 @@ from langchain_core.callbacks import (
Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models._utils import _filter_invocation_params_for_tracing
from langchain_core.language_models.base import (
BaseLanguageModel,
LangSmithParams,
@@ -537,6 +538,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
inheritable_metadata,
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
(run_manager,) = callback_manager.on_llm_start(
self._serialized,
@@ -607,6 +611,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
inheritable_metadata,
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
(run_manager,) = await callback_manager.on_llm_start(
self._serialized,
@@ -950,6 +957,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
run_name_list = run_name or cast(
"list[str | None]", ([None] * len(prompts))
)
params = self.dict()
params["stop"] = stop
callback_managers = [
CallbackManager.configure(
callback,
@@ -959,6 +968,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
meta,
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
for callback, tag, meta in zip(
callbacks, tags_list, metadata_list, strict=False
@@ -966,6 +978,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
]
else:
# We've received a single callbacks arg to apply to all inputs
params = self.dict()
params["stop"] = stop
callback_managers = [
CallbackManager.configure(
cast("Callbacks", callbacks),
@@ -975,12 +989,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
cast("dict[str, Any]", metadata),
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
] * len(prompts)
run_name_list = [cast("str | None", run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
(
existing_prompts,
@@ -1214,6 +1229,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
run_name_list = run_name or cast(
"list[str | None]", ([None] * len(prompts))
)
params = self.dict()
params["stop"] = stop
callback_managers = [
AsyncCallbackManager.configure(
callback,
@@ -1223,6 +1240,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
meta,
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
for callback, tag, meta in zip(
callbacks, tags_list, metadata_list, strict=False
@@ -1230,6 +1250,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
]
else:
# We've received a single callbacks arg to apply to all inputs
params = self.dict()
params["stop"] = stop
callback_managers = [
AsyncCallbackManager.configure(
cast("Callbacks", callbacks),
@@ -1239,12 +1261,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
cast("dict[str, Any]", metadata),
self.metadata,
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
params
),
)
] * len(prompts)
run_name_list = [cast("str | None", run_name)] * len(prompts)
run_ids_list = self._get_run_ids_list(run_id, prompts)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
(
existing_prompts,

View File

@@ -17,8 +17,14 @@ from langchain_core.language_models import (
FakeListChatModel,
ParrotFakeChatModel,
)
from langchain_core.language_models._utils import _normalize_messages
from langchain_core.language_models.chat_models import _generate_response_from_error
from langchain_core.language_models._utils import (
_filter_invocation_params_for_tracing,
_normalize_messages,
)
from langchain_core.language_models.chat_models import (
SimpleChatModel,
_generate_response_from_error,
)
from langchain_core.language_models.fake_chat_models import (
FakeListChatModelError,
GenericFakeChatModel,
@@ -1390,3 +1396,86 @@ def test_generate_response_from_error_handles_streaming_response_failure() -> No
assert metadata["body"] is None
assert metadata["headers"] == {"content-type": "application/json"}
assert metadata["status_code"] == 400
def test_filter_invocation_params_for_tracing() -> None:
    """Test that large fields are filtered from invocation params for tracing."""
    params = {
        "temperature": 0.7,
        "tools": [{"name": "test_tool"}],
        "functions": [{"name": "test_function"}],
        "messages": [{"role": "system", "content": "test"}],
        "response_format": {"type": "json_object"},
    }
    filtered = _filter_invocation_params_for_tracing(params)
    # Small scalar fields survive filtering unchanged.
    assert filtered.get("temperature") == 0.7
    # Large payload-style fields are dropped.
    for dropped in ("tools", "functions", "messages", "response_format"):
        assert dropped not in filtered
class FakeChatModelWithInvocationParams(SimpleChatModel):
    """Fake chat model with invocation params for testing tracing."""

    # The only real model parameter; the rest of _identifying_params is
    # static filler representing large fields that tracing should filter.
    temperature: float = 0.7

    @property
    @override
    def _llm_type(self) -> str:
        return "fake-chat-model-with-invocation-params"

    @property
    @override
    def _identifying_params(self) -> dict[str, Any]:
        params: dict[str, Any] = {"temperature": self.temperature}
        params["tools"] = [{"name": "test_tool"}]
        params["functions"] = [{"name": "test_function"}]
        params["messages"] = [{"role": "system", "content": "test"}]
        params["response_format"] = {"type": "json_object"}
        return params

    @override
    def _call(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> str:
        # Fixed canned reply; the test only inspects tracing side effects.
        return "test response"
def test_invocation_params_passed_to_tracer_metadata() -> None:
    """Test that invocation params are passed to tracer metadata."""
    llm = FakeChatModelWithInvocationParams()
    with collect_runs() as cb:
        llm.invoke([HumanMessage(content="Hello")], config={"callbacks": [cb]})

    assert len(cb.traced_runs) == 1
    run = cb.traced_runs[0]
    # The run's extra carries the full (unfiltered) invocation params;
    # only the inheritable tracing metadata is filtered elsewhere.
    expected_invocation_params = {
        "_type": "fake-chat-model-with-invocation-params",
        "functions": [{"name": "test_function"}],
        "messages": [{"content": "test", "role": "system"}],
        "response_format": {"type": "json_object"},
        "stop": None,
        "temperature": 0.7,
        "tools": [{"name": "test_tool"}],
    }
    expected_metadata = {
        "ls_integration": "langchain_chat_model",
        "ls_model_type": "chat",
        "ls_provider": "fakechatmodelwithinvocationparams",
        "ls_temperature": 0.7,
    }
    assert run.extra == {
        "batch_size": 1,
        "invocation_params": expected_invocation_params,
        "metadata": expected_metadata,
        "options": {"stop": None},
    }

View File

@@ -13,6 +13,7 @@ from langchain_core.language_models import (
BaseLLM,
FakeListLLM,
)
from langchain_core.language_models._utils import _filter_invocation_params_for_tracing
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.callbacks import (
@@ -284,3 +285,94 @@ def test_get_ls_params() -> None:
ls_params = llm._get_ls_params(stop=["stop"])
assert ls_params["ls_stop"] == ["stop"]
def test_filter_invocation_params_for_tracing() -> None:
    """Test that large fields are filtered from invocation params for tracing."""
    excluded = ("tools", "functions", "messages", "response_format")
    params = {
        "temperature": 0.7,
        "tools": [{"name": "test_tool"}],
        "functions": [{"name": "test_function"}],
        "messages": [{"role": "system", "content": "test"}],
        "response_format": {"type": "json_object"},
    }
    filtered = _filter_invocation_params_for_tracing(params)
    # The small scalar field is retained with its original value.
    assert "temperature" in filtered
    assert filtered["temperature"] == 0.7
    # None of the large payload fields make it through.
    assert not any(key in filtered for key in excluded)
class FakeLLMWithInvocationParams(BaseLLM):
    """Fake LLM with invocation params for testing tracing."""

    # Only genuine model parameter; the remaining identifying params are
    # static stand-ins for large fields that tracing metadata filters out.
    temperature: float = 0.7

    @property
    @override
    def _llm_type(self) -> str:
        return "fake-llm-with-invocation-params"

    @property
    @override
    def _identifying_params(self) -> dict[str, Any]:
        params: dict[str, Any] = {"temperature": self.temperature}
        params.update(
            tools=[{"name": "test_tool"}],
            functions=[{"name": "test_function"}],
            messages=[{"role": "system", "content": "test"}],
            response_format={"type": "json_object"},
        )
        return params

    @override
    def _generate(
        self,
        prompts: list[str],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> LLMResult:
        # Single canned generation regardless of input.
        return LLMResult(generations=[[Generation(text="test response")]])

    @override
    async def _agenerate(
        self,
        prompts: list[str],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> LLMResult:
        # Async mirror of _generate with the same canned output.
        return LLMResult(generations=[[Generation(text="test response")]])
async def test_llm_invocation_params_filtered_in_stream() -> None:
    """Test that invocation params are filtered when streaming."""

    # NOTE(review): declared async but contains no awaits — it exercises the
    # sync streaming path only; confirm whether it should be a plain def.
    class FakeStreamingLLM(FakeLLMWithInvocationParams):
        """Variant of the fake LLM that supports sync streaming."""

        @override
        def _stream(
            self,
            prompt: str,
            stop: list[str] | None = None,
            run_manager: CallbackManagerForLLMRun | None = None,
            **kwargs: Any,
        ) -> Iterator[GenerationChunk]:
            yield GenerationChunk(text="test ")

    streaming_llm = FakeStreamingLLM()
    with collect_runs() as cb:
        # Drain the stream so the run completes and gets traced.
        for _ in streaming_llm.stream("Hello", config={"callbacks": [cb]}):
            pass

    assert len(cb.traced_runs) == 1
    run = cb.traced_runs[0]
    # Verify the run was traced with some extra metadata attached.
    assert run.extra is not None