mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 11:55:21 +00:00
add mako template
This commit is contained in:
parent
a19ad935b3
commit
5b48ab8db3
@ -1,6 +1,7 @@
|
|||||||
"""Utilities for formatting strings."""
|
"""Utilities for formatting strings."""
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, Mapping, Sequence, Union
|
from typing import Any, Mapping, Sequence, Union
|
||||||
|
from mako.template import Template
|
||||||
|
|
||||||
|
|
||||||
class StrictFormatter(Formatter):
|
class StrictFormatter(Formatter):
|
||||||
@ -28,5 +29,10 @@ class StrictFormatter(Formatter):
|
|||||||
)
|
)
|
||||||
return super().vformat(format_string, args, kwargs)
|
return super().vformat(format_string, args, kwargs)
|
||||||
|
|
||||||
|
def mako_format(self, format_string: str, **kwargs: Any) -> str:
|
||||||
|
"""Format a string using mako."""
|
||||||
|
template = Template(format_string)
|
||||||
|
return template.render(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
formatter = StrictFormatter()
|
formatter = StrictFormatter()
|
||||||
|
@ -6,6 +6,7 @@ from langchain.formatting import formatter
|
|||||||
|
|
||||||
DEFAULT_FORMATTER_MAPPING = {
|
DEFAULT_FORMATTER_MAPPING = {
|
||||||
"f-string": formatter.format,
|
"f-string": formatter.format,
|
||||||
|
"mako": formatter.mako_format,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ from typing import Any, Dict, List
|
|||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
|
|
||||||
from langchain.prompts.base import (
|
from langchain.prompts.base import (
|
||||||
DEFAULT_FORMATTER_MAPPING,
|
DEFAULT_FORMATTER_MAPPING,
|
||||||
BasePromptTemplate,
|
BasePromptTemplate,
|
||||||
@ -106,6 +107,27 @@ class PromptTemplate(BaseModel, BasePromptTemplate):
|
|||||||
template = f.read()
|
template = f.read()
|
||||||
return cls(input_variables=input_variables, template=template)
|
return cls(input_variables=input_variables, template=template)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_mako_template(
|
||||||
|
cls, template_file: str, input_variables: List[str]
|
||||||
|
) -> "PromptTemplate":
|
||||||
|
"""Load a prompt from a mako template file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_file: The path to the file containing the prompt template.
|
||||||
|
input_variables: A list of variable names the final prompt template
|
||||||
|
will expect.
|
||||||
|
Returns:
|
||||||
|
The prompt loaded from the mako template file.
|
||||||
|
"""
|
||||||
|
with open(template_file, "r") as f:
|
||||||
|
template = f.read()
|
||||||
|
return cls(
|
||||||
|
input_variables=input_variables,
|
||||||
|
template=template,
|
||||||
|
template_format="mako",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# For backwards compatibility.
|
# For backwards compatibility.
|
||||||
Prompt = PromptTemplate
|
Prompt = PromptTemplate
|
||||||
|
1
tests/unit_tests/data/mako_prompt.txt
Normal file
1
tests/unit_tests/data/mako_prompt.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
This is a ${foo} test.
|
@ -87,3 +87,12 @@ def test_prompt_from_file() -> None:
|
|||||||
input_variables = ["question"]
|
input_variables = ["question"]
|
||||||
prompt = PromptTemplate.from_file(template_file, input_variables)
|
prompt = PromptTemplate.from_file(template_file, input_variables)
|
||||||
assert prompt.template == "Question: {question}\nAnswer:"
|
assert prompt.template == "Question: {question}\nAnswer:"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mako_template() -> None:
|
||||||
|
"""Test mako template can be used."""
|
||||||
|
template_file = "tests/unit_tests/data/mako_prompt.txt"
|
||||||
|
input_variables = ["foo"]
|
||||||
|
prompt = PromptTemplate.from_mako_template(template_file, input_variables)
|
||||||
|
assert prompt.template == "This is a ${foo} test."
|
||||||
|
assert prompt.format(foo="bar") == "This is a bar test."
|
||||||
|
Loading…
Reference in New Issue
Block a user