rfc: trace standard tool schema directly

This commit is contained in:
Bagatur 2024-10-23 15:34:51 -07:00
parent 948e2e6322
commit 0c1a576218
3 changed files with 30 additions and 2 deletions

View File

@ -257,6 +257,7 @@ class CallbackManagerMixin:
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
metadata: Optional[dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
tools: Optional[Sequence[_ToolSchema]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when a chat model starts running. """Run when a chat model starts running.
@ -1070,3 +1071,11 @@ class BaseCallbackManager(CallbackManagerMixin):
Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]] Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
from typing_extensions import TypedDict
class _ToolSchema(TypedDict):
name: str
description: str
parameters: dict

View File

@ -272,6 +272,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
) )
raise ValueError(msg) raise ValueError(msg)
def _standardize_tools(self, tools: Optional[Sequence]) -> Optional[List[_ToolSchema]]:
"""Convert tools to standard format for tracing."""
if not tools:
return tools
try:
return [convert_to_openai_tool(tool) for tool in tools]
except Exception:
return None
def invoke( def invoke(
self, self,
input: LanguageModelInput, input: LanguageModelInput,
@ -357,6 +367,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
config: Optional[RunnableConfig] = None, config: Optional[RunnableConfig] = None,
*, *,
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
tools: Optional[Sequence] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[BaseMessageChunk]: ) -> Iterator[BaseMessageChunk]:
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}): if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
@ -367,6 +378,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else: else:
config = ensure_config(config) config = ensure_config(config)
messages = self._convert_input(input).to_messages() messages = self._convert_input(input).to_messages()
tools_to_trace = self._standardize_tools(tools)
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs} options = {"stop": stop, **kwargs}
inheritable_metadata = { inheritable_metadata = {
@ -390,6 +402,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
name=config.get("run_name"), name=config.get("run_name"),
run_id=config.pop("run_id", None), run_id=config.pop("run_id", None),
batch_size=1, batch_size=1,
tools=tools_to_trace,
) )
generation: Optional[ChatGenerationChunk] = None generation: Optional[ChatGenerationChunk] = None

View File

@ -7,7 +7,7 @@ import logging
import warnings import warnings
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union, Sequence
from uuid import UUID from uuid import UUID
from langsmith import Client from langsmith import Client
@ -21,6 +21,7 @@ from tenacity import (
wait_exponential_jitter, wait_exponential_jitter,
) )
from langchain_core.callbacks.base import _ToolSchema
from langchain_core.env import get_runtime_environment from langchain_core.env import get_runtime_environment
from langchain_core.load import dumpd from langchain_core.load import dumpd
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
@ -135,6 +136,7 @@ class LangChainTracer(BaseTracer):
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[dict[str, Any]] = None, metadata: Optional[dict[str, Any]] = None,
name: Optional[str] = None, name: Optional[str] = None,
tools: Optional[Sequence[_ToolSchema]] = None,
**kwargs: Any, **kwargs: Any,
) -> Run: ) -> Run:
"""Start a trace for an LLM run. """Start a trace for an LLM run.
@ -147,6 +149,7 @@ class LangChainTracer(BaseTracer):
parent_run_id: The parent run ID. Defaults to None. parent_run_id: The parent run ID. Defaults to None.
metadata: The metadata. Defaults to None. metadata: The metadata. Defaults to None.
name: The name. Defaults to None. name: The name. Defaults to None.
tools: The tools.
kwargs: Additional keyword arguments. kwargs: Additional keyword arguments.
Returns: Returns:
@ -155,11 +158,14 @@ class LangChainTracer(BaseTracer):
start_time = datetime.now(timezone.utc) start_time = datetime.now(timezone.utc)
if metadata: if metadata:
kwargs.update({"metadata": metadata}) kwargs.update({"metadata": metadata})
inputs = {"messages": [[dumpd(msg) for msg in batch] for batch in messages]}
if tools:
inputs["tools"] = tools
chat_model_run = Run( chat_model_run = Run(
id=run_id, id=run_id,
parent_run_id=parent_run_id, parent_run_id=parent_run_id,
serialized=serialized, serialized=serialized,
inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]}, inputs=inputs,
extra=kwargs, extra=kwargs,
events=[{"name": "start", "time": start_time}], events=[{"name": "start", "time": start_time}],
start_time=start_time, start_time=start_time,