Compare commits

...

2 Commits

Author SHA1 Message Date
Eugene Yurtsev
9c787ceaa5 x 2023-08-15 13:53:33 -04:00
Eugene Yurtsev
7efcd2e67a x 2023-08-15 13:27:05 -04:00
2 changed files with 35 additions and 6 deletions

View File

@@ -22,8 +22,15 @@ from pydantic_v1 import (
# used raising an exception.
try:
from pydantic.v1.main import ModelMetaclass
from pydantic.v1 import BaseModel as V1_BASE_MODEL
from pydantic import BaseModel as V2_BASE_MODEL
BaseModelCompatible = Union[V1_BASE_MODEL, V2_BASE_MODEL]
except ImportError:
from pydantic.main import ModelMetaclass
from pydantic import BaseModel
BaseModelCompatible = BaseModel
from langchain.callbacks.base import BaseCallbackManager
@@ -106,7 +113,7 @@ class _SchemaConfig:
def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
) -> Type[BaseModel]: # Uses pydantic v1
"""Create a pydantic schema from a function's signature.
Args:
model_name: Name to assign to the generated pydandic schema
@@ -150,7 +157,7 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[Type[BaseModel]] = None
args_schema: Optional[Type[BaseModelCompatible]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means
@@ -555,7 +562,7 @@ class Tool(BaseTool):
name: str, # We keep these required to support backwards compatibility
description: str,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
args_schema: Optional[Type[BaseModelCompatible]] = None,
**kwargs: Any,
) -> Tool:
"""Initialize tool from a function."""
@@ -573,7 +580,7 @@ class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""
description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
args_schema: Type[BaseModelCompatible] = Field(..., description="The tool schema.")
"""The input arguments' schema."""
func: Callable[..., Any]
"""The function to run when the tool is called."""
@@ -650,7 +657,7 @@ class StructuredTool(BaseTool):
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
args_schema: Optional[Type[BaseModelCompatible]] = None,
infer_schema: bool = True,
**kwargs: Any,
) -> StructuredTool:
@@ -705,7 +712,7 @@ class StructuredTool(BaseTool):
def tool(
*args: Union[str, Callable],
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
args_schema: Optional[Type[BaseModelCompatible]] = None,
infer_schema: bool = True,
) -> Callable:
"""Make tools out of functions, can be used with or without arguments.

View File

@@ -0,0 +1,22 @@
"""Testing that declaring custom tools using pydantic v2 works."""
from langchain import _PYDANTIC_MAJOR_VERSION
from langchain.tools.base import tool
import pytest
if _PYDANTIC_MAJOR_VERSION != 2:
pytest.skip(
"Unit tests for testing compatibility with pydantic major version 2",
allow_module_level=True,
)
def test_custom_tool_pydantic_v2() -> None:
"""Test that custom tools can be declared using pydantic v2."""
@tool()
def speak(what: str) -> str:
"""Return what was said backwards."""
return what[::-1]
assert speak("hello") == "olleh"