core[patch]: Fix llm string representation for serializable models (#23416)
Fix LLM string representation for serializable objects. Fix for issue: https://github.com/langchain-ai/langchain/issues/23257

The llm string of serializable chat models is the serialized representation of the object. LangChain serialization dumps some basic information about non-serializable objects, including their repr(), which contains an object id. This means that if a chat model has any non-serializable fields (e.g., a cache), then any new instantiation of those fields will change the llm string representation of the chat model and cause cache misses; i.e., re-instantiating a Postgres cache would result in cache misses!
parent 3904f2cd40
commit b5aef4cf97
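To make the failure mode concrete, here is a minimal sketch (not part of the commit; plain Python with a hypothetical stand-in class) of why embedding repr() in a cache key breaks: two instances of the same non-serializable object carry different object ids, so any key derived from their repr() changes on every instantiation.

# Minimal sketch (hypothetical stand-in class, not from the commit): repr() of
# a non-serializable field embeds the object id, so a cache key derived from it
# changes every time the field is re-instantiated.
class PostgresLikeCache:
    """Stand-in for a non-serializable field such as a Postgres cache."""


first = PostgresLikeCache()
second = PostgresLikeCache()

print(repr(first))   # e.g. <__main__.PostgresLikeCache object at 0x7f...>
print(repr(second))  # same class, different address -> different string

# A cache key built from these strings differs between instantiations,
# so previously cached responses are never found again.
assert repr(first) != repr(second)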
langchain_core/language_models/chat_models.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import inspect
+import json
 import uuid
 import warnings
 from abc import ABC, abstractmethod
@@ -448,7 +449,11 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.is_lc_serializable():
             params = {**kwargs, **{"stop": stop}}
             param_string = str(sorted([(k, v) for k, v in params.items()]))
-            llm_string = dumps(self)
+            # This code is not super efficient as it goes back and forth between
+            # json and dict.
+            serialized_repr = dumpd(self)
+            _cleanup_llm_representation(serialized_repr, 1)
+            llm_string = json.dumps(serialized_repr, sort_keys=True)
             return llm_string + "---" + param_string
         else:
             params = self._get_invocation_params(stop=stop, **kwargs)
@@ -1215,3 +1220,20 @@ def _gen_info_and_msg_metadata(
         **(generation.generation_info or {}),
         **generation.message.response_metadata,
     }
+
+
+def _cleanup_llm_representation(serialized: Any, depth: int) -> None:
+    """Remove non-serializable objects from a serialized object."""
+    if depth > 100:  # Don't cooperate for pathological cases
+        return
+    if serialized["type"] == "not_implemented" and "repr" in serialized:
+        del serialized["repr"]
+
+    if "graph" in serialized:
+        del serialized["graph"]
+
+    if "kwargs" in serialized:
+        kwargs = serialized["kwargs"]
+
+        for value in kwargs.values():
+            _cleanup_llm_representation(value, depth + 1)
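As a rough usage sketch (assuming langchain_core with this change applied; GenericFakeChatModel, dumpd, and _cleanup_llm_representation are the names used in the diff above and the tests below), the new recipe is dumpd, then _cleanup_llm_representation, then json.dumps with sorted keys, which yields the same string for two separately constructed models:

import json

from langchain_core.language_models.chat_models import _cleanup_llm_representation
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.load import dumpd


def stable_llm_string(model) -> str:
    """Dump the model, strip volatile repr()/graph entries, emit sorted JSON."""
    serialized = dumpd(model)
    _cleanup_llm_representation(serialized, 1)
    return json.dumps(serialized, sort_keys=True)


# Two separately constructed models whose iterator fields have different object
# ids; after cleanup the serialized strings match, so they can share a cache key.
model_a = GenericFakeChatModel(messages=iter(["hello"]))
model_b = GenericFakeChatModel(messages=iter(["hello"]))
assert stable_llm_string(model_a) == stable_llm_string(model_b)

The actual _get_llm_string in the diff additionally appends the sorted invocation params after a "---" separator, as shown in the test below.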
tests/unit_tests/language_models/chat_models/test_cache.py
@@ -5,6 +5,7 @@ import pytest
 
 from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
 from langchain_core.globals import set_llm_cache
+from langchain_core.language_models.chat_models import _cleanup_llm_representation
 from langchain_core.language_models.fake_chat_models import (
     FakeListChatModel,
     GenericFakeChatModel,
@@ -266,3 +267,137 @@ def test_global_cache_stream() -> None:
         assert global_cache._cache != {}
     finally:
         set_llm_cache(None)
+
+
+class CustomChat(GenericFakeChatModel):
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return True
+
+
+async def test_can_swap_caches() -> None:
+    """Test that we can use a different cache object.
+
+    This test verifies that when we fetch the llm_string representation
+    of the chat model, we can swap the cache object and still get the same
+    result.
+    """
+    cache = InMemoryCache()
+    chat_model = CustomChat(cache=cache, messages=iter(["hello"]))
+    result = await chat_model.ainvoke("foo")
+    assert result.content == "hello"
+
+    new_cache = InMemoryCache()
+    new_cache._cache = cache._cache.copy()
+
+    # Confirm that we get a cache hit!
+    chat_model = CustomChat(cache=new_cache, messages=iter(["goodbye"]))
+    result = await chat_model.ainvoke("foo")
+    assert result.content == "hello"
+
+
+def test_llm_representation_for_serializable() -> None:
+    """Test that the llm representation of a serializable chat model is correct."""
+    cache = InMemoryCache()
+    chat = CustomChat(cache=cache, messages=iter([]))
+    assert chat._get_llm_string() == (
+        '{"id": ["tests", "unit_tests", "language_models", "chat_models", '
+        '"test_cache", "CustomChat"], "kwargs": {"cache": {"id": ["tests", '
+        '"unit_tests", "language_models", "chat_models", "test_cache", '
+        '"InMemoryCache"], "lc": 1, "type": "not_implemented"}, "messages": {"id": '
+        '["builtins", "list_iterator"], "lc": 1, "type": "not_implemented"}}, "lc": '
+        '1, "name": "CustomChat", "type": "constructor"}---[(\'stop\', None)]'
+    )
+
+
+def test_cleanup_serialized() -> None:
+    cleanup_serialized = {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+            "tests",
+            "unit_tests",
+            "language_models",
+            "chat_models",
+            "test_cache",
+            "CustomChat",
+        ],
+        "kwargs": {
+            "cache": {
+                "lc": 1,
+                "type": "not_implemented",
+                "id": [
+                    "tests",
+                    "unit_tests",
+                    "language_models",
+                    "chat_models",
+                    "test_cache",
+                    "InMemoryCache",
+                ],
+                "repr": "<tests.unit_tests.language_models.chat_models."
+                "test_cache.InMemoryCache object at 0x79ff437fe7d0>",
+            },
+            "messages": {
+                "lc": 1,
+                "type": "not_implemented",
+                "id": ["builtins", "list_iterator"],
+                "repr": "<list_iterator object at 0x79ff437f8d30>",
+            },
+        },
+        "name": "CustomChat",
+        "graph": {
+            "nodes": [
+                {"id": 0, "type": "schema", "data": "CustomChatInput"},
+                {
+                    "id": 1,
+                    "type": "runnable",
+                    "data": {
+                        "id": [
+                            "tests",
+                            "unit_tests",
+                            "language_models",
+                            "chat_models",
+                            "test_cache",
+                            "CustomChat",
+                        ],
+                        "name": "CustomChat",
+                    },
+                },
+                {"id": 2, "type": "schema", "data": "CustomChatOutput"},
+            ],
+            "edges": [{"source": 0, "target": 1}, {"source": 1, "target": 2}],
+        },
+    }
+    _cleanup_llm_representation(cleanup_serialized, 1)
+    assert cleanup_serialized == {
+        "id": [
+            "tests",
+            "unit_tests",
+            "language_models",
+            "chat_models",
+            "test_cache",
+            "CustomChat",
+        ],
+        "kwargs": {
+            "cache": {
+                "id": [
+                    "tests",
+                    "unit_tests",
+                    "language_models",
+                    "chat_models",
+                    "test_cache",
+                    "InMemoryCache",
+                ],
+                "lc": 1,
+                "type": "not_implemented",
+            },
+            "messages": {
+                "id": ["builtins", "list_iterator"],
+                "lc": 1,
+                "type": "not_implemented",
+            },
+        },
+        "lc": 1,
+        "name": "CustomChat",
+        "type": "constructor",
+    }