core: Cache serialized manifest for llms, chat models and prompts (#26404)

Nuno Campos
2024-09-12 14:57:06 -07:00
committed by GitHub
parent 85f673c7ea
commit d2a69a7b6b
4 changed files with 31 additions and 17 deletions
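In outline, the change adds a cached _serialized property to BaseChatModel, BaseLLM, and BasePromptTemplate, so the dumpd(self) manifest is computed once per instance instead of on every callback dispatch and cache-key computation. A minimal standalone sketch of the pattern (the toy dumpd and DemoModel below are illustrative stand-ins, not the langchain_core implementations):

from functools import cached_property
from typing import Any


def dumpd(obj: Any) -> dict[str, Any]:
    # Stand-in for langchain_core.load.dumpd, which produces a
    # JSON-compatible manifest of a serializable object.
    return {"id": type(obj).__name__, "kwargs": dict(vars(obj))}


class DemoModel:
    def __init__(self, temperature: float = 0.7) -> None:
        self.temperature = temperature

    @cached_property
    def _serialized(self) -> dict[str, Any]:
        # Computed on first access, then stored in the instance
        # __dict__; later accesses return the same dict object.
        return dumpd(self)


model = DemoModel()
assert model._serialized is model._serialized  # serialized once, reused

The trade-off is that every caller now shares a single dict, so code that mutates the serialized form (rather than a copy) affects all later readers.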

langchain_core/language_models/chat_models.py

@@ -6,6 +6,7 @@ import json
 import uuid
 import warnings
 from abc import ABC, abstractmethod
+from functools import cached_property
 from operator import itemgetter
 from typing import (
     TYPE_CHECKING,
@@ -247,6 +248,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         arbitrary_types_allowed=True,
     )

+    @cached_property
+    def _serialized(self) -> dict[str, Any]:
+        return dumpd(self)
+
     # --- Runnable methods ---

     @property
@@ -378,7 +383,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.metadata,
         )
         (run_manager,) = callback_manager.on_chat_model_start(
-            dumpd(self),
+            self._serialized,
             [messages],
             invocation_params=params,
             options=options,
@@ -450,7 +455,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_chat_model_start(
-            dumpd(self),
+            self._serialized,
             [messages],
             invocation_params=params,
             options=options,
@@ -551,7 +556,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         param_string = str(sorted([(k, v) for k, v in params.items()]))
         # This code is not super efficient as it goes back and forth between
         # json and dict.
-        serialized_repr = dumpd(self)
+        serialized_repr = self._serialized
         _cleanup_llm_representation(serialized_repr, 1)
         llm_string = json.dumps(serialized_repr, sort_keys=True)
         return llm_string + "---" + param_string
@@ -613,7 +618,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.metadata,
         )
         run_managers = callback_manager.on_chat_model_start(
-            dumpd(self),
+            self._serialized,
             messages,
             invocation_params=params,
             options=options,
@@ -705,7 +710,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             self.metadata,
         )
         run_managers = await callback_manager.on_chat_model_start(
-            dumpd(self),
+            self._serialized,
             messages,
             invocation_params=params,
             options=options,

langchain_core/language_models/llms.py

@@ -317,6 +317,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         values["callbacks"] = values.pop("callback_manager", None)
         return values

+    @functools.cached_property
+    def _serialized(self) -> dict[str, Any]:
+        return dumpd(self)
+
     # --- Runnable methods ---

     @property
@@ -544,7 +548,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             self.metadata,
         )
         (run_manager,) = callback_manager.on_llm_start(
-            dumpd(self),
+            self._serialized,
             [prompt],
             invocation_params=params,
             options=options,
@@ -609,7 +613,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             self.metadata,
         )
         (run_manager,) = await callback_manager.on_llm_start(
-            dumpd(self),
+            self._serialized,
             [prompt],
             invocation_params=params,
             options=options,
@@ -931,7 +935,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         if (self.cache is None and get_llm_cache() is None) or self.cache is False:
             run_managers = [
                 callback_manager.on_llm_start(
-                    dumpd(self),
+                    self._serialized,
                     [prompt],
                     invocation_params=params,
                     options=options,
@@ -950,7 +954,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         if len(missing_prompts) > 0:
             run_managers = [
                 callback_managers[idx].on_llm_start(
-                    dumpd(self),
+                    self._serialized,
                     [prompts[idx]],
                     invocation_params=params,
                     options=options,
@@ -1168,7 +1172,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_managers = await asyncio.gather(
                 *[
                     callback_manager.on_llm_start(
-                        dumpd(self),
+                        self._serialized,
                         [prompt],
                         invocation_params=params,
                         options=options,
@@ -1194,7 +1198,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             run_managers = await asyncio.gather(
                 *[
                     callback_managers[idx].on_llm_start(
-                        dumpd(self),
+                        self._serialized,
                         [prompts[idx]],
                         invocation_params=params,
                         options=options,

langchain_core/prompts/base.py

@@ -2,6 +2,7 @@ from __future__ import annotations

 import json
 from abc import ABC, abstractmethod
+from functools import cached_property
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
@@ -103,6 +104,10 @@ class BasePromptTemplate(
         arbitrary_types_allowed=True,
     )

+    @cached_property
+    def _serialized(self) -> dict[str, Any]:
+        return dumpd(self)
+
     @property
     def OutputType(self) -> Any:
         """Return the output type of the prompt."""
@@ -190,7 +195,7 @@ class BasePromptTemplate(
             input,
             config,
             run_type="prompt",
-            serialized=dumpd(self),
+            serialized=self._serialized,
         )

     async def ainvoke(
@@ -215,7 +220,7 @@ class BasePromptTemplate(
             input,
             config,
             run_type="prompt",
-            serialized=dumpd(self),
+            serialized=self._serialized,
         )

     @abstractmethod

File diff suppressed because one or more lines are too long