diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index d37ca232052..c5e7b43cebf 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -11,22 +11,9 @@ from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterator, Sequence
 from functools import cached_property
 from operator import itemgetter
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Literal,
-    Optional,
-    Union,
-    cast,
-)
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
 
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-    model_validator,
-)
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import override
 
 from langchain_core._api import deprecated
@@ -63,6 +50,7 @@ from langchain_core.outputs import (
     ChatGeneration,
     ChatGenerationChunk,
     ChatResult,
+    Generation,
     LLMResult,
     RunInfo,
 )
@@ -653,6 +641,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:  # noqa: ARG002
         return {}
 
+    def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]:
+        """Convert cached Generation objects to ChatGeneration objects.
+
+        Handle case where cache contains Generation objects instead of
+        ChatGeneration objects. This can happen due to serialization/deserialization
+        issues or legacy cache data (see #22389).
+
+        Args:
+            cache_val: List of cached generation objects.
+
+        Returns:
+            List of ChatGeneration objects.
+        """
+        converted_generations = []
+        for gen in cache_val:
+            if isinstance(gen, Generation) and not isinstance(gen, ChatGeneration):
+                # Convert Generation to ChatGeneration by creating AIMessage
+                # from the text content
+                chat_gen = ChatGeneration(
+                    message=AIMessage(content=gen.text),
+                    generation_info=gen.generation_info,
+                )
+                converted_generations.append(chat_gen)
+            else:
+                # Already a ChatGeneration or other expected type
+                converted_generations.append(gen)
+        return converted_generations
+
     def _get_invocation_params(
         self,
         stop: Optional[list[str]] = None,
@@ -1010,7 +1026,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 prompt = dumps(messages)
                 cache_val = llm_cache.lookup(prompt, llm_string)
                 if isinstance(cache_val, list):
-                    return ChatResult(generations=cache_val)
+                    converted_generations = self._convert_cached_generations(cache_val)
+                    return ChatResult(generations=converted_generations)
             elif self.cache is None:
                 pass
             else:
@@ -1082,7 +1099,8 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 prompt = dumps(messages)
                 cache_val = await llm_cache.alookup(prompt, llm_string)
                 if isinstance(cache_val, list):
-                    return ChatResult(generations=cache_val)
+                    converted_generations = self._convert_cached_generations(cache_val)
+                    return ChatResult(generations=converted_generations)
             elif self.cache is None:
                 pass
             else:
diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
index 1cceb0a146b..2c9b5608c91 100644
--- a/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
+++ b/libs/core/tests/unit_tests/language_models/chat_models/test_cache.py
@@ -13,7 +13,8 @@ from langchain_core.language_models.fake_chat_models import (
     GenericFakeChatModel,
 )
 from langchain_core.messages import AIMessage
-from langchain_core.outputs import ChatGeneration
+from langchain_core.outputs import ChatGeneration, Generation
+from langchain_core.outputs.chat_result import ChatResult
 
 
 class InMemoryCache(BaseCache):
@@ -305,6 +306,93 @@ def test_llm_representation_for_serializable() -> None:
     )
 
 
+def test_cache_with_generation_objects() -> None:
+    """Test that cache can handle Generation objects instead of ChatGeneration objects.
+
+    This test reproduces a bug where cache returns Generation objects
+    but ChatResult expects ChatGeneration objects, causing validation errors.
+
+    See #22389 for more info.
+
+    """
+    cache = InMemoryCache()
+
+    # Create a simple fake chat model that we can control
+    from langchain_core.messages import AIMessage
+
+    class SimpleFakeChat:
+        """Simple fake chat model for testing."""
+
+        def __init__(self, cache: BaseCache) -> None:
+            self.cache = cache
+            self.response = "hello"
+
+        def _get_llm_string(self) -> str:
+            return "test_llm_string"
+
+        def generate_response(self, prompt: str) -> ChatResult:
+            """Simulate the cache lookup and generation logic."""
+            from langchain_core.load import dumps
+
+            llm_string = self._get_llm_string()
+            prompt_str = dumps([prompt])
+
+            # Check cache first
+            cache_val = self.cache.lookup(prompt_str, llm_string)
+            if cache_val:
+                # This is where our fix should work
+                converted_generations = []
+                for gen in cache_val:
+                    if isinstance(gen, Generation) and not isinstance(
+                        gen, ChatGeneration
+                    ):
+                        # Convert Generation to ChatGeneration by creating an AIMessage
+                        chat_gen = ChatGeneration(
+                            message=AIMessage(content=gen.text),
+                            generation_info=gen.generation_info,
+                        )
+                        converted_generations.append(chat_gen)
+                    else:
+                        converted_generations.append(gen)
+                return ChatResult(generations=converted_generations)
+
+            # Generate new response
+            chat_gen = ChatGeneration(
+                message=AIMessage(content=self.response), generation_info={}
+            )
+            result = ChatResult(generations=[chat_gen])
+
+            # Store in cache
+            self.cache.update(prompt_str, llm_string, result.generations)
+            return result
+
+    model = SimpleFakeChat(cache)
+
+    # First call - normal operation
+    result1 = model.generate_response("test prompt")
+    assert result1.generations[0].message.content == "hello"
+
+    # Manually corrupt the cache by replacing ChatGeneration with Generation
+    cache_key = next(iter(cache._cache.keys()))
+    cached_chat_generations = cache._cache[cache_key]
+
+    # Replace with Generation objects (missing message field)
+    corrupted_generations = [
+        Generation(
+            text=gen.text,
+            generation_info=gen.generation_info,
+            type="Generation",  # This is the key - wrong type
+        )
+        for gen in cached_chat_generations
+    ]
+    cache._cache[cache_key] = corrupted_generations
+
+    # Second call should handle the Generation objects gracefully
+    result2 = model.generate_response("test prompt")
+    assert result2.generations[0].message.content == "hello"
+    assert isinstance(result2.generations[0], ChatGeneration)
+
+
 def test_cleanup_serialized() -> None:
     cleanup_serialized = {
         "lc": 1,
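
For reviewers, a minimal standalone sketch (not part of the patch) of the fallback that _convert_cached_generations introduces: a plain Generation read back from a stale cache entry is rebuilt into a ChatGeneration so that ChatResult validation succeeds. The cached text "hello" and the finish_reason value are illustrative placeholders, not values taken from the test suite.

# Illustrative only; mirrors the conversion added above.
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult, Generation

# A stale cache entry deserialized as plain Generation objects (the failure mode in #22389).
cached = [Generation(text="hello", generation_info={"finish_reason": "stop"})]

# Rebuild anything that is not already a ChatGeneration before constructing a ChatResult.
converted = [
    gen
    if isinstance(gen, ChatGeneration)
    else ChatGeneration(
        message=AIMessage(content=gen.text),
        generation_info=gen.generation_info,
    )
    for gen in cached
]

result = ChatResult(generations=converted)
assert isinstance(result.generations[0], ChatGeneration)
assert result.generations[0].message.content == "hello"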