mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-02 11:39:18 +00:00
core,groq,openai,mistralai,robocorp,fireworks,anthropic[patch]: Update BaseModel subclass and instance checks to handle both v1 and proper namespaces (#24417)
After this PR chat models will correctly handle pydantic 2 with bind_tools and with_structured_output. ```python import pydantic print(pydantic.__version__) ``` 2.8.2 ```python from langchain_openai import ChatOpenAI from pydantic import BaseModel, Field class Add(BaseModel): x: int y: int model = ChatOpenAI().bind_tools([Add]) print(model.invoke('2 + 5').tool_calls) model = ChatOpenAI().with_structured_output(Add) print(type(model.invoke('2 + 5'))) ``` ``` [{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_PNUFa4pdfNOYXxIMHc6ps2Do', 'type': 'tool_call'}] <class '__main__.Add'> ``` ```python from langchain_openai import ChatOpenAI from pydantic.v1 import BaseModel, Field class Add(BaseModel): x: int y: int model = ChatOpenAI().bind_tools([Add]) print(model.invoke('2 + 5').tool_calls) model = ChatOpenAI().with_structured_output(Add) print(type(model.invoke('2 + 5'))) ``` ```python [{'name': 'Add', 'args': {'x': 2, 'y': 5}, 'id': 'call_hhiHYP441cp14TtrHKx3Upg0', 'type': 'tool_call'}] <class '__main__.Add'> ``` Addresses issues: https://github.com/langchain-ai/langchain/issues/22782 --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -1,20 +1,58 @@
|
||||
"""Unit tests for chat models."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Literal, Optional, Type
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
class Person(BaseModel):
|
||||
|
||||
class Person(BaseModel): # Used by some dependent tests. Should be deprecated.
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
|
||||
|
||||
def generate_schema_pydantic_v1_from_2() -> Any:
|
||||
"""Use to generate a schema from v1 namespace in pydantic 2."""
|
||||
if PYDANTIC_MAJOR_VERSION != 2:
|
||||
raise AssertionError("This function is only compatible with Pydantic v2.")
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
class PersonB(BaseModel):
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
|
||||
return PersonB
|
||||
|
||||
|
||||
def generate_schema_pydantic() -> Any:
|
||||
"""Works with either pydantic 1 or 2"""
|
||||
from pydantic import BaseModel as BaseModelProper
|
||||
from pydantic import Field as FieldProper
|
||||
|
||||
class PersonA(BaseModelProper):
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = FieldProper(..., description="The name of the person.")
|
||||
age: int = FieldProper(..., description="The age of the person.")
|
||||
|
||||
return PersonA
|
||||
|
||||
|
||||
TEST_PYDANTIC_MODELS = [generate_schema_pydantic()]
|
||||
|
||||
if PYDANTIC_MAJOR_VERSION == 2:
|
||||
TEST_PYDANTIC_MODELS.append(generate_schema_pydantic_v1_from_2())
|
||||
|
||||
|
||||
@tool
|
||||
def my_adder_tool(a: int, b: int) -> int:
|
||||
"""Takes two integers, a and b, and returns their sum."""
|
||||
@@ -112,12 +150,18 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
if not self.has_tool_calling:
|
||||
return
|
||||
|
||||
tool_model = model.bind_tools(
|
||||
[Person, Person.schema(), my_adder_tool, my_adder], tool_choice="any"
|
||||
)
|
||||
tools = [my_adder_tool, my_adder]
|
||||
|
||||
for pydantic_model in TEST_PYDANTIC_MODELS:
|
||||
tools.extend([pydantic_model, pydantic_model.schema()])
|
||||
|
||||
# Doing a mypy ignore here since some of the tools are from pydantic
|
||||
# BaseModel 2 which isn't typed properly yet. This will need to be fixed
|
||||
# so type checking does not become annoying to users.
|
||||
tool_model = model.bind_tools(tools, tool_choice="any") # type: ignore
|
||||
assert isinstance(tool_model, RunnableBinding)
|
||||
|
||||
@pytest.mark.parametrize("schema", [Person, Person.schema()])
|
||||
@pytest.mark.parametrize("schema", TEST_PYDANTIC_MODELS)
|
||||
def test_with_structured_output(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
@@ -129,6 +173,8 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
assert model.with_structured_output(schema) is not None
|
||||
|
||||
def test_standard_params(self, model: BaseChatModel) -> None:
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
|
||||
class ExpectedParams(BaseModel):
|
||||
ls_provider: str
|
||||
ls_model_name: str
|
||||
|
@@ -0,0 +1,14 @@
|
||||
"""Utilities for working with pydantic models."""
|
||||
|
||||
|
||||
def get_pydantic_major_version() -> int:
|
||||
"""Get the major version of Pydantic."""
|
||||
try:
|
||||
import pydantic
|
||||
|
||||
return int(pydantic.__version__.split(".")[0])
|
||||
except ImportError:
|
||||
return 0
|
||||
|
||||
|
||||
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
Reference in New Issue
Block a user