core: Cache serialized manifest for llms, chat models and prompts (#26404)

Author: Nuno Campos
Date: 2024-09-12 14:57:06 -07:00
Committed by: GitHub
Parent: 85f673c7ea
Commit: d2a69a7b6b
4 changed files with 31 additions and 17 deletions
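
This change memoizes the serialized manifest that BaseChatModel, BaseLLM, and BasePromptTemplate hand to the callback and tracing machinery. Previously every run start called dumpd(self) and rebuilt the same dictionary; since that manifest does not change after the object is constructed, the diffs below compute it once per instance with functools.cached_property and reuse it. A minimal, illustrative sketch of the pattern follows; FakeChatModel and fake_dumpd are stand-ins, not code from this commit.

# Minimal sketch of the caching pattern applied in this commit. fake_dumpd stands
# in for langchain_core.load.dumpd, which walks the object and builds a
# JSON-serializable manifest, a cost the old code paid on every run start.
from functools import cached_property
from typing import Any


def fake_dumpd(obj: Any) -> dict[str, Any]:
    # Stand-in only: the real dumpd serializes the full object graph.
    return {"id": [type(obj).__name__], "kwargs": dict(vars(obj))}


class FakeChatModel:
    def __init__(self, model: str, temperature: float) -> None:
        self.model = model
        self.temperature = temperature

    @cached_property
    def _serialized(self) -> dict[str, Any]:
        # Computed on first access, then stored on the instance and reused.
        return fake_dumpd(self)


m = FakeChatModel("stub-model", temperature=0.0)
assert m._serialized is m._serialized  # one dict object, serialized exactly once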


@@ -6,6 +6,7 @@ import json
 import uuid
 import warnings
 from abc import ABC, abstractmethod
+from functools import cached_property
 from operator import itemgetter
 from typing import (
     TYPE_CHECKING,
@@ -247,6 +248,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         arbitrary_types_allowed=True,
     )
 
+    @cached_property
+    def _serialized(self) -> dict[str, Any]:
+        return dumpd(self)
+
     # --- Runnable methods ---
 
     @property
@@ -378,7 +383,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.metadata,
         )
         (run_manager,) = callback_manager.on_chat_model_start(
-            dumpd(self),
+            self._serialized,
             [messages],
             invocation_params=params,
             options=options,
@@ -450,7 +455,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_chat_model_start(
-            dumpd(self),
+            self._serialized,
             [messages],
             invocation_params=params,
             options=options,
@@ -551,7 +556,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         param_string = str(sorted([(k, v) for k, v in params.items()]))
         # This code is not super efficient as it goes back and forth between
         # json and dict.
-        serialized_repr = dumpd(self)
+        serialized_repr = self._serialized
         _cleanup_llm_representation(serialized_repr, 1)
         llm_string = json.dumps(serialized_repr, sort_keys=True)
         return llm_string + "---" + param_string
@@ -613,7 +618,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.metadata,
         )
         run_managers = callback_manager.on_chat_model_start(
-            dumpd(self),
+            self._serialized,
             messages,
             invocation_params=params,
             options=options,
@@ -705,7 +710,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         )
 
         run_managers = await callback_manager.on_chat_model_start(
-            dumpd(self),
+            self._serialized,
             messages,
             invocation_params=params,
             options=options,

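The _get_llm_string hunk above is the one call site where the manifest feeds the LLM response cache rather than a callback: the serialized representation is JSON-encoded with sorted keys and joined with the stringified call parameters to form the cache key. A rough, self-contained illustration of that construction follows; llm_cache_key and _cleanup are hypothetical stand-ins (the real helper is _cleanup_llm_representation), and the copy is only there so the sketch does not mutate its input.

# Rough illustration of the cache-key construction shown in _get_llm_string above.
import json
from typing import Any


def _cleanup(serialized: dict[str, Any]) -> None:
    # Stand-in only: the real _cleanup_llm_representation trims the representation
    # in place before it is encoded.
    serialized.pop("graph", None)


def llm_cache_key(serialized: dict[str, Any], params: dict[str, Any]) -> str:
    param_string = str(sorted(params.items()))
    serialized_repr = dict(serialized)  # copy so this sketch leaves its input untouched
    _cleanup(serialized_repr)
    llm_string = json.dumps(serialized_repr, sort_keys=True)
    return llm_string + "---" + param_string


key = llm_cache_key({"id": ["FakeChatModel"], "kwargs": {"temperature": 0.0}}, {"stop": None})
print(key)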

@@ -317,6 +317,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             values["callbacks"] = values.pop("callback_manager", None)
         return values
 
+    @functools.cached_property
+    def _serialized(self) -> dict[str, Any]:
+        return dumpd(self)
+
     # --- Runnable methods ---
 
     @property
@@ -544,7 +548,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             self.metadata,
         )
         (run_manager,) = callback_manager.on_llm_start(
-            dumpd(self),
+            self._serialized,
             [prompt],
             invocation_params=params,
             options=options,
@@ -609,7 +613,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_llm_start(
-            dumpd(self),
+            self._serialized,
             [prompt],
             invocation_params=params,
             options=options,
@@ -931,7 +935,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         if (self.cache is None and get_llm_cache() is None) or self.cache is False:
             run_managers = [
                 callback_manager.on_llm_start(
-                    dumpd(self),
+                    self._serialized,
                     [prompt],
                     invocation_params=params,
                     options=options,
@@ -950,7 +954,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         if len(missing_prompts) > 0:
             run_managers = [
                 callback_managers[idx].on_llm_start(
-                    dumpd(self),
+                    self._serialized,
                     [prompts[idx]],
                     invocation_params=params,
                     options=options,
@@ -1168,7 +1172,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_managers = await asyncio.gather(
                 *[
                     callback_manager.on_llm_start(
-                        dumpd(self),
+                        self._serialized,
                         [prompt],
                         invocation_params=params,
                         options=options,
@@ -1194,7 +1198,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_managers = await asyncio.gather(
                 *[
                     callback_managers[idx].on_llm_start(
-                        dumpd(self),
+                        self._serialized,
                         [prompts[idx]],
                         invocation_params=params,
                         options=options,

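For BaseLLM the savings are largest on the batched paths above: generate() and agenerate() issue one on_llm_start callback per prompt, and each of those calls previously re-ran dumpd(self). A hypothetical, instrumented sketch of the new behaviour follows; FakeLLM and fake_dumpd are illustrative only.

# Hypothetical illustration of the batched fan-out in the BaseLLM hunks above:
# one callback start per prompt, but only one serialization per instance.
from functools import cached_property
from typing import Any

DUMP_CALLS = 0


def fake_dumpd(obj: Any) -> dict[str, Any]:
    # Instrumented stand-in for langchain_core.load.dumpd.
    global DUMP_CALLS
    DUMP_CALLS += 1
    return {"id": [type(obj).__name__]}


class FakeLLM:
    @cached_property
    def _serialized(self) -> dict[str, Any]:
        return fake_dumpd(self)


llm = FakeLLM()
prompts = ["p1", "p2", "p3"]
manifests = [llm._serialized for _ in prompts]  # one manifest per on_llm_start call
assert DUMP_CALLS == 1                          # serialized once, not once per prompt
assert all(m is manifests[0] for m in manifests)

One practical implication of sharing a single dict instance across every callback start is that consumers of the manifest should treat it as read-only.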

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import json
 from abc import ABC, abstractmethod
+from functools import cached_property
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
@@ -103,6 +104,10 @@ class BasePromptTemplate(
         arbitrary_types_allowed=True,
     )
 
+    @cached_property
+    def _serialized(self) -> dict[str, Any]:
+        return dumpd(self)
+
     @property
     def OutputType(self) -> Any:
         """Return the output type of the prompt."""
@@ -190,7 +195,7 @@ class BasePromptTemplate(
             input,
             config,
             run_type="prompt",
-            serialized=dumpd(self),
+            serialized=self._serialized,
         )
 
     async def ainvoke(
@@ -215,7 +220,7 @@ class BasePromptTemplate(
             input,
             config,
             run_type="prompt",
-            serialized=dumpd(self),
+            serialized=self._serialized,
         )
 
     @abstractmethod

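For BasePromptTemplate the cached manifest is what invoke() and ainvoke() now pass as the serialized= argument shown above, which the tracing callbacks receive for the run_type="prompt" run. With this commit applied, a quick sanity check along these lines (illustrative, not part of the commit's tests) confirms the cached value matches a fresh serialization and is reused across accesses.

# Illustrative check against langchain_core with this commit applied.
from langchain_core.load import dumpd
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Tell me a joke about {topic}")
assert prompt._serialized == dumpd(prompt)         # same content as a fresh dumpd
assert prompt._serialized is prompt._serialized    # computed once, then reused
prompt.invoke({"topic": "caching"})                # passes the cached manifest as serialized=
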
File diff suppressed because one or more lines are too long