From b5aef4cf9740badac84733240e8f53737283013d Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Mon, 1 Jul 2024 14:06:33 -0400
Subject: [PATCH] core[patch]: Fix llm string representation for serializable
 models (#23416)

Fix LLM string representation for serializable objects.

Fix for issue: https://github.com/langchain-ai/langchain/issues/23257

The llm string of a serializable chat model is the serialized representation
of the object. LangChain serialization dumps some basic information about
non-serializable objects, including their repr(), which contains the object
id. This means that if a chat model has any non-serializable fields (e.g., a
cache), then any new instantiation of those fields changes the llm string
representation of the chat model and causes cache misses; for example,
re-instantiating a postgres cache would result in cache misses!
---
 .../language_models/chat_models.py            |  24 +++-
 .../language_models/chat_models/test_cache.py | 135 ++++++++++++++++++
 2 files changed, 158 insertions(+), 1 deletion(-)

diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 51fff5dc3e4..8dd6dcc7719 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import inspect
+import json
 import uuid
 import warnings
 from abc import ABC, abstractmethod
@@ -448,7 +449,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.is_lc_serializable():
             params = {**kwargs, **{"stop": stop}}
             param_string = str(sorted([(k, v) for k, v in params.items()]))
-            llm_string = dumps(self)
+            # This code is not super efficient as it goes back and forth between
+            # json and dict.
+            serialized_repr = dumpd(self)
+            _cleanup_llm_representation(serialized_repr, 1)
+            llm_string = json.dumps(serialized_repr, sort_keys=True)
             return llm_string + "---" + param_string
         else:
             params = self._get_invocation_params(stop=stop, **kwargs)
@@ -1215,3 +1220,20 @@ def _gen_info_and_msg_metadata(
         **(generation.generation_info or {}),
         **generation.message.response_metadata,
     }
+
+
+def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
+    """Remove non-serializable objects from a serialized object."""
+    if depth > 100:  # Don't cooperate for pathological cases
+        return
+    if serialized["type"] == "not_implemented" and "repr" in serialized:
+        del serialized["repr"]
+
+    if "graph" in serialized:
+        del serialized["graph"]
+
+    if "kwargs" in serialized:
+        kwargs = serialized["kwargs"]
+
+        for value in kwargs.values():
+            _cleanup_llm_representation(value, depth + 1)
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
index a2cc4fc4591..51d18b59c1e 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
@@ -5,6 +5,7 @@ import pytest
 
 from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
 from langchain_core.globals import set_llm_cache
+from langchain_core.language_models.chat_models import _cleanup_llm_representation
 from langchain_core.language_models.fake_chat_models import (
     FakeListChatModel,
     GenericFakeChatModel,
@@ -266,3 +267,137 @@ def test_global_cache_stream() -> None:
         assert global_cache._cache != {}
     finally:
         set_llm_cache(None)
+
+
+class CustomChat(GenericFakeChatModel):
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+
+async def test_can_swap_caches() -> None:
+    """Test that we can use a different cache object.
+
+    This test verifies that when we fetch the llm_string representation
+    of the chat model, we can swap the cache object and still get the same
+    result.
+    """
+    cache = InMemoryCache()
+    chat_model = CustomChat(cache=cache, messages=iter(["hello"]))
+    result = await chat_model.ainvoke("foo")
+    assert result.content == "hello"
+
+    new_cache = InMemoryCache()
+    new_cache._cache = cache._cache.copy()
+
+    # Confirm that we get a cache hit!
+    chat_model = CustomChat(cache=new_cache, messages=iter(["goodbye"]))
+    result = await chat_model.ainvoke("foo")
+    assert result.content == "hello"
+
+
+def test_llm_representation_for_serializable() -> None:
+    """Test that the llm representation of a serializable chat model is correct."""
+    cache = InMemoryCache()
+    chat = CustomChat(cache=cache, messages=iter([]))
+    assert chat._get_llm_string() == (
+        '{"id": ["tests", "unit_tests", "language_models", "chat_models", '
+        '"test_cache", "CustomChat"], "kwargs": {"cache": {"id": ["tests", '
+        '"unit_tests", "language_models", "chat_models", "test_cache", '
+        '"InMemoryCache"], "lc": 1, "type": "not_implemented"}, "messages": {"id": '
+        '["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}}, "lc": '
+        '1, "name": "CustomChat", "type": "constructor"}---[(\'stop\', None)]'
+    )
+
+
+def test_cleanup_serialized() -> None:
+    cleanup_serialized = {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+            "tests",
+            "unit_tests",
+            "language_models",
+            "chat_models",
+            "test_cache",
+            "CustomChat",
+        ],
+        "kwargs": {
+            "cache": {
+                "lc": 1,
+                "type": "not_implemented",
+                "id": [
+                    "tests",
+                    "unit_tests",
+                    "language_models",
+                    "chat_models",
+                    "test_cache",
+                    "InMemoryCache",
+                ],
+                "repr": "",
+            },
+            "messages": {
+                "lc": 1,
+                "type": "not_implemented",
+                "id": ["builtins", "list_iterator"],
+                "repr": "",
+            },
+        },
+        "name": "CustomChat",
+        "graph": {
+            "nodes": [
+                {"id": 0, "type": "schema", "data": "CustomChatInput"},
+                {
+                    "id": 1,
+                    "type": "runnable",
+                    "data": {
+                        "id": [
+                            "tests",
+                            "unit_tests",
+                            "language_models",
+                            "chat_models",
+                            "test_cache",
+                            "CustomChat",
+                        ],
+                        "name": "CustomChat",
+                    },
+                },
+                {"id": 2, "type": "schema", "data": "CustomChatOutput"},
+            ],
+            "edges": [{"source": 0, "target": 1}, {"source": 1, "target": 2}],
+        },
+    }
+    _cleanup_llm_representation(cleanup_serialized, 1)
+    assert cleanup_serialized == {
+        "id": [
+            "tests",
+            "unit_tests",
+            "language_models",
+            "chat_models",
+            "test_cache",
+            "CustomChat",
+        ],
+        "kwargs": {
+            "cache": {
+                "id": [
+                    "tests",
+                    "unit_tests",
+                    "language_models",
+                    "chat_models",
+                    "test_cache",
+                    "InMemoryCache",
+                ],
+                "lc": 1,
+                "type": "not_implemented",
+            },
+            "messages": {
+                "id": ["builtins", "list_iterator"],
+                "lc": 1,
+                "type": "not_implemented",
+            },
+        },
+        "lc": 1,
+        "name": "CustomChat",
+        "type": "constructor",
+    }
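
Note: a minimal sketch of the failure mode this patch fixes, assuming
langchain-core with this patch applied. CustomChat here mirrors the test
file above; it is not a library class.

    from langchain_core.caches import InMemoryCache
    from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
    from langchain_core.load import dumpd


    class CustomChat(GenericFakeChatModel):
        @classmethod
        def is_lc_serializable(cls) -> bool:
            return True


    # Two logically identical models that differ only in the identity of
    # their non-serializable fields (the cache and the message iterator).
    a = CustomChat(cache=InMemoryCache(), messages=iter(["hi"]))
    b = CustomChat(cache=InMemoryCache(), messages=iter(["hi"]))

    # The raw serialized form embeds repr() of each non-serializable field,
    # and repr() contains the object id, so it changes on every
    # instantiation. The old cache key was built from exactly this.
    assert dumpd(a) != dumpd(b)

    # The patched _get_llm_string() strips the unstable "repr" and "graph"
    # entries before building the key, so the keys now match and a
    # re-instantiated cache gets hits instead of misses.
    assert a._get_llm_string() == b._get_llm_string()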