From 5f2dea2b2059084630275402d2e0329ab7705efc Mon Sep 17 00:00:00 2001 From: Shenhai Ran Date: Tue, 16 Jul 2024 15:35:09 +0200 Subject: [PATCH] core[patch]: Add encoding options when create prompt template from a file (#24054) - Adds an optional encoding argument for loading prompt templates from file (defaults to None, i.e. the OS default encoding is used) --------- Co-authored-by: Eugene Yurtsev --- libs/core/langchain_core/prompts/prompt.py | 5 +++- .../tests/unit_tests/prompts/test_prompt.py | 23 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index ef4084e83c1..cfa3836ac7c 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -214,6 +214,7 @@ class PromptTemplate(StringPromptTemplate): cls, template_file: Union[str, Path], input_variables: Optional[List[str]] = None, + encoding: Optional[str] = None, **kwargs: Any, ) -> PromptTemplate: """Load a prompt from a file. @@ -222,13 +223,15 @@ class PromptTemplate(StringPromptTemplate): template_file: The path to the file containing the prompt template. input_variables: [DEPRECATED] A list of variable names the final prompt template will expect. Defaults to None. + encoding: The encoding system for opening the template file. + If not provided, will use the OS default. input_variables is ignored as from_file now delegates to from_template(). Returns: The prompt loaded from the file. 
""" - with open(str(template_file), "r") as f: + with open(str(template_file), "r", encoding=encoding) as f: template = f.read() if input_variables: warnings.warn( diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index ecb30b2c79c..e166ef29ac4 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -18,6 +18,29 @@ def test_prompt_valid() -> None: assert prompt.input_variables == input_variables +def test_from_file_encoding() -> None: + """Test that we can load a template from a file with a non utf-8 encoding.""" + template = "This is a {foo} test with special character €." + input_variables = ["foo"] + + # First write to a file using CP-1252 encoding. + from tempfile import NamedTemporaryFile + + with NamedTemporaryFile(delete=True, mode="w", encoding="cp1252") as f: + f.write(template) + f.flush() + file_name = f.name + + # Now read from the file using CP-1252 encoding and test + prompt = PromptTemplate.from_file(file_name, encoding="cp1252") + assert prompt.template == template + assert prompt.input_variables == input_variables + + # Now read from the file using UTF-8 encoding and test + with pytest.raises(UnicodeDecodeError): + PromptTemplate.from_file(file_name, encoding="utf-8") + + def test_prompt_from_template() -> None: """Test prompts can be constructed from a template.""" # Single input variable.