rfc: trace standard tool schema directly

Bagatur 2024-10-23 15:34:51 -07:00
parent 948e2e6322
commit 0c1a576218
3 changed files with 30 additions and 2 deletions
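At a glance: BaseChatModel.stream gains a tools parameter, standardizes it via convert_to_openai_tool, and forwards the result through on_chat_model_start so that LangChainTracer can record the schemas on the run's inputs. A minimal call-site sketch of the proposed API; chat_model and get_weather are illustrative stand-ins, not part of this diff:

# Sketch only: assumes the stream(tools=...) parameter added below.
for chunk in chat_model.stream(
    "What's the weather in SF?",
    tools=[get_weather],  # standardized and attached to the traced run
):
    print(chunk.content, end="")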

libs/core/langchain_core/callbacks/base.py

@@ -257,6 +257,7 @@ class CallbackManagerMixin:
         parent_run_id: Optional[UUID] = None,
         tags: Optional[list[str]] = None,
         metadata: Optional[dict[str, Any]] = None,
+        tools: Optional[Sequence[_ToolSchema]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when a chat model starts running.
@@ -1070,3 +1071,11 @@ class BaseCallbackManager(CallbackManagerMixin):
 Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
+
+from typing_extensions import TypedDict
+
+
+class _ToolSchema(TypedDict):
+    name: str
+    description: str
+    parameters: dict
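For concreteness, a value matching the proposed _ToolSchema shape could look like the sketch below (the weather tool is illustrative). Note that this flat shape differs from the nested {"type": "function", "function": {...}} dicts that convert_to_openai_tool produces in _standardize_tools below; the rfc would presumably reconcile the two.

# Illustrative _ToolSchema value; field contents are hypothetical.
weather_tool: _ToolSchema = {
    "name": "get_weather",
    "description": "Look up the current weather for a city.",
    "parameters": {  # JSON Schema for the tool's arguments
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}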

libs/core/langchain_core/language_models/chat_models.py

@@ -272,6 +272,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             )
             raise ValueError(msg)
 
+    def _standardize_tools(self, tools: Optional[Sequence]) -> Optional[list[_ToolSchema]]:
+        """Convert tools to the standard schema for tracing, best effort."""
+        if not tools:
+            return None
+        try:
+            # Assumes convert_to_openai_tool is imported from
+            # langchain_core.utils.function_calling.
+            return [convert_to_openai_tool(tool) for tool in tools]
+        except Exception:
+            # Tracing is best effort; never fail the call over a schema issue.
+            return None
+
     def invoke(
         self,
         input: LanguageModelInput,
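convert_to_openai_tool lives in langchain_core.utils.function_calling and accepts plain functions, pydantic models, and BaseTool instances. A rough sketch of the standardized output for a plain function (exact schema details can vary by version):

from langchain_core.utils.function_calling import convert_to_openai_tool


def get_weather(city: str) -> str:
    """Look up the current weather for a city."""
    ...


print(convert_to_openai_tool(get_weather))
# Roughly:
# {"type": "function",
#  "function": {"name": "get_weather",
#               "description": "Look up the current weather for a city.",
#               "parameters": {"type": "object",
#                              "properties": {"city": {"type": "string"}},
#                              "required": ["city"]}}}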
@@ -357,6 +367,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[list[str]] = None,
+        tools: Optional[Sequence] = None,
         **kwargs: Any,
     ) -> Iterator[BaseMessageChunk]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
@@ -367,6 +378,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
+            tools_to_trace = self._standardize_tools(tools)
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             inheritable_metadata = {
@@ -390,6 +402,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 name=config.get("run_name"),
                 run_id=config.pop("run_id", None),
                 batch_size=1,
+                tools=tools_to_trace,
             )
             generation: Optional[ChatGenerationChunk] = None
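With tools forwarded to the callback manager, any handler that overrides on_chat_model_start can observe the standardized schemas. A sketch under that assumption; the handler class is hypothetical:

from typing import Any, Optional, Sequence
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage


class ToolLoggingHandler(BaseCallbackHandler):
    """Hypothetical handler that logs tool schemas passed for tracing."""

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        tools: Optional[Sequence[dict]] = None,
        **kwargs: Any,
    ) -> Any:
        for tool in tools or []:
            print("traced tool:", tool)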

libs/core/langchain_core/tracers/langchain.py

@@ -7,7 +7,7 @@ import logging
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Sequence, Union
 from uuid import UUID
 
 from langsmith import Client
@@ -21,6 +21,7 @@ from tenacity import (
     wait_exponential_jitter,
 )
 
+from langchain_core.callbacks.base import _ToolSchema
 from langchain_core.env import get_runtime_environment
 from langchain_core.load import dumpd
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
@@ -135,6 +136,7 @@ class LangChainTracer(BaseTracer):
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[dict[str, Any]] = None,
         name: Optional[str] = None,
+        tools: Optional[Sequence[_ToolSchema]] = None,
         **kwargs: Any,
     ) -> Run:
         """Start a trace for an LLM run.
@@ -147,6 +149,7 @@
             parent_run_id: The parent run ID. Defaults to None.
             metadata: The metadata. Defaults to None.
             name: The name. Defaults to None.
+            tools: Standardized schemas for the tools bound to the model. Defaults to None.
             kwargs: Additional keyword arguments.
 
         Returns:
@@ -155,11 +158,14 @@
         start_time = datetime.now(timezone.utc)
         if metadata:
             kwargs.update({"metadata": metadata})
+        inputs = {"messages": [[dumpd(msg) for msg in batch] for batch in messages]}
+        if tools:
+            inputs["tools"] = tools
         chat_model_run = Run(
             id=run_id,
             parent_run_id=parent_run_id,
             serialized=serialized,
-            inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
+            inputs=inputs,
             extra=kwargs,
             events=[{"name": "start", "time": start_time}],
             start_time=start_time,
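Net effect: the traced chat-model run's inputs carry the tool schemas next to the serialized messages, roughly (shape illustrative):

# Illustrative inputs payload on the traced Run after this change:
# {
#     "messages": [[<serialized BaseMessage dicts>]],
#     "tools": [{"type": "function", "function": {...}}],
# }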