mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 22:05:29 +00:00
standard-tests[major]: use pydantic v2
This commit is contained in:
@@ -16,10 +16,10 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import tool
|
||||
from pydantic import BaseModel as RawBaseModel
|
||||
from pydantic import Field as RawField
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.v1 import BaseModel as BaseModelV1
|
||||
from pydantic.v1 import Field as FieldV1
|
||||
|
||||
from langchain_standard_tests.unit_tests.chat_models import (
|
||||
ChatModelTests,
|
||||
@@ -28,8 +28,8 @@ from langchain_standard_tests.unit_tests.chat_models import (
|
||||
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
|
||||
|
||||
class MagicFunctionSchema(RawBaseModel):
|
||||
input: int = RawField(..., gt=-1000, lt=1000)
|
||||
class MagicFunctionSchema(BaseModel):
|
||||
input: int = Field(..., gt=-1000, lt=1000)
|
||||
|
||||
|
||||
@tool(args_schema=MagicFunctionSchema)
|
||||
@@ -240,14 +240,11 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
from pydantic import BaseModel as BaseModelProper
|
||||
from pydantic import Field as FieldProper
|
||||
|
||||
class Joke(BaseModelProper):
|
||||
class Joke(BaseModel):
|
||||
"""Joke to tell user."""
|
||||
|
||||
setup: str = FieldProper(description="question to set up a joke")
|
||||
punchline: str = FieldProper(description="answer to resolve the joke")
|
||||
setup: str = Field(description="question to set up a joke")
|
||||
punchline: str = Field(description="answer to resolve the joke")
|
||||
|
||||
# Pydantic class
|
||||
# Type ignoring since the interface only officially supports pydantic 1
|
||||
@@ -280,11 +277,11 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
if not self.has_tool_calling:
|
||||
pytest.skip("Test requires tool calling.")
|
||||
|
||||
class Joke(BaseModel): # Uses langchain_core.pydantic_v1.BaseModel
|
||||
class Joke(BaseModelV1): # Uses langchain_core.pydantic_v1.BaseModel
|
||||
"""Joke to tell user."""
|
||||
|
||||
setup: str = Field(description="question to set up a joke")
|
||||
punchline: str = Field(description="answer to resolve the joke")
|
||||
setup: str = FieldV1(description="question to set up a joke")
|
||||
punchline: str = FieldV1(description="answer to resolve the joke")
|
||||
|
||||
# Pydantic class
|
||||
chat = model.with_structured_output(Joke)
|
||||
@@ -439,7 +436,7 @@ class ChatModelIntegrationTests(ChatModelTests):
|
||||
if not self.supports_anthropic_inputs:
|
||||
return
|
||||
|
||||
class color_picker(BaseModel):
|
||||
class color_picker(BaseModelV1):
|
||||
"""Input your fav color and get a random fact about it."""
|
||||
|
||||
fav_color: str
|
||||
|
@@ -6,9 +6,18 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_core.tools import tool
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic.v1 import (
|
||||
BaseModel as BaseModelV1,
|
||||
)
|
||||
from pydantic.v1 import (
|
||||
Field as FieldV1,
|
||||
)
|
||||
from pydantic.v1 import (
|
||||
ValidationError as ValidationErrorV1,
|
||||
)
|
||||
|
||||
from langchain_standard_tests.base import BaseStandardTests
|
||||
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
@@ -25,27 +34,24 @@ def generate_schema_pydantic_v1_from_2() -> Any:
|
||||
"""Use to generate a schema from v1 namespace in pydantic 2."""
|
||||
if PYDANTIC_MAJOR_VERSION != 2:
|
||||
raise AssertionError("This function is only compatible with Pydantic v2.")
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
class PersonB(BaseModel):
|
||||
class PersonB(BaseModelV1):
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
name: str = FieldV1(..., description="The name of the person.")
|
||||
age: int = FieldV1(..., description="The age of the person.")
|
||||
|
||||
return PersonB
|
||||
|
||||
|
||||
def generate_schema_pydantic() -> Any:
|
||||
"""Works with either pydantic 1 or 2"""
|
||||
from pydantic import BaseModel as BaseModelProper
|
||||
from pydantic import Field as FieldProper
|
||||
|
||||
class PersonA(BaseModelProper):
|
||||
class PersonA(BaseModel):
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = FieldProper(..., description="The name of the person.")
|
||||
age: int = FieldProper(..., description="The age of the person.")
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
|
||||
return PersonA
|
||||
|
||||
@@ -199,9 +205,7 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
assert model.with_structured_output(schema) is not None
|
||||
|
||||
def test_standard_params(self, model: BaseChatModel) -> None:
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
|
||||
class ExpectedParams(BaseModel):
|
||||
class ExpectedParams(BaseModelV1):
|
||||
ls_provider: str
|
||||
ls_model_name: str
|
||||
ls_model_type: Literal["chat"]
|
||||
@@ -212,7 +216,7 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
ls_params = model._get_ls_params()
|
||||
try:
|
||||
ExpectedParams(**ls_params)
|
||||
except ValidationError as e:
|
||||
except ValidationErrorV1 as e:
|
||||
pytest.fail(f"Validation error: {e}")
|
||||
|
||||
# Test optional params
|
||||
@@ -222,5 +226,5 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
ls_params = model._get_ls_params()
|
||||
try:
|
||||
ExpectedParams(**ls_params)
|
||||
except ValidationError as e:
|
||||
except ValidationErrorV1 as e:
|
||||
pytest.fail(f"Validation error: {e}")
|
||||
|
@@ -5,7 +5,7 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_standard_tests.base import BaseStandardTests
|
||||
|
||||
|
Reference in New Issue
Block a user