mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-10 15:06:18 +00:00
Make tests stricter, remove old code, fix up pydantic import when using v2 (#11231)
Make tests stricter, remove old code, fix up pydantic import when using v2 (#11231)
This commit is contained in:
parent
572968fee3
commit
b4354b7694
@ -15,7 +15,7 @@ from langchain.schema.runnable import Runnable
|
||||
from typing_extensions import Annotated
|
||||
|
||||
try:
|
||||
from pydantic.v1 import BaseModel
|
||||
from pydantic.v1 import BaseModel, create_model
|
||||
except ImportError:
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
|
@ -7,6 +7,11 @@ from langchain.schema.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
try:
|
||||
from pydantic.v1 import BaseModel
|
||||
except ImportError:
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langserve.serialization import simple_dumps, simple_loads
|
||||
|
||||
|
||||
@ -120,3 +125,31 @@ def test_serialization(data: Any, expected_json: Any) -> None:
|
||||
assert json.loads(simple_dumps(data)) == expected_json
|
||||
# Test decoding
|
||||
assert simple_loads(json.dumps(expected_json)) == data
|
||||
# Test full representation are equivalent including the pydantic model classes
|
||||
assert _get_full_representation(data) == _get_full_representation(
|
||||
simple_loads(json.dumps(expected_json))
|
||||
)
|
||||
|
||||
|
||||
def _get_full_representation(data: Any) -> Any:
|
||||
"""Get the full representation of the data, replacing pydantic models with schema.
|
||||
|
||||
Pydantic tests two different models for equality based on equality
|
||||
of their schema; instead we will rely on the equality of their full
|
||||
schema representation. This will make sure that both models have the
|
||||
same name (e.g., HumanMessage vs. HumanMessageChunk).
|
||||
|
||||
Args:
|
||||
data: python primitives + pydantic models
|
||||
|
||||
Returns:
|
||||
data represented entirely with python primitives
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return {key: _get_full_representation(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [_get_full_representation(value) for value in data]
|
||||
elif isinstance(data, BaseModel):
|
||||
return data.schema()
|
||||
else:
|
||||
return data
|
||||
|
@ -231,19 +231,6 @@ def test_invoke_as_part_of_sequence(client: RemoteRunnable) -> None:
|
||||
# assert list(runnable.stream([1, 2], config={"tags": ["test"]})) == [3, 4]
|
||||
|
||||
|
||||
def test_pydantic_root():
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Model(BaseModel):
|
||||
__root__: str
|
||||
|
||||
class Q(BaseModel):
|
||||
input: Model
|
||||
|
||||
# s = Model(__root__=[23])
|
||||
Q(input="hello")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_as_part_of_sequence_async(async_client: RemoteRunnable) -> None:
|
||||
"""Test as part of a sequence.
|
||||
|
Loading…
Reference in New Issue
Block a user