core[patch]: add InjectedToolArg annotation (#24279)

```python
from typing_extensions import Annotated
from langchain_core.tools import tool, InjectedToolArg
from langchain_anthropic import ChatAnthropic

@tool
def multiply(x: int, y: int, not_for_model: Annotated[dict, InjectedToolArg]) -> str:
    """multiply."""
    return x * y 

ChatAnthropic(model='claude-3-sonnet-20240229',).bind_tools([multiply]).invoke('5 times 3').tool_calls
'''
-> [{'name': 'multiply',
  'args': {'x': 5, 'y': 3},
  'id': 'toolu_01Y1QazYWhu4R8vF4hF4z9no',
  'type': 'tool_call'}]
'''
```

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
William FH
2024-07-17 15:28:40 -07:00
committed by GitHub
parent 80f3d48195
commit c5a07e2dd8
5 changed files with 660 additions and 136 deletions

View File

@@ -101,14 +101,12 @@ def _is_annotated_type(typ: Type[Any]) -> bool:
return get_origin(typ) is Annotated
def _get_annotation_description(arg: str, arg_type: Type[Any]) -> str | None:
def _get_annotation_description(arg_type: Type) -> str | None:
if _is_annotated_type(arg_type):
annotated_args = get_args(arg_type)
arg_type = annotated_args[0]
if len(annotated_args) > 1:
for annotation in annotated_args[1:]:
if isinstance(annotation, str):
return annotation
for annotation in annotated_args[1:]:
if isinstance(annotation, str):
return annotation
return None
@@ -244,7 +242,7 @@ def _infer_arg_descriptions(
for arg, arg_type in annotations.items():
if arg in arg_descriptions:
continue
if desc := _get_annotation_description(arg, arg_type):
if desc := _get_annotation_description(arg_type):
arg_descriptions[arg] = desc
return description, arg_descriptions
@@ -274,6 +272,7 @@ def create_schema_from_function(
error_on_invalid_docstring: bool = False,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydantic schema.
func: Function to generate the schema from.
@@ -417,11 +416,18 @@ class ChildTool(BaseTool):
@property
def args(self) -> dict:
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
schema = create_schema_from_function(self.name, self._run)
return schema.schema()["properties"]
return self.get_input_schema().schema()["properties"]
@property
def tool_call_schema(self) -> Type[BaseModel]:
full_schema = self.get_input_schema()
fields = []
for name, type_ in full_schema.__annotations__.items():
if not _is_injected_arg_type(type_):
fields.append(name)
return _create_subset_model(
self.name, full_schema, fields, fn_description=self.description
)
# --- Runnable ---
@@ -1034,9 +1040,20 @@ class StructuredTool(BaseTool):
else:
raise ValueError("Function and/or coroutine must be provided")
name = name or source_function.__name__
description_ = description or source_function.__doc__
if args_schema is None and infer_schema:
# schema name is appended within function
args_schema = create_schema_from_function(
name,
source_function,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
filter_args=_filter_schema_args(source_function),
)
description_ = description
if description is None and not parse_docstring:
description_ = source_function.__doc__ or None
if description_ is None and args_schema:
description_ = args_schema.__doc__
description_ = args_schema.__doc__ or None
if description_ is None:
raise ValueError(
"Function must have a docstring if description not provided."
@@ -1048,29 +1065,11 @@ class StructuredTool(BaseTool):
# Description example:
# search_api(query: str) - Searches the API for the query.
description_ = f"{description_.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
if config_param := _get_runnable_config_param(source_function):
filter_args: Tuple[str, ...] = (
config_param,
"run_manager",
"callbacks",
)
else:
filter_args = ("run_manager", "callbacks")
# schema name is appended within function
_args_schema = create_schema_from_function(
name,
source_function,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
filter_args=filter_args,
)
return cls(
name=name,
func=func,
coroutine=coroutine,
args_schema=_args_schema, # type: ignore[arg-type]
args_schema=args_schema, # type: ignore[arg-type]
description=description_,
return_direct=return_direct,
response_format=response_format,
@@ -1624,15 +1623,40 @@ def convert_runnable_to_tool(
)
def _get_runnable_config_param(func: Callable) -> Optional[str]:
def _get_type_hints(func: Callable) -> Optional[Dict[str, Type]]:
if isinstance(func, functools.partial):
func = func.func
try:
type_hints = get_type_hints(func)
return get_type_hints(func)
except Exception:
return None
else:
for name, type_ in type_hints.items():
if type_ is RunnableConfig:
return name
def _get_runnable_config_param(func: Callable) -> Optional[str]:
type_hints = _get_type_hints(func)
if not type_hints:
return None
for name, type_ in type_hints.items():
if type_ is RunnableConfig:
return name
return None
class InjectedToolArg:
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""
def _is_injected_arg_type(type_: Type) -> bool:
return any(
isinstance(arg, InjectedToolArg)
or (isinstance(arg, type) and issubclass(arg, InjectedToolArg))
for arg in get_args(type_)[1:]
)
def _filter_schema_args(func: Callable) -> List[str]:
filter_args = list(FILTERED_ARGS)
if config_param := _get_runnable_config_param(func):
filter_args.append(config_param)
# filter_args.extend(_get_non_model_params(type_hints))
return filter_args

View File

@@ -196,9 +196,9 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
Returns:
The function description.
"""
if tool.args_schema:
if tool.tool_call_schema:
return convert_pydantic_to_openai_function(
tool.args_schema, name=tool.name, description=tool.description
tool.tool_call_schema, name=tool.name, description=tool.description
)
else:
return {

View File

@@ -26,6 +26,7 @@ from langchain_core.runnables import (
)
from langchain_core.tools import (
BaseTool,
InjectedToolArg,
SchemaAnnotationError,
StructuredTool,
Tool,
@@ -33,6 +34,7 @@ from langchain_core.tools import (
_create_subset_model,
tool,
)
from langchain_core.utils.function_calling import convert_to_openai_function
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@@ -1284,3 +1286,134 @@ def test_convert_from_runnable_other() -> None:
as_tool = runnable.as_tool()
result = as_tool.invoke("b", config={"configurable": {"foo": "not-bar"}})
assert result == "ba"
@tool("foo", parse_docstring=True)
def injected_tool(x: int, y: Annotated[str, InjectedToolArg]) -> str:
"""foo.
Args:
x: abc
y: 123
"""
return y
class InjectedTool(BaseTool):
name: str = "foo"
description: str = "foo."
def _run(self, x: int, y: Annotated[str, InjectedToolArg]) -> Any:
"""foo.
Args:
x: abc
y: 123
"""
return y
class fooSchema(BaseModel):
"""foo."""
x: int = Field(..., description="abc")
y: Annotated[str, "foobar comment", InjectedToolArg()] = Field(
..., description="123"
)
class InjectedToolWithSchema(BaseTool):
name: str = "foo"
description: str = "foo."
args_schema: Type[BaseModel] = fooSchema
def _run(self, x: int, y: str) -> Any:
return y
@tool("foo", args_schema=fooSchema)
def injected_tool_with_schema(x: int, y: str) -> str:
return y
@pytest.mark.parametrize("tool_", [InjectedTool()])
def test_tool_injected_arg_without_schema(tool_: BaseTool) -> None:
assert tool_.get_input_schema().schema() == {
"title": "fooSchema",
"description": "foo.\n\nArgs:\n x: abc\n y: 123",
"type": "object",
"properties": {
"x": {"title": "X", "type": "integer"},
"y": {"title": "Y", "type": "string"},
},
"required": ["x", "y"],
}
assert tool_.tool_call_schema.schema() == {
"title": "foo",
"description": "foo.",
"type": "object",
"properties": {"x": {"title": "X", "type": "integer"}},
"required": ["x"],
}
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
assert tool_.invoke(
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
) == ToolMessage("bar", tool_call_id="123", name="foo")
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
with pytest.raises(expected_error):
tool_.invoke({"x": 5})
assert convert_to_openai_function(tool_) == {
"name": "foo",
"description": "foo.",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer"}},
"required": ["x"],
},
}
@pytest.mark.parametrize(
"tool_",
[injected_tool, injected_tool_with_schema, InjectedToolWithSchema()],
)
def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
assert tool_.get_input_schema().schema() == {
"title": "fooSchema",
"description": "foo.",
"type": "object",
"properties": {
"x": {"description": "abc", "title": "X", "type": "integer"},
"y": {"description": "123", "title": "Y", "type": "string"},
},
"required": ["x", "y"],
}
assert tool_.tool_call_schema.schema() == {
"title": "foo",
"description": "foo.",
"type": "object",
"properties": {"x": {"description": "abc", "title": "X", "type": "integer"}},
"required": ["x"],
}
assert tool_.invoke({"x": 5, "y": "bar"}) == "bar"
assert tool_.invoke(
{"name": "foo", "args": {"x": 5, "y": "bar"}, "id": "123", "type": "tool_call"}
) == ToolMessage("bar", tool_call_id="123", name="foo")
expected_error = (
ValidationError if not isinstance(tool_, InjectedTool) else TypeError
)
with pytest.raises(expected_error):
tool_.invoke({"x": 5})
assert convert_to_openai_function(tool_) == {
"name": "foo",
"description": "foo.",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "abc"}},
"required": ["x"],
},
}