mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 19:12:42 +00:00
core[patch]: add InjectedToolArg annotation (#24279)
```python from typing_extensions import Annotated from langchain_core.tools import tool, InjectedToolArg from langchain_anthropic import ChatAnthropic @tool def multiply(x: int, y: int, not_for_model: Annotated[dict, InjectedToolArg]) -> str: """multiply.""" return x * y ChatAnthropic(model='claude-3-sonnet-20240229',).bind_tools([multiply]).invoke('5 times 3').tool_calls ''' -> [{'name': 'multiply', 'args': {'x': 5, 'y': 3}, 'id': 'toolu_01Y1QazYWhu4R8vF4hF4z9no', 'type': 'tool_call'}] ''' ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
@@ -101,14 +101,12 @@ def _is_annotated_type(typ: Type[Any]) -> bool:
|
||||
return get_origin(typ) is Annotated
|
||||
|
||||
|
||||
def _get_annotation_description(arg: str, arg_type: Type[Any]) -> str | None:
|
||||
def _get_annotation_description(arg_type: Type) -> str | None:
|
||||
if _is_annotated_type(arg_type):
|
||||
annotated_args = get_args(arg_type)
|
||||
arg_type = annotated_args[0]
|
||||
if len(annotated_args) > 1:
|
||||
for annotation in annotated_args[1:]:
|
||||
if isinstance(annotation, str):
|
||||
return annotation
|
||||
for annotation in annotated_args[1:]:
|
||||
if isinstance(annotation, str):
|
||||
return annotation
|
||||
return None
|
||||
|
||||
|
||||
@@ -244,7 +242,7 @@ def _infer_arg_descriptions(
|
||||
for arg, arg_type in annotations.items():
|
||||
if arg in arg_descriptions:
|
||||
continue
|
||||
if desc := _get_annotation_description(arg, arg_type):
|
||||
if desc := _get_annotation_description(arg_type):
|
||||
arg_descriptions[arg] = desc
|
||||
return description, arg_descriptions
|
||||
|
||||
@@ -274,6 +272,7 @@ def create_schema_from_function(
|
||||
error_on_invalid_docstring: bool = False,
|
||||
) -> Type[BaseModel]:
|
||||
"""Create a pydantic schema from a function's signature.
|
||||
|
||||
Args:
|
||||
model_name: Name to assign to the generated pydantic schema.
|
||||
func: Function to generate the schema from.
|
||||
@@ -417,11 +416,18 @@ class ChildTool(BaseTool):
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema.schema()["properties"]
|
||||
else:
|
||||
schema = create_schema_from_function(self.name, self._run)
|
||||
return schema.schema()["properties"]
|
||||
return self.get_input_schema().schema()["properties"]
|
||||
|
||||
@property
|
||||
def tool_call_schema(self) -> Type[BaseModel]:
|
||||
full_schema = self.get_input_schema()
|
||||
fields = []
|
||||
for name, type_ in full_schema.__annotations__.items():
|
||||
if not _is_injected_arg_type(type_):
|
||||
fields.append(name)
|
||||
return _create_subset_model(
|
||||
self.name, full_schema, fields, fn_description=self.description
|
||||
)
|
||||
|
||||
# --- Runnable ---
|
||||
|
||||
@@ -1034,9 +1040,20 @@ class StructuredTool(BaseTool):
|
||||
else:
|
||||
raise ValueError("Function and/or coroutine must be provided")
|
||||
name = name or source_function.__name__
|
||||
description_ = description or source_function.__doc__
|
||||
if args_schema is None and infer_schema:
|
||||
# schema name is appended within function
|
||||
args_schema = create_schema_from_function(
|
||||
name,
|
||||
source_function,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
filter_args=_filter_schema_args(source_function),
|
||||
)
|
||||
description_ = description
|
||||
if description is None and not parse_docstring:
|
||||
description_ = source_function.__doc__ or None
|
||||
if description_ is None and args_schema:
|
||||
description_ = args_schema.__doc__
|
||||
description_ = args_schema.__doc__ or None
|
||||
if description_ is None:
|
||||
raise ValueError(
|
||||
"Function must have a docstring if description not provided."
|
||||
@@ -1048,29 +1065,11 @@ class StructuredTool(BaseTool):
|
||||
# Description example:
|
||||
# search_api(query: str) - Searches the API for the query.
|
||||
description_ = f"{description_.strip()}"
|
||||
_args_schema = args_schema
|
||||
if _args_schema is None and infer_schema:
|
||||
if config_param := _get_runnable_config_param(source_function):
|
||||
filter_args: Tuple[str, ...] = (
|
||||
config_param,
|
||||
"run_manager",
|
||||
"callbacks",
|
||||
)
|
||||
else:
|
||||
filter_args = ("run_manager", "callbacks")
|
||||
# schema name is appended within function
|
||||
_args_schema = create_schema_from_function(
|
||||
name,
|
||||
source_function,
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
filter_args=filter_args,
|
||||
)
|
||||
return cls(
|
||||
name=name,
|
||||
func=func,
|
||||
coroutine=coroutine,
|
||||
args_schema=_args_schema, # type: ignore[arg-type]
|
||||
args_schema=args_schema, # type: ignore[arg-type]
|
||||
description=description_,
|
||||
return_direct=return_direct,
|
||||
response_format=response_format,
|
||||
@@ -1624,15 +1623,40 @@ def convert_runnable_to_tool(
|
||||
)
|
||||
|
||||
|
||||
def _get_runnable_config_param(func: Callable) -> Optional[str]:
|
||||
def _get_type_hints(func: Callable) -> Optional[Dict[str, Type]]:
|
||||
if isinstance(func, functools.partial):
|
||||
func = func.func
|
||||
try:
|
||||
type_hints = get_type_hints(func)
|
||||
return get_type_hints(func)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
for name, type_ in type_hints.items():
|
||||
if type_ is RunnableConfig:
|
||||
return name
|
||||
|
||||
|
||||
def _get_runnable_config_param(func: Callable) -> Optional[str]:
|
||||
type_hints = _get_type_hints(func)
|
||||
if not type_hints:
|
||||
return None
|
||||
for name, type_ in type_hints.items():
|
||||
if type_ is RunnableConfig:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
class InjectedToolArg:
|
||||
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
|
||||
|
||||
|
||||
def _is_injected_arg_type(type_: Type) -> bool:
|
||||
return any(
|
||||
isinstance(arg, InjectedToolArg)
|
||||
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
|
||||
for arg in get_args(type_)[1:]
|
||||
)
|
||||
|
||||
|
||||
def _filter_schema_args(func: Callable) -> List[str]:
|
||||
filter_args = list(FILTERED_ARGS)
|
||||
if config_param := _get_runnable_config_param(func):
|
||||
filter_args.append(config_param)
|
||||
# filter_args.extend(_get_non_model_params(type_hints))
|
||||
return filter_args
|
||||
|
@@ -196,9 +196,9 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
Returns:
|
||||
The function description.
|
||||
"""
|
||||
if tool.args_schema:
|
||||
if tool.tool_call_schema:
|
||||
return convert_pydantic_to_openai_function(
|
||||
tool.args_schema, name=tool.name, description=tool.description
|
||||
tool.tool_call_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
else:
|
||||
return {
|
||||
|
@@ -26,6 +26,7 @@ from langchain_core.runnables import (
|
||||
)
|
||||
from langchain_core.tools import (
|
||||
BaseTool,
|
||||
InjectedToolArg,
|
||||
SchemaAnnotationError,
|
||||
StructuredTool,
|
||||
Tool,
|
||||
@@ -33,6 +34,7 @@ from langchain_core.tools import (
|
||||
_create_subset_model,
|
||||
tool,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
|
||||
|
||||
|
||||
@@ -1284,3 +1286,134 @@ def test_convert_from_runnable_other() -> None:
|
||||
as_tool = runnable.as_tool()
|
||||
result = as_tool.invoke("b", config={"configurable": {"foo": "not-bar"}})
|
||||
assert result == "ba"
|
||||
|
||||
|
||||
@tool("foo", parse_docstring=True)
|
||||
def injected_tool(x: int, y: Annotated[str, InjectedToolArg]) -> str:
|
||||
"""foo.
|
||||
|
||||
Args:
|
||||
x: abc
|
||||
y: 123
|
||||
"""
|
||||
return y
|
||||
|
||||
|
||||
class InjectedTool(BaseTool):
|
||||
name: str = "foo"
|
||||
description: str = "foo."
|
||||
|
||||
def _run(self, x: int, y: Annotated[str, InjectedToolArg]) -> Any:
|
||||
"""foo.
|
||||
|
||||
Args:
|
||||
x: abc
|
||||
y: 123
|
||||
"""
|
||||
return y
|
||||
|
||||
|
||||
class fooSchema(BaseModel):
|
||||
"""foo."""
|
||||
|
||||
x: int = Field(..., description="abc")
|
||||
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
|
||||
..., description="123"
|
||||
)
|
||||
|
||||
|
||||
class InjectedToolWithSchema(BaseTool):
|
||||
name: str = "foo"
|
||||
description: str = "foo."
|
||||
args_schema: Type[BaseModel] = fooSchema
|
||||
|
||||
def _run(self, x: int, y: str) -> Any:
|
||||
return y
|
||||
|
||||
|
||||
@tool("foo", args_schema=fooSchema)
|
||||
def injected_tool_with_schema(x: int, y: str) -> str:
|
||||
return y
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_", [InjectedTool()])
|
||||
def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
|
||||
assert tool_.get_input_schema().schema() == {
|
||||
"title": "fooSchema",
|
||||
"description": "foo.\n\nArgs:\n x: abc\n y: 123",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"title": "X", "type": "integer"},
|
||||
"y": {"title": "Y", "type": "string"},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
}
|
||||
assert tool_.tool_call_schema.schema() == {
|
||||
"title": "foo",
|
||||
"description": "foo.",
|
||||
"type": "object",
|
||||
"properties": {"x": {"title": "X", "type": "integer"}},
|
||||
"required": ["x"],
|
||||
}
|
||||
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
|
||||
assert tool_.invoke(
|
||||
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
|
||||
) == ToolMessage("bar", tool_call_id="123", name="foo")
|
||||
expected_error = (
|
||||
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
|
||||
)
|
||||
with pytest.raises(expected_error):
|
||||
tool_.invoke({"x": 5})
|
||||
|
||||
assert convert_to_openai_function(tool_) == {
|
||||
"name": "foo",
|
||||
"description": "foo.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tool_",
|
||||
[injected_tool, injected_tool_with_schema, InjectedToolWithSchema()],
|
||||
)
|
||||
def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
|
||||
assert tool_.get_input_schema().schema() == {
|
||||
"title": "fooSchema",
|
||||
"description": "foo.",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"description": "abc", "title": "X", "type": "integer"},
|
||||
"y": {"description": "123", "title": "Y", "type": "string"},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
}
|
||||
assert tool_.tool_call_schema.schema() == {
|
||||
"title": "foo",
|
||||
"description": "foo.",
|
||||
"type": "object",
|
||||
"properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},
|
||||
"required": ["x"],
|
||||
}
|
||||
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
|
||||
assert tool_.invoke(
|
||||
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
|
||||
) == ToolMessage("bar", tool_call_id="123", name="foo")
|
||||
expected_error = (
|
||||
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
|
||||
)
|
||||
with pytest.raises(expected_error):
|
||||
tool_.invoke({"x": 5})
|
||||
|
||||
assert convert_to_openai_function(tool_) == {
|
||||
"name": "foo",
|
||||
"description": "foo.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "abc"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
|
Reference in New Issue
Block a user