mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 13:07:58 +00:00
multiple: pydantic 2 compatibility, v0.3 (#26443)
Signed-off-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com> Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com> Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: ccurme <chester.curme@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com> Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com> Co-authored-by: ZhangShenao <15201440436@163.com> Co-authored-by: Friso H. Kingma <fhkingma@gmail.com> Co-authored-by: ChengZi <chen.zhang@zilliz.com> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Morgante Pell <morgantep@google.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Unit tests for chat models."""
|
||||
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List, Literal, Optional, Tuple, Type
|
||||
@@ -7,9 +8,18 @@ from unittest import mock
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
|
||||
from langchain_core.runnables import RunnableBinding
|
||||
from langchain_core.tools import tool
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic.v1 import (
|
||||
BaseModel as BaseModelV1,
|
||||
)
|
||||
from pydantic.v1 import (
|
||||
Field as FieldV1,
|
||||
)
|
||||
from pydantic.v1 import (
|
||||
ValidationError as ValidationErrorV1,
|
||||
)
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_standard_tests.base import BaseStandardTests
|
||||
@@ -27,27 +37,24 @@ def generate_schema_pydantic_v1_from_2() -> Any:
|
||||
"""Use to generate a schema from v1 namespace in pydantic 2."""
|
||||
if PYDANTIC_MAJOR_VERSION != 2:
|
||||
raise AssertionError("This function is only compatible with Pydantic v2.")
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
class PersonB(BaseModel):
|
||||
class PersonB(BaseModelV1):
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
name: str = FieldV1(..., description="The name of the person.")
|
||||
age: int = FieldV1(..., description="The age of the person.")
|
||||
|
||||
return PersonB
|
||||
|
||||
|
||||
def generate_schema_pydantic() -> Any:
|
||||
"""Works with either pydantic 1 or 2"""
|
||||
from pydantic import BaseModel as BaseModelProper
|
||||
from pydantic import Field as FieldProper
|
||||
|
||||
class PersonA(BaseModelProper):
|
||||
class PersonA(BaseModel):
|
||||
"""Record attributes of a person."""
|
||||
|
||||
name: str = FieldProper(..., description="The name of the person.")
|
||||
age: int = FieldProper(..., description="The age of the person.")
|
||||
name: str = Field(..., description="The name of the person.")
|
||||
age: int = Field(..., description="The age of the person.")
|
||||
|
||||
return PersonA
|
||||
|
||||
@@ -181,7 +188,12 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
tools = [my_adder_tool, my_adder]
|
||||
|
||||
for pydantic_model in TEST_PYDANTIC_MODELS:
|
||||
tools.extend([pydantic_model, pydantic_model.schema()])
|
||||
model_schema = (
|
||||
pydantic_model.model_json_schema()
|
||||
if hasattr(pydantic_model, "model_json_schema")
|
||||
else pydantic_model.schema()
|
||||
)
|
||||
tools.extend([pydantic_model, model_schema])
|
||||
|
||||
# Doing a mypy ignore here since some of the tools are from pydantic
|
||||
# BaseModel 2 which isn't typed properly yet. This will need to be fixed
|
||||
@@ -201,9 +213,7 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
assert model.with_structured_output(schema) is not None
|
||||
|
||||
def test_standard_params(self, model: BaseChatModel) -> None:
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
|
||||
class ExpectedParams(BaseModel):
|
||||
class ExpectedParams(BaseModelV1):
|
||||
ls_provider: str
|
||||
ls_model_name: str
|
||||
ls_model_type: Literal["chat"]
|
||||
@@ -214,7 +224,7 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
ls_params = model._get_ls_params()
|
||||
try:
|
||||
ExpectedParams(**ls_params)
|
||||
except ValidationError as e:
|
||||
except ValidationErrorV1 as e:
|
||||
pytest.fail(f"Validation error: {e}")
|
||||
|
||||
# Test optional params
|
||||
@@ -224,7 +234,7 @@ class ChatModelUnitTests(ChatModelTests):
|
||||
ls_params = model._get_ls_params()
|
||||
try:
|
||||
ExpectedParams(**ls_params)
|
||||
except ValidationError as e:
|
||||
except ValidationErrorV1 as e:
|
||||
pytest.fail(f"Validation error: {e}")
|
||||
|
||||
def test_serdes(self, model: BaseChatModel, snapshot: SnapshotAssertion) -> None:
|
||||
|
@@ -5,7 +5,7 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pydantic import SecretStr
|
||||
|
||||
from langchain_standard_tests.base import BaseStandardTests
|
||||
|
||||
|
Reference in New Issue
Block a user