From 0c1a576218d78e294c5aab5f907663b81a92ac79 Mon Sep 17 00:00:00 2001
From: Bagatur
Date: Wed, 23 Oct 2024 15:34:51 -0700
Subject: [PATCH] rfc: trace standard tool schema directly

---
 libs/core/langchain_core/callbacks/base.py        |  9 +++++++++
 .../langchain_core/language_models/chat_models.py | 13 +++++++++++++
 libs/core/langchain_core/tracers/langchain.py     | 10 ++++++++--
 3 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/libs/core/langchain_core/callbacks/base.py b/libs/core/langchain_core/callbacks/base.py
index b3a61aa9a2a..0e0812989c5 100644
--- a/libs/core/langchain_core/callbacks/base.py
+++ b/libs/core/langchain_core/callbacks/base.py
@@ -257,6 +257,7 @@ class CallbackManagerMixin:
         parent_run_id: Optional[UUID] = None,
         tags: Optional[list[str]] = None,
         metadata: Optional[dict[str, Any]] = None,
+        tools: Optional[Sequence[_ToolSchema]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when a chat model starts running.
@@ -1070,3 +1071,11 @@ class BaseCallbackManager(CallbackManagerMixin):
 
 
 Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
+
+
+from typing_extensions import TypedDict
+
+class _ToolSchema(TypedDict):
+    name: str
+    description: str
+    parameters: dict
\ No newline at end of file
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 39fd11c247f..fc994f2016f 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -272,6 +272,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             )
             raise ValueError(msg)
 
+    def _standardize_tools(self, tools: Optional[Sequence]) -> Optional[list[_ToolSchema]]:
+        """Best-effort conversion of tools to the standard (OpenAI) schema for tracing."""
+        if not tools:
+            return None
+
+        try:
+            return [convert_to_openai_tool(tool) for tool in tools]
+        except Exception:
+            return None
+
     def invoke(
         self,
         input: LanguageModelInput,
@@ -357,6 +367,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[list[str]] = None,
+        tools: Optional[Sequence] = None,
         **kwargs: Any,
     ) -> Iterator[BaseMessageChunk]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
@@ -367,6 +378,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
+            tools_to_trace = self._standardize_tools(tools)
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             inheritable_metadata = {
@@ -390,6 +402,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 name=config.get("run_name"),
                 run_id=config.pop("run_id", None),
                 batch_size=1,
+                tools=tools_to_trace,
             )
             generation: Optional[ChatGenerationChunk] = None
 
diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py
index 0183adb2604..8a26ee3e61c 100644
--- a/libs/core/langchain_core/tracers/langchain.py
+++ b/libs/core/langchain_core/tracers/langchain.py
@@ -7,7 +7,7 @@ import logging
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Sequence, Union
 from uuid import UUID
 
 from langsmith import Client
@@ -21,6 +21,7 @@ from tenacity import (
     wait_exponential_jitter,
 )
 
+from langchain_core.callbacks.base import _ToolSchema
 from langchain_core.env import get_runtime_environment
 from langchain_core.load import dumpd
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
@@ -135,6 +136,7 @@ class LangChainTracer(BaseTracer):
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[dict[str, Any]] = None,
         name: Optional[str] = None,
+        tools: Optional[Sequence[_ToolSchema]] = None,
         **kwargs: Any,
     ) -> Run:
         """Start a trace for an LLM run.
@@ -147,6 +149,7 @@ class LangChainTracer(BaseTracer):
             parent_run_id: The parent run ID. Defaults to None.
             metadata: The metadata. Defaults to None.
             name: The name. Defaults to None.
+            tools: Standardized schemas of tools bound for this run. Defaults to None.
             kwargs: Additional keyword arguments.
 
         Returns:
@@ -155,11 +158,14 @@ class LangChainTracer(BaseTracer):
         start_time = datetime.now(timezone.utc)
         if metadata:
             kwargs.update({"metadata": metadata})
+        inputs = {"messages": [[dumpd(msg) for msg in batch] for batch in messages]}
+        if tools:
+            inputs["tools"] = tools
         chat_model_run = Run(
             id=run_id,
             parent_run_id=parent_run_id,
             serialized=serialized,
-            inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
+            inputs=inputs,
             extra=kwargs,
             events=[{"name": "start", "time": start_time}],
             start_time=start_time,