standard-tests[major]: use pydantic v2

This commit is contained in:
Bagatur
2024-09-03 17:48:20 -07:00
parent 9a9ab65030
commit 02f87203f7
3 changed files with 32 additions and 31 deletions

View File

@@ -16,10 +16,10 @@ from langchain_core.messages import (
) )
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool from langchain_core.tools import tool
from pydantic import BaseModel as RawBaseModel from pydantic import BaseModel, Field
from pydantic import Field as RawField from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as FieldV1
from langchain_standard_tests.unit_tests.chat_models import ( from langchain_standard_tests.unit_tests.chat_models import (
ChatModelTests, ChatModelTests,
@@ -28,8 +28,8 @@ from langchain_standard_tests.unit_tests.chat_models import (
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
class MagicFunctionSchema(RawBaseModel): class MagicFunctionSchema(BaseModel):
input: int = RawField(..., gt=-1000, lt=1000) input: int = Field(..., gt=-1000, lt=1000)
@tool(args_schema=MagicFunctionSchema) @tool(args_schema=MagicFunctionSchema)
@@ -240,14 +240,11 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
from pydantic import BaseModel as BaseModelProper class Joke(BaseModel):
from pydantic import Field as FieldProper
class Joke(BaseModelProper):
"""Joke to tell user.""" """Joke to tell user."""
setup: str = FieldProper(description="question to set up a joke") setup: str = Field(description="question to set up a joke")
punchline: str = FieldProper(description="answer to resolve the joke") punchline: str = Field(description="answer to resolve the joke")
# Pydantic class # Pydantic class
# Type ignoring since the interface only officially supports pydantic 1 # Type ignoring since the interface only officially supports pydantic 1
@@ -280,11 +277,11 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.has_tool_calling: if not self.has_tool_calling:
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
class Joke(BaseModel): # Uses langchain_core.pydantic_v1.BaseModel class Joke(BaseModelV1): # Uses langchain_core.pydantic_v1.BaseModel
"""Joke to tell user.""" """Joke to tell user."""
setup: str = Field(description="question to set up a joke") setup: str = FieldV1(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke") punchline: str = FieldV1(description="answer to resolve the joke")
# Pydantic class # Pydantic class
chat = model.with_structured_output(Joke) chat = model.with_structured_output(Joke)
@@ -439,7 +436,7 @@ class ChatModelIntegrationTests(ChatModelTests):
if not self.supports_anthropic_inputs: if not self.supports_anthropic_inputs:
return return
class color_picker(BaseModel): class color_picker(BaseModelV1):
"""Input your fav color and get a random fact about it.""" """Input your fav color and get a random fact about it."""
fav_color: str fav_color: str

View File

@@ -6,9 +6,18 @@ from unittest import mock
import pytest import pytest
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.runnables import RunnableBinding from langchain_core.runnables import RunnableBinding
from langchain_core.tools import tool from langchain_core.tools import tool
from pydantic import BaseModel, Field, SecretStr
from pydantic.v1 import (
BaseModel as BaseModelV1,
)
from pydantic.v1 import (
Field as FieldV1,
)
from pydantic.v1 import (
ValidationError as ValidationErrorV1,
)
from langchain_standard_tests.base import BaseStandardTests from langchain_standard_tests.base import BaseStandardTests
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
@@ -25,27 +34,24 @@ def generate_schema_pydantic_v1_from_2() -> Any:
"""Use to generate a schema from v1 namespace in pydantic 2.""" """Use to generate a schema from v1 namespace in pydantic 2."""
if PYDANTIC_MAJOR_VERSION != 2: if PYDANTIC_MAJOR_VERSION != 2:
raise AssertionError("This function is only compatible with Pydantic v2.") raise AssertionError("This function is only compatible with Pydantic v2.")
from pydantic.v1 import BaseModel, Field
class PersonB(BaseModel): class PersonB(BaseModelV1):
"""Record attributes of a person.""" """Record attributes of a person."""
name: str = Field(..., description="The name of the person.") name: str = FieldV1(..., description="The name of the person.")
age: int = Field(..., description="The age of the person.") age: int = FieldV1(..., description="The age of the person.")
return PersonB return PersonB
def generate_schema_pydantic() -> Any: def generate_schema_pydantic() -> Any:
"""Works with either pydantic 1 or 2""" """Works with either pydantic 1 or 2"""
from pydantic import BaseModel as BaseModelProper
from pydantic import Field as FieldProper
class PersonA(BaseModelProper): class PersonA(BaseModel):
"""Record attributes of a person.""" """Record attributes of a person."""
name: str = FieldProper(..., description="The name of the person.") name: str = Field(..., description="The name of the person.")
age: int = FieldProper(..., description="The age of the person.") age: int = Field(..., description="The age of the person.")
return PersonA return PersonA
@@ -199,9 +205,7 @@ class ChatModelUnitTests(ChatModelTests):
assert model.with_structured_output(schema) is not None assert model.with_structured_output(schema) is not None
def test_standard_params(self, model: BaseChatModel) -> None: def test_standard_params(self, model: BaseChatModel) -> None:
from langchain_core.pydantic_v1 import BaseModel, ValidationError class ExpectedParams(BaseModelV1):
class ExpectedParams(BaseModel):
ls_provider: str ls_provider: str
ls_model_name: str ls_model_name: str
ls_model_type: Literal["chat"] ls_model_type: Literal["chat"]
@@ -212,7 +216,7 @@ class ChatModelUnitTests(ChatModelTests):
ls_params = model._get_ls_params() ls_params = model._get_ls_params()
try: try:
ExpectedParams(**ls_params) ExpectedParams(**ls_params)
except ValidationError as e: except ValidationErrorV1 as e:
pytest.fail(f"Validation error: {e}") pytest.fail(f"Validation error: {e}")
# Test optional params # Test optional params
@@ -222,5 +226,5 @@ class ChatModelUnitTests(ChatModelTests):
ls_params = model._get_ls_params() ls_params = model._get_ls_params()
try: try:
ExpectedParams(**ls_params) ExpectedParams(**ls_params)
except ValidationError as e: except ValidationErrorV1 as e:
pytest.fail(f"Validation error: {e}") pytest.fail(f"Validation error: {e}")

View File

@@ -5,7 +5,7 @@ from unittest import mock
import pytest import pytest
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import SecretStr from pydantic import SecretStr
from langchain_standard_tests.base import BaseStandardTests from langchain_standard_tests.base import BaseStandardTests