core,groq,openai,mistralai,robocorp,fireworks,anthropic[patch]: Update BaseModel subclass and instance checks to handle both v1 and proper namespaces (#24417)

After this PR chat models will correctly handle pydantic 2 with
bind_tools and with_structured_output.


```python
import pydantic
print(pydantic.__version__)
```
2.8.2

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

class Add(BaseModel):
    x: int
    y: int

model = ChatOpenAI().bind_tools([Add])
print(model.invoke('2 + 5').tool_calls)

model = ChatOpenAI().with_structured_output(Add)
print(type(model.invoke('2 + 5')))
```

```
[{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_PNUFa4pdfNOYXxIMHc6ps2Do', 'type': 'tool_call'}]
<class '__main__.Add'>
```


```python
from langchain_openai import ChatOpenAI
from pydantic.v1 import BaseModel, Field

class Add(BaseModel):
    x: int
    y: int

model = ChatOpenAI().bind_tools([Add])
print(model.invoke('2 + 5').tool_calls)

model = ChatOpenAI().with_structured_output(Add)
print(type(model.invoke('2 + 5')))
```

```python
[{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_hhiHYP441cp14TtrHKx3Upg0', 'type': 'tool_call'}]
<class '__main__.Add'>
```

Addresses issues: https://github.com/langchain-ai/langchain/issues/22782

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Bagatur
2024-07-22 13:07:39 -07:00
committed by GitHub
parent 199e64d372
commit 236e957abb
26 changed files with 185 additions and 59 deletions

View File

@@ -1,20 +1,58 @@
"""Unit tests for chat models."""
from abc import ABC, abstractmethod
from typing import Any, List, Literal, Optional, Type
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
class Person(BaseModel):
class Person(BaseModel): # Used by some dependent tests. Should be deprecated.
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
def generate_schema_pydantic_v1_from_2() -> Any:
"""Use to generate a schema from v1 namespace in pydantic 2."""
if PYDANTIC_MAJOR_VERSION != 2:
raise AssertionError("This function is only compatible with Pydantic v2.")
from pydantic.v1 import BaseModel, Field
class PersonB(BaseModel):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
return PersonB
def generate_schema_pydantic() -> Any:
"""Works with either pydantic 1 or 2"""
from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper
class PersonA(BaseModelProper):
"""Record attributes of a person."""
name: str = FieldProper(..., description="The name of the person.")
age: int = FieldProper(..., description="The age of the person.")
return PersonA
TEST_PYDANTIC_MODELS = [generate_schema_pydantic()]
if PYDANTIC_MAJOR_VERSION == 2:
TEST_PYDANTIC_MODELS.append(generate_schema_pydantic_v1_from_2())
@tool
def my_adder_tool(a: int, b: int) -> int:
"""Takes two integers, a and b, and returns their sum."""
@@ -112,12 +150,18 @@ class ChatModelUnitTests(ChatModelTests):
if not self.has_tool_calling:
return
tool_model = model.bind_tools(
[Person, Person.schema(), my_adder_tool, my_adder], tool_choice="any"
)
tools = [my_adder_tool, my_adder]
for pydantic_model in TEST_PYDANTIC_MODELS:
tools.extend([pydantic_model, pydantic_model.schema()])
# Doing a mypy ignore here since some of the tools are from pydantic
# BaseModel 2 which isn't typed properly yet. This will need to be fixed
# so type checking does not become annoying to users.
tool_model = model.bind_tools(tools, tool_choice="any") # type: ignore
assert isinstance(tool_model, RunnableBinding)
@pytest.mark.parametrize("schema", [Person, Person.schema()])
@pytest.mark.parametrize("schema", TEST_PYDANTIC_MODELS)
def test_with_structured_output(
self,
model: BaseChatModel,
@@ -129,6 +173,8 @@ class ChatModelUnitTests(ChatModelTests):
assert model.with_structured_output(schema) is not None
def test_standard_params(self, model: BaseChatModel) -> None:
from langchain_core.pydantic_v1 import BaseModel, ValidationError
class ExpectedParams(BaseModel):
ls_provider: str
ls_model_name: str

View File

@@ -0,0 +1,14 @@
"""Utilities for working with pydantic models."""
def get_pydantic_major_version() -> int:
"""Get the major version of Pydantic."""
try:
import pydantic
return int(pydantic.__version__.split(".")[0])
except ImportError:
return 0
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()