standard-tests[major]: use pydantic v2

This commit is contained in:
Bagatur
2024-09-03 17:48:20 -07:00
parent 9a9ab65030
commit 02f87203f7
3 changed files with 32 additions and 31 deletions

View File

@@ -16,10 +16,10 @@ from langchain_core.messages import (
)
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from pydantic import BaseModel as RawBaseModel
from pydantic import Field as RawField
from pydantic import BaseModel, Field
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as FieldV1
from langchain_standard_tests.unit_tests.chat_models import (
ChatModelTests,
@@ -28,8 +28,8 @@ from langchain_standard_tests.unit_tests.chat_models import (
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
class MagicFunctionSchema(RawBaseModel):
input: int = RawField(..., gt=-1000, lt=1000)
class MagicFunctionSchema(BaseModel):
input: int = Field(..., gt=-1000, lt=1000)
@tool(args_schema=MagicFunctionSchema)
@@ -240,14 +240,11 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper
class Joke(BaseModelProper):
class Joke(BaseModel):
"""Joke to tell user."""
setup: str = FieldProper(description="question to set up a joke")
punchline: str = FieldProper(description="answer to resolve the joke")
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")
# Pydantic class
# Type ignoring since the interface only officially supports pydantic 1
@@ -280,11 +277,11 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
class Joke(BaseModel): # Uses langchain_core.pydantic_v1.BaseModel
class Joke(BaseModelV1): # Uses langchain_core.pydantic_v1.BaseModel
"""Joke to tell user."""
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")
setup: str = FieldV1(description="question to set up a joke")
punchline: str = FieldV1(description="answer to resolve the joke")
# Pydantic class
chat = model.with_structured_output(Joke)
@@ -439,7 +436,7 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.supports_anthropic_inputs:
return
class color_picker(BaseModel):
class color_picker(BaseModelV1):
"""Input your fav color and get a random fact about it."""
fav_color: str

View File

@@ -6,9 +6,18 @@ from unittest import mock
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool
from pydantic import BaseModel, Field, SecretStr
from pydantic.v1 import (
BaseModel as BaseModelV1,
)
from pydantic.v1 import (
Field as FieldV1,
)
from pydantic.v1 import (
ValidationError as ValidationErrorV1,
)
from langchain_standard_tests.base import BaseStandardTests
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
@@ -25,27 +34,24 @@ def generate_schema_pydantic_v1_from_2() -> Any:
"""Use to generate a schema from v1 namespace in pydantic 2."""
if PYDANTIC_MAJOR_VERSION != 2:
raise AssertionError("This function is only compatible with Pydantic v2.")
from pydantic.v1 import BaseModel, Field
class PersonB(BaseModel):
class PersonB(BaseModelV1):
"""Record attributes of a person."""
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
name: str = FieldV1(..., description="The name of the person.")
age: int = FieldV1(..., description="The age of the person.")
return PersonB
def generate_schema_pydantic() -> Any:
"""Works with either pydantic 1 or 2"""
from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper
class PersonA(BaseModelProper):
class PersonA(BaseModel):
"""Record attributes of a person."""
name: str = FieldProper(..., description="The name of the person.")
age: int = FieldProper(..., description="The age of the person.")
name: str = Field(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.")
return PersonA
@@ -199,9 +205,7 @@ class ChatModelUnitTests(ChatModelTests):
assert model.with_structured_output(schema) is not None
def test_standard_params(self, model: BaseChatModel) -> None:
from langchain_core.pydantic_v1 import BaseModel, ValidationError
class ExpectedParams(BaseModel):
class ExpectedParams(BaseModelV1):
ls_provider: str
ls_model_name: str
ls_model_type: Literal["chat"]
@@ -212,7 +216,7 @@ class ChatModelUnitTests(ChatModelTests):
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params)
except ValidationError as e:
except ValidationErrorV1 as e:
pytest.fail(f"Validation error: {e}")
# Test optional params
@@ -222,5 +226,5 @@ class ChatModelUnitTests(ChatModelTests):
ls_params = model._get_ls_params()
try:
ExpectedParams(**ls_params)
except ValidationError as e:
except ValidationErrorV1 as e:
pytest.fail(f"Validation error: {e}")

View File

@@ -5,7 +5,7 @@ from unittest import mock
import pytest
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import SecretStr
from pydantic import SecretStr
from langchain_standard_tests.base import BaseStandardTests