diff --git a/libs/langserve/langserve/server.py b/libs/langserve/langserve/server.py index 54e264be60c..cfae77e6b9f 100644 --- a/libs/langserve/langserve/server.py +++ b/libs/langserve/langserve/server.py @@ -15,7 +15,7 @@ from langchain.schema.runnable import Runnable from typing_extensions import Annotated try: - from pydantic.v1 import BaseModel + from pydantic.v1 import BaseModel, create_model except ImportError: from pydantic import BaseModel, create_model diff --git a/libs/langserve/tests/unit_tests/test_encoders.py b/libs/langserve/tests/unit_tests/test_encoders.py index ec1c4700fb3..b3c7374b3cd 100644 --- a/libs/langserve/tests/unit_tests/test_encoders.py +++ b/libs/langserve/tests/unit_tests/test_encoders.py @@ -7,6 +7,11 @@ from langchain.schema.messages import ( SystemMessage, ) +try: + from pydantic.v1 import BaseModel +except ImportError: + from pydantic import BaseModel + from langserve.serialization import simple_dumps, simple_loads @@ -120,3 +125,31 @@ def test_serialization(data: Any, expected_json: Any) -> None: assert json.loads(simple_dumps(data)) == expected_json # Test decoding assert simple_loads(json.dumps(expected_json)) == data + # Test full representation are equivalent including the pydantic model classes + assert _get_full_representation(data) == _get_full_representation( + simple_loads(json.dumps(expected_json)) + ) + + +def _get_full_representation(data: Any) -> Any: + """Get the full representation of the data, replacing pydantic models with schema. + + Pydantic tests two different models for equality based on equality + of their schema; instead we will rely on the equality of their full + schema representation. This will make sure that both models have the + same name (e.g., HumanMessage vs. HumanMessageChunk). + + Args: + data: python primitives + pydantic models + + Returns: + data represented entirely with python primitives + """ + if isinstance(data, dict): + return {key: _get_full_representation(value) for key, value in data.items()} + elif isinstance(data, list): + return [_get_full_representation(value) for value in data] + elif isinstance(data, BaseModel): + return data.schema() + else: + return data diff --git a/libs/langserve/tests/unit_tests/test_server_client.py b/libs/langserve/tests/unit_tests/test_server_client.py index 9c0b57d384e..875cdecea13 100644 --- a/libs/langserve/tests/unit_tests/test_server_client.py +++ b/libs/langserve/tests/unit_tests/test_server_client.py @@ -231,19 +231,6 @@ def test_invoke_as_part_of_sequence(client: RemoteRunnable) -> None: # assert list(runnable.stream([1, 2], config={"tags": ["test"]})) == [3, 4] -def test_pydantic_root(): - from pydantic import BaseModel - - class Model(BaseModel): - __root__: str - - class Q(BaseModel): - input: Model - - # s = Model(__root__=[23]) - Q(input="hello") - - @pytest.mark.asyncio async def test_invoke_as_part_of_sequence_async(async_client: RemoteRunnable) -> None: """Test as part of a sequence.