add methods to deserialize prompts that were old (#14857)

This commit is contained in:
Harrison Chase 2023-12-18 13:45:08 -08:00 committed by GitHub
parent 714bef0cb6
commit 193f107cb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 169 additions and 4 deletions

View File

@ -3,11 +3,16 @@ import json
import os
from typing import Any, Dict, List, Optional
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
from langchain_core.load.mapping import (
OLD_PROMPT_TEMPLATE_FORMATS,
SERIALIZABLE_MAPPING,
)
from langchain_core.load.serializable import Serializable
DEFAULT_NAMESPACES = ["langchain", "langchain_core", "langchain_community"]
ALL_SERIALIZABLE_MAPPINGS = {**SERIALIZABLE_MAPPING, **OLD_PROMPT_TEMPLATE_FORMATS}
class Reviver:
"""Reviver for JSON objects."""
@ -67,13 +72,13 @@ class Reviver:
if namespace[0] in DEFAULT_NAMESPACES:
# Get the importable path
key = tuple(namespace + [name])
if key not in SERIALIZABLE_MAPPING:
if key not in ALL_SERIALIZABLE_MAPPINGS:
raise ValueError(
"Trying to deserialize something that cannot "
"be deserialized in current version of langchain-core: "
f"{key}"
)
import_path = SERIALIZABLE_MAPPING[key]
import_path = ALL_SERIALIZABLE_MAPPINGS[key]
# Split into module and name
import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module

View File

@ -476,3 +476,162 @@ SERIALIZABLE_MAPPING = {
"RunnableRetry",
),
}
# Needed for backwards compatibility for a few versions where we serialized
# with langchain_core
OLD_PROMPT_TEMPLATE_FORMATS = {
(
"langchain_core",
"prompts",
"base",
"BasePromptTemplate",
): (
"langchain_core",
"prompts",
"base",
"BasePromptTemplate",
),
(
"langchain_core",
"prompts",
"prompt",
"PromptTemplate",
): (
"langchain_core",
"prompts",
"prompt",
"PromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"MessagesPlaceholder",
): (
"langchain_core",
"prompts",
"chat",
"MessagesPlaceholder",
),
(
"langchain_core",
"prompts",
"chat",
"ChatPromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"ChatPromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"HumanMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"HumanMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"SystemMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"SystemMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"BaseMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"BaseMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"BaseChatPromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"BaseChatPromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"ChatMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"ChatMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"few_shot_with_templates",
"FewShotPromptWithTemplates",
): (
"langchain_core",
"prompts",
"few_shot_with_templates",
"FewShotPromptWithTemplates",
),
(
"langchain_core",
"prompts",
"pipeline",
"PipelinePromptTemplate",
): (
"langchain_core",
"prompts",
"pipeline",
"PipelinePromptTemplate",
),
(
"langchain_core",
"prompts",
"string",
"StringPromptTemplate",
): (
"langchain_core",
"prompts",
"string",
"StringPromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"BaseStringMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"BaseStringMessagePromptTemplate",
),
(
"langchain_core",
"prompts",
"chat",
"AIMessagePromptTemplate",
): (
"langchain_core",
"prompts",
"chat",
"AIMessagePromptTemplate",
),
}

View File

@ -19,7 +19,8 @@ def test_interfaces() -> None:
def _get_get_session_history(
*, store: Optional[Dict[str, Any]] = None
*,
store: Optional[Dict[str, Any]] = None,
) -> Callable[..., ChatMessageHistory]:
chat_history_store = store if store is not None else {}