mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 07:07:34 +00:00
rfc: trace standard tool schema directly
This commit is contained in:
parent
948e2e6322
commit
0c1a576218
@ -257,6 +257,7 @@ class CallbackManagerMixin:
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
tools: Optional[Sequence[_ToolSchema]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running.
|
||||
@ -1070,3 +1071,11 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
|
||||
|
||||
Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
|
||||
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
class _ToolSchema(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict
|
@ -272,6 +272,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def _standardize_tools(self, tools: Optional[Sequence]) -> Optional[List[_ToolSchema]]:
|
||||
"""Convert tools to standard format for tracing."""
|
||||
if not tools:
|
||||
return tools
|
||||
|
||||
try:
|
||||
return [convert_to_openai_tool(tool) for tool in tools]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
@ -357,6 +367,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[list[str]] = None,
|
||||
tools: Optional[Sequence] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[BaseMessageChunk]:
|
||||
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
||||
@ -367,6 +378,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
else:
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input).to_messages()
|
||||
tools_to_trace = self._standardize_tools(tools)
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
options = {"stop": stop, **kwargs}
|
||||
inheritable_metadata = {
|
||||
@ -390,6 +402,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
name=config.get("run_name"),
|
||||
run_id=config.pop("run_id", None),
|
||||
batch_size=1,
|
||||
tools=tools_to_trace,
|
||||
)
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
|
||||
|
@ -7,7 +7,7 @@ import logging
|
||||
import warnings
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith import Client
|
||||
@ -21,6 +21,7 @@ from tenacity import (
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks.base import _ToolSchema
|
||||
from langchain_core.env import get_runtime_environment
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
@ -135,6 +136,7 @@ class LangChainTracer(BaseTracer):
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
tools: Optional[Sequence[_ToolSchema]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for an LLM run.
|
||||
@ -147,6 +149,7 @@ class LangChainTracer(BaseTracer):
|
||||
parent_run_id: The parent run ID. Defaults to None.
|
||||
metadata: The metadata. Defaults to None.
|
||||
name: The name. Defaults to None.
|
||||
tools: The tools.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
@ -155,11 +158,14 @@ class LangChainTracer(BaseTracer):
|
||||
start_time = datetime.now(timezone.utc)
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
inputs = {"messages": [[dumpd(msg) for msg in batch] for batch in messages]}
|
||||
if tools:
|
||||
inputs["tools"] = tools
|
||||
chat_model_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
serialized=serialized,
|
||||
inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
|
||||
inputs=inputs,
|
||||
extra=kwargs,
|
||||
events=[{"name": "start", "time": start_time}],
|
||||
start_time=start_time,
|
||||
|
Loading…
Reference in New Issue
Block a user