core,groq,openai,mistralai,robocorp,fireworks,anthropic[patch]: Update BaseModel subclass and instance checks to handle both v1 and proper namespaces (#24417)

After this PR chat models will correctly handle pydantic 2 with
bind_tools and with_structured_output.


```python
import pydantic
print(pydantic.__version__)
```
2.8.2

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

class Add(BaseModel):
    x: int
    y: int

model = ChatOpenAI().bind_tools([Add])
print(model.invoke('2 + 5').tool_calls)

model = ChatOpenAI().with_structured_output(Add)
print(type(model.invoke('2 + 5')))
```

```
[{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_PNUFa4pdfNOYXxIMHc6ps2Do', 'type': 'tool_call'}]
<class '__main__.Add'>
```


```python
from langchain_openai import ChatOpenAI
from pydantic.v1 import BaseModel, Field

class Add(BaseModel):
    x: int
    y: int

model = ChatOpenAI().bind_tools([Add])
print(model.invoke('2 + 5').tool_calls)

model = ChatOpenAI().with_structured_output(Add)
print(type(model.invoke('2 + 5')))
```

```python
[{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_hhiHYP441cp14TtrHKx3Upg0', 'type': 'tool_call'}]
<class '__main__.Add'>
```

Addresses issues: https://github.com/langchain-ai/langchain/issues/22782

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Bagatur 2024-07-22 13:07:39 -07:00 committed by GitHub
parent 199e64d372
commit 236e957abb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 185 additions and 59 deletions

View File

@ -24,6 +24,7 @@ from langchain_core.output_parsers import (
from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable from langchain_core.runnables import Runnable
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_community.output_parsers.ernie_functions import ( from langchain_community.output_parsers.ernie_functions import (
JsonOutputFunctionsParser, JsonOutputFunctionsParser,
@ -94,7 +95,7 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
for arg, arg_type in annotations.items(): for arg, arg_type in annotations.items():
if arg == "return": if arg == "return":
continue continue
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel): if isinstance(arg_type, type) and is_basemodel_subclass(arg_type):
# Mypy error: # Mypy error:
# "type" has no attribute "schema" # "type" has no attribute "schema"
properties[arg] = arg_type.schema() # type: ignore[attr-defined] properties[arg] = arg_type.schema() # type: ignore[attr-defined]
@ -156,7 +157,7 @@ def convert_to_ernie_function(
""" """
if isinstance(function, dict): if isinstance(function, dict):
return function return function
elif isinstance(function, type) and issubclass(function, BaseModel): elif isinstance(function, type) and is_basemodel_subclass(function):
return cast(Dict, convert_pydantic_to_ernie_function(function)) return cast(Dict, convert_pydantic_to_ernie_function(function))
elif callable(function): elif callable(function):
return convert_python_function_to_ernie_function(function) return convert_python_function_to_ernie_function(function)
@ -185,7 +186,7 @@ def get_ernie_output_parser(
only the function arguments and not the function name. only the function arguments and not the function name.
""" """
function_names = [convert_to_ernie_function(f)["name"] for f in functions] function_names = [convert_to_ernie_function(f)["name"] for f in functions]
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel): if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
if len(functions) > 1: if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = { pydantic_schema: Union[Dict, Type[BaseModel]] = {
name: fn for name, fn in zip(function_names, functions) name: fn for name, fn in zip(function_names, functions)

View File

@ -40,11 +40,17 @@ from langchain_core.output_parsers.openai_tools import (
PydanticToolsParser, PydanticToolsParser,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -769,7 +775,7 @@ class QianfanChatEndpoint(BaseChatModel):
""" # noqa: E501 """ # noqa: E501
if kwargs: if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}") raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel) is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
llm = self.bind_tools([schema]) llm = self.bind_tools([schema])
if is_pydantic_schema: if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(

View File

@ -57,6 +57,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_community.utilities.requests import Requests from langchain_community.utilities.requests import Requests
@ -443,7 +444,7 @@ class ChatEdenAI(BaseChatModel):
if kwargs: if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}") raise ValueError(f"Received unsupported arguments {kwargs}")
llm = self.bind_tools([schema], tool_choice="required") llm = self.bind_tools([schema], tool_choice="required")
if isinstance(schema, type) and issubclass(schema, BaseModel): if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True tools=[schema], first_tool_only=True
) )

View File

@ -46,10 +46,15 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call, parse_tool_call,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.pydantic_v1 import (
BaseModel,
Field,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
class ChatLlamaCpp(BaseChatModel): class ChatLlamaCpp(BaseChatModel):
@ -525,7 +530,7 @@ class ChatLlamaCpp(BaseChatModel):
if kwargs: if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}") raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel) is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if schema is None: if schema is None:
raise ValueError( raise ValueError(
"schema must be specified when method is 'function_calling'. " "schema must be specified when method is 'function_calling'. "

View File

@ -53,11 +53,16 @@ from langchain_core.outputs import (
ChatGenerationChunk, ChatGenerationChunk,
ChatResult, ChatResult,
) )
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env, pre_init
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from tenacity import ( from tenacity import (
before_sleep_log, before_sleep_log,
@ -865,7 +870,7 @@ class ChatTongyi(BaseChatModel):
""" """
if kwargs: if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}") raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel) is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
llm = self.bind_tools([schema]) llm = self.bind_tools([schema])
if is_pydantic_schema: if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(

View File

@ -55,11 +55,16 @@ from langchain_core.outputs import (
RunInfo, RunInfo,
) )
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.pydantic_v1 import (
BaseModel,
Field,
root_validator,
)
from langchain_core.runnables import RunnableMap, RunnablePassthrough from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.base import OutputParserLike
@ -1162,7 +1167,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"with_structured_output is not implemented for this model." "with_structured_output is not implemented for this model."
) )
llm = self.bind_tools([schema], tool_choice="any") llm = self.bind_tools([schema], tool_choice="any")
if isinstance(schema, type) and issubclass(schema, BaseModel): if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True tools=[schema], first_tool_only=True
) )

View File

@ -82,6 +82,7 @@ from langchain_core.runnables.utils import (
) )
from langchain_core.utils.aiter import aclosing, atee, py_anext from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.callbacks.manager import ( from langchain_core.callbacks.manager import (
@ -300,7 +301,7 @@ class Runnable(Generic[Input, Output], ABC):
""" """
root_type = self.InputType root_type = self.InputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel): if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
return root_type return root_type
return create_model( return create_model(
@ -332,7 +333,7 @@ class Runnable(Generic[Input, Output], ABC):
""" """
root_type = self.OutputType root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel): if inspect.isclass(root_type) and is_basemodel_subclass(root_type):
return root_type return root_type
return create_model( return create_model(

View File

@ -22,6 +22,7 @@ from typing import (
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.runnables.base import Runnable as RunnableType from langchain_core.runnables.base import Runnable as RunnableType
@ -229,7 +230,7 @@ def node_data_json(
"name": node_data_str(node.id, node.data), "name": node_data_str(node.id, node.data),
}, },
} }
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel): elif inspect.isclass(node.data) and is_basemodel_subclass(node.data):
json = ( json = (
{ {
"type": "schema", "type": "schema",

View File

@ -28,6 +28,7 @@ from langchain_core.messages import (
) )
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.json_schema import dereference_refs from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@ -100,7 +101,11 @@ def convert_pydantic_to_openai_function(
Returns: Returns:
The function description. The function description.
""" """
schema = dereference_refs(model.schema()) if hasattr(model, "model_json_schema"):
schema = model.model_json_schema() # Pydantic 2
else:
schema = model.schema() # Pydantic 1
schema = dereference_refs(schema)
schema.pop("definitions", None) schema.pop("definitions", None)
title = schema.pop("title", "") title = schema.pop("title", "")
default_description = schema.pop("description", "") default_description = schema.pop("description", "")
@ -272,7 +277,7 @@ def convert_to_openai_function(
"description": function.pop("description"), "description": function.pop("description"),
"parameters": function, "parameters": function,
} }
elif isinstance(function, type) and issubclass(function, BaseModel): elif isinstance(function, type) and is_basemodel_subclass(function):
return cast(Dict, convert_pydantic_to_openai_function(function)) return cast(Dict, convert_pydantic_to_openai_function(function))
elif isinstance(function, BaseTool): elif isinstance(function, BaseTool):
return cast(Dict, format_tool_to_openai_function(function)) return cast(Dict, format_tool_to_openai_function(function))

View File

@ -8,12 +8,13 @@ from langchain_core.load.load import loads
from langchain_core.prompts.structured import StructuredPrompt from langchain_core.prompts.structured import StructuredPrompt
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableLambda from langchain_core.runnables.base import Runnable, RunnableLambda
from langchain_core.utils.pydantic import is_basemodel_subclass
def _fake_runnable( def _fake_runnable(
schema: Union[Dict, Type[BaseModel]], _: Any schema: Union[Dict, Type[BaseModel]], _: Any
) -> Union[BaseModel, Dict]: ) -> Union[BaseModel, Dict]:
if isclass(schema) and issubclass(schema, BaseModel): if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=42) return schema(name="yo", value=42)
else: else:
params = cast(Dict, schema)["parameters"] params = cast(Dict, schema)["parameters"]

View File

@ -34,11 +34,14 @@ from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import (
BaseModel,
)
from langchain_core.runnables import Runnable, RunnableLambda from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.runnables.base import RunnableMap from langchain_core.runnables.base import RunnableMap
from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils.pydantic import is_basemodel_instance, is_basemodel_subclass
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools: DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
@ -75,14 +78,10 @@ _DictOrPydantic = Union[Dict, _BM]
def _is_pydantic_class(obj: Any) -> bool: def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and ( return isinstance(obj, type) and (
issubclass(obj, BaseModel) or BaseModel in obj.__bases__ is_basemodel_subclass(obj) or BaseModel in obj.__bases__
) )
def _is_pydantic_object(obj: Any) -> bool:
return isinstance(obj, BaseModel)
def convert_to_ollama_tool(tool: Any) -> Dict: def convert_to_ollama_tool(tool: Any) -> Dict:
"""Convert a tool to an Ollama tool.""" """Convert a tool to an Ollama tool."""
description = None description = None
@ -93,7 +92,7 @@ def convert_to_ollama_tool(tool: Any) -> Dict:
schema = tool.tool_call_schema.schema() schema = tool.tool_call_schema.schema()
name = tool.get_name() name = tool.get_name()
description = tool.description description = tool.description
elif _is_pydantic_object(tool): elif is_basemodel_instance(tool):
schema = tool.get_input_schema().schema() schema = tool.get_input_schema().schema()
name = tool.get_name() name = tool.get_name()
description = tool.description description = tool.description

View File

@ -1,11 +1,12 @@
import asyncio import asyncio
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union, cast
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.pydantic_v1 import BaseModel, root_validator from langchain.pydantic_v1 import BaseModel, root_validator
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.few_shot import FewShotPromptTemplate from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.utils.pydantic import is_basemodel_instance
class SyntheticDataGenerator(BaseModel): class SyntheticDataGenerator(BaseModel):
@ -63,8 +64,10 @@ class SyntheticDataGenerator(BaseModel):
"""Prevents duplicates by adding previously generated examples to the few shot """Prevents duplicates by adding previously generated examples to the few shot
list.""" list."""
if self.template and self.template.examples: if self.template and self.template.examples:
if isinstance(example, BaseModel): if is_basemodel_instance(example):
formatted_example = self._format_dict_to_string(example.dict()) formatted_example = self._format_dict_to_string(
cast(BaseModel, example).dict()
)
elif isinstance(example, dict): elif isinstance(example, dict):
formatted_example = self._format_dict_to_string(example) formatted_example = self._format_dict_to_string(example)
else: else:

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union, cast
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage
@ -10,6 +10,7 @@ from langchain_core.output_parsers.openai_functions import (
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import get_llm_kwargs from langchain.chains.openai_functions.utils import get_llm_kwargs
@ -45,7 +46,7 @@ def create_qa_with_structure_chain(
""" """
if output_parser == "pydantic": if output_parser == "pydantic":
if not (isinstance(schema, type) and issubclass(schema, BaseModel)): if not (isinstance(schema, type) and is_basemodel_subclass(schema)):
raise ValueError( raise ValueError(
"Must provide a pydantic class for schema when output_parser is " "Must provide a pydantic class for schema when output_parser is "
"'pydantic'." "'pydantic'."
@ -60,10 +61,10 @@ def create_qa_with_structure_chain(
f"Got unexpected output_parser: {output_parser}. " f"Got unexpected output_parser: {output_parser}. "
f"Should be one of `pydantic` or `base`." f"Should be one of `pydantic` or `base`."
) )
if isinstance(schema, type) and issubclass(schema, BaseModel): if isinstance(schema, type) and is_basemodel_subclass(schema):
schema_dict = schema.schema() schema_dict = cast(dict, schema.schema())
else: else:
schema_dict = schema schema_dict = cast(dict, schema)
function = { function = {
"name": schema_dict["title"], "name": schema_dict["title"],
"description": schema_dict["description"], "description": schema_dict["description"],

View File

@ -24,6 +24,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function, convert_to_openai_function,
convert_to_openai_tool, convert_to_openai_tool,
) )
from langchain_core.utils.pydantic import is_basemodel_subclass
@deprecated( @deprecated(
@ -465,7 +466,7 @@ def _get_openai_tool_output_parser(
*, *,
first_tool_only: bool = False, first_tool_only: bool = False,
) -> Union[BaseOutputParser, BaseGenerationOutputParser]: ) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
if isinstance(tool, type) and issubclass(tool, BaseModel): if isinstance(tool, type) and is_basemodel_subclass(tool):
output_parser: Union[BaseOutputParser, BaseGenerationOutputParser] = ( output_parser: Union[BaseOutputParser, BaseGenerationOutputParser] = (
PydanticToolsParser(tools=[tool], first_tool_only=first_tool_only) PydanticToolsParser(tools=[tool], first_tool_only=first_tool_only)
) )
@ -493,7 +494,7 @@ def get_openai_output_parser(
not a Pydantic class, then the output parser will automatically extract not a Pydantic class, then the output parser will automatically extract
only the function arguments and not the function name. only the function arguments and not the function name.
""" """
if isinstance(functions[0], type) and issubclass(functions[0], BaseModel): if isinstance(functions[0], type) and is_basemodel_subclass(functions[0]):
if len(functions) > 1: if len(functions) > 1:
pydantic_schema: Union[Dict, Type[BaseModel]] = { pydantic_schema: Union[Dict, Type[BaseModel]] = {
convert_to_openai_function(fn)["name"]: fn for fn in functions convert_to_openai_function(fn)["name"]: fn for fn in functions
@ -516,7 +517,7 @@ def _create_openai_json_runnable(
output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None, output_parser: Optional[Union[BaseOutputParser, BaseGenerationOutputParser]] = None,
) -> Runnable: ) -> Runnable:
"""""" """"""
if isinstance(output_schema, type) and issubclass(output_schema, BaseModel): if isinstance(output_schema, type) and is_basemodel_subclass(output_schema):
output_parser = output_parser or PydanticOutputParser( output_parser = output_parser or PydanticOutputParser(
pydantic_object=output_schema, # type: ignore pydantic_object=output_schema, # type: ignore
) )

View File

@ -50,7 +50,12 @@ from langchain_core.output_parsers import (
) )
from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import ( from langchain_core.runnables import (
Runnable, Runnable,
RunnableMap, RunnableMap,
@ -63,6 +68,7 @@ from langchain_core.utils import (
get_pydantic_field_names, get_pydantic_field_names,
) )
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_anthropic.output_parsers import extract_tool_calls from langchain_anthropic.output_parsers import extract_tool_calls
@ -994,7 +1000,7 @@ class ChatAnthropic(BaseChatModel):
tool_name = convert_to_anthropic_tool(schema)["name"] tool_name = convert_to_anthropic_tool(schema)["name"]
llm = self.bind_tools([schema], tool_choice=tool_name) llm = self.bind_tools([schema], tool_choice=tool_name)
if isinstance(schema, type) and issubclass(schema, BaseModel): if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True tools=[schema], first_tool_only=True
) )

View File

@ -69,7 +69,12 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call, parse_tool_call,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import ( from langchain_core.utils import (
@ -81,6 +86,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function, convert_to_openai_function,
convert_to_openai_tool, convert_to_openai_tool,
) )
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import build_extra_kwargs from langchain_core.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -938,7 +944,7 @@ class ChatFireworks(BaseChatModel):
def _is_pydantic_class(obj: Any) -> bool: def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel) return isinstance(obj, type) and is_basemodel_subclass(obj)
def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict: def _lc_tool_call_to_fireworks_tool_call(tool_call: ToolCall) -> dict:

View File

@ -66,7 +66,12 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call, parse_tool_call,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import ( from langchain_core.utils import (
@ -78,6 +83,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function, convert_to_openai_function,
convert_to_openai_tool, convert_to_openai_tool,
) )
from langchain_core.utils.pydantic import is_basemodel_subclass
class ChatGroq(BaseChatModel): class ChatGroq(BaseChatModel):
@ -1053,7 +1059,7 @@ class ChatGroq(BaseChatModel):
def _is_pydantic_class(obj: Any) -> bool: def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel) return isinstance(obj, type) and is_basemodel_subclass(obj)
class _FunctionCall(TypedDict): class _FunctionCall(TypedDict):

View File

@ -388,7 +388,7 @@ def test_json_mode_structured_output() -> None:
result = chat.invoke( result = chat.invoke(
"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys" "Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys"
) )
assert type(result) == Joke assert type(result) is Joke
assert len(result.setup) != 0 assert len(result.setup) != 0
assert len(result.punchline) != 0 assert len(result.punchline) != 0

View File

@ -173,7 +173,7 @@ def test_groq_invoke(mock_completion: dict) -> None:
): ):
res = llm.invoke("bar") res = llm.invoke("bar")
assert res.content == "Bar Baz" assert res.content == "Bar Baz"
assert type(res) == AIMessage assert type(res) is AIMessage
assert completed assert completed
@ -195,7 +195,7 @@ async def test_groq_ainvoke(mock_completion: dict) -> None:
): ):
res = await llm.ainvoke("bar") res = await llm.ainvoke("bar")
assert res.content == "Bar Baz" assert res.content == "Bar Baz"
assert type(res) == AIMessage assert type(res) is AIMessage
assert completed assert completed

View File

@ -63,11 +63,17 @@ from langchain_core.output_parsers.openai_tools import (
parse_tool_call, parse_tool_call,
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.pydantic_v1 import (
BaseModel,
Field,
SecretStr,
root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -779,7 +785,7 @@ class ChatMistralAI(BaseChatModel):
""" # noqa: E501 """ # noqa: E501
if kwargs: if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}") raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel) is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
if method == "function_calling": if method == "function_calling":
if schema is None: if schema is None:
raise ValueError( raise ValueError(

View File

@ -36,6 +36,7 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai.chat_models.base import BaseChatOpenAI from langchain_openai.chat_models.base import BaseChatOpenAI
@ -54,7 +55,7 @@ class _AllReturnType(TypedDict):
def _is_pydantic_class(obj: Any) -> bool: def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel) return isinstance(obj, type) and is_basemodel_subclass(obj)
class AzureChatOpenAI(BaseChatOpenAI): class AzureChatOpenAI(BaseChatOpenAI):

View File

@ -86,6 +86,7 @@ from langchain_core.utils.function_calling import (
convert_to_openai_function, convert_to_openai_function,
convert_to_openai_tool, convert_to_openai_tool,
) )
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_core.utils.utils import build_extra_kwargs from langchain_core.utils.utils import build_extra_kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -1765,7 +1766,7 @@ class ChatOpenAI(BaseChatOpenAI):
def _is_pydantic_class(obj: Any) -> bool: def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel) return isinstance(obj, type) and is_basemodel_subclass(obj)
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict: def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:

View File

@ -1,9 +1,14 @@
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Set, Tuple, Union from typing import Any, Dict, List, Set, Tuple, Union, cast
from langchain_core.pydantic_v1 import BaseModel, Field, create_model from langchain_core.pydantic_v1 import (
BaseModel,
Field,
create_model,
)
from langchain_core.utils.json_schema import dereference_refs from langchain_core.utils.json_schema import dereference_refs
from langchain_core.utils.pydantic import is_basemodel_instance
@dataclass(frozen=True) @dataclass(frozen=True)
@ -160,8 +165,8 @@ def get_param_fields(endpoint_spec: dict) -> dict:
def model_to_dict( def model_to_dict(
item: Union[BaseModel, List, Dict[str, Any]], item: Union[BaseModel, List, Dict[str, Any]],
) -> Any: ) -> Any:
if isinstance(item, BaseModel): if is_basemodel_instance(item):
return item.dict() return cast(BaseModel, item).dict()
elif isinstance(item, dict): elif isinstance(item, dict):
return {key: model_to_dict(value) for key, value in item.items()} return {key: model_to_dict(value) for key, value in item.items()}
elif isinstance(item, list): elif isinstance(item, list):

View File

@ -1,20 +1,58 @@
"""Unit tests for chat models."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Literal, Optional, Type from typing import Any, List, Literal, Optional, Type
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableBinding from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
class Person(BaseModel):
class Person(BaseModel): # Used by some dependent tests. Should be deprecated.
"""Record attributes of a person.""" """Record attributes of a person."""
name: str = Field(..., description="The name of the person.") name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.") age: int = Field(..., description="The age of the person.")
def generate_schema_pydantic_v1_from_2() -> Any:
"""Use to generate a schema from v1 namespace in pydantic 2."""
if PYDANTIC_MAJOR_VERSION != 2:
raise AssertionError("This function is only compatible with Pydantic v2.")
from pydantic.v1 import BaseModel, Field
class PersonB(BaseModel):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
return PersonB
def generate_schema_pydantic() -> Any:
"""Works with either pydantic 1 or 2"""
from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper
class PersonA(BaseModelProper):
"""Record attributes of a person."""
name: str = FieldProper(..., description="The name of the person.")
age: int = FieldProper(..., description="The age of the person.")
return PersonA
TEST_PYDANTIC_MODELS = [generate_schema_pydantic()]
if PYDANTIC_MAJOR_VERSION == 2:
TEST_PYDANTIC_MODELS.append(generate_schema_pydantic_v1_from_2())
@tool @tool
def my_adder_tool(a: int, b: int) -> int: def my_adder_tool(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum.""" """Takes two integers, a and b, and returns their sum."""
@ -112,12 +150,18 @@ class ChatModelUnitTests(ChatModelTests):
if not self.has_tool_calling: if not self.has_tool_calling:
return return
tool_model = model.bind_tools( tools = [my_adder_tool, my_adder]
[Person, Person.schema(), my_adder_tool, my_adder], tool_choice="any"
) for pydantic_model in TEST_PYDANTIC_MODELS:
tools.extend([pydantic_model, pydantic_model.schema()])
# Doing a mypy ignore here since some of the tools are from pydantic
# BaseModel 2 which isn't typed properly yet. This will need to be fixed
# so type checking does not become annoying to users.
tool_model = model.bind_tools(tools, tool_choice="any") # type: ignore
assert isinstance(tool_model, RunnableBinding) assert isinstance(tool_model, RunnableBinding)
@pytest.mark.parametrize("schema", [Person, Person.schema()]) @pytest.mark.parametrize("schema", TEST_PYDANTIC_MODELS)
def test_with_structured_output( def test_with_structured_output(
self, self,
model: BaseChatModel, model: BaseChatModel,
@ -129,6 +173,8 @@ class ChatModelUnitTests(ChatModelTests):
assert model.with_structured_output(schema) is not None assert model.with_structured_output(schema) is not None
def test_standard_params(self, model: BaseChatModel) -> None: def test_standard_params(self, model: BaseChatModel) -> None:
from langchain_core.pydantic_v1 import BaseModel, ValidationError
class ExpectedParams(BaseModel): class ExpectedParams(BaseModel):
ls_provider: str ls_provider: str
ls_model_name: str ls_model_name: str

View File

@ -0,0 +1,14 @@
"""Utilities for working with pydantic models."""
def get_pydantic_major_version() -> int:
"""Get the major version of Pydantic."""
try:
import pydantic
return int(pydantic.__version__.split(".")[0])
except ImportError:
return 0
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()