core[patch]: Fix llm string representation for serializable models (#23416)

Fix LLM string representation for serializable objects.

Fix for issue: https://github.com/langchain-ai/langchain/issues/23257

The llm string of serializable chat models is the serialized representation
of the object. LangChain serialization dumps some basic information about
non-serializable objects, including their repr(), which includes an object id.

This means that if a chat model has any non-serializable fields (e.g., a
cache), then any new instantiation of those fields will change the llm
string representation of the chat model and cause cache misses.

i.e., re-instantiating a postgres cache would result in cache misses!
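
To make the failure concrete, here is a minimal sketch (not part of the commit;
it only assumes langchain_core's public dumpd helper and InMemoryCache): dumping
a non-serializable object produces a "not_implemented" stub whose repr embeds a
memory address, so two otherwise identical cache instances serialize differently,
and any llm string built from that dump differs too.

from langchain_core.caches import InMemoryCache
from langchain_core.load import dumpd

# Non-serializable objects are dumped as "not_implemented" stubs whose
# "repr" embeds the object's memory address.
first = dumpd(InMemoryCache())
second = dumpd(InMemoryCache())
print(first["repr"])   # e.g. <langchain_core.caches.InMemoryCache object at 0x7f...>
assert first["repr"] != second["repr"]  # different addresses, so different llm strings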
Eugene Yurtsev 2024-07-01 14:06:33 -04:00 committed by GitHub
parent 3904f2cd40
commit b5aef4cf97
2 changed files with 158 additions and 1 deletions

libs/core/langchain_core/language_models/chat_models.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 import asyncio
 import inspect
+import json
 import uuid
 import warnings
 from abc import ABC, abstractmethod
@@ -448,7 +449,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.is_lc_serializable():
             params = {**kwargs, **{"stop": stop}}
             param_string = str(sorted([(k, v) for k, v in params.items()]))
-            llm_string = dumps(self)
+            # This code is not super efficient as it goes back and forth between
+            # json and dict.
+            serialized_repr = dumpd(self)
+            _cleanup_llm_representation(serialized_repr, 1)
+            llm_string = json.dumps(serialized_repr, sort_keys=True)
             return llm_string + "---" + param_string
         else:
             params = self._get_invocation_params(stop=stop, **kwargs)
@@ -1215,3 +1220,20 @@ def _gen_info_and_msg_metadata(
         **(generation.generation_info or {}),
         **generation.message.response_metadata,
     }
+
+
+def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
+    """Remove non-serializable objects from a serialized object."""
+    if depth > 100:  # Don't cooperate for pathological cases
+        return
+    if serialized["type"] == "not_implemented" and "repr" in serialized:
+        del serialized["repr"]
+    if "graph" in serialized:
+        del serialized["graph"]
+    if "kwargs" in serialized:
+        kwargs = serialized["kwargs"]
+        for value in kwargs.values():
+            _cleanup_llm_representation(value, depth + 1)
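
With this change, the llm string no longer depends on which cache instance the
model holds. A rough usage sketch of the resulting behaviour (the subclass name
here is a stand-in; the actual tests are in the next file):

from langchain_core.caches import InMemoryCache
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel


class SerializableFakeChat(GenericFakeChatModel):
    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True


model_a = SerializableFakeChat(cache=InMemoryCache(), messages=iter(["hello"]))
model_b = SerializableFakeChat(cache=InMemoryCache(), messages=iter(["hello"]))

# The "not_implemented" repr of each cache is stripped before the llm string is
# built, so both models produce the same cache key prefix.
assert model_a._get_llm_string() == model_b._get_llm_string()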

libs/core/tests/unit_tests/language_models/chat_models/test_cache.py

@@ -5,6 +5,7 @@ import pytest
 from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
 from langchain_core.globals import set_llm_cache
+from langchain_core.language_models.chat_models import _cleanup_llm_representation
 from langchain_core.language_models.fake_chat_models import (
     FakeListChatModel,
     GenericFakeChatModel,
@@ -266,3 +267,137 @@ def test_global_cache_stream() -> None:
         assert global_cache._cache != {}
     finally:
         set_llm_cache(None)
+
+
+class CustomChat(GenericFakeChatModel):
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+
+async def test_can_swap_caches() -> None:
+    """Test that we can use a different cache object.
+
+    This test verifies that when we fetch the llm_string representation
+    of the chat model, we can swap the cache object and still get the same
+    result.
+    """
+    cache = InMemoryCache()
+    chat_model = CustomChat(cache=cache, messages=iter(["hello"]))
+    result = await chat_model.ainvoke("foo")
+    assert result.content == "hello"
+
+    new_cache = InMemoryCache()
+    new_cache._cache = cache._cache.copy()
+
+    # Confirm that we get a cache hit!
+    chat_model = CustomChat(cache=new_cache, messages=iter(["goodbye"]))
+    result = await chat_model.ainvoke("foo")
+    assert result.content == "hello"
+
+
+def test_llm_representation_for_serializable() -> None:
+    """Test that the llm representation of a serializable chat model is correct."""
+    cache = InMemoryCache()
+    chat = CustomChat(cache=cache, messages=iter([]))
+    assert chat._get_llm_string() == (
+        '{"id": ["tests", "unit_tests", "language_models", "chat_models", '
+        '"test_cache", "CustomChat"], "kwargs": {"cache": {"id": ["tests", '
+        '"unit_tests", "language_models", "chat_models", "test_cache", '
+        '"InMemoryCache"], "lc": 1, "type": "not_implemented"}, "messages": {"id": '
+        '["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}}, "lc": '
+        '1, "name": "CustomChat", "type": "constructor"}---[(\'stop\', None)]'
+    )
+
+
+def test_cleanup_serialized() -> None:
+    cleanup_serialized = {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+            "tests",
+            "unit_tests",
+            "language_models",
+            "chat_models",
+            "test_cache",
+            "CustomChat",
+        ],
+        "kwargs": {
+            "cache": {
+                "lc": 1,
+                "type": "not_implemented",
+                "id": [
+                    "tests",
+                    "unit_tests",
+                    "language_models",
+                    "chat_models",
+                    "test_cache",
+                    "InMemoryCache",
+                ],
+                "repr": "<tests.unit_tests.language_models.chat_models."
+                "test_cache.InMemoryCache object at 0x79ff437fe7d0>",
+            },
+            "messages": {
+                "lc": 1,
+                "type": "not_implemented",
+                "id": ["builtins", "list_iterator"],
+                "repr": "<list_iterator object at 0x79ff437f8d30>",
+            },
+        },
+        "name": "CustomChat",
+        "graph": {
+            "nodes": [
+                {"id": 0, "type": "schema", "data": "CustomChatInput"},
+                {
+                    "id": 1,
+                    "type": "runnable",
+                    "data": {
+                        "id": [
+                            "tests",
+                            "unit_tests",
+                            "language_models",
+                            "chat_models",
+                            "test_cache",
+                            "CustomChat",
+                        ],
+                        "name": "CustomChat",
+                    },
+                },
+                {"id": 2, "type": "schema", "data": "CustomChatOutput"},
+            ],
+            "edges": [{"source": 0, "target": 1}, {"source": 1, "target": 2}],
+        },
+    }
+    _cleanup_llm_representation(cleanup_serialized, 1)
+    assert cleanup_serialized == {
+        "id": [
+            "tests",
+            "unit_tests",
+            "language_models",
+            "chat_models",
+            "test_cache",
+            "CustomChat",
+        ],
+        "kwargs": {
+            "cache": {
+                "id": [
+                    "tests",
+                    "unit_tests",
+                    "language_models",
+                    "chat_models",
+                    "test_cache",
+                    "InMemoryCache",
+                ],
+                "lc": 1,
+                "type": "not_implemented",
+            },
+            "messages": {
+                "id": ["builtins", "list_iterator"],
+                "lc": 1,
+                "type": "not_implemented",
+            },
+        },
+        "lc": 1,
+        "name": "CustomChat",
+        "type": "constructor",
+    }