refactor: extract cached generations conversion to a separate method

Author: Mason Daugherty
Date: 2025-07-28 18:11:40 -04:00
parent 29aabcd6d9
commit a384c0d052


@@ -11,22 +11,9 @@ from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterator, Sequence
 from functools import cached_property
 from operator import itemgetter
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Literal,
-    Optional,
-    Union,
-    cast,
-)
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast

-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-    model_validator,
-)
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 from typing_extensions import override

 from langchain_core._api import deprecated
@@ -654,6 +641,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
     def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:  # noqa: ARG002
         return {}

+    def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]:
+        """Convert cached Generation objects to ChatGeneration objects.
+
+        Handle case where cache contains Generation objects instead of
+        ChatGeneration objects. This can happen due to serialization/deserialization
+        issues or legacy cache data (see #22389).
+
+        Args:
+            cache_val: List of cached generation objects.
+
+        Returns:
+            List of ChatGeneration objects.
+        """
+        converted_generations = []
+        for gen in cache_val:
+            if isinstance(gen, Generation) and not isinstance(gen, ChatGeneration):
+                # Convert Generation to ChatGeneration by creating AIMessage
+                # from the text content
+                chat_gen = ChatGeneration(
+                    message=AIMessage(content=gen.text),
+                    generation_info=gen.generation_info,
+                )
+                converted_generations.append(chat_gen)
+            else:
+                # Already a ChatGeneration or other expected type
+                converted_generations.append(gen)
+        return converted_generations
+
     def _get_invocation_params(
         self,
         stop: Optional[list[str]] = None,
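
For reference, the extracted helper can be exercised in isolation. Below is a
minimal standalone sketch (not part of the commit) that mirrors the new method
as a free function; it assumes only that langchain_core is installed:

    from langchain_core.messages import AIMessage
    from langchain_core.outputs import ChatGeneration, Generation

    def convert_cached_generations(cache_val: list) -> list[ChatGeneration]:
        # Standalone mirror of BaseChatModel._convert_cached_generations above.
        converted = []
        for gen in cache_val:
            if isinstance(gen, Generation) and not isinstance(gen, ChatGeneration):
                # Upgrade a bare Generation to a ChatGeneration by wrapping
                # its text in an AIMessage.
                converted.append(
                    ChatGeneration(
                        message=AIMessage(content=gen.text),
                        generation_info=gen.generation_info,
                    )
                )
            else:
                converted.append(gen)
        return converted

    # A legacy cache entry deserialized as a plain Generation (see #22389):
    legacy = Generation(text="hello", generation_info={"finish_reason": "stop"})
    (result,) = convert_cached_generations([legacy])
    assert isinstance(result, ChatGeneration)
    assert result.message.content == "hello"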
@@ -1011,25 +1026,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 prompt = dumps(messages)
                 cache_val = llm_cache.lookup(prompt, llm_string)
                 if isinstance(cache_val, list):
-                    # Handle case where cache contains Generation objects instead of
-                    # ChatGeneration objects
-                    # This can happen due to serialization/deserialization issues or
-                    # legacy cache data
-                    converted_generations = []
-                    for gen in cache_val:
-                        if isinstance(gen, Generation) and not isinstance(
-                            gen, ChatGeneration
-                        ):
-                            # Convert Generation to ChatGeneration by creating AIMessage
-                            # from the text content
-                            chat_gen = ChatGeneration(
-                                message=AIMessage(content=gen.text),
-                                generation_info=gen.generation_info,
-                            )
-                            converted_generations.append(chat_gen)
-                        else:
-                            # Already a ChatGeneration or other expected type
-                            converted_generations.append(gen)
+                    converted_generations = self._convert_cached_generations(cache_val)
                     return ChatResult(generations=converted_generations)
                 elif self.cache is None:
                     pass
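
A minimal sketch of the situation this lookup path guards against (the
"prompt"/"llm_string" values are hypothetical placeholders; assumes
langchain_core is installed):

    from langchain_core.caches import InMemoryCache
    from langchain_core.outputs import ChatGeneration, Generation

    cache = InMemoryCache()
    # Simulate legacy cache data: plain Generation objects stored where
    # ChatGeneration objects are expected.
    cache.update("prompt", "llm_string", [Generation(text="cached answer")])

    cache_val = cache.lookup("prompt", "llm_string")
    assert isinstance(cache_val, list)
    assert not isinstance(cache_val[0], ChatGeneration)
    # The refactored code now routes such values through
    # _convert_cached_generations() before building the ChatResult.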
@@ -1102,25 +1099,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 prompt = dumps(messages)
                 cache_val = await llm_cache.alookup(prompt, llm_string)
                 if isinstance(cache_val, list):
-                    # Handle case where cache contains Generation objects instead of
-                    # ChatGeneration objects
-                    # This can happen due to serialization/deserialization issues or
-                    # legacy cache data
-                    converted_generations = []
-                    for gen in cache_val:
-                        if isinstance(gen, Generation) and not isinstance(
-                            gen, ChatGeneration
-                        ):
-                            # Convert Generation to ChatGeneration by creating AIMessage
-                            # from the text content
-                            chat_gen = ChatGeneration(
-                                message=AIMessage(content=gen.text),
-                                generation_info=gen.generation_info,
-                            )
-                            converted_generations.append(chat_gen)
-                        else:
-                            # Already a ChatGeneration or other expected type
-                            converted_generations.append(gen)
+                    converted_generations = self._convert_cached_generations(cache_val)
                     return ChatResult(generations=converted_generations)
                 elif self.cache is None:
                     pass
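
The async path is the same shape, just going through the cache's async API.
A sketch (alookup/aupdate are defined on BaseCache and default to the sync
methods; placeholder prompt/llm_string values as above):

    import asyncio

    from langchain_core.caches import InMemoryCache
    from langchain_core.outputs import Generation

    async def main() -> None:
        cache = InMemoryCache()
        await cache.aupdate("prompt", "llm_string", [Generation(text="cached answer")])
        cache_val = await cache.alookup("prompt", "llm_string")
        # This list is what gets handed to _convert_cached_generations().
        assert isinstance(cache_val, list)

    asyncio.run(main())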