mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-23 19:39:58 +00:00
core: mustache prompt templates (#19980)
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
@@ -929,6 +930,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
def from_messages(
|
||||
cls,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
template_format: Literal["f-string", "mustache"] = "f-string",
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a variety of message formats.
|
||||
|
||||
@@ -964,7 +966,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
Returns:
|
||||
a chat prompt template
|
||||
"""
|
||||
_messages = [_convert_to_message(message) for message in messages]
|
||||
_messages = [
|
||||
_convert_to_message(message, template_format) for message in messages
|
||||
]
|
||||
|
||||
# Automatically infer input variables from messages
|
||||
input_vars: Set[str] = set()
|
||||
@@ -1121,7 +1125,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
|
||||
|
||||
def _create_template_from_message_type(
|
||||
message_type: str, template: Union[str, list]
|
||||
message_type: str,
|
||||
template: Union[str, list],
|
||||
template_format: Literal["f-string", "mustache"] = "f-string",
|
||||
) -> BaseMessagePromptTemplate:
|
||||
"""Create a message prompt template from a message type and template string.
|
||||
|
||||
@@ -1134,12 +1140,16 @@ def _create_template_from_message_type(
|
||||
"""
|
||||
if message_type in ("human", "user"):
|
||||
message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template(
|
||||
template
|
||||
template, template_format=template_format
|
||||
)
|
||||
elif message_type in ("ai", "assistant"):
|
||||
message = AIMessagePromptTemplate.from_template(cast(str, template))
|
||||
message = AIMessagePromptTemplate.from_template(
|
||||
cast(str, template), template_format=template_format
|
||||
)
|
||||
elif message_type == "system":
|
||||
message = SystemMessagePromptTemplate.from_template(cast(str, template))
|
||||
message = SystemMessagePromptTemplate.from_template(
|
||||
cast(str, template), template_format=template_format
|
||||
)
|
||||
elif message_type == "placeholder":
|
||||
if isinstance(template, str):
|
||||
if template[0] != "{" or template[-1] != "}":
|
||||
@@ -1180,6 +1190,7 @@ def _create_template_from_message_type(
|
||||
|
||||
def _convert_to_message(
|
||||
message: MessageLikeRepresentation,
|
||||
template_format: Literal["f-string", "mustache"] = "f-string",
|
||||
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
@@ -1204,16 +1215,22 @@ def _convert_to_message(
|
||||
elif isinstance(message, BaseMessage):
|
||||
_message = message
|
||||
elif isinstance(message, str):
|
||||
_message = _create_template_from_message_type("human", message)
|
||||
_message = _create_template_from_message_type(
|
||||
"human", message, template_format=template_format
|
||||
)
|
||||
elif isinstance(message, tuple):
|
||||
if len(message) != 2:
|
||||
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
|
||||
message_type_str, template = message
|
||||
if isinstance(message_type_str, str):
|
||||
_message = _create_template_from_message_type(message_type_str, template)
|
||||
_message = _create_template_from_message_type(
|
||||
message_type_str, template, template_format=template_format
|
||||
)
|
||||
else:
|
||||
_message = message_type_str(
|
||||
prompt=PromptTemplate.from_template(cast(str, template))
|
||||
prompt=PromptTemplate.from_template(
|
||||
cast(str, template), template_format=template_format
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message type: {type(message)}")
|
||||
|
@@ -10,8 +10,10 @@ from langchain_core.prompts.string import (
|
||||
StringPromptTemplate,
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
mustache_schema,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
class PromptTemplate(StringPromptTemplate):
|
||||
@@ -65,12 +67,19 @@ class PromptTemplate(StringPromptTemplate):
|
||||
template: str
|
||||
"""The prompt template."""
|
||||
|
||||
template_format: Literal["f-string", "jinja2"] = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"
|
||||
"""The format of the prompt template.
|
||||
Options are: 'f-string', 'mustache', 'jinja2'."""
|
||||
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
||||
if self.template_format != "mustache":
|
||||
return super().get_input_schema(config)
|
||||
|
||||
return mustache_schema(self.template)
|
||||
|
||||
def __add__(self, other: Any) -> PromptTemplate:
|
||||
"""Override the + operator to allow for combining prompt templates."""
|
||||
# Allow for easy combining
|
||||
@@ -121,6 +130,8 @@ class PromptTemplate(StringPromptTemplate):
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that template and input variables are consistent."""
|
||||
if values["validate_template"]:
|
||||
if values["template_format"] == "mustache":
|
||||
raise ValueError("Mustache templates cannot be validated.")
|
||||
all_inputs = values["input_variables"] + list(values["partial_variables"])
|
||||
check_valid_template(
|
||||
values["template"], values["template_format"], all_inputs
|
||||
|
@@ -5,10 +5,12 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Set
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Type
|
||||
|
||||
import langchain_core.utils.mustache as mustache
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, create_model
|
||||
from langchain_core.utils import get_colored_text
|
||||
from langchain_core.utils.formatting import formatter
|
||||
from langchain_core.utils.interactive_env import is_interactive_env
|
||||
@@ -85,8 +87,70 @@ def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||
return variables
|
||||
|
||||
|
||||
def mustache_formatter(template: str, **kwargs: Any) -> str:
|
||||
"""Format a template using mustache."""
|
||||
return mustache.render(template, kwargs)
|
||||
|
||||
|
||||
def mustache_template_vars(
|
||||
template: str,
|
||||
) -> Set[str]:
|
||||
"""Get the variables from a mustache template."""
|
||||
vars: Set[str] = set()
|
||||
in_section = False
|
||||
for type, key in mustache.tokenize(template):
|
||||
if type == "end":
|
||||
in_section = False
|
||||
elif in_section:
|
||||
continue
|
||||
elif type in ("variable", "section") and key != ".":
|
||||
vars.add(key.split(".")[0])
|
||||
if type == "section":
|
||||
in_section = True
|
||||
return vars
|
||||
|
||||
|
||||
Defs = Dict[str, "Defs"]
|
||||
|
||||
|
||||
def mustache_schema(
|
||||
template: str,
|
||||
) -> Type[BaseModel]:
|
||||
"""Get the variables from a mustache template."""
|
||||
fields = set()
|
||||
prefix: Tuple[str, ...] = ()
|
||||
for type, key in mustache.tokenize(template):
|
||||
if key == ".":
|
||||
continue
|
||||
if type == "end":
|
||||
prefix = prefix[: -key.count(".")]
|
||||
elif type == "section":
|
||||
prefix = prefix + tuple(key.split("."))
|
||||
elif type == "variable":
|
||||
fields.add(prefix + tuple(key.split(".")))
|
||||
defs: Defs = {} # None means leaf node
|
||||
while fields:
|
||||
field = fields.pop()
|
||||
current = defs
|
||||
for part in field[:-1]:
|
||||
current = current.setdefault(part, {})
|
||||
current[field[-1]] = {}
|
||||
return _create_model_recursive("PromptInput", defs)
|
||||
|
||||
|
||||
def _create_model_recursive(name: str, defs: Defs) -> Type:
|
||||
return create_model( # type: ignore[call-overload]
|
||||
name,
|
||||
**{
|
||||
k: (_create_model_recursive(k, v), None) if v else (str, None)
|
||||
for k, v in defs.items()
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
||||
"f-string": formatter.format,
|
||||
"mustache": mustache_formatter,
|
||||
"jinja2": jinja2_formatter,
|
||||
}
|
||||
|
||||
@@ -145,6 +209,8 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
|
||||
input_variables = {
|
||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||
}
|
||||
elif template_format == "mustache":
|
||||
input_variables = mustache_template_vars(template)
|
||||
else:
|
||||
raise ValueError(f"Unsupported template format: {template_format}")
|
||||
|
||||
|
Reference in New Issue
Block a user