mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-18 17:11:25 +00:00
core[minor], integrations...[patch]: Support ToolCall as Tool input and ToolMessage as Tool output (#24038)
Changes: - ToolCall, InvalidToolCall and ToolCallChunk can all accept a "type" parameter now - LLM integration packages add "type" to all the above - Tool supports ToolCall inputs that have "type" specified - Tool outputs ToolMessage when a ToolCall is passed as input - Tools can separately specify ToolMessage.content and ToolMessage.raw_output - Tools emit events for validation errors (using on_tool_error and on_tool_end) Example: ```python @tool("structured_api", response_format="content_and_raw_output") def _mock_structured_tool_with_raw_output( arg1: int, arg2: bool, arg3: Optional[dict] = None ) -> Tuple[str, dict]: """A Structured Tool""" return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} def test_tool_call_input_tool_message_with_raw_output() -> None: tool_call: Dict = { "name": "structured_api", "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}, "id": "123", "type": "tool_call", } expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123") tool = _mock_structured_tool_with_raw_output actual = tool.invoke(tool_call) assert actual == expected tool_call.pop("type") with pytest.raises(ValidationError): tool.invoke(tool_call) actual_content = tool.invoke(tool_call["args"]) assert actual_content == expected.content ``` --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
eeb996034b
commit
5fd1e67808
@ -39,7 +39,7 @@ def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
|
|||||||
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
|
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
|
||||||
"""Test that tools defined in this repo accept a run manager argument."""
|
"""Test that tools defined in this repo accept a run manager argument."""
|
||||||
# This wouldn't be necessary if the BaseTool had a strict API.
|
# This wouldn't be necessary if the BaseTool had a strict API.
|
||||||
if cls._run is not BaseTool._arun:
|
if cls._run is not BaseTool._run:
|
||||||
run_func = cls._run
|
run_func = cls._run
|
||||||
params = inspect.signature(run_func).parameters
|
params = inspect.signature(run_func).parameters
|
||||||
assert "run_manager" in params
|
assert "run_manager" in params
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
|
||||||
from langchain_core.utils._merge import merge_dicts, merge_obj
|
from langchain_core.utils._merge import merge_dicts, merge_obj
|
||||||
@ -146,6 +146,11 @@ class ToolCall(TypedDict):
|
|||||||
An identifier is needed to associate a tool call request with a tool
|
An identifier is needed to associate a tool call request with a tool
|
||||||
call result in events when multiple concurrent tool calls are made.
|
call result in events when multiple concurrent tool calls are made.
|
||||||
"""
|
"""
|
||||||
|
type: NotRequired[Literal["tool_call"]]
|
||||||
|
|
||||||
|
|
||||||
|
def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
|
||||||
|
return ToolCall(name=name, args=args, id=id, type="tool_call")
|
||||||
|
|
||||||
|
|
||||||
class ToolCallChunk(TypedDict):
|
class ToolCallChunk(TypedDict):
|
||||||
@ -176,6 +181,19 @@ class ToolCallChunk(TypedDict):
|
|||||||
"""An identifier associated with the tool call."""
|
"""An identifier associated with the tool call."""
|
||||||
index: Optional[int]
|
index: Optional[int]
|
||||||
"""The index of the tool call in a sequence."""
|
"""The index of the tool call in a sequence."""
|
||||||
|
type: NotRequired[Literal["tool_call_chunk"]]
|
||||||
|
|
||||||
|
|
||||||
|
def tool_call_chunk(
|
||||||
|
*,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
args: Optional[str] = None,
|
||||||
|
id: Optional[str] = None,
|
||||||
|
index: Optional[int] = None,
|
||||||
|
) -> ToolCallChunk:
|
||||||
|
return ToolCallChunk(
|
||||||
|
name=name, args=args, id=id, index=index, type="tool_call_chunk"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvalidToolCall(TypedDict):
|
class InvalidToolCall(TypedDict):
|
||||||
@ -193,6 +211,19 @@ class InvalidToolCall(TypedDict):
|
|||||||
"""An identifier associated with the tool call."""
|
"""An identifier associated with the tool call."""
|
||||||
error: Optional[str]
|
error: Optional[str]
|
||||||
"""An error message associated with the tool call."""
|
"""An error message associated with the tool call."""
|
||||||
|
type: NotRequired[Literal["invalid_tool_call"]]
|
||||||
|
|
||||||
|
|
||||||
|
def invalid_tool_call(
|
||||||
|
*,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
args: Optional[str] = None,
|
||||||
|
id: Optional[str] = None,
|
||||||
|
error: Optional[str] = None,
|
||||||
|
) -> InvalidToolCall:
|
||||||
|
return InvalidToolCall(
|
||||||
|
name=name, args=args, id=id, error=error, type="invalid_tool_call"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def default_tool_parser(
|
def default_tool_parser(
|
||||||
|
@ -5,6 +5,12 @@ from typing import Any, Dict, List, Optional, Type
|
|||||||
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.messages import AIMessage, InvalidToolCall
|
from langchain_core.messages import AIMessage, InvalidToolCall
|
||||||
|
from langchain_core.messages.tool import (
|
||||||
|
invalid_tool_call,
|
||||||
|
)
|
||||||
|
from langchain_core.messages.tool import (
|
||||||
|
tool_call as create_tool_call,
|
||||||
|
)
|
||||||
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||||
@ -59,6 +65,7 @@ def parse_tool_call(
|
|||||||
}
|
}
|
||||||
if return_id:
|
if return_id:
|
||||||
parsed["id"] = raw_tool_call.get("id")
|
parsed["id"] = raw_tool_call.get("id")
|
||||||
|
parsed = create_tool_call(**parsed) # type: ignore
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
@ -75,7 +82,7 @@ def make_invalid_tool_call(
|
|||||||
Returns:
|
Returns:
|
||||||
An InvalidToolCall instance with the error message.
|
An InvalidToolCall instance with the error message.
|
||||||
"""
|
"""
|
||||||
return InvalidToolCall(
|
return invalid_tool_call(
|
||||||
name=raw_tool_call["function"]["name"],
|
name=raw_tool_call["function"]["name"],
|
||||||
args=raw_tool_call["function"]["arguments"],
|
args=raw_tool_call["function"]["arguments"],
|
||||||
id=raw_tool_call.get("id"),
|
id=raw_tool_call.get("id"),
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import textwrap
|
import textwrap
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
@ -34,6 +35,7 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
@ -42,7 +44,7 @@ from typing import (
|
|||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import Annotated, get_args, get_origin
|
from typing_extensions import Annotated, cast, get_args, get_origin
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -56,6 +58,7 @@ from langchain_core.callbacks.manager import (
|
|||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain_core.load.serializable import Serializable
|
from langchain_core.load.serializable import Serializable
|
||||||
|
from langchain_core.messages.tool import ToolCall, ToolMessage
|
||||||
from langchain_core.prompts import (
|
from langchain_core.prompts import (
|
||||||
BasePromptTemplate,
|
BasePromptTemplate,
|
||||||
PromptTemplate,
|
PromptTemplate,
|
||||||
@ -306,7 +309,7 @@ class ToolException(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(RunnableSerializable[Union[str, Dict], Any]):
|
class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]):
|
||||||
"""Interface LangChain tools must implement."""
|
"""Interface LangChain tools must implement."""
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||||
@ -378,6 +381,14 @@ class ChildTool(BaseTool):
|
|||||||
] = False
|
] = False
|
||||||
"""Handle the content of the ValidationError thrown."""
|
"""Handle the content of the ValidationError thrown."""
|
||||||
|
|
||||||
|
response_format: Literal["content", "content_and_raw_output"] = "content"
|
||||||
|
"""The tool response format.
|
||||||
|
|
||||||
|
If "content" then the output of the tool is interpreted as the contents of a
|
||||||
|
ToolMessage. If "content_and_raw_output" then the output is expected to be a
|
||||||
|
two-tuple corresponding to the (content, raw_output) of a ToolMessage.
|
||||||
|
"""
|
||||||
|
|
||||||
class Config(Serializable.Config):
|
class Config(Serializable.Config):
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
@ -410,46 +421,25 @@ class ChildTool(BaseTool):
|
|||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, Dict],
|
input: Union[str, Dict, ToolCall],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
config = ensure_config(config)
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||||
return self.run(
|
return self.run(tool_input, **kwargs)
|
||||||
input,
|
|
||||||
callbacks=config.get("callbacks"),
|
|
||||||
tags=config.get("tags"),
|
|
||||||
metadata=config.get("metadata"),
|
|
||||||
run_name=config.get("run_name"),
|
|
||||||
run_id=config.pop("run_id", None),
|
|
||||||
config=config,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, Dict],
|
input: Union[str, Dict, ToolCall],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
config = ensure_config(config)
|
tool_input, kwargs = _prep_run_args(input, config, **kwargs)
|
||||||
return await self.arun(
|
return await self.arun(tool_input, **kwargs)
|
||||||
input,
|
|
||||||
callbacks=config.get("callbacks"),
|
|
||||||
tags=config.get("tags"),
|
|
||||||
metadata=config.get("metadata"),
|
|
||||||
run_name=config.get("run_name"),
|
|
||||||
run_id=config.pop("run_id", None),
|
|
||||||
config=config,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Tool ---
|
# --- Tool ---
|
||||||
|
|
||||||
def _parse_input(
|
def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]:
|
||||||
self,
|
|
||||||
tool_input: Union[str, Dict],
|
|
||||||
) -> Union[str, Dict[str, Any]]:
|
|
||||||
"""Convert tool input to pydantic model."""
|
"""Convert tool input to pydantic model."""
|
||||||
input_args = self.args_schema
|
input_args = self.args_schema
|
||||||
if isinstance(tool_input, str):
|
if isinstance(tool_input, str):
|
||||||
@ -465,7 +455,7 @@ class ChildTool(BaseTool):
|
|||||||
for k, v in result.dict().items()
|
for k, v in result.dict().items()
|
||||||
if k in tool_input
|
if k in tool_input
|
||||||
}
|
}
|
||||||
return tool_input
|
return tool_input
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||||
@ -479,30 +469,27 @@ class ChildTool(BaseTool):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run(
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
self,
|
|
||||||
*args: Any,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Any:
|
|
||||||
"""Use the tool.
|
"""Use the tool.
|
||||||
|
|
||||||
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
Add run_manager: Optional[CallbackManagerForToolRun] = None
|
||||||
to child implementations to enable tracing,
|
to child implementations to enable tracing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
self,
|
|
||||||
*args: Any,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Any:
|
|
||||||
"""Use the tool asynchronously.
|
"""Use the tool asynchronously.
|
||||||
|
|
||||||
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None
|
||||||
to child implementations to enable tracing,
|
to child implementations to enable tracing.
|
||||||
"""
|
"""
|
||||||
|
if kwargs.get("run_manager") and signature(self._run).parameters.get(
|
||||||
|
"run_manager"
|
||||||
|
):
|
||||||
|
kwargs["run_manager"] = kwargs["run_manager"].get_sync()
|
||||||
return await run_in_executor(None, self._run, *args, **kwargs)
|
return await run_in_executor(None, self._run, *args, **kwargs)
|
||||||
|
|
||||||
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
|
||||||
|
tool_input = self._parse_input(tool_input)
|
||||||
# For backwards compatibility, if run_input is a string,
|
# For backwards compatibility, if run_input is a string,
|
||||||
# pass as a positional argument.
|
# pass as a positional argument.
|
||||||
if isinstance(tool_input, str):
|
if isinstance(tool_input, str):
|
||||||
@ -523,24 +510,20 @@ class ChildTool(BaseTool):
|
|||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
run_id: Optional[uuid.UUID] = None,
|
run_id: Optional[uuid.UUID] = None,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
|
tool_call_id: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run the tool."""
|
"""Run the tool."""
|
||||||
if not self.verbose and verbose is not None:
|
|
||||||
verbose_ = verbose
|
|
||||||
else:
|
|
||||||
verbose_ = self.verbose
|
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
callbacks,
|
callbacks,
|
||||||
self.callbacks,
|
self.callbacks,
|
||||||
verbose_,
|
self.verbose or bool(verbose),
|
||||||
tags,
|
tags,
|
||||||
self.tags,
|
self.tags,
|
||||||
metadata,
|
metadata,
|
||||||
self.metadata,
|
self.metadata,
|
||||||
)
|
)
|
||||||
# TODO: maybe also pass through run_manager is _run supports kwargs
|
|
||||||
new_arg_supported = signature(self._run).parameters.get("run_manager")
|
|
||||||
run_manager = callback_manager.on_tool_start(
|
run_manager = callback_manager.on_tool_start(
|
||||||
{"name": self.name, "description": self.description},
|
{"name": self.name, "description": self.description},
|
||||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||||
@ -550,67 +533,52 @@ class ChildTool(BaseTool):
|
|||||||
# Inputs by definition should always be dicts.
|
# Inputs by definition should always be dicts.
|
||||||
# For now, it's unclear whether this assumption is ever violated,
|
# For now, it's unclear whether this assumption is ever violated,
|
||||||
# but if it is we will send a `None` value to the callback instead
|
# but if it is we will send a `None` value to the callback instead
|
||||||
# And will need to address issue via a patch.
|
# TODO: will need to address issue via a patch.
|
||||||
inputs=None if isinstance(tool_input, str) else tool_input,
|
inputs=tool_input if isinstance(tool_input, dict) else None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
content = None
|
||||||
|
raw_output = None
|
||||||
|
error_to_raise: Union[Exception, KeyboardInterrupt, None] = None
|
||||||
try:
|
try:
|
||||||
child_config = patch_config(
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(),
|
|
||||||
)
|
|
||||||
context = copy_context()
|
context = copy_context()
|
||||||
context.run(_set_config_context, child_config)
|
context.run(_set_config_context, child_config)
|
||||||
parsed_input = self._parse_input(tool_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
if signature(self._run).parameters.get("run_manager"):
|
||||||
observation = (
|
tool_kwargs["run_manager"] = run_manager
|
||||||
context.run(
|
response = context.run(self._run, *tool_args, **tool_kwargs)
|
||||||
self._run, *tool_args, run_manager=run_manager, **tool_kwargs
|
if self.response_format == "content_and_raw_output":
|
||||||
)
|
if not isinstance(response, tuple) or len(response) != 2:
|
||||||
if new_arg_supported
|
raise ValueError(
|
||||||
else context.run(self._run, *tool_args, **tool_kwargs)
|
"Since response_format='content_and_raw_output' "
|
||||||
)
|
"a two-tuple of the message content and raw tool output is "
|
||||||
|
f"expected. Instead generated response of type: "
|
||||||
|
f"{type(response)}."
|
||||||
|
)
|
||||||
|
content, raw_output = response
|
||||||
|
else:
|
||||||
|
content = response
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if not self.handle_validation_error:
|
if not self.handle_validation_error:
|
||||||
raise e
|
error_to_raise = e
|
||||||
elif isinstance(self.handle_validation_error, bool):
|
|
||||||
observation = "Tool input validation error"
|
|
||||||
elif isinstance(self.handle_validation_error, str):
|
|
||||||
observation = self.handle_validation_error
|
|
||||||
elif callable(self.handle_validation_error):
|
|
||||||
observation = self.handle_validation_error(e)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
|
||||||
f"str or callable. Received: {self.handle_validation_error}"
|
|
||||||
)
|
|
||||||
return observation
|
|
||||||
except ToolException as e:
|
except ToolException as e:
|
||||||
if not self.handle_tool_error:
|
if not self.handle_tool_error:
|
||||||
run_manager.on_tool_error(e)
|
error_to_raise = e
|
||||||
raise e
|
|
||||||
elif isinstance(self.handle_tool_error, bool):
|
|
||||||
if e.args:
|
|
||||||
observation = e.args[0]
|
|
||||||
else:
|
|
||||||
observation = "Tool execution error"
|
|
||||||
elif isinstance(self.handle_tool_error, str):
|
|
||||||
observation = self.handle_tool_error
|
|
||||||
elif callable(self.handle_tool_error):
|
|
||||||
observation = self.handle_tool_error(e)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
|
||||||
f"or callable. Received: {self.handle_tool_error}"
|
|
||||||
)
|
|
||||||
run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs)
|
|
||||||
return observation
|
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
run_manager.on_tool_error(e)
|
error_to_raise = e
|
||||||
raise e
|
|
||||||
else:
|
if error_to_raise:
|
||||||
run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs)
|
run_manager.on_tool_error(error_to_raise)
|
||||||
return observation
|
raise error_to_raise
|
||||||
|
output = _format_output(content, raw_output, tool_call_id)
|
||||||
|
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||||
|
return output
|
||||||
|
|
||||||
async def arun(
|
async def arun(
|
||||||
self,
|
self,
|
||||||
@ -625,99 +593,80 @@ class ChildTool(BaseTool):
|
|||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
run_id: Optional[uuid.UUID] = None,
|
run_id: Optional[uuid.UUID] = None,
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
|
tool_call_id: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run the tool asynchronously."""
|
"""Run the tool asynchronously."""
|
||||||
if not self.verbose and verbose is not None:
|
|
||||||
verbose_ = verbose
|
|
||||||
else:
|
|
||||||
verbose_ = self.verbose
|
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = AsyncCallbackManager.configure(
|
||||||
callbacks,
|
callbacks,
|
||||||
self.callbacks,
|
self.callbacks,
|
||||||
verbose_,
|
self.verbose or bool(verbose),
|
||||||
tags,
|
tags,
|
||||||
self.tags,
|
self.tags,
|
||||||
metadata,
|
metadata,
|
||||||
self.metadata,
|
self.metadata,
|
||||||
)
|
)
|
||||||
new_arg_supported = signature(self._arun).parameters.get("run_manager")
|
|
||||||
run_manager = await callback_manager.on_tool_start(
|
run_manager = await callback_manager.on_tool_start(
|
||||||
{"name": self.name, "description": self.description},
|
{"name": self.name, "description": self.description},
|
||||||
tool_input if isinstance(tool_input, str) else str(tool_input),
|
tool_input if isinstance(tool_input, str) else str(tool_input),
|
||||||
color=start_color,
|
color=start_color,
|
||||||
name=run_name,
|
name=run_name,
|
||||||
inputs=tool_input,
|
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
# Inputs by definition should always be dicts.
|
||||||
|
# For now, it's unclear whether this assumption is ever violated,
|
||||||
|
# but if it is we will send a `None` value to the callback instead
|
||||||
|
# TODO: will need to address issue via a patch.
|
||||||
|
inputs=tool_input if isinstance(tool_input, dict) else None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
content = None
|
||||||
|
raw_output = None
|
||||||
|
error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None
|
||||||
try:
|
try:
|
||||||
parsed_input = self._parse_input(tool_input)
|
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
|
||||||
# We then call the tool on the tool input to get an observation
|
child_config = patch_config(config, callbacks=run_manager.get_child())
|
||||||
tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
|
|
||||||
child_config = patch_config(
|
|
||||||
config,
|
|
||||||
callbacks=run_manager.get_child(),
|
|
||||||
)
|
|
||||||
context = copy_context()
|
context = copy_context()
|
||||||
context.run(_set_config_context, child_config)
|
context.run(_set_config_context, child_config)
|
||||||
coro = (
|
if self.__class__._arun is BaseTool._arun or signature(
|
||||||
context.run(
|
self._arun
|
||||||
self._arun, *tool_args, run_manager=run_manager, **tool_kwargs
|
).parameters.get("run_manager"):
|
||||||
)
|
tool_kwargs["run_manager"] = run_manager
|
||||||
if new_arg_supported
|
coro = context.run(self._arun, *tool_args, **tool_kwargs)
|
||||||
else context.run(self._arun, *tool_args, **tool_kwargs)
|
|
||||||
)
|
|
||||||
if accepts_context(asyncio.create_task):
|
if accepts_context(asyncio.create_task):
|
||||||
observation = await asyncio.create_task(coro, context=context) # type: ignore
|
response = await asyncio.create_task(coro, context=context) # type: ignore
|
||||||
else:
|
else:
|
||||||
observation = await coro
|
response = await coro
|
||||||
|
if self.response_format == "content_and_raw_output":
|
||||||
|
if not isinstance(response, tuple) or len(response) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Since response_format='content_and_raw_output' "
|
||||||
|
"a two-tuple of the message content and raw tool output is "
|
||||||
|
f"expected. Instead generated response of type: "
|
||||||
|
f"{type(response)}."
|
||||||
|
)
|
||||||
|
content, raw_output = response
|
||||||
|
else:
|
||||||
|
content = response
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if not self.handle_validation_error:
|
if not self.handle_validation_error:
|
||||||
raise e
|
error_to_raise = e
|
||||||
elif isinstance(self.handle_validation_error, bool):
|
|
||||||
observation = "Tool input validation error"
|
|
||||||
elif isinstance(self.handle_validation_error, str):
|
|
||||||
observation = self.handle_validation_error
|
|
||||||
elif callable(self.handle_validation_error):
|
|
||||||
observation = self.handle_validation_error(e)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
content = _handle_validation_error(e, flag=self.handle_validation_error)
|
||||||
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
|
||||||
f"str or callable. Received: {self.handle_validation_error}"
|
|
||||||
)
|
|
||||||
return observation
|
|
||||||
except ToolException as e:
|
except ToolException as e:
|
||||||
if not self.handle_tool_error:
|
if not self.handle_tool_error:
|
||||||
await run_manager.on_tool_error(e)
|
error_to_raise = e
|
||||||
raise e
|
|
||||||
elif isinstance(self.handle_tool_error, bool):
|
|
||||||
if e.args:
|
|
||||||
observation = e.args[0]
|
|
||||||
else:
|
|
||||||
observation = "Tool execution error"
|
|
||||||
elif isinstance(self.handle_tool_error, str):
|
|
||||||
observation = self.handle_tool_error
|
|
||||||
elif callable(self.handle_tool_error):
|
|
||||||
observation = self.handle_tool_error(e)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
content = _handle_tool_error(e, flag=self.handle_tool_error)
|
||||||
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
|
||||||
f"or callable. Received: {self.handle_tool_error}"
|
|
||||||
)
|
|
||||||
await run_manager.on_tool_end(
|
|
||||||
observation, color="red", name=self.name, **kwargs
|
|
||||||
)
|
|
||||||
return observation
|
|
||||||
except (Exception, KeyboardInterrupt) as e:
|
except (Exception, KeyboardInterrupt) as e:
|
||||||
await run_manager.on_tool_error(e)
|
error_to_raise = e
|
||||||
raise e
|
|
||||||
else:
|
if error_to_raise:
|
||||||
await run_manager.on_tool_end(
|
await run_manager.on_tool_error(error_to_raise)
|
||||||
observation, color=color, name=self.name, **kwargs
|
raise error_to_raise
|
||||||
)
|
|
||||||
return observation
|
output = _format_output(content, raw_output, tool_call_id)
|
||||||
|
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
|
||||||
|
return output
|
||||||
|
|
||||||
@deprecated("0.1.47", alternative="invoke", removal="0.3.0")
|
@deprecated("0.1.47", alternative="invoke", removal="0.3.0")
|
||||||
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str:
|
||||||
@ -738,7 +687,7 @@ class Tool(BaseTool):
|
|||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, Dict],
|
input: Union[str, Dict, ToolCall],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -780,17 +729,10 @@ class Tool(BaseTool):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
if self.func:
|
if self.func:
|
||||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||||
return (
|
kwargs["callbacks"] = run_manager.get_child()
|
||||||
self.func(
|
return self.func(*args, **kwargs)
|
||||||
*args,
|
raise NotImplementedError("Tool does not support sync invocation.")
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if new_argument_supported
|
|
||||||
else self.func(*args, **kwargs)
|
|
||||||
)
|
|
||||||
raise NotImplementedError("Tool does not support sync")
|
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self,
|
self,
|
||||||
@ -800,26 +742,13 @@ class Tool(BaseTool):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
if self.coroutine:
|
if self.coroutine:
|
||||||
new_argument_supported = signature(self.coroutine).parameters.get(
|
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||||
"callbacks"
|
kwargs["callbacks"] = run_manager.get_child()
|
||||||
)
|
return await self.coroutine(*args, **kwargs)
|
||||||
return (
|
|
||||||
await self.coroutine(
|
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||||
*args,
|
# None.
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if new_argument_supported
|
|
||||||
else await self.coroutine(*args, **kwargs)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return await run_in_executor(
|
|
||||||
None,
|
|
||||||
self._run,
|
|
||||||
run_manager=run_manager.get_sync() if run_manager else None,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: this is for backwards compatibility, remove in future
|
# TODO: this is for backwards compatibility, remove in future
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -870,9 +799,10 @@ class StructuredTool(BaseTool):
|
|||||||
|
|
||||||
# --- Runnable ---
|
# --- Runnable ---
|
||||||
|
|
||||||
|
# TODO: Is this needed?
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, Dict],
|
input: Union[str, Dict, ToolCall],
|
||||||
config: Optional[RunnableConfig] = None,
|
config: Optional[RunnableConfig] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -897,45 +827,26 @@ class StructuredTool(BaseTool):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
if self.func:
|
if self.func:
|
||||||
new_argument_supported = signature(self.func).parameters.get("callbacks")
|
if run_manager and signature(self.func).parameters.get("callbacks"):
|
||||||
return (
|
kwargs["callbacks"] = run_manager.get_child()
|
||||||
self.func(
|
return self.func(*args, **kwargs)
|
||||||
*args,
|
raise NotImplementedError("StructuredTool does not support sync invocation.")
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if new_argument_supported
|
|
||||||
else self.func(*args, **kwargs)
|
|
||||||
)
|
|
||||||
raise NotImplementedError("Tool does not support sync")
|
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> Any:
|
||||||
"""Use the tool asynchronously."""
|
"""Use the tool asynchronously."""
|
||||||
if self.coroutine:
|
if self.coroutine:
|
||||||
new_argument_supported = signature(self.coroutine).parameters.get(
|
if run_manager and signature(self.coroutine).parameters.get("callbacks"):
|
||||||
"callbacks"
|
kwargs["callbacks"] = run_manager.get_child()
|
||||||
)
|
return await self.coroutine(*args, **kwargs)
|
||||||
return (
|
|
||||||
await self.coroutine(
|
# NOTE: this code is unreachable since _arun is only called if coroutine is not
|
||||||
*args,
|
# None.
|
||||||
callbacks=run_manager.get_child() if run_manager else None,
|
return await super()._arun(*args, run_manager=run_manager, **kwargs)
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if new_argument_supported
|
|
||||||
else await self.coroutine(*args, **kwargs)
|
|
||||||
)
|
|
||||||
return await run_in_executor(
|
|
||||||
None,
|
|
||||||
self._run,
|
|
||||||
run_manager=run_manager.get_sync() if run_manager else None,
|
|
||||||
*args,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_function(
|
def from_function(
|
||||||
@ -947,6 +858,8 @@ class StructuredTool(BaseTool):
|
|||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[Type[BaseModel]] = None,
|
args_schema: Optional[Type[BaseModel]] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
|
*,
|
||||||
|
response_format: Literal["content", "content_and_raw_output"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
error_on_invalid_docstring: bool = False,
|
error_on_invalid_docstring: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -963,6 +876,10 @@ class StructuredTool(BaseTool):
|
|||||||
return_direct: Whether to return the result directly or as a callback
|
return_direct: Whether to return the result directly or as a callback
|
||||||
args_schema: The schema of the tool's input arguments
|
args_schema: The schema of the tool's input arguments
|
||||||
infer_schema: Whether to infer the schema from the function's signature
|
infer_schema: Whether to infer the schema from the function's signature
|
||||||
|
response_format: The tool response format. If "content" then the output of
|
||||||
|
the tool is interpreted as the contents of a ToolMessage. If
|
||||||
|
"content_and_raw_output" then the output is expected to be a two-tuple
|
||||||
|
corresponding to the (content, raw_output) of a ToolMessage.
|
||||||
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
|
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt
|
||||||
to parse parameter descriptions from Google Style function docstrings.
|
to parse parameter descriptions from Google Style function docstrings.
|
||||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
||||||
@ -1020,6 +937,7 @@ class StructuredTool(BaseTool):
|
|||||||
args_schema=_args_schema, # type: ignore[arg-type]
|
args_schema=_args_schema, # type: ignore[arg-type]
|
||||||
description=description_,
|
description=description_,
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
|
response_format=response_format,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1029,6 +947,7 @@ def tool(
|
|||||||
return_direct: bool = False,
|
return_direct: bool = False,
|
||||||
args_schema: Optional[Type[BaseModel]] = None,
|
args_schema: Optional[Type[BaseModel]] = None,
|
||||||
infer_schema: bool = True,
|
infer_schema: bool = True,
|
||||||
|
response_format: Literal["content", "content_and_raw_output"] = "content",
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
error_on_invalid_docstring: bool = True,
|
error_on_invalid_docstring: bool = True,
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
@ -1042,6 +961,10 @@ def tool(
|
|||||||
infer_schema: Whether to infer the schema of the arguments from
|
infer_schema: Whether to infer the schema of the arguments from
|
||||||
the function's signature. This also makes the resultant tool
|
the function's signature. This also makes the resultant tool
|
||||||
accept a dictionary input to its `run()` function.
|
accept a dictionary input to its `run()` function.
|
||||||
|
response_format: The tool response format. If "content" then the output of
|
||||||
|
the tool is interpreted as the contents of a ToolMessage. If
|
||||||
|
"content_and_raw_output" then the output is expected to be a two-tuple
|
||||||
|
corresponding to the (content, raw_output) of a ToolMessage.
|
||||||
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
|
parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to
|
||||||
parse parameter descriptions from Google Style function docstrings.
|
parse parameter descriptions from Google Style function docstrings.
|
||||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
error_on_invalid_docstring: if ``parse_docstring`` is provided, configures
|
||||||
@ -1064,8 +987,12 @@ def tool(
|
|||||||
# Searches the API for the query.
|
# Searches the API for the query.
|
||||||
return
|
return
|
||||||
|
|
||||||
.. versionadded:: 0.2.14
|
@tool(response_format="content_and_raw_output")
|
||||||
Parse Google-style docstrings:
|
def search_api(query: str) -> Tuple[str, dict]:
|
||||||
|
return "partial json of results", {"full": "object of results"}
|
||||||
|
|
||||||
|
.. versionadded:: 0.2.14
|
||||||
|
Parse Google-style docstrings:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@ -1179,6 +1106,7 @@ def tool(
|
|||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
args_schema=schema,
|
args_schema=schema,
|
||||||
infer_schema=infer_schema,
|
infer_schema=infer_schema,
|
||||||
|
response_format=response_format,
|
||||||
parse_docstring=parse_docstring,
|
parse_docstring=parse_docstring,
|
||||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||||
)
|
)
|
||||||
@ -1195,6 +1123,7 @@ def tool(
|
|||||||
description=f"{tool_name} tool",
|
description=f"{tool_name} tool",
|
||||||
return_direct=return_direct,
|
return_direct=return_direct,
|
||||||
coroutine=coroutine,
|
coroutine=coroutine,
|
||||||
|
response_format=response_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
return _make_tool
|
return _make_tool
|
||||||
@ -1350,6 +1279,103 @@ class BaseToolkit(BaseModel, ABC):
|
|||||||
"""Get the tools in the toolkit."""
|
"""Get the tools in the toolkit."""
|
||||||
|
|
||||||
|
|
||||||
|
def _is_tool_call(x: Any) -> bool:
|
||||||
|
return isinstance(x, dict) and x.get("type") == "tool_call"
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_validation_error(
|
||||||
|
e: ValidationError,
|
||||||
|
*,
|
||||||
|
flag: Union[Literal[True], str, Callable[[ValidationError], str]],
|
||||||
|
) -> str:
|
||||||
|
if isinstance(flag, bool):
|
||||||
|
content = "Tool input validation error"
|
||||||
|
elif isinstance(flag, str):
|
||||||
|
content = flag
|
||||||
|
elif callable(flag):
|
||||||
|
content = flag(e)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Got unexpected type of `handle_validation_error`. Expected bool, "
|
||||||
|
f"str or callable. Received: {flag}"
|
||||||
|
)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_tool_error(
|
||||||
|
e: ToolException,
|
||||||
|
*,
|
||||||
|
flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
|
||||||
|
) -> str:
|
||||||
|
if isinstance(flag, bool):
|
||||||
|
if e.args:
|
||||||
|
content = e.args[0]
|
||||||
|
else:
|
||||||
|
content = "Tool execution error"
|
||||||
|
elif isinstance(flag, str):
|
||||||
|
content = flag
|
||||||
|
elif callable(flag):
|
||||||
|
content = flag(e)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Got unexpected type of `handle_tool_error`. Expected bool, str "
|
||||||
|
f"or callable. Received: {flag}"
|
||||||
|
)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def _prep_run_args(
|
||||||
|
input: Union[str, dict, ToolCall],
|
||||||
|
config: Optional[RunnableConfig],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Tuple[Union[str, Dict], Dict]:
|
||||||
|
config = ensure_config(config)
|
||||||
|
if _is_tool_call(input):
|
||||||
|
tool_call_id: Optional[str] = cast(ToolCall, input)["id"]
|
||||||
|
tool_input: Union[str, dict] = cast(ToolCall, input)["args"]
|
||||||
|
else:
|
||||||
|
tool_call_id = None
|
||||||
|
tool_input = cast(Union[str, dict], input)
|
||||||
|
return (
|
||||||
|
tool_input,
|
||||||
|
dict(
|
||||||
|
callbacks=config.get("callbacks"),
|
||||||
|
tags=config.get("tags"),
|
||||||
|
metadata=config.get("metadata"),
|
||||||
|
run_name=config.get("run_name"),
|
||||||
|
run_id=config.pop("run_id", None),
|
||||||
|
config=config,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_output(
|
||||||
|
content: Any, raw_output: Any, tool_call_id: Optional[str]
|
||||||
|
) -> Union[ToolMessage, Any]:
|
||||||
|
if tool_call_id:
|
||||||
|
# NOTE: This will fail to stringify lists which aren't actually content blocks
|
||||||
|
# but whose first element happens to be a string or dict. Tools should avoid
|
||||||
|
# returning such contents.
|
||||||
|
if not isinstance(content, str) and not (
|
||||||
|
isinstance(content, list)
|
||||||
|
and content
|
||||||
|
and isinstance(content[0], (str, dict))
|
||||||
|
):
|
||||||
|
content = _stringify(content)
|
||||||
|
return ToolMessage(content, raw_output=raw_output, tool_call_id=tool_call_id)
|
||||||
|
else:
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def _stringify(content: Any) -> str:
|
||||||
|
try:
|
||||||
|
return json.dumps(content)
|
||||||
|
except Exception:
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
def _get_description_from_runnable(runnable: Runnable) -> str:
|
def _get_description_from_runnable(runnable: Runnable) -> str:
|
||||||
"""Generate a placeholder description of a runnable."""
|
"""Generate a placeholder description of a runnable."""
|
||||||
input_schema = runnable.input_schema.schema()
|
input_schema = runnable.input_schema.schema()
|
||||||
|
@ -317,6 +317,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -419,6 +426,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -908,6 +922,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -1010,6 +1031,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
|
@ -674,6 +674,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -776,6 +783,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
|
@ -5577,6 +5577,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -5701,6 +5708,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -6237,6 +6251,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -6361,6 +6382,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -6834,6 +6862,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -6936,6 +6971,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -7444,6 +7486,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -7568,6 +7617,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -8068,6 +8124,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -8203,6 +8266,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -8683,6 +8753,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -8785,6 +8862,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -9238,6 +9322,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -9340,6 +9431,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -9880,6 +9978,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'invalid_tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
@ -10004,6 +10109,13 @@
|
|||||||
'title': 'Name',
|
'title': 'Name',
|
||||||
'type': 'string',
|
'type': 'string',
|
||||||
}),
|
}),
|
||||||
|
'type': dict({
|
||||||
|
'enum': list([
|
||||||
|
'tool_call',
|
||||||
|
]),
|
||||||
|
'title': 'Type',
|
||||||
|
'type': 'string',
|
||||||
|
}),
|
||||||
}),
|
}),
|
||||||
'required': list([
|
'required': list([
|
||||||
'name',
|
'name',
|
||||||
|
@ -8,7 +8,7 @@ import textwrap
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from typing_extensions import Annotated, TypedDict
|
from typing_extensions import Annotated, TypedDict
|
||||||
@ -17,6 +17,7 @@ from langchain_core.callbacks import (
|
|||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||||
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
|
from langchain_core.runnables import Runnable, RunnableLambda, ensure_config
|
||||||
from langchain_core.tools import (
|
from langchain_core.tools import (
|
||||||
@ -1067,6 +1068,65 @@ def test_tool_annotated_descriptions() -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_call_input_tool_message_output() -> None:
|
||||||
|
tool_call = {
|
||||||
|
"name": "structured_api",
|
||||||
|
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
|
||||||
|
"id": "123",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
tool = _MockStructuredTool()
|
||||||
|
expected = ToolMessage("1 True {'img': 'base64string...'}", tool_call_id="123")
|
||||||
|
actual = tool.invoke(tool_call)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
tool_call.pop("type")
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
tool.invoke(tool_call)
|
||||||
|
|
||||||
|
|
||||||
|
class _MockStructuredToolWithRawOutput(BaseTool):
|
||||||
|
name: str = "structured_api"
|
||||||
|
args_schema: Type[BaseModel] = _MockSchema
|
||||||
|
description: str = "A Structured Tool"
|
||||||
|
response_format: Literal["content_and_raw_output"] = "content_and_raw_output"
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self, arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
|
) -> Tuple[str, dict]:
|
||||||
|
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
||||||
|
|
||||||
|
|
||||||
|
@tool("structured_api", response_format="content_and_raw_output")
|
||||||
|
def _mock_structured_tool_with_raw_output(
|
||||||
|
arg1: int, arg2: bool, arg3: Optional[dict] = None
|
||||||
|
) -> Tuple[str, dict]:
|
||||||
|
"""A Structured Tool"""
|
||||||
|
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_raw_output]
|
||||||
|
)
|
||||||
|
def test_tool_call_input_tool_message_with_raw_output(tool: BaseTool) -> None:
|
||||||
|
tool_call: Dict = {
|
||||||
|
"name": "structured_api",
|
||||||
|
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
|
||||||
|
"id": "123",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
|
||||||
|
actual = tool.invoke(tool_call)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
tool_call.pop("type")
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
tool.invoke(tool_call)
|
||||||
|
|
||||||
|
actual_content = tool.invoke(tool_call["args"])
|
||||||
|
assert actual_content == expected.content
|
||||||
|
|
||||||
|
|
||||||
def test_convert_from_runnable_dict() -> None:
|
def test_convert_from_runnable_dict() -> None:
|
||||||
# Test with typed dict input
|
# Test with typed dict input
|
||||||
class Args(TypedDict):
|
class Args(TypedDict):
|
||||||
|
26
libs/langchain/poetry.lock
generated
26
libs/langchain/poetry.lock
generated
@ -1,4 +1,4 @@
|
|||||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiohttp"
|
name = "aiohttp"
|
||||||
@ -1760,7 +1760,7 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-core"
|
name = "langchain-core"
|
||||||
version = "0.2.12"
|
version = "0.2.13"
|
||||||
description = "Building applications with LLMs through composability"
|
description = "Building applications with LLMs through composability"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
@ -1784,7 +1784,7 @@ url = "../core"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-openai"
|
name = "langchain-openai"
|
||||||
version = "0.1.14"
|
version = "0.1.15"
|
||||||
description = "An integration package connecting OpenAI and LangChain"
|
description = "An integration package connecting OpenAI and LangChain"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.8.1,<4.0"
|
python-versions = ">=3.8.1,<4.0"
|
||||||
@ -1792,7 +1792,7 @@ files = []
|
|||||||
develop = true
|
develop = true
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
langchain-core = ">=0.2.2,<0.3"
|
langchain-core = "^0.2.13"
|
||||||
openai = "^1.32.0"
|
openai = "^1.32.0"
|
||||||
tiktoken = ">=0.7,<1"
|
tiktoken = ">=0.7,<1"
|
||||||
|
|
||||||
@ -1834,13 +1834,13 @@ types-requests = ">=2.31.0.2,<3.0.0.0"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langsmith"
|
name = "langsmith"
|
||||||
version = "0.1.84"
|
version = "0.1.85"
|
||||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "<4.0,>=3.8.1"
|
python-versions = "<4.0,>=3.8.1"
|
||||||
files = [
|
files = [
|
||||||
{file = "langsmith-0.1.84-py3-none-any.whl", hash = "sha256:01f3c6390dba26c583bac8dd0e551ce3d0509c7f55cad714db0b5c8d36e4c7ff"},
|
{file = "langsmith-0.1.85-py3-none-any.whl", hash = "sha256:c1f94384f10cea96f7b4d33fd3db7ec180c03c7468877d50846f881d2017ff94"},
|
||||||
{file = "langsmith-0.1.84.tar.gz", hash = "sha256:5220c0439838b9a5bd320fd3686be505c5083dcee22d2452006c23891153bea1"},
|
{file = "langsmith-0.1.85.tar.gz", hash = "sha256:acff31f9e53efa48586cf8e32f65625a335c74d7c4fa306d1655ac18452296f6"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -2350,13 +2350,13 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openai"
|
name = "openai"
|
||||||
version = "1.35.10"
|
version = "1.35.13"
|
||||||
description = "The official Python library for the openai API"
|
description = "The official Python library for the openai API"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7.1"
|
python-versions = ">=3.7.1"
|
||||||
files = [
|
files = [
|
||||||
{file = "openai-1.35.10-py3-none-any.whl", hash = "sha256:962cb5c23224b5cbd16078308dabab97a08b0a5ad736a4fdb3dc2ffc44ac974f"},
|
{file = "openai-1.35.13-py3-none-any.whl", hash = "sha256:36ec3e93e0d1f243f69be85c89b9221a471c3e450dfd9df16c9829e3cdf63e60"},
|
||||||
{file = "openai-1.35.10.tar.gz", hash = "sha256:85966949f4f960f3e4b239a659f9fd64d3a97ecc43c44dc0a044b5c7f11cccc6"},
|
{file = "openai-1.35.13.tar.gz", hash = "sha256:c684f3945608baf7d2dcc0ef3ee6f3e27e4c66f21076df0b47be45d57e6ae6e4"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -4141,13 +4141,13 @@ urllib3 = ">=2"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "types-setuptools"
|
name = "types-setuptools"
|
||||||
version = "70.2.0.20240704"
|
version = "70.3.0.20240710"
|
||||||
description = "Typing stubs for setuptools"
|
description = "Typing stubs for setuptools"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "types-setuptools-70.2.0.20240704.tar.gz", hash = "sha256:2f8d28d16ca1607080f9fdf19595bd49c942884b2bbd6529c9b8a9a8fc8db911"},
|
{file = "types-setuptools-70.3.0.20240710.tar.gz", hash = "sha256:842cbf399812d2b65042c9d6ff35113bbf282dee38794779aa1f94e597bafc35"},
|
||||||
{file = "types_setuptools-70.2.0.20240704-py3-none-any.whl", hash = "sha256:6b892d5441c2ed58dd255724516e3df1db54892fb20597599aea66d04c3e4d7f"},
|
{file = "types_setuptools-70.3.0.20240710-py3-none-any.whl", hash = "sha256:bd0db2a4b9f2c49ac5564be4e0fb3125c4c46b1f73eafdcbceffa5b005cceca4"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -43,6 +43,7 @@ from langchain_core.messages import (
|
|||||||
ToolMessage,
|
ToolMessage,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.ai import UsageMetadata
|
from langchain_core.messages.ai import UsageMetadata
|
||||||
|
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||||
from langchain_core.output_parsers import (
|
from langchain_core.output_parsers import (
|
||||||
JsonOutputKeyToolsParser,
|
JsonOutputKeyToolsParser,
|
||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
@ -1102,12 +1103,12 @@ def _make_message_chunk_from_anthropic_event(
|
|||||||
warnings.warn("Received unexpected tool content block.")
|
warnings.warn("Received unexpected tool content block.")
|
||||||
content_block = event.content_block.model_dump()
|
content_block = event.content_block.model_dump()
|
||||||
content_block["index"] = event.index
|
content_block["index"] = event.index
|
||||||
tool_call_chunk = {
|
tool_call_chunk = create_tool_call_chunk(
|
||||||
"index": event.index,
|
index=event.index,
|
||||||
"id": event.content_block.id,
|
id=event.content_block.id,
|
||||||
"name": event.content_block.name,
|
name=event.content_block.name,
|
||||||
"args": "",
|
args="",
|
||||||
}
|
)
|
||||||
message_chunk = AIMessageChunk(
|
message_chunk = AIMessageChunk(
|
||||||
content=[content_block],
|
content=[content_block],
|
||||||
tool_call_chunks=[tool_call_chunk], # type: ignore
|
tool_call_chunks=[tool_call_chunk], # type: ignore
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import Any, List, Optional, Type, Union, cast
|
from typing import Any, List, Optional, Type, Union, cast
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, ToolCall
|
from langchain_core.messages import AIMessage, ToolCall
|
||||||
|
from langchain_core.messages.tool import tool_call
|
||||||
from langchain_core.output_parsers import BaseGenerationOutputParser
|
from langchain_core.output_parsers import BaseGenerationOutputParser
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
@ -79,7 +80,7 @@ def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[Tool
|
|||||||
if block["type"] != "tool_use":
|
if block["type"] != "tool_use":
|
||||||
continue
|
continue
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(name=block["name"], args=block["input"], id=block["id"])
|
tool_call(name=block["name"], args=block["input"], id=block["id"])
|
||||||
)
|
)
|
||||||
return tool_calls
|
return tool_calls
|
||||||
else:
|
else:
|
||||||
|
@ -365,10 +365,7 @@ async def test_astreaming() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_tool_use() -> None:
|
def test_tool_use() -> None:
|
||||||
llm = ChatAnthropic( # type: ignore[call-arg]
|
llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
|
||||||
model=MODEL_NAME,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_with_tools = llm.bind_tools(
|
llm_with_tools = llm.bind_tools(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
@ -478,6 +475,7 @@ def test_anthropic_with_empty_text_block() -> None:
|
|||||||
"name": "type_letter",
|
"name": "type_letter",
|
||||||
"args": {"letter": "d"},
|
"args": {"letter": "d"},
|
||||||
"id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
|
"id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
|
||||||
|
"type": "tool_call",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
@ -33,8 +33,20 @@ class _Foo2(BaseModel):
|
|||||||
def test_tools_output_parser() -> None:
|
def test_tools_output_parser() -> None:
|
||||||
output_parser = ToolsOutputParser()
|
output_parser = ToolsOutputParser()
|
||||||
expected = [
|
expected = [
|
||||||
{"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1},
|
{
|
||||||
{"name": "_Foo2", "args": {"baz": "a"}, "id": "2", "index": 3},
|
"name": "_Foo1",
|
||||||
|
"args": {"bar": 0},
|
||||||
|
"id": "1",
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "_Foo2",
|
||||||
|
"args": {"baz": "a"},
|
||||||
|
"id": "2",
|
||||||
|
"index": 3,
|
||||||
|
"type": "tool_call",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
actual = output_parser.parse_result(_RESULT)
|
actual = output_parser.parse_result(_RESULT)
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
@ -56,7 +68,13 @@ def test_tools_output_parser_args_only() -> None:
|
|||||||
|
|
||||||
def test_tools_output_parser_first_tool_only() -> None:
|
def test_tools_output_parser_first_tool_only() -> None:
|
||||||
output_parser = ToolsOutputParser(first_tool_only=True)
|
output_parser = ToolsOutputParser(first_tool_only=True)
|
||||||
expected: Any = {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1}
|
expected: Any = {
|
||||||
|
"name": "_Foo1",
|
||||||
|
"args": {"bar": 0},
|
||||||
|
"id": "1",
|
||||||
|
"index": 1,
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
actual = output_parser.parse_result(_RESULT)
|
actual = output_parser.parse_result(_RESULT)
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
||||||
@ -81,7 +99,14 @@ def test_tools_output_parser_empty_content() -> None:
|
|||||||
)
|
)
|
||||||
message = AIMessage(
|
message = AIMessage(
|
||||||
"",
|
"",
|
||||||
tool_calls=[{"name": "ChartType", "args": {"chart_type": "pie"}, "id": "foo"}],
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "ChartType",
|
||||||
|
"args": {"chart_type": "pie"},
|
||||||
|
"id": "foo",
|
||||||
|
"type": "tool_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
)
|
)
|
||||||
actual = output_parser.invoke(message)
|
actual = output_parser.invoke(message)
|
||||||
expected = ChartType(chart_type="pie")
|
expected = ChartType(chart_type="pie")
|
||||||
|
@ -9,10 +9,11 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import urllib
|
import urllib
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, BinaryIO, Callable, List, Optional
|
from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -126,6 +127,8 @@ class SessionsPythonREPLTool(BaseTool):
|
|||||||
session_id: str = str(uuid4())
|
session_id: str = str(uuid4())
|
||||||
"""The session ID to use for the code interpreter. Defaults to a random UUID."""
|
"""The session ID to use for the code interpreter. Defaults to a random UUID."""
|
||||||
|
|
||||||
|
response_format: Literal["content_and_raw_output"] = "content_and_raw_output"
|
||||||
|
|
||||||
def _build_url(self, path: str) -> str:
|
def _build_url(self, path: str) -> str:
|
||||||
pool_management_endpoint = self.pool_management_endpoint
|
pool_management_endpoint = self.pool_management_endpoint
|
||||||
if not pool_management_endpoint:
|
if not pool_management_endpoint:
|
||||||
@ -164,16 +167,16 @@ class SessionsPythonREPLTool(BaseTool):
|
|||||||
properties = response_json.get("properties", {})
|
properties = response_json.get("properties", {})
|
||||||
return properties
|
return properties
|
||||||
|
|
||||||
def _run(self, python_code: str) -> Any:
|
def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]:
|
||||||
response = self.execute(python_code)
|
response = self.execute(python_code)
|
||||||
|
|
||||||
# if the result is an image, remove the base64 data
|
# if the result is an image, remove the base64 data
|
||||||
result = response.get("result")
|
result = deepcopy(response.get("result"))
|
||||||
if isinstance(result, dict):
|
if isinstance(result, dict):
|
||||||
if result.get("type") == "image" and "base64_data" in result:
|
if result.get("type") == "image" and "base64_data" in result:
|
||||||
result.pop("base64_data")
|
result.pop("base64_data")
|
||||||
|
|
||||||
return json.dumps(
|
content = json.dumps(
|
||||||
{
|
{
|
||||||
"result": result,
|
"result": result,
|
||||||
"stdout": response.get("stdout"),
|
"stdout": response.get("stdout"),
|
||||||
@ -181,6 +184,7 @@ class SessionsPythonREPLTool(BaseTool):
|
|||||||
},
|
},
|
||||||
indent=2,
|
indent=2,
|
||||||
)
|
)
|
||||||
|
return content, response
|
||||||
|
|
||||||
def upload_file(
|
def upload_file(
|
||||||
self,
|
self,
|
||||||
|
@ -54,6 +54,12 @@ from langchain_core.messages import (
|
|||||||
ToolMessage,
|
ToolMessage,
|
||||||
ToolMessageChunk,
|
ToolMessageChunk,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages.tool import (
|
||||||
|
ToolCallChunk,
|
||||||
|
)
|
||||||
|
from langchain_core.messages.tool import (
|
||||||
|
tool_call_chunk as create_tool_call_chunk,
|
||||||
|
)
|
||||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||||
from langchain_core.output_parsers.base import OutputParserLike
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
from langchain_core.output_parsers.openai_tools import (
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
@ -199,6 +205,7 @@ def _convert_chunk_to_message_chunk(
|
|||||||
role = cast(str, _dict.get("role"))
|
role = cast(str, _dict.get("role"))
|
||||||
content = cast(str, _dict.get("content") or "")
|
content = cast(str, _dict.get("content") or "")
|
||||||
additional_kwargs: Dict = {}
|
additional_kwargs: Dict = {}
|
||||||
|
tool_call_chunks: List[ToolCallChunk] = []
|
||||||
if _dict.get("function_call"):
|
if _dict.get("function_call"):
|
||||||
function_call = dict(_dict["function_call"])
|
function_call = dict(_dict["function_call"])
|
||||||
if "name" in function_call and function_call["name"] is None:
|
if "name" in function_call and function_call["name"] is None:
|
||||||
@ -206,21 +213,18 @@ def _convert_chunk_to_message_chunk(
|
|||||||
additional_kwargs["function_call"] = function_call
|
additional_kwargs["function_call"] = function_call
|
||||||
if raw_tool_calls := _dict.get("tool_calls"):
|
if raw_tool_calls := _dict.get("tool_calls"):
|
||||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||||
try:
|
for rtc in raw_tool_calls:
|
||||||
tool_call_chunks = [
|
try:
|
||||||
{
|
tool_call_chunks.append(
|
||||||
"name": rtc["function"].get("name"),
|
create_tool_call_chunk(
|
||||||
"args": rtc["function"].get("arguments"),
|
name=rtc["function"].get("name"),
|
||||||
"id": rtc.get("id"),
|
args=rtc["function"].get("arguments"),
|
||||||
"index": rtc["index"],
|
id=rtc.get("id"),
|
||||||
}
|
index=rtc.get("index"),
|
||||||
for rtc in raw_tool_calls
|
)
|
||||||
]
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
else:
|
|
||||||
tool_call_chunks = []
|
|
||||||
|
|
||||||
if role == "user" or default_class == HumanMessageChunk:
|
if role == "user" or default_class == HumanMessageChunk:
|
||||||
return HumanMessageChunk(content=content)
|
return HumanMessageChunk(content=content)
|
||||||
elif role == "assistant" or default_class == AIMessageChunk:
|
elif role == "assistant" or default_class == AIMessageChunk:
|
||||||
@ -237,7 +241,7 @@ def _convert_chunk_to_message_chunk(
|
|||||||
return AIMessageChunk(
|
return AIMessageChunk(
|
||||||
content=content,
|
content=content,
|
||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
tool_call_chunks=tool_call_chunks,
|
||||||
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
usage_metadata=usage_metadata, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
elif role == "system" or default_class == SystemMessageChunk:
|
elif role == "system" or default_class == SystemMessageChunk:
|
||||||
|
@ -53,6 +53,7 @@ from langchain_core.messages import (
|
|||||||
ToolMessage,
|
ToolMessage,
|
||||||
ToolMessageChunk,
|
ToolMessageChunk,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||||
from langchain_core.output_parsers import (
|
from langchain_core.output_parsers import (
|
||||||
JsonOutputParser,
|
JsonOutputParser,
|
||||||
PydanticOutputParser,
|
PydanticOutputParser,
|
||||||
@ -511,19 +512,19 @@ class ChatGroq(BaseChatModel):
|
|||||||
generation = chat_result.generations[0]
|
generation = chat_result.generations[0]
|
||||||
message = cast(AIMessage, generation.message)
|
message = cast(AIMessage, generation.message)
|
||||||
tool_call_chunks = [
|
tool_call_chunks = [
|
||||||
{
|
create_tool_call_chunk(
|
||||||
"name": rtc["function"].get("name"),
|
name=rtc["function"].get("name"),
|
||||||
"args": rtc["function"].get("arguments"),
|
args=rtc["function"].get("arguments"),
|
||||||
"id": rtc.get("id"),
|
id=rtc.get("id"),
|
||||||
"index": rtc.get("index"),
|
index=rtc.get("index"),
|
||||||
}
|
)
|
||||||
for rtc in message.additional_kwargs.get("tool_calls", [])
|
for rtc in message.additional_kwargs.get("tool_calls", [])
|
||||||
]
|
]
|
||||||
chunk_ = ChatGenerationChunk(
|
chunk_ = ChatGenerationChunk(
|
||||||
message=AIMessageChunk(
|
message=AIMessageChunk(
|
||||||
content=message.content,
|
content=message.content,
|
||||||
additional_kwargs=message.additional_kwargs,
|
additional_kwargs=message.additional_kwargs,
|
||||||
tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
|
tool_call_chunks=tool_call_chunks,
|
||||||
usage_metadata=message.usage_metadata,
|
usage_metadata=message.usage_metadata,
|
||||||
),
|
),
|
||||||
generation_info=generation.generation_info,
|
generation_info=generation.generation_info,
|
||||||
|
@ -77,6 +77,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
name="GenerateUsername",
|
name="GenerateUsername",
|
||||||
args={"name": "Sally", "hair_color": "green"},
|
args={"name": "Sally", "hair_color": "green"},
|
||||||
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
||||||
|
type="tool_call",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -112,6 +113,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
args="oops",
|
args="oops",
|
||||||
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
||||||
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
||||||
|
type="invalid_tool_call",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@ -119,6 +121,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
name="GenerateUsername",
|
name="GenerateUsername",
|
||||||
args={"name": "Sally", "hair_color": "green"},
|
args={"name": "Sally", "hair_color": "green"},
|
||||||
id="call_abc123",
|
id="call_abc123",
|
||||||
|
type="tool_call",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -42,10 +42,12 @@ from langchain_core.messages import (
|
|||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
|
ToolCallChunk,
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
ToolMessageChunk,
|
ToolMessageChunk,
|
||||||
convert_to_messages,
|
convert_to_messages,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||||
from langchain_core.output_parsers.base import OutputParserLike
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
from langchain_core.output_parsers.openai_tools import (
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
@ -174,6 +176,7 @@ def _convert_delta_to_message_chunk(
|
|||||||
role = cast(str, _dict.get("role"))
|
role = cast(str, _dict.get("role"))
|
||||||
content = cast(str, _dict.get("content") or "")
|
content = cast(str, _dict.get("content") or "")
|
||||||
additional_kwargs: Dict = {}
|
additional_kwargs: Dict = {}
|
||||||
|
tool_call_chunks: List[ToolCallChunk] = []
|
||||||
if _dict.get("function_call"):
|
if _dict.get("function_call"):
|
||||||
function_call = dict(_dict["function_call"])
|
function_call = dict(_dict["function_call"])
|
||||||
if "name" in function_call and function_call["name"] is None:
|
if "name" in function_call and function_call["name"] is None:
|
||||||
@ -181,21 +184,18 @@ def _convert_delta_to_message_chunk(
|
|||||||
additional_kwargs["function_call"] = function_call
|
additional_kwargs["function_call"] = function_call
|
||||||
if raw_tool_calls := _dict.get("tool_calls"):
|
if raw_tool_calls := _dict.get("tool_calls"):
|
||||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||||
try:
|
for rtc in raw_tool_calls:
|
||||||
tool_call_chunks = [
|
try:
|
||||||
{
|
tool_call_chunks.append(
|
||||||
"name": rtc["function"].get("name"),
|
create_tool_call_chunk(
|
||||||
"args": rtc["function"].get("arguments"),
|
name=rtc["function"].get("name"),
|
||||||
"id": rtc.get("id"),
|
args=rtc["function"].get("arguments"),
|
||||||
"index": rtc["index"],
|
id=rtc.get("id"),
|
||||||
}
|
index=rtc.get("index"),
|
||||||
for rtc in raw_tool_calls
|
)
|
||||||
]
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
else:
|
|
||||||
tool_call_chunks = []
|
|
||||||
|
|
||||||
if role == "user" or default_class == HumanMessageChunk:
|
if role == "user" or default_class == HumanMessageChunk:
|
||||||
return HumanMessageChunk(content=content)
|
return HumanMessageChunk(content=content)
|
||||||
elif role == "assistant" or default_class == AIMessageChunk:
|
elif role == "assistant" or default_class == AIMessageChunk:
|
||||||
|
@ -50,6 +50,7 @@ from langchain_core.messages import (
|
|||||||
ToolCall,
|
ToolCall,
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages.tool import tool_call_chunk
|
||||||
from langchain_core.output_parsers import (
|
from langchain_core.output_parsers import (
|
||||||
JsonOutputParser,
|
JsonOutputParser,
|
||||||
PydanticOutputParser,
|
PydanticOutputParser,
|
||||||
@ -103,19 +104,10 @@ def _convert_mistral_chat_message_to_message(
|
|||||||
dict, parse_tool_call(raw_tool_call, return_id=True)
|
dict, parse_tool_call(raw_tool_call, return_id=True)
|
||||||
)
|
)
|
||||||
if not parsed["id"]:
|
if not parsed["id"]:
|
||||||
tool_call_id = uuid.uuid4().hex[:]
|
parsed["id"] = uuid.uuid4().hex[:]
|
||||||
tool_calls.append(
|
tool_calls.append(parsed)
|
||||||
{
|
|
||||||
**parsed,
|
|
||||||
**{"id": tool_call_id},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
tool_calls.append(parsed)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
invalid_tool_calls.append(
|
invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e)))
|
||||||
dict(make_invalid_tool_call(raw_tool_call, str(e)))
|
|
||||||
)
|
|
||||||
return AIMessage(
|
return AIMessage(
|
||||||
content=content,
|
content=content,
|
||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=additional_kwargs,
|
||||||
@ -206,12 +198,12 @@ def _convert_chunk_to_message_chunk(
|
|||||||
else:
|
else:
|
||||||
tool_call_id = raw_tool_call.get("id")
|
tool_call_id = raw_tool_call.get("id")
|
||||||
tool_call_chunks.append(
|
tool_call_chunks.append(
|
||||||
{
|
tool_call_chunk(
|
||||||
"name": raw_tool_call["function"].get("name"),
|
name=raw_tool_call["function"].get("name"),
|
||||||
"args": raw_tool_call["function"].get("arguments"),
|
args=raw_tool_call["function"].get("arguments"),
|
||||||
"id": tool_call_id,
|
id=tool_call_id,
|
||||||
"index": raw_tool_call.get("index"),
|
index=raw_tool_call.get("index"),
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
@ -144,6 +144,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
name="GenerateUsername",
|
name="GenerateUsername",
|
||||||
args={"name": "Sally", "hair_color": "green"},
|
args={"name": "Sally", "hair_color": "green"},
|
||||||
id="abc123",
|
id="abc123",
|
||||||
|
type="tool_call",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -178,6 +179,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
args="oops",
|
args="oops",
|
||||||
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
||||||
id="abc123",
|
id="abc123",
|
||||||
|
type="invalid_tool_call",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@ -185,6 +187,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
name="GenerateUsername",
|
name="GenerateUsername",
|
||||||
args={"name": "Sally", "hair_color": "green"},
|
args={"name": "Sally", "hair_color": "green"},
|
||||||
id="def456",
|
id="def456",
|
||||||
|
type="tool_call",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -63,6 +63,7 @@ from langchain_core.messages import (
|
|||||||
ToolMessageChunk,
|
ToolMessageChunk,
|
||||||
)
|
)
|
||||||
from langchain_core.messages.ai import UsageMetadata
|
from langchain_core.messages.ai import UsageMetadata
|
||||||
|
from langchain_core.messages.tool import tool_call_chunk
|
||||||
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||||
from langchain_core.output_parsers.base import OutputParserLike
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
from langchain_core.output_parsers.openai_tools import (
|
from langchain_core.output_parsers.openai_tools import (
|
||||||
@ -244,12 +245,12 @@ def _convert_delta_to_message_chunk(
|
|||||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||||
try:
|
try:
|
||||||
tool_call_chunks = [
|
tool_call_chunks = [
|
||||||
{
|
tool_call_chunk(
|
||||||
"name": rtc["function"].get("name"),
|
name=rtc["function"].get("name"),
|
||||||
"args": rtc["function"].get("arguments"),
|
args=rtc["function"].get("arguments"),
|
||||||
"id": rtc.get("id"),
|
id=rtc.get("id"),
|
||||||
"index": rtc["index"],
|
index=rtc["index"],
|
||||||
}
|
)
|
||||||
for rtc in raw_tool_calls
|
for rtc in raw_tool_calls
|
||||||
]
|
]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -117,6 +117,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
name="GenerateUsername",
|
name="GenerateUsername",
|
||||||
args={"name": "Sally", "hair_color": "green"},
|
args={"name": "Sally", "hair_color": "green"},
|
||||||
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
||||||
|
type="tool_call",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -151,6 +152,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
args="oops",
|
args="oops",
|
||||||
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
|
||||||
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
|
||||||
|
type="invalid_tool_call",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@ -158,6 +160,7 @@ def test__convert_dict_to_message_tool_call() -> None:
|
|||||||
name="GenerateUsername",
|
name="GenerateUsername",
|
||||||
args={"name": "Sally", "hair_color": "green"},
|
args={"name": "Sally", "hair_color": "green"},
|
||||||
id="call_abc123",
|
id="call_abc123",
|
||||||
|
type="tool_call",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -353,7 +356,10 @@ def test_get_num_tokens_from_messages() -> None:
|
|||||||
),
|
),
|
||||||
AIMessage("a nice bird"),
|
AIMessage("a nice bird"),
|
||||||
AIMessage(
|
AIMessage(
|
||||||
"", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
|
"",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call")
|
||||||
|
],
|
||||||
),
|
),
|
||||||
AIMessage(
|
AIMessage(
|
||||||
"",
|
"",
|
||||||
@ -362,7 +368,10 @@ def test_get_num_tokens_from_messages() -> None:
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
AIMessage(
|
AIMessage(
|
||||||
"text", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
|
"text",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call")
|
||||||
|
],
|
||||||
),
|
),
|
||||||
ToolMessage("foobar", tool_call_id="foo"),
|
ToolMessage("foobar", tool_call_id="foo"),
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user