diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 51fff5dc3e4..8dd6dcc7719 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import inspect
+import json
 import uuid
 import warnings
 from abc import ABC, abstractmethod
@@ -448,7 +449,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.is_lc_serializable():
            params = {**kwargs, **{"stop": stop}}
            param_string = str(sorted([(k, v) for k, v in params.items()]))
-            llm_string = dumps(self)
+            # This code is not super efficient as it goes back and forth between
+            # json and dict.
+            serialized_repr = dumpd(self)
+            _cleanup_llm_representation(serialized_repr, 1)
+            llm_string = json.dumps(serialized_repr, sort_keys=True)
             return llm_string + "---" + param_string
         else:
             params = self._get_invocation_params(stop=stop, **kwargs)
@@ -1215,3 +1220,20 @@ def _gen_info_and_msg_metadata(
         **(generation.generation_info or {}),
         **generation.message.response_metadata,
     }
+
+
+def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
+    """Remove non-serializable objects from a serialized object."""
+    if depth > 100:  # Don't cooperate for pathological cases
+        return
+    if serialized["type"] == "not_implemented" and "repr" in serialized:
+        del serialized["repr"]
+
+    if "graph" in serialized:
+        del serialized["graph"]
+
+    if "kwargs" in serialized:
+        kwargs = serialized["kwargs"]
+
+        for value in kwargs.values():
+            _cleanup_llm_representation(value, depth + 1)
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
index a2cc4fc4591..51d18b59c1e 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
@@ -5,6 +5,7 @@ import pytest
 
 from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
 from langchain_core.globals import set_llm_cache
+from langchain_core.language_models.chat_models import _cleanup_llm_representation
 from langchain_core.language_models.fake_chat_models import (
     FakeListChatModel,
     GenericFakeChatModel,
@@ -266,3 +267,137 @@ def test_global_cache_stream() -> None:
         assert global_cache._cache != {}
     finally:
         set_llm_cache(None)
+
+
+class CustomChat(GenericFakeChatModel):
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+
+async def test_can_swap_caches() -> None:
+    """Test that we can use a different cache object.
+
+    This test verifies that when we fetch the llm_string representation
+    of the chat model, we can swap the cache object and still get the same
+    result.
+    """
+    cache = InMemoryCache()
+    chat_model = CustomChat(cache=cache, messages=iter(["hello"]))
+    result = await chat_model.ainvoke("foo")
+    assert result.content == "hello"
+
+    new_cache = InMemoryCache()
+    new_cache._cache = cache._cache.copy()
+
+    # Confirm that we get a cache hit!
+    chat_model = CustomChat(cache=new_cache, messages=iter(["goodbye"]))
+    result = await chat_model.ainvoke("foo")
+    assert result.content == "hello"
+
+
+def test_llm_representation_for_serializable() -> None:
+    """Test that the llm representation of a serializable chat model is correct."""
+    cache = InMemoryCache()
+    chat = CustomChat(cache=cache, messages=iter([]))
+    assert chat._get_llm_string() == (
+        '{"id": ["tests", "unit_tests", "language_models", "chat_models", '
+        '"test_cache", "CustomChat"], "kwargs": {"cache": {"id": ["tests", '
+        '"unit_tests", "language_models", "chat_models", "test_cache", '
+        '"InMemoryCache"], "lc": 1, "type": "not_implemented"}, "messages": {"id": '
+        '["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}}, "lc": '
+        '1, "name": "CustomChat", "type": "constructor"}---[(\'stop\', None)]'
+    )
+
+
+def test_cleanup_serialized() -> None:
+    cleanup_serialized = {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+            "tests",
+            "unit_tests",
+            "language_models",
+            "chat_models",
+            "test_cache",
+            "CustomChat",
+        ],
+        "kwargs": {
+            "cache": {
+                "lc": 1,
+                "type": "not_implemented",
+                "id": [
+                    "tests",
+                    "unit_tests",
+                    "language_models",
+                    "chat_models",
+                    "test_cache",
+                    "InMemoryCache",
+                ],
+                "repr": "",
+            },
+            "messages": {
+                "lc": 1,
+                "type": "not_implemented",
+                "id": ["builtins", "list_iterator"],
+                "repr": "",
+            },
+        },
+        "name": "CustomChat",
+        "graph": {
+            "nodes": [
+                {"id": 0, "type": "schema", "data": "CustomChatInput"},
+                {
+                    "id": 1,
+                    "type": "runnable",
+                    "data": {
+                        "id": [
+                            "tests",
+                            "unit_tests",
+                            "language_models",
+                            "chat_models",
+                            "test_cache",
+                            "CustomChat",
+                        ],
+                        "name": "CustomChat",
+                    },
+                },
+                {"id": 2, "type": "schema", "data": "CustomChatOutput"},
+            ],
+            "edges": [{"source": 0, "target": 1}, {"source": 1, "target": 2}],
+        },
+    }
+    _cleanup_llm_representation(cleanup_serialized, 1)
+    assert cleanup_serialized == {
+        "id": [
+            "tests",
+            "unit_tests",
+            "language_models",
+            "chat_models",
+            "test_cache",
+            "CustomChat",
+        ],
+        "kwargs": {
+            "cache": {
+                "id": [
+                    "tests",
+                    "unit_tests",
+                    "language_models",
+                    "chat_models",
+                    "test_cache",
+                    "InMemoryCache",
+                ],
+                "lc": 1,
+                "type": "not_implemented",
+            },
+            "messages": {
+                "id": ["builtins", "list_iterator"],
+                "lc": 1,
+                "type": "not_implemented",
+            },
+        },
+        "lc": 1,
+        "name": "CustomChat",
+        "type": "constructor",
+    }
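A minimal sketch (not part of the diff) of the behavior the new tests exercise: two chat model instances that differ only in the non-serializable cache object they hold now produce the same llm_string cache key, because _cleanup_llm_representation strips the "repr" entries (which embed object memory addresses) and the "graph" entry before the key is built. It assumes langchain_core.caches.InMemoryCache in place of the test module's local InMemoryCache helper, and it calls the private _get_llm_string method directly for illustration.

    from langchain_core.caches import InMemoryCache
    from langchain_core.language_models.fake_chat_models import GenericFakeChatModel


    class CustomChat(GenericFakeChatModel):
        # Opt in to LangChain serialization so _get_llm_string takes the dumpd() path.
        @classmethod
        def is_lc_serializable(cls) -> bool:
            return True


    # Two instances that differ only in the (non-serializable) cache and iterator
    # objects they hold; their reprs would otherwise differ by memory address.
    model_a = CustomChat(cache=InMemoryCache(), messages=iter(["hello"]))
    model_b = CustomChat(cache=InMemoryCache(), messages=iter(["hello"]))

    # With "repr" and "graph" removed from the serialized form, both instances map
    # to the same cache key, so entries written through one cache object can be
    # reused through another.
    assert model_a._get_llm_string() == model_b._get_llm_string()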