mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
core[patch]: Respect injected in bound fns (#24733)
Since right now you cant use the nice injected arg syntas directly with model.bind_tools()
This commit is contained in:
parent
7fcfe7c1f4
commit
01ab2918a2
@ -120,6 +120,7 @@ def _get_filtered_args(
|
|||||||
func: Callable,
|
func: Callable,
|
||||||
*,
|
*,
|
||||||
filter_args: Sequence[str],
|
filter_args: Sequence[str],
|
||||||
|
include_injected: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Get the arguments from a function's signature."""
|
"""Get the arguments from a function's signature."""
|
||||||
schema = inferred_model.schema()["properties"]
|
schema = inferred_model.schema()["properties"]
|
||||||
@ -127,7 +128,9 @@ def _get_filtered_args(
|
|||||||
return {
|
return {
|
||||||
k: schema[k]
|
k: schema[k]
|
||||||
for i, (k, param) in enumerate(valid_keys.items())
|
for i, (k, param) in enumerate(valid_keys.items())
|
||||||
if k not in filter_args and (i > 0 or param.name not in ("self", "cls"))
|
if k not in filter_args
|
||||||
|
and (i > 0 or param.name not in ("self", "cls"))
|
||||||
|
and (include_injected or not _is_injected_arg_type(param.annotation))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -247,6 +250,7 @@ def create_schema_from_function(
|
|||||||
filter_args: Optional[Sequence[str]] = None,
|
filter_args: Optional[Sequence[str]] = None,
|
||||||
parse_docstring: bool = False,
|
parse_docstring: bool = False,
|
||||||
error_on_invalid_docstring: bool = False,
|
error_on_invalid_docstring: bool = False,
|
||||||
|
include_injected: bool = True,
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
"""Create a pydantic schema from a function's signature.
|
"""Create a pydantic schema from a function's signature.
|
||||||
|
|
||||||
@ -260,6 +264,9 @@ def create_schema_from_function(
|
|||||||
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
|
||||||
whether to raise ValueError on invalid Google Style docstrings.
|
whether to raise ValueError on invalid Google Style docstrings.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
|
include_injected: Whether to include injected arguments in the schema.
|
||||||
|
Defaults to True, since we want to include them in the schema
|
||||||
|
when *validating* tool inputs.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A pydantic model with the same arguments as the function.
|
A pydantic model with the same arguments as the function.
|
||||||
@ -277,7 +284,9 @@ def create_schema_from_function(
|
|||||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||||
)
|
)
|
||||||
# Pydantic adds placeholder virtual fields we need to strip
|
# Pydantic adds placeholder virtual fields we need to strip
|
||||||
valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args)
|
valid_properties = _get_filtered_args(
|
||||||
|
inferred_model, func, filter_args=filter_args, include_injected=include_injected
|
||||||
|
)
|
||||||
return _create_subset_model(
|
return _create_subset_model(
|
||||||
f"{model_name}Schema",
|
f"{model_name}Schema",
|
||||||
inferred_model,
|
inferred_model,
|
||||||
|
@ -179,6 +179,7 @@ def convert_python_function_to_openai_function(
|
|||||||
filter_args=(),
|
filter_args=(),
|
||||||
parse_docstring=True,
|
parse_docstring=True,
|
||||||
error_on_invalid_docstring=False,
|
error_on_invalid_docstring=False,
|
||||||
|
include_injected=False,
|
||||||
)
|
)
|
||||||
return convert_pydantic_to_openai_function(
|
return convert_pydantic_to_openai_function(
|
||||||
model,
|
model,
|
||||||
|
@ -1429,6 +1429,36 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parametrized_tools() -> list:
|
||||||
|
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
|
||||||
|
"""my_tool."""
|
||||||
|
return some_tool
|
||||||
|
|
||||||
|
async def my_async_tool(
|
||||||
|
x: int, y: str, *, some_tool: Annotated[Any, InjectedToolArg]
|
||||||
|
) -> str:
|
||||||
|
"""my_tool."""
|
||||||
|
return some_tool
|
||||||
|
|
||||||
|
return [my_tool, my_async_tool]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tool_", _get_parametrized_tools())
|
||||||
|
def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
|
||||||
|
assert convert_to_openai_function(tool_) == {
|
||||||
|
"name": tool_.__name__,
|
||||||
|
"description": "my_tool.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"x": {"type": "integer"},
|
||||||
|
"y": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["x", "y"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def generate_models() -> List[Any]:
|
def generate_models() -> List[Any]:
|
||||||
"""Generate a list of base models depending on the pydantic version."""
|
"""Generate a list of base models depending on the pydantic version."""
|
||||||
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
|
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
|
||||||
|
Loading…
Reference in New Issue
Block a user