mirror of
				https://github.com/hwchase17/langchain.git
				synced 2025-11-04 02:03:32 +00:00 
			
		
		
		
	add mako template
This commit is contained in:
		@@ -1,6 +1,7 @@
 | 
				
			|||||||
"""Utilities for formatting strings."""
 | 
					"""Utilities for formatting strings."""
 | 
				
			||||||
from string import Formatter
 | 
					from string import Formatter
 | 
				
			||||||
from typing import Any, Mapping, Sequence, Union
 | 
					from typing import Any, Mapping, Sequence, Union
 | 
				
			||||||
 | 
					from mako.template import Template
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class StrictFormatter(Formatter):
 | 
					class StrictFormatter(Formatter):
 | 
				
			||||||
@@ -28,5 +29,10 @@ class StrictFormatter(Formatter):
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
        return super().vformat(format_string, args, kwargs)
 | 
					        return super().vformat(format_string, args, kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def mako_format(self, format_string: str, **kwargs: Any) -> str:
 | 
				
			||||||
 | 
					        """Format a string using mako."""
 | 
				
			||||||
 | 
					        template = Template(format_string)
 | 
				
			||||||
 | 
					        return template.render(**kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
formatter = StrictFormatter()
 | 
					formatter = StrictFormatter()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,6 +6,7 @@ from langchain.formatting import formatter
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
DEFAULT_FORMATTER_MAPPING = {
 | 
					DEFAULT_FORMATTER_MAPPING = {
 | 
				
			||||||
    "f-string": formatter.format,
 | 
					    "f-string": formatter.format,
 | 
				
			||||||
 | 
					    "mako": formatter.mako_format,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@ from typing import Any, Dict, List
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from pydantic import BaseModel, Extra, root_validator
 | 
					from pydantic import BaseModel, Extra, root_validator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from langchain.prompts.base import (
 | 
					from langchain.prompts.base import (
 | 
				
			||||||
    DEFAULT_FORMATTER_MAPPING,
 | 
					    DEFAULT_FORMATTER_MAPPING,
 | 
				
			||||||
    BasePromptTemplate,
 | 
					    BasePromptTemplate,
 | 
				
			||||||
@@ -106,6 +107,27 @@ class PromptTemplate(BaseModel, BasePromptTemplate):
 | 
				
			|||||||
            template = f.read()
 | 
					            template = f.read()
 | 
				
			||||||
        return cls(input_variables=input_variables, template=template)
 | 
					        return cls(input_variables=input_variables, template=template)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def from_mako_template(
 | 
				
			||||||
 | 
					        cls, template_file: str, input_variables: List[str]
 | 
				
			||||||
 | 
					    ) -> "PromptTemplate":
 | 
				
			||||||
 | 
					        """Load a prompt from a mako template file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            template_file: The path to the file containing the prompt template.
 | 
				
			||||||
 | 
					            input_variables: A list of variable names the final prompt template
 | 
				
			||||||
 | 
					                will expect.
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            The prompt loaded from the mako template file.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with open(template_file, "r") as f:
 | 
				
			||||||
 | 
					            template = f.read()
 | 
				
			||||||
 | 
					        return cls(
 | 
				
			||||||
 | 
					            input_variables=input_variables,
 | 
				
			||||||
 | 
					            template=template,
 | 
				
			||||||
 | 
					            template_format="mako",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# For backwards compatibility.
 | 
					# For backwards compatibility.
 | 
				
			||||||
Prompt = PromptTemplate
 | 
					Prompt = PromptTemplate
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										1
									
								
								tests/unit_tests/data/mako_prompt.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								tests/unit_tests/data/mako_prompt.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
				
			|||||||
 | 
					This is a ${foo} test.
 | 
				
			||||||
@@ -87,3 +87,12 @@ def test_prompt_from_file() -> None:
 | 
				
			|||||||
    input_variables = ["question"]
 | 
					    input_variables = ["question"]
 | 
				
			||||||
    prompt = PromptTemplate.from_file(template_file, input_variables)
 | 
					    prompt = PromptTemplate.from_file(template_file, input_variables)
 | 
				
			||||||
    assert prompt.template == "Question: {question}\nAnswer:"
 | 
					    assert prompt.template == "Question: {question}\nAnswer:"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_mako_template() -> None:
 | 
				
			||||||
 | 
					    """Test mako template can be used."""
 | 
				
			||||||
 | 
					    template_file = "tests/unit_tests/data/mako_prompt.txt"
 | 
				
			||||||
 | 
					    input_variables = ["foo"]
 | 
				
			||||||
 | 
					    prompt = PromptTemplate.from_mako_template(template_file, input_variables)
 | 
				
			||||||
 | 
					    assert prompt.template == "This is a ${foo} test."
 | 
				
			||||||
 | 
					    assert prompt.format(foo="bar") == "This is a bar test."
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user