mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 14:31:55 +00:00
add few shot example (#148)
This commit is contained in:
89
tests/unit_tests/prompts/test_prompt.py
Normal file
89
tests/unit_tests/prompts/test_prompt.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Test functionality related to prompts."""
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
def test_prompt_valid() -> None:
|
||||
"""Test prompts can be constructed."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo"]
|
||||
prompt = PromptTemplate(input_variables=input_variables, template=template)
|
||||
assert prompt.template == template
|
||||
assert prompt.input_variables == input_variables
|
||||
|
||||
|
||||
def test_prompt_missing_input_variables() -> None:
|
||||
"""Test error is raised when input variables are not provided."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables: list = []
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(input_variables=input_variables, template=template)
|
||||
|
||||
|
||||
def test_prompt_extra_input_variables() -> None:
|
||||
"""Test error is raised when there are too many input variables."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo", "bar"]
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(input_variables=input_variables, template=template)
|
||||
|
||||
|
||||
def test_prompt_wrong_input_variables() -> None:
|
||||
"""Test error is raised when name of input variable is wrong."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["bar"]
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(input_variables=input_variables, template=template)
|
||||
|
||||
|
||||
def test_prompt_from_examples_valid() -> None:
|
||||
"""Test prompt can be successfully constructed from examples."""
|
||||
template = """Test Prompt:
|
||||
|
||||
Question: who are you?
|
||||
Answer: foo
|
||||
|
||||
Question: what are you?
|
||||
Answer: bar
|
||||
|
||||
Question: {question}
|
||||
Answer:"""
|
||||
input_variables = ["question"]
|
||||
example_separator = "\n\n"
|
||||
prefix = """Test Prompt:"""
|
||||
suffix = """Question: {question}\nAnswer:"""
|
||||
examples = [
|
||||
"""Question: who are you?\nAnswer: foo""",
|
||||
"""Question: what are you?\nAnswer: bar""",
|
||||
]
|
||||
prompt_from_examples = PromptTemplate.from_examples(
|
||||
examples,
|
||||
suffix,
|
||||
input_variables,
|
||||
example_separator=example_separator,
|
||||
prefix=prefix,
|
||||
)
|
||||
prompt_from_template = PromptTemplate(
|
||||
input_variables=input_variables, template=template
|
||||
)
|
||||
assert prompt_from_examples.template == prompt_from_template.template
|
||||
assert prompt_from_examples.input_variables == prompt_from_template.input_variables
|
||||
|
||||
|
||||
def test_prompt_invalid_template_format() -> None:
|
||||
"""Test initializing a prompt with invalid template format."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo"]
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="bar"
|
||||
)
|
||||
|
||||
|
||||
def test_prompt_from_file() -> None:
|
||||
"""Test prompt can be successfully constructed from a file."""
|
||||
template_file = "tests/unit_tests/data/prompt_file.txt"
|
||||
input_variables = ["question"]
|
||||
prompt = PromptTemplate.from_file(template_file, input_variables)
|
||||
assert prompt.template == "Question: {question}\nAnswer:"
|
Reference in New Issue
Block a user