core[patch]: Add an encoding option when creating a prompt template from a file (#24054)

- Adds an optional `encoding` argument for loading prompt templates from a file; when it is not provided, the OS default encoding is used

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Shenhai Ran 2024-07-16 15:35:09 +02:00 committed by GitHub
parent 69b1603173
commit 5f2dea2b20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 1 deletions

View File

@ -214,6 +214,7 @@ class PromptTemplate(StringPromptTemplate):
cls,
template_file: Union[str, Path],
input_variables: Optional[List[str]] = None,
encoding: Optional[str] = None,
**kwargs: Any,
) -> PromptTemplate:
"""Load a prompt from a file.
@ -222,13 +223,15 @@ class PromptTemplate(StringPromptTemplate):
template_file: The path to the file containing the prompt template.
input_variables: [DEPRECATED] A list of variable names the final prompt
template will expect. Defaults to None.
encoding: The encoding system for opening the template file.
If not provided, will use the OS default.
input_variables is ignored as from_file now delegates to from_template().
Returns:
The prompt loaded from the file.
"""
with open(str(template_file), "r") as f:
with open(str(template_file), "r", encoding=encoding) as f:
template = f.read()
if input_variables:
warnings.warn(

View File

@ -18,6 +18,29 @@ def test_prompt_valid() -> None:
assert prompt.input_variables == input_variables
def test_from_file_encoding() -> None:
    """Test that we can load a template from a file with a non utf-8 encoding.

    The template is written to disk as CP-1252 and then loaded twice through
    ``PromptTemplate.from_file``: once with the matching encoding, which must
    round-trip the text and its input variables, and once as UTF-8, which must
    raise ``UnicodeDecodeError``.
    """
    # Imported locally (as the original test did) so this test does not depend
    # on the module-level import block.
    import os
    from tempfile import TemporaryDirectory

    template = "This is a {foo} test with special character €."
    input_variables = ["foo"]

    # A temporary directory is used instead of NamedTemporaryFile(delete=True):
    # a named temporary file that is still open cannot be reopened by name on
    # Windows, whereas a regular file that has been closed can be read on any
    # platform. The directory and its contents are removed when the block exits.
    with TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "template.txt")

        # First write to a file using CP-1252 encoding, closing it before use.
        with open(file_name, "w", encoding="cp1252") as f:
            f.write(template)

        # Now read from the file using CP-1252 encoding and test
        prompt = PromptTemplate.from_file(file_name, encoding="cp1252")
        assert prompt.template == template
        assert prompt.input_variables == input_variables

        # Now read from the file using UTF-8 encoding and test.
        # CP-1252 encodes the euro sign as the single byte 0x80, which is not
        # valid UTF-8, so decoding must fail.
        with pytest.raises(UnicodeDecodeError):
            PromptTemplate.from_file(file_name, encoding="utf-8")
def test_prompt_from_template() -> None:
"""Test prompts can be constructed from a template."""
# Single input variable.