mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 07:26:16 +00:00
refactor: extract cached generations conversion to a separate method
This commit is contained in:
@@ -11,22 +11,9 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator, Iterator, Sequence
|
||||
from functools import cached_property
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, cast
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
@@ -654,6 +641,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict: # noqa: ARG002
|
||||
return {}
|
||||
|
||||
def _convert_cached_generations(self, cache_val: list) -> list[ChatGeneration]:
|
||||
"""Convert cached Generation objects to ChatGeneration objects.
|
||||
|
||||
Handle case where cache contains Generation objects instead of
|
||||
ChatGeneration objects. This can happen due to serialization/deserialization
|
||||
issues or legacy cache data (see #22389).
|
||||
|
||||
Args:
|
||||
cache_val: List of cached generation objects.
|
||||
|
||||
Returns:
|
||||
List of ChatGeneration objects.
|
||||
"""
|
||||
converted_generations = []
|
||||
for gen in cache_val:
|
||||
if isinstance(gen, Generation) and not isinstance(gen, ChatGeneration):
|
||||
# Convert Generation to ChatGeneration by creating AIMessage
|
||||
# from the text content
|
||||
chat_gen = ChatGeneration(
|
||||
message=AIMessage(content=gen.text),
|
||||
generation_info=gen.generation_info,
|
||||
)
|
||||
converted_generations.append(chat_gen)
|
||||
else:
|
||||
# Already a ChatGeneration or other expected type
|
||||
converted_generations.append(gen)
|
||||
return converted_generations
|
||||
|
||||
def _get_invocation_params(
|
||||
self,
|
||||
stop: Optional[list[str]] = None,
|
||||
@@ -1011,25 +1026,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
prompt = dumps(messages)
|
||||
cache_val = llm_cache.lookup(prompt, llm_string)
|
||||
if isinstance(cache_val, list):
|
||||
# Handle case where cache contains Generation objects instead of
|
||||
# ChatGeneration objects
|
||||
# This can happen due to serialization/deserialization issues or
|
||||
# legacy cache data
|
||||
converted_generations = []
|
||||
for gen in cache_val:
|
||||
if isinstance(gen, Generation) and not isinstance(
|
||||
gen, ChatGeneration
|
||||
):
|
||||
# Convert Generation to ChatGeneration by creating AIMessage
|
||||
# from the text content
|
||||
chat_gen = ChatGeneration(
|
||||
message=AIMessage(content=gen.text),
|
||||
generation_info=gen.generation_info,
|
||||
)
|
||||
converted_generations.append(chat_gen)
|
||||
else:
|
||||
# Already a ChatGeneration or other expected type
|
||||
converted_generations.append(gen)
|
||||
converted_generations = self._convert_cached_generations(cache_val)
|
||||
return ChatResult(generations=converted_generations)
|
||||
elif self.cache is None:
|
||||
pass
|
||||
@@ -1102,25 +1099,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
prompt = dumps(messages)
|
||||
cache_val = await llm_cache.alookup(prompt, llm_string)
|
||||
if isinstance(cache_val, list):
|
||||
# Handle case where cache contains Generation objects instead of
|
||||
# ChatGeneration objects
|
||||
# This can happen due to serialization/deserialization issues or
|
||||
# legacy cache data
|
||||
converted_generations = []
|
||||
for gen in cache_val:
|
||||
if isinstance(gen, Generation) and not isinstance(
|
||||
gen, ChatGeneration
|
||||
):
|
||||
# Convert Generation to ChatGeneration by creating AIMessage
|
||||
# from the text content
|
||||
chat_gen = ChatGeneration(
|
||||
message=AIMessage(content=gen.text),
|
||||
generation_info=gen.generation_info,
|
||||
)
|
||||
converted_generations.append(chat_gen)
|
||||
else:
|
||||
# Already a ChatGeneration or other expected type
|
||||
converted_generations.append(gen)
|
||||
converted_generations = self._convert_cached_generations(cache_val)
|
||||
return ChatResult(generations=converted_generations)
|
||||
elif self.cache is None:
|
||||
pass
|
||||
|
Reference in New Issue
Block a user