Compare commits

...

2 Commits

Author   SHA1        Message                                    Date
Bagatur  d13d101c66  fmt                                        2024-10-23 15:38:48 -07:00
Bagatur  0c1a576218  rfc: trace standard tool schema directly   2024-10-23 15:34:51 -07:00
3 changed files with 33 additions and 1 deletion

libs/core/langchain_core/callbacks/base.py

@@ -257,6 +257,7 @@ class CallbackManagerMixin:
         parent_run_id: Optional[UUID] = None,
         tags: Optional[list[str]] = None,
         metadata: Optional[dict[str, Any]] = None,
+        tools: Optional[Sequence[_ToolSchema]] = None,
         **kwargs: Any,
     ) -> Any:
         """Run when a chat model starts running.
@@ -1070,3 +1071,12 @@ class BaseCallbackManager(CallbackManagerMixin):
 Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
+
+from typing_extensions import TypedDict
+
+
+class _ToolSchema(TypedDict):
+    name: str
+    description: str
+    parameters: dict

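For illustration, here is how a handler on the consuming side would see the new argument through `on_chat_model_start`. This is a minimal sketch; the `ToolLoggingHandler` name and its logging logic are hypothetical, not part of this change:

from typing import Any, Optional
from uuid import UUID

from langchain_core.callbacks.base import BaseCallbackHandler


class ToolLoggingHandler(BaseCallbackHandler):
    """Hypothetical handler that inspects the tool schemas for a run."""

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[Any]],
        *,
        run_id: UUID,
        tools: Optional[list] = None,  # the parameter added by this change
        **kwargs: Any,
    ) -> Any:
        # Handlers that don't care about tools simply ignore the kwarg;
        # tracers can now record exactly which tools were bound to the model.
        if tools:
            print(f"run {run_id} started with {len(tools)} tool schema(s)")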
libs/core/langchain_core/language_models/chat_models.py

@@ -272,6 +272,18 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             )
             raise ValueError(msg)
 
+    def _standardize_tools(
+        self, tools: Optional[Sequence]
+    ) -> Optional[List[_ToolSchema]]:
+        """Convert tools to standard format for tracing."""
+        if not tools:
+            return tools
+        try:
+            return [convert_to_openai_tool(tool) for tool in tools]
+        except Exception:
+            return None
+
     def invoke(
         self,
         input: LanguageModelInput,
@@ -357,6 +369,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         config: Optional[RunnableConfig] = None,
         *,
         stop: Optional[list[str]] = None,
+        tools: Optional[Sequence] = None,
         **kwargs: Any,
     ) -> Iterator[BaseMessageChunk]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
@@ -367,6 +380,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
+            tools_to_trace = self._standardize_tools(tools)
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             inheritable_metadata = {
inheritable_metadata = {
@@ -390,6 +404,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 name=config.get("run_name"),
                 run_id=config.pop("run_id", None),
                 batch_size=1,
+                tools=tools_to_trace,
             )
             generation: Optional[ChatGenerationChunk] = None

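`_standardize_tools` leans on `convert_to_openai_tool`, which normalizes LangChain tools, Pydantic models, and plain callables into the OpenAI function-calling schema. A rough sketch of what would get attached to the trace; the `add` tool is illustrative:

from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


# Roughly what _standardize_tools would hand to on_chat_model_start:
print(convert_to_openai_tool(add))
# {'type': 'function', 'function': {'name': 'add',
#   'description': 'Add two integers.', 'parameters': {...}}}

Callers would pass tools through the new `stream` kwarg (e.g. `model.stream(messages, tools=[add])`), and since `_standardize_tools` swallows conversion errors by returning None, a tool that fails to convert degrades to an untraced call rather than breaking the stream.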
libs/core/langchain_core/tracers/langchain.py

@@ -5,6 +5,7 @@ from __future__ import annotations
 import copy
 import logging
 import warnings
+from collections.abc import Sequence
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any, Optional, Union
@@ -21,6 +22,7 @@ from tenacity import (
     wait_exponential_jitter,
 )
 
+from langchain_core.callbacks.base import _ToolSchema
 from langchain_core.env import get_runtime_environment
 from langchain_core.load import dumpd
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
@@ -135,6 +137,7 @@ class LangChainTracer(BaseTracer):
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[dict[str, Any]] = None,
         name: Optional[str] = None,
+        tools: Optional[Sequence[_ToolSchema]] = None,
         **kwargs: Any,
     ) -> Run:
         """Start a trace for an LLM run.
@@ -147,6 +150,7 @@ class LangChainTracer(BaseTracer):
             parent_run_id: The parent run ID. Defaults to None.
             metadata: The metadata. Defaults to None.
             name: The name. Defaults to None.
+            tools: The tool schemas to trace. Defaults to None.
             kwargs: Additional keyword arguments.
 
         Returns:
@@ -155,11 +159,14 @@ class LangChainTracer(BaseTracer):
         start_time = datetime.now(timezone.utc)
         if metadata:
             kwargs.update({"metadata": metadata})
+        inputs = {"messages": [[dumpd(msg) for msg in batch] for batch in messages]}
+        if tools:
+            inputs["tools"] = tools
         chat_model_run = Run(
             id=run_id,
             parent_run_id=parent_run_id,
             serialized=serialized,
-            inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
+            inputs=inputs,
             extra=kwargs,
             events=[{"name": "start", "time": start_time}],
             start_time=start_time,
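
Net effect on the trace: when tools are passed, the run's inputs carry the standardized schemas next to the serialized messages. A rough sketch of the resulting Run.inputs, with illustrative values:

# Rough shape of Run.inputs after this change (values illustrative):
inputs = {
    # batches of dumpd-serialized messages, as before
    "messages": [[{"lc": 1, "type": "constructor", "id": ["..."], "kwargs": {}}]],
    # new: standardized tool schemas, present only when tools were passed
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "add",
                "description": "Add two integers.",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ],
}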