diff --git a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py index 6ad014bab3f..1d7444ee4ba 100644 --- a/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py +++ b/libs/standard-tests/langchain_standard_tests/unit_tests/chat_models.py @@ -5,13 +5,12 @@ from typing import Any, List, Literal, Optional, Tuple, Type from unittest import mock import pytest -from syrupy import SnapshotAssertion - from langchain_core.language_models import BaseChatModel from langchain_core.load import dumpd, load from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr from langchain_core.runnables import RunnableBinding from langchain_core.tools import tool +from syrupy import SnapshotAssertion from langchain_standard_tests.base import BaseStandardTests from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION @@ -231,6 +230,8 @@ class ChatModelUnitTests(ChatModelTests): def test_serdes(self, model: BaseChatModel, snapshot: SnapshotAssertion) -> None: if not self.chat_model_class.is_lc_serializable(): return - ser = dumpd(model) - assert ser == snapshot(name="serialized") - assert model.dict() == load(dumpd(model)).dict() + env_params, model_params, expected_attrs = self.init_from_env_params + with mock.patch.dict(os.environ, env_params): + ser = dumpd(model) + assert ser == snapshot(name="serialized") + assert model.dict() == load(dumpd(model)).dict()