core: mustache prompt templates (#19980)

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Nuno Campos
2024-04-10 11:25:32 -07:00
committed by GitHub
parent 4cb5f4c353
commit 15271ac832
6 changed files with 904 additions and 12 deletions

View File

@@ -8,6 +8,7 @@ from typing import (
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
@@ -929,6 +930,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def from_messages(
cls,
messages: Sequence[MessageLikeRepresentation],
template_format: Literal["f-string", "mustache"] = "f-string",
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
@@ -964,7 +966,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Returns:
a chat prompt template
"""
_messages = [_convert_to_message(message) for message in messages]
_messages = [
_convert_to_message(message, template_format) for message in messages
]
# Automatically infer input variables from messages
input_vars: Set[str] = set()
@@ -1121,7 +1125,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def _create_template_from_message_type(
message_type: str, template: Union[str, list]
message_type: str,
template: Union[str, list],
template_format: Literal["f-string", "mustache"] = "f-string",
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.
@@ -1134,12 +1140,16 @@ def _create_template_from_message_type(
"""
if message_type in ("human", "user"):
message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template(
template
template, template_format=template_format
)
elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(cast(str, template))
message = AIMessagePromptTemplate.from_template(
cast(str, template), template_format=template_format
)
elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(cast(str, template))
message = SystemMessagePromptTemplate.from_template(
cast(str, template), template_format=template_format
)
elif message_type == "placeholder":
if isinstance(template, str):
if template[0] != "{" or template[-1] != "}":
@@ -1180,6 +1190,7 @@ def _create_template_from_message_type(
def _convert_to_message(
message: MessageLikeRepresentation,
template_format: Literal["f-string", "mustache"] = "f-string",
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.
@@ -1204,16 +1215,22 @@ def _convert_to_message(
elif isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):
_message = _create_template_from_message_type("human", message)
_message = _create_template_from_message_type(
"human", message, template_format=template_format
)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
_message = _create_template_from_message_type(
message_type_str, template, template_format=template_format
)
else:
_message = message_type_str(
prompt=PromptTemplate.from_template(cast(str, template))
prompt=PromptTemplate.from_template(
cast(str, template), template_format=template_format
)
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")

View File

@@ -10,8 +10,10 @@ from langchain_core.prompts.string import (
StringPromptTemplate,
check_valid_template,
get_template_variables,
mustache_schema,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import RunnableConfig
class PromptTemplate(StringPromptTemplate):
@@ -65,12 +67,19 @@ class PromptTemplate(StringPromptTemplate):
template: str
"""The prompt template."""
template_format: Literal["f-string", "jinja2"] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'."""
validate_template: bool = False
"""Whether or not to try validating the template."""
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
    """Return a pydantic model describing this template's expected inputs.

    For non-mustache formats, defer to the base-class schema (a flat model
    derived from the template's input variables). Mustache templates may
    reference nested keys (``{{a.b}}``), which a flat schema cannot express,
    so the schema is built from the template text via ``mustache_schema``.
    """
    if self.template_format != "mustache":
        return super().get_input_schema(config)
    return mustache_schema(self.template)
def __add__(self, other: Any) -> PromptTemplate:
"""Override the + operator to allow for combining prompt templates."""
# Allow for easy combining
@@ -121,6 +130,8 @@ class PromptTemplate(StringPromptTemplate):
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
if values["validate_template"]:
if values["template_format"] == "mustache":
raise ValueError("Mustache templates cannot be validated.")
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs

View File

@@ -5,10 +5,12 @@ from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Set
from typing import Any, Callable, Dict, List, Set, Tuple, Type
import langchain_core.utils.mustache as mustache
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env
@@ -85,8 +87,70 @@ def _get_jinja2_variables_from_template(template: str) -> Set[str]:
return variables
def mustache_formatter(template: str, **kwargs: Any) -> str:
    """Render ``template`` as a mustache template.

    The keyword arguments form the rendering context (the data the
    template's ``{{...}}`` tags are resolved against).
    """
    context = kwargs
    return mustache.render(template, context)
def mustache_template_vars(
    template: str,
) -> Set[str]:
    """Get the top-level variables from a mustache template.

    Only the first segment of dotted keys is reported (``{{a.b}}`` -> ``a``),
    and keys appearing inside a section are skipped because they are scoped
    to that section's variable. The implicit-iterator key ``.`` is ignored.

    Args:
        template: The mustache template string.

    Returns:
        The set of top-level variable names referenced by the template.

    Note:
        The previous implementation tracked sections with a single boolean,
        which was cleared by the *first* closing tag — so variables placed
        after a nested inner section, but still inside an outer section,
        leaked into the result. A depth counter fixes nested sections.
        Locals were also renamed to stop shadowing the builtins ``type``
        and ``vars``.
    """
    variables: Set[str] = set()
    section_depth = 0
    for token_type, key in mustache.tokenize(template):
        if token_type == "end":
            section_depth -= 1
        elif (
            token_type in ("variable", "section")
            and key != "."
            and section_depth == 0
        ):
            # Only top-level (depth 0) keys count; record the root segment.
            variables.add(key.split(".")[0])
        if token_type == "section":
            # Every opening section nests one level deeper, including ``.``
            # sections, whose contents are scoped to the iterated items.
            section_depth += 1
    return variables
# Recursive mapping of dotted-key segments; an empty dict marks a leaf field.
Defs = Dict[str, "Defs"]


def mustache_schema(
    template: str,
) -> Type[BaseModel]:
    """Build a pydantic input schema from a mustache template.

    Dotted keys (``{{a.b}}``) and sections (``{{#a}}...{{/a}}``) produce
    nested models; plain keys become top-level ``str`` fields. The
    implicit-iterator key ``.`` is ignored.

    Args:
        template: The mustache template string.

    Returns:
        A pydantic ``BaseModel`` subclass named ``PromptInput``.
    """
    fields = set()
    prefix: Tuple[str, ...] = ()
    for token_type, key in mustache.tokenize(template):
        if key == ".":
            continue
        if token_type == "end":
            # Closing a section pops every segment that section pushed:
            # key "a.b" pushed dots + 1 = 2 segments. The previous slice,
            # ``prefix[: -key.count(".")]``, evaluated to ``prefix[:0]`` for
            # dot-free keys and so wiped the whole prefix, breaking nested
            # sections (and under-popped dotted section keys).
            prefix = prefix[: -(key.count(".") + 1)]
        elif token_type == "section":
            prefix = prefix + tuple(key.split("."))
        elif token_type == "variable":
            fields.add(prefix + tuple(key.split(".")))
    defs: Defs = {}  # an empty dict marks a leaf node
    # Fold each dotted path into the nested Defs tree.
    while fields:
        field = fields.pop()
        current = defs
        for part in field[:-1]:
            current = current.setdefault(part, {})
        current[field[-1]] = {}
    return _create_model_recursive("PromptInput", defs)
def _create_model_recursive(name: str, defs: Defs) -> Type:
    """Build a pydantic model named ``name`` from a nested field mapping.

    Each key in ``defs`` becomes an optional field: a non-empty sub-mapping
    yields a nested model of the same shape, while an empty mapping marks a
    leaf field typed as ``str``.
    """
    field_definitions: Dict[str, Any] = {}
    for field_name, children in defs.items():
        if children:
            # Non-empty mapping -> recurse into a nested optional model.
            nested = _create_model_recursive(field_name, children)
            field_definitions[field_name] = (nested, None)
        else:
            # Leaf node -> optional string field.
            field_definitions[field_name] = (str, None)
    return create_model(name, **field_definitions)  # type: ignore[call-overload]
# Maps each supported ``template_format`` value to the callable that renders
# a template string with keyword arguments in that format.
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
    "f-string": formatter.format,
    "mustache": mustache_formatter,
    "jinja2": jinja2_formatter,
}
@@ -145,6 +209,8 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
elif template_format == "mustache":
input_variables = mustache_template_vars(template)
else:
raise ValueError(f"Unsupported template format: {template_format}")