mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-16 08:06:14 +00:00
rfc: trace standard tool schema directly
This commit is contained in:
parent
948e2e6322
commit
0c1a576218
@ -257,6 +257,7 @@ class CallbackManagerMixin:
|
|||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
tags: Optional[list[str]] = None,
|
tags: Optional[list[str]] = None,
|
||||||
metadata: Optional[dict[str, Any]] = None,
|
metadata: Optional[dict[str, Any]] = None,
|
||||||
|
tools: Optional[Sequence[_ToolSchema]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when a chat model starts running.
|
"""Run when a chat model starts running.
|
||||||
@ -1070,3 +1071,11 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
|
|
||||||
|
|
||||||
Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
|
Callbacks = Optional[Union[list[BaseCallbackHandler], BaseCallbackManager]]
|
||||||
|
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
class _ToolSchema(TypedDict):
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: dict
|
@ -272,6 +272,16 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
def _standardize_tools(self, tools: Optional[Sequence]) -> Optional[List[_ToolSchema]]:
|
||||||
|
"""Convert tools to standard format for tracing."""
|
||||||
|
if not tools:
|
||||||
|
return tools
|
||||||
|
|
||||||
|
try:
|
||||||
|
return [convert_to_openai_tool(tool) for tool in tools]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: LanguageModelInput,
|
input: LanguageModelInput,
|
||||||
@ -357,6 +367,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
*,
|
*,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
|
tools: Optional[Sequence] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[BaseMessageChunk]:
|
) -> Iterator[BaseMessageChunk]:
|
||||||
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
||||||
@ -367,6 +378,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
else:
|
else:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
messages = self._convert_input(input).to_messages()
|
messages = self._convert_input(input).to_messages()
|
||||||
|
tools_to_trace = self._standardize_tools(tools)
|
||||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||||
options = {"stop": stop, **kwargs}
|
options = {"stop": stop, **kwargs}
|
||||||
inheritable_metadata = {
|
inheritable_metadata = {
|
||||||
@ -390,6 +402,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
name=config.get("run_name"),
|
name=config.get("run_name"),
|
||||||
run_id=config.pop("run_id", None),
|
run_id=config.pop("run_id", None),
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
|
tools=tools_to_trace,
|
||||||
)
|
)
|
||||||
generation: Optional[ChatGenerationChunk] = None
|
generation: Optional[ChatGenerationChunk] = None
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import logging
|
|||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union, Sequence
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langsmith import Client
|
from langsmith import Client
|
||||||
@ -21,6 +21,7 @@ from tenacity import (
|
|||||||
wait_exponential_jitter,
|
wait_exponential_jitter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from langchain_core.callbacks.base import _ToolSchema
|
||||||
from langchain_core.env import get_runtime_environment
|
from langchain_core.env import get_runtime_environment
|
||||||
from langchain_core.load import dumpd
|
from langchain_core.load import dumpd
|
||||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||||
@ -135,6 +136,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
metadata: Optional[dict[str, Any]] = None,
|
metadata: Optional[dict[str, Any]] = None,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
|
tools: Optional[Sequence[_ToolSchema]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Run:
|
) -> Run:
|
||||||
"""Start a trace for an LLM run.
|
"""Start a trace for an LLM run.
|
||||||
@ -147,6 +149,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
parent_run_id: The parent run ID. Defaults to None.
|
parent_run_id: The parent run ID. Defaults to None.
|
||||||
metadata: The metadata. Defaults to None.
|
metadata: The metadata. Defaults to None.
|
||||||
name: The name. Defaults to None.
|
name: The name. Defaults to None.
|
||||||
|
tools: The tools.
|
||||||
kwargs: Additional keyword arguments.
|
kwargs: Additional keyword arguments.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -155,11 +158,14 @@ class LangChainTracer(BaseTracer):
|
|||||||
start_time = datetime.now(timezone.utc)
|
start_time = datetime.now(timezone.utc)
|
||||||
if metadata:
|
if metadata:
|
||||||
kwargs.update({"metadata": metadata})
|
kwargs.update({"metadata": metadata})
|
||||||
|
inputs = {"messages": [[dumpd(msg) for msg in batch] for batch in messages]}
|
||||||
|
if tools:
|
||||||
|
inputs["tools"] = tools
|
||||||
chat_model_run = Run(
|
chat_model_run = Run(
|
||||||
id=run_id,
|
id=run_id,
|
||||||
parent_run_id=parent_run_id,
|
parent_run_id=parent_run_id,
|
||||||
serialized=serialized,
|
serialized=serialized,
|
||||||
inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
|
inputs=inputs,
|
||||||
extra=kwargs,
|
extra=kwargs,
|
||||||
events=[{"name": "start", "time": start_time}],
|
events=[{"name": "start", "time": start_time}],
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
|
Loading…
Reference in New Issue
Block a user