mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-16 18:02:57 +00:00
Compare commits
5 Commits
langchain-
...
jacob/trac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
03241dc80c | ||
|
|
9c4de86bca | ||
|
|
bc0e99d045 | ||
|
|
91b1ef049c | ||
|
|
c993ba06bb |
@@ -2,6 +2,7 @@ import re
|
||||
from collections.abc import Sequence
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
@@ -14,6 +15,21 @@ from langchain_core.messages.content import (
|
||||
)
|
||||
|
||||
|
||||
def _filter_invocation_params_for_tracing(params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Filter out large/inappropriate fields from invocation params for tracing.
|
||||
|
||||
Removes fields like tools, functions, messages, response_format that can be large.
|
||||
|
||||
Args:
|
||||
params: The invocation parameters to filter.
|
||||
|
||||
Returns:
|
||||
The filtered parameters with large fields removed.
|
||||
"""
|
||||
excluded_keys = {"tools", "functions", "messages", "response_format"}
|
||||
return {k: v for k, v in params.items() if k not in excluded_keys}
|
||||
|
||||
|
||||
def is_openai_data_block(
|
||||
block: dict, filter_: Literal["image", "audio", "file"] | None = None
|
||||
) -> bool:
|
||||
|
||||
@@ -25,6 +25,7 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.globals import get_llm_cache
|
||||
from langchain_core.language_models._utils import (
|
||||
_filter_invocation_params_for_tracing,
|
||||
_normalize_messages,
|
||||
_update_message_content_to_blocks,
|
||||
)
|
||||
@@ -567,6 +568,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
(run_manager,) = callback_manager.on_chat_model_start(
|
||||
self._serialized,
|
||||
@@ -695,6 +699,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
(run_manager,) = await callback_manager.on_chat_model_start(
|
||||
self._serialized,
|
||||
@@ -972,6 +979,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
messages_to_trace = [
|
||||
_format_for_tracing(message_list) for message_list in messages
|
||||
@@ -1095,6 +1105,9 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
|
||||
messages_to_trace = [
|
||||
|
||||
@@ -42,6 +42,7 @@ from langchain_core.callbacks import (
|
||||
Callbacks,
|
||||
)
|
||||
from langchain_core.globals import get_llm_cache
|
||||
from langchain_core.language_models._utils import _filter_invocation_params_for_tracing
|
||||
from langchain_core.language_models.base import (
|
||||
BaseLanguageModel,
|
||||
LangSmithParams,
|
||||
@@ -537,6 +538,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
(run_manager,) = callback_manager.on_llm_start(
|
||||
self._serialized,
|
||||
@@ -607,6 +611,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
self.tags,
|
||||
inheritable_metadata,
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
(run_manager,) = await callback_manager.on_llm_start(
|
||||
self._serialized,
|
||||
@@ -950,6 +957,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
run_name_list = run_name or cast(
|
||||
"list[str | None]", ([None] * len(prompts))
|
||||
)
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
callback,
|
||||
@@ -959,6 +968,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
self.tags,
|
||||
meta,
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
for callback, tag, meta in zip(
|
||||
callbacks, tags_list, metadata_list, strict=False
|
||||
@@ -966,6 +978,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
]
|
||||
else:
|
||||
# We've received a single callbacks arg to apply to all inputs
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
callback_managers = [
|
||||
CallbackManager.configure(
|
||||
cast("Callbacks", callbacks),
|
||||
@@ -975,12 +989,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
self.tags,
|
||||
cast("dict[str, Any]", metadata),
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
] * len(prompts)
|
||||
run_name_list = [cast("str | None", run_name)] * len(prompts)
|
||||
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
(
|
||||
existing_prompts,
|
||||
@@ -1214,6 +1229,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
run_name_list = run_name or cast(
|
||||
"list[str | None]", ([None] * len(prompts))
|
||||
)
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
callback,
|
||||
@@ -1223,6 +1240,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
self.tags,
|
||||
meta,
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
for callback, tag, meta in zip(
|
||||
callbacks, tags_list, metadata_list, strict=False
|
||||
@@ -1230,6 +1250,8 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
]
|
||||
else:
|
||||
# We've received a single callbacks arg to apply to all inputs
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
callback_managers = [
|
||||
AsyncCallbackManager.configure(
|
||||
cast("Callbacks", callbacks),
|
||||
@@ -1239,12 +1261,13 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
self.tags,
|
||||
cast("dict[str, Any]", metadata),
|
||||
self.metadata,
|
||||
langsmith_inheritable_metadata=_filter_invocation_params_for_tracing(
|
||||
params
|
||||
),
|
||||
)
|
||||
] * len(prompts)
|
||||
run_name_list = [cast("str | None", run_name)] * len(prompts)
|
||||
run_ids_list = self._get_run_ids_list(run_id, prompts)
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
(
|
||||
existing_prompts,
|
||||
|
||||
@@ -17,8 +17,14 @@ from langchain_core.language_models import (
|
||||
FakeListChatModel,
|
||||
ParrotFakeChatModel,
|
||||
)
|
||||
from langchain_core.language_models._utils import _normalize_messages
|
||||
from langchain_core.language_models.chat_models import _generate_response_from_error
|
||||
from langchain_core.language_models._utils import (
|
||||
_filter_invocation_params_for_tracing,
|
||||
_normalize_messages,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import (
|
||||
SimpleChatModel,
|
||||
_generate_response_from_error,
|
||||
)
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeListChatModelError,
|
||||
GenericFakeChatModel,
|
||||
@@ -1390,3 +1396,86 @@ def test_generate_response_from_error_handles_streaming_response_failure() -> No
|
||||
assert metadata["body"] is None
|
||||
assert metadata["headers"] == {"content-type": "application/json"}
|
||||
assert metadata["status_code"] == 400
|
||||
|
||||
|
||||
def test_filter_invocation_params_for_tracing() -> None:
|
||||
"""Test that large fields are filtered from invocation params for tracing."""
|
||||
params = {
|
||||
"temperature": 0.7,
|
||||
"tools": [{"name": "test_tool"}],
|
||||
"functions": [{"name": "test_function"}],
|
||||
"messages": [{"role": "system", "content": "test"}],
|
||||
"response_format": {"type": "json_object"},
|
||||
}
|
||||
filtered = _filter_invocation_params_for_tracing(params)
|
||||
|
||||
# Should include temperature
|
||||
assert "temperature" in filtered
|
||||
assert filtered["temperature"] == 0.7
|
||||
|
||||
# Should exclude these large fields
|
||||
assert "tools" not in filtered
|
||||
assert "functions" not in filtered
|
||||
assert "messages" not in filtered
|
||||
assert "response_format" not in filtered
|
||||
|
||||
|
||||
class FakeChatModelWithInvocationParams(SimpleChatModel):
|
||||
"""Fake chat model with invocation params for testing tracing."""
|
||||
|
||||
temperature: float = 0.7
|
||||
|
||||
@property
|
||||
@override
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-chat-model-with-invocation-params"
|
||||
|
||||
@property
|
||||
@override
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
return {
|
||||
"temperature": self.temperature,
|
||||
"tools": [{"name": "test_tool"}],
|
||||
"functions": [{"name": "test_function"}],
|
||||
"messages": [{"role": "system", "content": "test"}],
|
||||
"response_format": {"type": "json_object"},
|
||||
}
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
return "test response"
|
||||
|
||||
|
||||
def test_invocation_params_passed_to_tracer_metadata() -> None:
|
||||
"""Test that invocation params are passed to tracer metadata."""
|
||||
llm = FakeChatModelWithInvocationParams()
|
||||
with collect_runs() as cb:
|
||||
llm.invoke([HumanMessage(content="Hello")], config={"callbacks": [cb]})
|
||||
assert len(cb.traced_runs) == 1
|
||||
run = cb.traced_runs[0]
|
||||
# The invocation params should be in the run's extra
|
||||
assert run.extra == {
|
||||
"batch_size": 1,
|
||||
"invocation_params": {
|
||||
"_type": "fake-chat-model-with-invocation-params",
|
||||
"functions": [{"name": "test_function"}],
|
||||
"messages": [{"content": "test", "role": "system"}],
|
||||
"response_format": {"type": "json_object"},
|
||||
"stop": None,
|
||||
"temperature": 0.7,
|
||||
"tools": [{"name": "test_tool"}],
|
||||
},
|
||||
"metadata": {
|
||||
"ls_integration": "langchain_chat_model",
|
||||
"ls_model_type": "chat",
|
||||
"ls_provider": "fakechatmodelwithinvocationparams",
|
||||
"ls_temperature": 0.7,
|
||||
},
|
||||
"options": {"stop": None},
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ from langchain_core.language_models import (
|
||||
BaseLLM,
|
||||
FakeListLLM,
|
||||
)
|
||||
from langchain_core.language_models._utils import _filter_invocation_params_for_tracing
|
||||
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.fake.callbacks import (
|
||||
@@ -284,3 +285,94 @@ def test_get_ls_params() -> None:
|
||||
|
||||
ls_params = llm._get_ls_params(stop=["stop"])
|
||||
assert ls_params["ls_stop"] == ["stop"]
|
||||
|
||||
|
||||
def test_filter_invocation_params_for_tracing() -> None:
|
||||
"""Test that large fields are filtered from invocation params for tracing."""
|
||||
params = {
|
||||
"temperature": 0.7,
|
||||
"tools": [{"name": "test_tool"}],
|
||||
"functions": [{"name": "test_function"}],
|
||||
"messages": [{"role": "system", "content": "test"}],
|
||||
"response_format": {"type": "json_object"},
|
||||
}
|
||||
filtered = _filter_invocation_params_for_tracing(params)
|
||||
|
||||
# Should include temperature
|
||||
assert "temperature" in filtered
|
||||
assert filtered["temperature"] == 0.7
|
||||
|
||||
# Should exclude these large fields
|
||||
assert "tools" not in filtered
|
||||
assert "functions" not in filtered
|
||||
assert "messages" not in filtered
|
||||
assert "response_format" not in filtered
|
||||
|
||||
|
||||
class FakeLLMWithInvocationParams(BaseLLM):
|
||||
"""Fake LLM with invocation params for testing tracing."""
|
||||
|
||||
temperature: float = 0.7
|
||||
|
||||
@property
|
||||
@override
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-llm-with-invocation-params"
|
||||
|
||||
@property
|
||||
@override
|
||||
def _identifying_params(self) -> dict[str, Any]:
|
||||
return {
|
||||
"temperature": self.temperature,
|
||||
"tools": [{"name": "test_tool"}],
|
||||
"functions": [{"name": "test_function"}],
|
||||
"messages": [{"role": "system", "content": "test"}],
|
||||
"response_format": {"type": "json_object"},
|
||||
}
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations = [[Generation(text="test response")]]
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
@override
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
generations = [[Generation(text="test response")]]
|
||||
return LLMResult(generations=generations)
|
||||
|
||||
|
||||
async def test_llm_invocation_params_filtered_in_stream() -> None:
|
||||
"""Test that invocation params are filtered when streaming."""
|
||||
|
||||
# Create a custom LLM that supports streaming
|
||||
class FakeStreamingLLM(FakeLLMWithInvocationParams):
|
||||
@override
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
yield GenerationChunk(text="test ")
|
||||
|
||||
streaming_llm = FakeStreamingLLM()
|
||||
|
||||
with collect_runs() as cb:
|
||||
list(streaming_llm.stream("Hello", config={"callbacks": [cb]}))
|
||||
assert len(cb.traced_runs) == 1
|
||||
run = cb.traced_runs[0]
|
||||
# Verify the run was traced
|
||||
assert run.extra is not None
|
||||
|
||||
Reference in New Issue
Block a user