mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
core,groq,openai,mistralai,robocorp,fireworks,anthropic[patch]: Update BaseModel subclass and instance checks to handle both v1 and proper namespaces (#24417)
After this PR chat models will correctly handle pydantic 2 with bind_tools and with_structured_output. ```python import pydantic print(pydantic.__version__) ``` 2.8.2 ```python from langchain_openai import ChatOpenAI from pydantic import BaseModel, Field class Add(BaseModel): x: int y: int model = ChatOpenAI().bind_tools([Add]) print(model.invoke('2 + 5').tool_calls) model = ChatOpenAI().with_structured_output(Add) print(type(model.invoke('2 + 5'))) ``` ``` [{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_PNUFa4pdfNOYXxIMHc6ps2Do', 'type': 'tool_call'}] <class '__main__.Add'> ``` ```python from langchain_openai import ChatOpenAI from pydantic.v1 import BaseModel, Field class Add(BaseModel): x: int y: int model = ChatOpenAI().bind_tools([Add]) print(model.invoke('2 + 5').tool_calls) model = ChatOpenAI().with_structured_output(Add) print(type(model.invoke('2 + 5'))) ``` ```python [{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_hhiHYP441cp14TtrHKx3Upg0', 'type': 'tool_call'}] <class '__main__.Add'> ``` Addresses issues: https://github.com/langchain-ai/langchain/issues/22782 --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
199e64d372
commit
236e957abb
@ -24,6 +24,7 @@ from langchain_core.output_parsers import (
|
|||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables import Runnable
|
from langchain_core.runnables import Runnable
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
from langchain_community.output_parsers.ernie_functions import (
|
from langchain_community.output_parsers.ernie_functions import (
|
||||||
JsonOutputFunctionsParser,
|
JsonOutputFunctionsParser,
|
||||||
@ -94,7 +95,7 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
|
|||||||
for arg, arg_type in annotations.items():
|
for arg, arg_type in annotations.items():
|
||||||
if arg == "return":
|
if arg == "return":
|
||||||
continue
|
continue
|
||||||
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
|
if isinstance(arg_type, type) and is_basemodel_subclass(arg_type):
|
||||||
# Mypy error:
|
# Mypy error:
|
||||||
# "type" has no attribute "schema"
|
# "type" has no attribute "schema"
|
||||||
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
|
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
|
||||||
@ -156,7 +157,7 @@ def convert_to_ernie_function(
|
|||||||
"""
|
"""
|
||||||
if isinstance(function, dict):
|
if isinstance(function, dict):
|
||||||
return function
|
return function
|
||||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
elif isinstance(function, type) and is_basemodel_subclass(function):
|
||||||
return cast(Dict, convert_pydantic_to_ernie_function(function))
|
return cast(Dict, convert_pydantic_to_ernie_function(function))
|
||||||
elif callable(function):
|
elif callable(function):
|
||||||
return convert_python_function_to_ernie_function(function)
|
return convert_python_function_to_ernie_function(function)
|
||||||
@ -185,7 +186,7 @@ def get_ernie_output_parser(
|
|||||||
only the function arguments and not the function name.
|
only the function arguments and not the function name.
|
||||||
"""
|
"""
|
||||||
function_names = [convert_to_ernie_function(f)["name"] for f in functions]
|
function_names = [convert_to_ernie_function(f)["name"] for f in functions]
|
||||||
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
|
if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
|
||||||
if len(functions) > 1:
|
if len(functions) > 1:
|
||||||
pydantic_schema: Union[Dict, Type[BaseModel]] = {
|
pydantic_schema: Union[Dict, Type[BaseModel]] = {
|
||||||
name: fn for name, fn in zip(function_names, functions)
|
name: fn for name, fn in zip(function_names, functions)
|
||||||
|
@ -40,11 +40,17 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -769,7 +775,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
|||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||||
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
|
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||||
llm = self.bind_tools([schema])
|
llm = self.bind_tools([schema])
|
||||||
if is_pydantic_schema:
|
if is_pydantic_schema:
|
||||||
output_parser: OutputParserLike = PydanticToolsParser(
|
output_parser: OutputParserLike = PydanticToolsParser(
|
||||||
|
@ -57,6 +57,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
|||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
from langchain_community.utilities.requests import Requests
|
from langchain_community.utilities.requests import Requests
|
||||||
|
|
||||||
@ -443,7 +444,7 @@ class ChatEdenAI(BaseChatModel):
|
|||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||||
llm = self.bind_tools([schema], tool_choice="required")
|
llm = self.bind_tools([schema], tool_choice="required")
|
||||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||||
output_parser: OutputParserLike = PydanticToolsParser(
|
output_parser: OutputParserLike = PydanticToolsParser(
|
||||||
tools=[schema], first_tool_only=True
|
tools=[schema], first_tool_only=True
|
||||||
)
|
)
|
||||||
|
@ -46,10 +46,15 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
parse_tool_call,
|
parse_tool_call,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
|
|
||||||
class ChatLlamaCpp(BaseChatModel):
|
class ChatLlamaCpp(BaseChatModel):
|
||||||
@ -525,7 +530,7 @@ class ChatLlamaCpp(BaseChatModel):
|
|||||||
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||||
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
|
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||||
if schema is None:
|
if schema is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"schema must be specified when method is 'function_calling'. "
|
"schema must be specified when method is 'function_calling'. "
|
||||||
|
@ -53,11 +53,16 @@ from langchain_core.outputs import (
|
|||||||
ChatGenerationChunk,
|
ChatGenerationChunk,
|
||||||
ChatResult,
|
ChatResult,
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
)
|
||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
before_sleep_log,
|
before_sleep_log,
|
||||||
@ -865,7 +870,7 @@ class ChatTongyi(BaseChatModel):
|
|||||||
"""
|
"""
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||||
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
|
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||||
llm = self.bind_tools([schema])
|
llm = self.bind_tools([schema])
|
||||||
if is_pydantic_schema:
|
if is_pydantic_schema:
|
||||||
output_parser: OutputParserLike = PydanticToolsParser(
|
output_parser: OutputParserLike = PydanticToolsParser(
|
||||||
|
@ -55,11 +55,16 @@ from langchain_core.outputs import (
|
|||||||
RunInfo,
|
RunInfo,
|
||||||
)
|
)
|
||||||
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.runnables.config import ensure_config, run_in_executor
|
from langchain_core.runnables.config import ensure_config, run_in_executor
|
||||||
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
from langchain_core.tracers._streaming import _StreamingCallbackHandler
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.output_parsers.base import OutputParserLike
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
@ -1162,7 +1167,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
"with_structured_output is not implemented for this model."
|
"with_structured_output is not implemented for this model."
|
||||||
)
|
)
|
||||||
llm = self.bind_tools([schema], tool_choice="any")
|
llm = self.bind_tools([schema], tool_choice="any")
|
||||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||||
output_parser: OutputParserLike = PydanticToolsParser(
|
output_parser: OutputParserLike = PydanticToolsParser(
|
||||||
tools=[schema], first_tool_only=True
|
tools=[schema], first_tool_only=True
|
||||||
)
|
)
|
||||||
|
@ -82,6 +82,7 @@ from langchain_core.runnables.utils import (
|
|||||||
)
|
)
|
||||||
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
from langchain_core.utils.aiter import aclosing, atee, py_anext
|
||||||
from langchain_core.utils.iter import safetee
|
from langchain_core.utils.iter import safetee
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
@ -300,7 +301,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""
|
"""
|
||||||
root_type = self.InputType
|
root_type = self.InputType
|
||||||
|
|
||||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||||
return root_type
|
return root_type
|
||||||
|
|
||||||
return create_model(
|
return create_model(
|
||||||
@ -332,7 +333,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""
|
"""
|
||||||
root_type = self.OutputType
|
root_type = self.OutputType
|
||||||
|
|
||||||
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
|
||||||
return root_type
|
return root_type
|
||||||
|
|
||||||
return create_model(
|
return create_model(
|
||||||
|
@ -22,6 +22,7 @@ from typing import (
|
|||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.runnables.base import Runnable as RunnableType
|
from langchain_core.runnables.base import Runnable as RunnableType
|
||||||
@ -229,7 +230,7 @@ def node_data_json(
|
|||||||
"name": node_data_str(node.id, node.data),
|
"name": node_data_str(node.id, node.data),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
|
elif inspect.isclass(node.data) and is_basemodel_subclass(node.data):
|
||||||
json = (
|
json = (
|
||||||
{
|
{
|
||||||
"type": "schema",
|
"type": "schema",
|
||||||
|
@ -28,6 +28,7 @@ from langchain_core.messages import (
|
|||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.utils.json_schema import dereference_refs
|
from langchain_core.utils.json_schema import dereference_refs
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
@ -100,7 +101,11 @@ def convert_pydantic_to_openai_function(
|
|||||||
Returns:
|
Returns:
|
||||||
The function description.
|
The function description.
|
||||||
"""
|
"""
|
||||||
schema = dereference_refs(model.schema())
|
if hasattr(model, "model_json_schema"):
|
||||||
|
schema = model.model_json_schema() # Pydantic 2
|
||||||
|
else:
|
||||||
|
schema = model.schema() # Pydantic 1
|
||||||
|
schema = dereference_refs(schema)
|
||||||
schema.pop("definitions", None)
|
schema.pop("definitions", None)
|
||||||
title = schema.pop("title", "")
|
title = schema.pop("title", "")
|
||||||
default_description = schema.pop("description", "")
|
default_description = schema.pop("description", "")
|
||||||
@ -272,7 +277,7 @@ def convert_to_openai_function(
|
|||||||
"description": function.pop("description"),
|
"description": function.pop("description"),
|
||||||
"parameters": function,
|
"parameters": function,
|
||||||
}
|
}
|
||||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
elif isinstance(function, type) and is_basemodel_subclass(function):
|
||||||
return cast(Dict, convert_pydantic_to_openai_function(function))
|
return cast(Dict, convert_pydantic_to_openai_function(function))
|
||||||
elif isinstance(function, BaseTool):
|
elif isinstance(function, BaseTool):
|
||||||
return cast(Dict, format_tool_to_openai_function(function))
|
return cast(Dict, format_tool_to_openai_function(function))
|
||||||
|
@ -8,12 +8,13 @@ from langchain_core.load.load import loads
|
|||||||
from langchain_core.prompts.structured import StructuredPrompt
|
from langchain_core.prompts.structured import StructuredPrompt
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables.base import Runnable, RunnableLambda
|
from langchain_core.runnables.base import Runnable, RunnableLambda
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
|
|
||||||
def _fake_runnable(
|
def _fake_runnable(
|
||||||
schema: Union[Dict, Type[BaseModel]], _: Any
|
schema: Union[Dict, Type[BaseModel]], _: Any
|
||||||
) -> Union[BaseModel, Dict]:
|
) -> Union[BaseModel, Dict]:
|
||||||
if isclass(schema) and issubclass(schema, BaseModel):
|
if isclass(schema) and is_basemodel_subclass(schema):
|
||||||
return schema(name="yo", value=42)
|
return schema(name="yo", value=42)
|
||||||
else:
|
else:
|
||||||
params = cast(Dict, schema)["parameters"]
|
params = cast(Dict, schema)["parameters"]
|
||||||
|
@ -34,11 +34,14 @@ from langchain_core.output_parsers.json import JsonOutputParser
|
|||||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
from langchain_core.prompts import SystemMessagePromptTemplate
|
from langchain_core.prompts import SystemMessagePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
)
|
||||||
from langchain_core.runnables import Runnable, RunnableLambda
|
from langchain_core.runnables import Runnable, RunnableLambda
|
||||||
from langchain_core.runnables.base import RunnableMap
|
from langchain_core.runnables.base import RunnableMap
|
||||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
|
||||||
|
|
||||||
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
|
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
|
||||||
|
|
||||||
@ -75,14 +78,10 @@ _DictOrPydantic = Union[Dict, _BM]
|
|||||||
|
|
||||||
def _is_pydantic_class(obj: Any) -> bool:
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
return isinstance(obj, type) and (
|
return isinstance(obj, type) and (
|
||||||
issubclass(obj, BaseModel) or BaseModel in obj.__bases__
|
is_basemodel_subclass(obj) or BaseModel in obj.__bases__
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_pydantic_object(obj: Any) -> bool:
|
|
||||||
return isinstance(obj, BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_to_ollama_tool(tool: Any) -> Dict:
|
def convert_to_ollama_tool(tool: Any) -> Dict:
|
||||||
"""Convert a tool to an Ollama tool."""
|
"""Convert a tool to an Ollama tool."""
|
||||||
description = None
|
description = None
|
||||||
@ -93,7 +92,7 @@ def convert_to_ollama_tool(tool: Any) -> Dict:
|
|||||||
schema = tool.tool_call_schema.schema()
|
schema = tool.tool_call_schema.schema()
|
||||||
name = tool.get_name()
|
name = tool.get_name()
|
||||||
description = tool.description
|
description = tool.description
|
||||||
elif _is_pydantic_object(tool):
|
elif is_basemodel_instance(tool):
|
||||||
schema = tool.get_input_schema().schema()
|
schema = tool.get_input_schema().schema()
|
||||||
name = tool.get_name()
|
name = tool.get_name()
|
||||||
description = tool.description
|
description = tool.description
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
from langchain_core.prompts.few_shot import FewShotPromptTemplate
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_instance
|
||||||
|
|
||||||
|
|
||||||
class SyntheticDataGenerator(BaseModel):
|
class SyntheticDataGenerator(BaseModel):
|
||||||
@ -63,8 +64,10 @@ class SyntheticDataGenerator(BaseModel):
|
|||||||
"""Prevents duplicates by adding previously generated examples to the few shot
|
"""Prevents duplicates by adding previously generated examples to the few shot
|
||||||
list."""
|
list."""
|
||||||
if self.template and self.template.examples:
|
if self.template and self.template.examples:
|
||||||
if isinstance(example, BaseModel):
|
if is_basemodel_instance(example):
|
||||||
formatted_example = self._format_dict_to_string(example.dict())
|
formatted_example = self._format_dict_to_string(
|
||||||
|
cast(BaseModel, example).dict()
|
||||||
|
)
|
||||||
elif isinstance(example, dict):
|
elif isinstance(example, dict):
|
||||||
formatted_example = self._format_dict_to_string(example)
|
formatted_example = self._format_dict_to_string(example)
|
||||||
else:
|
else:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any, List, Optional, Type, Union
|
from typing import Any, List, Optional, Type, Union, cast
|
||||||
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
@ -10,6 +10,7 @@ from langchain_core.output_parsers.openai_functions import (
|
|||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.openai_functions.utils import get_llm_kwargs
|
from langchain.chains.openai_functions.utils import get_llm_kwargs
|
||||||
@ -45,7 +46,7 @@ def create_qa_with_structure_chain(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if output_parser == "pydantic":
|
if output_parser == "pydantic":
|
||||||
if not (isinstance(schema, type) and issubclass(schema, BaseModel)):
|
if not (isinstance(schema, type) and is_basemodel_subclass(schema)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Must provide a pydantic class for schema when output_parser is "
|
"Must provide a pydantic class for schema when output_parser is "
|
||||||
"'pydantic'."
|
"'pydantic'."
|
||||||
@ -60,10 +61,10 @@ def create_qa_with_structure_chain(
|
|||||||
f"Got unexpected output_parser: {output_parser}. "
|
f"Got unexpected output_parser: {output_parser}. "
|
||||||
f"Should be one of `pydantic` or `base`."
|
f"Should be one of `pydantic` or `base`."
|
||||||
)
|
)
|
||||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||||
schema_dict = schema.schema()
|
schema_dict = cast(dict, schema.schema())
|
||||||
else:
|
else:
|
||||||
schema_dict = schema
|
schema_dict = cast(dict, schema)
|
||||||
function = {
|
function = {
|
||||||
"name": schema_dict["title"],
|
"name": schema_dict["title"],
|
||||||
"description": schema_dict["description"],
|
"description": schema_dict["description"],
|
||||||
|
@ -24,6 +24,7 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_function,
|
convert_to_openai_function,
|
||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
@ -465,7 +466,7 @@ def _get_openai_tool_output_parser(
|
|||||||
*,
|
*,
|
||||||
first_tool_only: bool = False,
|
first_tool_only: bool = False,
|
||||||
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
|
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
|
||||||
if isinstance(tool, type) and issubclass(tool, BaseModel):
|
if isinstance(tool, type) and is_basemodel_subclass(tool):
|
||||||
output_parser: Union[BaseOutputParser, BaseGenerationOutputParser] = (
|
output_parser: Union[BaseOutputParser, BaseGenerationOutputParser] = (
|
||||||
PydanticToolsParser(tools=[tool], first_tool_only=first_tool_only)
|
PydanticToolsParser(tools=[tool], first_tool_only=first_tool_only)
|
||||||
)
|
)
|
||||||
@ -493,7 +494,7 @@ def get_openai_output_parser(
|
|||||||
not a Pydantic class, then the output parser will automatically extract
|
not a Pydantic class, then the output parser will automatically extract
|
||||||
only the function arguments and not the function name.
|
only the function arguments and not the function name.
|
||||||
"""
|
"""
|
||||||
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel):
|
if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
|
||||||
if len(functions) > 1:
|
if len(functions) > 1:
|
||||||
pydantic_schema: Union[Dict, Type[BaseModel]] = {
|
pydantic_schema: Union[Dict, Type[BaseModel]] = {
|
||||||
convert_to_openai_function(fn)["name"]: fn for fn in functions
|
convert_to_openai_function(fn)["name"]: fn for fn in functions
|
||||||
@ -516,7 +517,7 @@ def _create_openai_json_runnable(
|
|||||||
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
|
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
|
||||||
) -> Runnable:
|
) -> Runnable:
|
||||||
""""""
|
""""""
|
||||||
if isinstance(output_schema, type) and issubclass(output_schema, BaseModel):
|
if isinstance(output_schema, type) and is_basemodel_subclass(output_schema):
|
||||||
output_parser = output_parser or PydanticOutputParser(
|
output_parser = output_parser or PydanticOutputParser(
|
||||||
pydantic_object=output_schema, # type: ignore
|
pydantic_object=output_schema, # type: ignore
|
||||||
)
|
)
|
||||||
|
@ -50,7 +50,12 @@ from langchain_core.output_parsers import (
|
|||||||
)
|
)
|
||||||
from langchain_core.output_parsers.base import OutputParserLike
|
from langchain_core.output_parsers.base import OutputParserLike
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableMap,
|
RunnableMap,
|
||||||
@ -63,6 +68,7 @@ from langchain_core.utils import (
|
|||||||
get_pydantic_field_names,
|
get_pydantic_field_names,
|
||||||
)
|
)
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
from langchain_anthropic.output_parsers import extract_tool_calls
|
from langchain_anthropic.output_parsers import extract_tool_calls
|
||||||
|
|
||||||
@ -994,7 +1000,7 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
|
|
||||||
tool_name = convert_to_anthropic_tool(schema)["name"]
|
tool_name = convert_to_anthropic_tool(schema)["name"]
|
||||||
llm = self.bind_tools([schema], tool_choice=tool_name)
|
llm = self.bind_tools([schema], tool_choice=tool_name)
|
||||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||||
output_parser: OutputParserLike = PydanticToolsParser(
|
output_parser: OutputParserLike = PydanticToolsParser(
|
||||||
tools=[schema], first_tool_only=True
|
tools=[schema], first_tool_only=True
|
||||||
)
|
)
|
||||||
|
@ -69,7 +69,12 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
parse_tool_call,
|
parse_tool_call,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
@ -81,6 +86,7 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_function,
|
convert_to_openai_function,
|
||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
from langchain_core.utils.utils import build_extra_kwargs
|
from langchain_core.utils.utils import build_extra_kwargs
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -938,7 +944,7 @@ class ChatFireworks(BaseChatModel):
|
|||||||
|
|
||||||
|
|
||||||
def _is_pydantic_class(obj: Any) -> bool:
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||||
|
|
||||||
|
|
||||||
def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict:
|
def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict:
|
||||||
|
@ -66,7 +66,12 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
parse_tool_call,
|
parse_tool_call,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
@ -78,6 +83,7 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_function,
|
convert_to_openai_function,
|
||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
|
|
||||||
class ChatGroq(BaseChatModel):
|
class ChatGroq(BaseChatModel):
|
||||||
@ -1053,7 +1059,7 @@ class ChatGroq(BaseChatModel):
|
|||||||
|
|
||||||
|
|
||||||
def _is_pydantic_class(obj: Any) -> bool:
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||||
|
|
||||||
|
|
||||||
class _FunctionCall(TypedDict):
|
class _FunctionCall(TypedDict):
|
||||||
|
@ -388,7 +388,7 @@ def test_json_mode_structured_output() -> None:
|
|||||||
result = chat.invoke(
|
result = chat.invoke(
|
||||||
"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
|
"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
|
||||||
)
|
)
|
||||||
assert type(result) == Joke
|
assert type(result) is Joke
|
||||||
assert len(result.setup) != 0
|
assert len(result.setup) != 0
|
||||||
assert len(result.punchline) != 0
|
assert len(result.punchline) != 0
|
||||||
|
|
||||||
|
@ -173,7 +173,7 @@ def test_groq_invoke(mock_completion: dict) -> None:
|
|||||||
):
|
):
|
||||||
res = llm.invoke("bar")
|
res = llm.invoke("bar")
|
||||||
assert res.content == "Bar Baz"
|
assert res.content == "Bar Baz"
|
||||||
assert type(res) == AIMessage
|
assert type(res) is AIMessage
|
||||||
assert completed
|
assert completed
|
||||||
|
|
||||||
|
|
||||||
@ -195,7 +195,7 @@ async def test_groq_ainvoke(mock_completion: dict) -> None:
|
|||||||
):
|
):
|
||||||
res = await llm.ainvoke("bar")
|
res = await llm.ainvoke("bar")
|
||||||
assert res.content == "Bar Baz"
|
assert res.content == "Bar Baz"
|
||||||
assert type(res) == AIMessage
|
assert type(res) is AIMessage
|
||||||
assert completed
|
assert completed
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,11 +63,17 @@ from langchain_core.output_parsers.openai_tools import (
|
|||||||
parse_tool_call,
|
parse_tool_call,
|
||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
SecretStr,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -779,7 +785,7 @@ class ChatMistralAI(BaseChatModel):
|
|||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError(f"Received unsupported arguments {kwargs}")
|
raise ValueError(f"Received unsupported arguments {kwargs}")
|
||||||
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
|
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
|
||||||
if method == "function_calling":
|
if method == "function_calling":
|
||||||
if schema is None:
|
if schema is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -36,6 +36,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
|||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
|
|
||||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||||
|
|
||||||
@ -54,7 +55,7 @@ class _AllReturnType(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
def _is_pydantic_class(obj: Any) -> bool:
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||||
|
|
||||||
|
|
||||||
class AzureChatOpenAI(BaseChatOpenAI):
|
class AzureChatOpenAI(BaseChatOpenAI):
|
||||||
|
@ -86,6 +86,7 @@ from langchain_core.utils.function_calling import (
|
|||||||
convert_to_openai_function,
|
convert_to_openai_function,
|
||||||
convert_to_openai_tool,
|
convert_to_openai_tool,
|
||||||
)
|
)
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
from langchain_core.utils.utils import build_extra_kwargs
|
from langchain_core.utils.utils import build_extra_kwargs
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -1765,7 +1766,7 @@ class ChatOpenAI(BaseChatOpenAI):
|
|||||||
|
|
||||||
|
|
||||||
def _is_pydantic_class(obj: Any) -> bool:
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||||
|
|
||||||
|
|
||||||
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
|
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
|
||||||
|
@ -1,9 +1,14 @@
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Set, Tuple, Union
|
from typing import Any, Dict, List, Set, Tuple, Union, cast
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
from langchain_core.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
create_model,
|
||||||
|
)
|
||||||
from langchain_core.utils.json_schema import dereference_refs
|
from langchain_core.utils.json_schema import dereference_refs
|
||||||
|
from langchain_core.utils.pydantic import is_basemodel_instance
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -160,8 +165,8 @@ def get_param_fields(endpoint_spec: dict) -> dict:
|
|||||||
def model_to_dict(
|
def model_to_dict(
|
||||||
item: Union[BaseModel, List, Dict[str, Any]],
|
item: Union[BaseModel, List, Dict[str, Any]],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
if isinstance(item, BaseModel):
|
if is_basemodel_instance(item):
|
||||||
return item.dict()
|
return cast(BaseModel, item).dict()
|
||||||
elif isinstance(item, dict):
|
elif isinstance(item, dict):
|
||||||
return {key: model_to_dict(value) for key, value in item.items()}
|
return {key: model_to_dict(value) for key, value in item.items()}
|
||||||
elif isinstance(item, list):
|
elif isinstance(item, list):
|
||||||
|
@ -1,20 +1,58 @@
|
|||||||
|
"""Unit tests for chat models."""
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, Literal, Optional, Type
|
from typing import Any, List, Literal, Optional, Type
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.runnables import RunnableBinding
|
from langchain_core.runnables import RunnableBinding
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||||
|
|
||||||
class Person(BaseModel):
|
|
||||||
|
class Person(BaseModel): # Used by some dependent tests. Should be deprecated.
|
||||||
"""Record attributes of a person."""
|
"""Record attributes of a person."""
|
||||||
|
|
||||||
name: str = Field(..., description="The name of the person.")
|
name: str = Field(..., description="The name of the person.")
|
||||||
age: int = Field(..., description="The age of the person.")
|
age: int = Field(..., description="The age of the person.")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_schema_pydantic_v1_from_2() -> Any:
|
||||||
|
"""Use to generate a schema from v1 namespace in pydantic 2."""
|
||||||
|
if PYDANTIC_MAJOR_VERSION != 2:
|
||||||
|
raise AssertionError("This function is only compatible with Pydantic v2.")
|
||||||
|
from pydantic.v1 import BaseModel, Field
|
||||||
|
|
||||||
|
class PersonB(BaseModel):
|
||||||
|
"""Record attributes of a person."""
|
||||||
|
|
||||||
|
name: str = Field(..., description="The name of the person.")
|
||||||
|
age: int = Field(..., description="The age of the person.")
|
||||||
|
|
||||||
|
return PersonB
|
||||||
|
|
||||||
|
|
||||||
|
def generate_schema_pydantic() -> Any:
|
||||||
|
"""Works with either pydantic 1 or 2"""
|
||||||
|
from pydantic import BaseModel as BaseModelProper
|
||||||
|
from pydantic import Field as FieldProper
|
||||||
|
|
||||||
|
class PersonA(BaseModelProper):
|
||||||
|
"""Record attributes of a person."""
|
||||||
|
|
||||||
|
name: str = FieldProper(..., description="The name of the person.")
|
||||||
|
age: int = FieldProper(..., description="The age of the person.")
|
||||||
|
|
||||||
|
return PersonA
|
||||||
|
|
||||||
|
|
||||||
|
TEST_PYDANTIC_MODELS = [generate_schema_pydantic()]
|
||||||
|
|
||||||
|
if PYDANTIC_MAJOR_VERSION == 2:
|
||||||
|
TEST_PYDANTIC_MODELS.append(generate_schema_pydantic_v1_from_2())
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def my_adder_tool(a: int, b: int) -> int:
|
def my_adder_tool(a: int, b: int) -> int:
|
||||||
"""Takes two integers, a and b, and returns their sum."""
|
"""Takes two integers, a and b, and returns their sum."""
|
||||||
@ -112,12 +150,18 @@ class ChatModelUnitTests(ChatModelTests):
|
|||||||
if not self.has_tool_calling:
|
if not self.has_tool_calling:
|
||||||
return
|
return
|
||||||
|
|
||||||
tool_model = model.bind_tools(
|
tools = [my_adder_tool, my_adder]
|
||||||
[Person, Person.schema(), my_adder_tool, my_adder], tool_choice="any"
|
|
||||||
)
|
for pydantic_model in TEST_PYDANTIC_MODELS:
|
||||||
|
tools.extend([pydantic_model, pydantic_model.schema()])
|
||||||
|
|
||||||
|
# Doing a mypy ignore here since some of the tools are from pydantic
|
||||||
|
# BaseModel 2 which isn't typed properly yet. This will need to be fixed
|
||||||
|
# so type checking does not become annoying to users.
|
||||||
|
tool_model = model.bind_tools(tools, tool_choice="any") # type: ignore
|
||||||
assert isinstance(tool_model, RunnableBinding)
|
assert isinstance(tool_model, RunnableBinding)
|
||||||
|
|
||||||
@pytest.mark.parametrize("schema", [Person, Person.schema()])
|
@pytest.mark.parametrize("schema", TEST_PYDANTIC_MODELS)
|
||||||
def test_with_structured_output(
|
def test_with_structured_output(
|
||||||
self,
|
self,
|
||||||
model: BaseChatModel,
|
model: BaseChatModel,
|
||||||
@ -129,6 +173,8 @@ class ChatModelUnitTests(ChatModelTests):
|
|||||||
assert model.with_structured_output(schema) is not None
|
assert model.with_structured_output(schema) is not None
|
||||||
|
|
||||||
def test_standard_params(self, model: BaseChatModel) -> None:
|
def test_standard_params(self, model: BaseChatModel) -> None:
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||||
|
|
||||||
class ExpectedParams(BaseModel):
|
class ExpectedParams(BaseModel):
|
||||||
ls_provider: str
|
ls_provider: str
|
||||||
ls_model_name: str
|
ls_model_name: str
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
"""Utilities for working with pydantic models."""
|
||||||
|
|
||||||
|
|
||||||
|
def get_pydantic_major_version() -> int:
|
||||||
|
"""Get the major version of Pydantic."""
|
||||||
|
try:
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
return int(pydantic.__version__.split(".")[0])
|
||||||
|
except ImportError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
Loading…
Reference in New Issue
Block a user