Compare commits

...

3 Commits

Author SHA1 Message Date
Bagatur
d72cb36f14 Merge branch 'master' into bagatur/tool_description_optional 2024-10-30 17:02:52 -07:00
Bagatur
58e5b7b8fa undo 2024-10-30 13:27:29 -07:00
Bagatur
cb40fbd3be core[patch]: make Tool.description optional 2024-10-30 13:25:39 -07:00
5 changed files with 33 additions and 18 deletions

View File

@@ -348,7 +348,7 @@ class ChildTool(BaseTool):
name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
description: Optional[str] = None
"""Used to tell the model how/when/why to use the tool.
You can provide few-shot examples as a part of the description.

View File

@@ -32,7 +32,7 @@ from langchain_core.utils.pydantic import TypeBaseModel
class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""
description: str = ""
description: Optional[str] = ""
args_schema: Annotated[TypeBaseModel, SkipValidation()] = Field(
..., description="The tool schema."
)
@@ -185,16 +185,14 @@ class StructuredTool(BaseTool):
description_ = source_function.__doc__ or None
if description_ is None and args_schema:
description_ = args_schema.__doc__ or None
if description_ is None:
msg = "Function must have a docstring if description not provided."
raise ValueError(msg)
if description is None:
if description is None and description_ is not None:
# Only apply if using the function's docstring
description_ = textwrap.dedent(description_).strip()
# Description example:
# search_api(query: str) - Searches the API for the query.
description_ = f"{description_.strip()}"
if description_:
description_ = f"{description_.strip()}"
return cls(
name=name,
func=func,

View File

@@ -20,7 +20,7 @@ from typing import (
)
from pydantic import BaseModel
from typing_extensions import TypedDict, get_args, get_origin, is_typeddict
from typing_extensions import NotRequired, TypedDict, get_args, get_origin, is_typeddict
from langchain_core._api import deprecated
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
@@ -45,10 +45,16 @@ class FunctionDescription(TypedDict):
name: str
"""The name of the function."""
description: str
description: NotRequired[str]
"""A description of the function."""
parameters: dict
parameters: NotRequired[dict]
"""The parameters of the function."""
strict: NotRequired[Optional[bool]]
"""Whether to enable strict schema adherence when generating the function call.
If set to True, the model will follow the exact schema defined in the parameters
field. Only a subset of JSON Schema is supported when strict is True.
"""
class ToolDescription(TypedDict):
@@ -294,9 +300,8 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
tool.tool_call_schema, name=tool.name, description=tool.description
)
else:
return {
oai_function = {
"name": tool.name,
"description": tool.description,
"parameters": {
# This is a hack to get around the fact that some tools
# do not expose an args_schema, and expect an argument
@@ -310,6 +315,9 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"type": "object",
},
}
if tool.description:
oai_function["description"] = tool.description
return cast(FunctionDescription, oai_function)
@deprecated(

View File

@@ -4623,6 +4623,7 @@ async def test_tool_from_runnable() -> None:
assert await chain_tool.arun({"question": "What up"}) == await chain.ainvoke(
{"question": "What up"}
)
assert chain_tool.description
assert chain_tool.description.endswith(repr(chain))
assert _schema(chain_tool.args_schema) == chain.get_input_jsonschema()
assert _schema(chain_tool.args_schema) == {

View File

@@ -651,13 +651,21 @@ def test_tool_with_kwargs() -> None:
def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing."""
# expect to throw a value error if there's no docstring
with pytest.raises(ValueError, match="Function must have a docstring"):
"""Test error is not raised when docstring is missing."""
@tool
def search_api(query: str) -> str:
return "API result"
@tool
def search_api(query: str) -> str:
return "API result"
assert search_api.name == "search_api"
assert search_api.description is None
assert search_api.args_schema
assert search_api.args_schema.model_json_schema() == {
"properties": {"query": {"title": "Query", "type": "string"}},
"required": ["query"],
"title": "search_api",
"type": "object",
}
def test_create_tool_positional_args() -> None: