mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 11:02:37 +00:00
add few shot example (#148)
This commit is contained in:
@@ -2,14 +2,14 @@
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.prompt import Prompt
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm_chain() -> LLMChain:
|
||||
"""Fake LLM chain for testing purposes."""
|
||||
prompt = Prompt(input_variables=["bar"], template="This is a {bar}:")
|
||||
prompt = PromptTemplate(input_variables=["bar"], template="This is a {bar}:")
|
||||
return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")
|
||||
|
||||
|
||||
|
@@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input
|
||||
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
|
||||
from langchain.prompts import Prompt
|
||||
from langchain.prompts import PromptTemplate
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@@ -66,5 +66,5 @@ def test_from_chains() -> None:
|
||||
tools=expected_tools_prompt, tool_names=expected_tool_names
|
||||
)
|
||||
prompt = mrkl_chain.prompt
|
||||
assert isinstance(prompt, Prompt)
|
||||
assert isinstance(prompt, PromptTemplate)
|
||||
assert prompt.template == expected_template
|
||||
|
@@ -9,7 +9,7 @@ from langchain.chains.react.base import ReActChain, predict_until_observation
|
||||
from langchain.docstore.base import Docstore
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.prompt import Prompt
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
_PAGE_CONTENT = """This is a page about LangChain.
|
||||
|
||||
@@ -19,7 +19,7 @@ What isn't there to love about langchain?
|
||||
|
||||
Made in 2022."""
|
||||
|
||||
_FAKE_PROMPT = Prompt(input_variables=["input"], template="{input}")
|
||||
_FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")
|
||||
|
||||
|
||||
class FakeListLLM(LLM):
|
||||
|
1
tests/unit_tests/prompts/__init__.py
Normal file
1
tests/unit_tests/prompts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test prompt functionality."""
|
87
tests/unit_tests/prompts/test_few_shot.py
Normal file
87
tests/unit_tests/prompts/test_few_shot.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Test few shot prompt template."""
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
EXAMPLE_PROMPT = PromptTemplate(
|
||||
input_variables=["question", "answer"], template="{question}: {answer}"
|
||||
)
|
||||
|
||||
|
||||
def test_suffix_only() -> None:
|
||||
"""Test prompt works with just a suffix."""
|
||||
suffix = "This is a {foo} test."
|
||||
input_variables = ["foo"]
|
||||
prompt = FewShotPromptTemplate(
|
||||
input_variables=input_variables,
|
||||
suffix=suffix,
|
||||
examples=[],
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
)
|
||||
output = prompt.format(foo="bar")
|
||||
expected_output = "This is a bar test."
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_prompt_missing_input_variables() -> None:
|
||||
"""Test error is raised when input variables are not provided."""
|
||||
# Test when missing in suffix
|
||||
template = "This is a {foo} test."
|
||||
with pytest.raises(ValueError):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=[],
|
||||
suffix=template,
|
||||
examples=[],
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
)
|
||||
|
||||
# Test when missing in prefix
|
||||
template = "This is a {foo} test."
|
||||
with pytest.raises(ValueError):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=[],
|
||||
suffix="foo",
|
||||
examples=[],
|
||||
prefix=template,
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
def test_prompt_extra_input_variables() -> None:
|
||||
"""Test error is raised when there are too many input variables."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo", "bar"]
|
||||
with pytest.raises(ValueError):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=input_variables,
|
||||
suffix=template,
|
||||
examples=[],
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
def test_few_shot_functionality() -> None:
|
||||
"""Test that few shot works with examples."""
|
||||
prefix = "This is a test about {content}."
|
||||
suffix = "Now you try to talk about {new_content}."
|
||||
examples = [
|
||||
{"question": "foo", "answer": "bar"},
|
||||
{"question": "baz", "answer": "foo"},
|
||||
]
|
||||
prompt = FewShotPromptTemplate(
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
input_variables=["content", "new_content"],
|
||||
examples=examples,
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
example_separator="\n",
|
||||
)
|
||||
output = prompt.format(content="animals", new_content="party")
|
||||
expected_output = (
|
||||
"This is a test about animals.\n"
|
||||
"foo: bar\n"
|
||||
"baz: foo\n"
|
||||
"Now you try to talk about party."
|
||||
)
|
||||
assert output == expected_output
|
@@ -0,0 +1,48 @@
|
||||
"""Test functionality related to dynamic prompts."""
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.example_selector.length_based import LengthBasedExampleSelector
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
EXAMPLES = [
|
||||
{"question": "Question: who are you?\nAnswer: foo"},
|
||||
{"question": "Question: who are you?\nAnswer: foo"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def selector() -> LengthBasedExampleSelector:
|
||||
"""Get length based selector to use in tests."""
|
||||
prompts = PromptTemplate(input_variables=["question"], template="{question}")
|
||||
selector = LengthBasedExampleSelector(
|
||||
examples=EXAMPLES,
|
||||
example_prompt=prompts,
|
||||
max_length=25,
|
||||
)
|
||||
return selector
|
||||
|
||||
|
||||
def test_dynamic_prompt_valid(selector: LengthBasedExampleSelector) -> None:
|
||||
"""Test dynamic prompt can be successfully constructed from examples."""
|
||||
short_question = "Short question?"
|
||||
output = selector.select_examples({"question": short_question})
|
||||
assert output == EXAMPLES
|
||||
|
||||
|
||||
def test_dynamic_prompt_trims_one_example(selector: LengthBasedExampleSelector) -> None:
|
||||
"""Test dynamic prompt can trim one example."""
|
||||
long_question = """I am writing a really long question,
|
||||
this probably is going to affect the example right?"""
|
||||
output = selector.select_examples({"question": long_question})
|
||||
assert output == EXAMPLES[:1]
|
||||
|
||||
|
||||
def test_dynamic_prompt_trims_all_examples(
|
||||
selector: LengthBasedExampleSelector,
|
||||
) -> None:
|
||||
"""Test dynamic prompt can trim all examples."""
|
||||
longest_question = """This question is super super super,
|
||||
super super super super super super super super super super super,
|
||||
super super super super long, this will affect the example right?"""
|
||||
output = selector.select_examples({"question": longest_question})
|
||||
assert output == []
|
134
tests/unit_tests/prompts/test_loading.py
Normal file
134
tests/unit_tests/prompts/test_loading.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Test loading functionality."""
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.loading import load_prompt
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
@contextmanager
|
||||
def change_directory() -> Iterator:
|
||||
"""Change the working directory to the right folder."""
|
||||
origin = Path().absolute()
|
||||
try:
|
||||
os.chdir("docs/examples/prompts")
|
||||
yield
|
||||
finally:
|
||||
os.chdir(origin)
|
||||
|
||||
|
||||
def test_loading_from_YAML() -> None:
|
||||
"""Test loading from yaml file."""
|
||||
with change_directory():
|
||||
prompt = load_prompt("simple_prompt.yaml")
|
||||
expected_prompt = PromptTemplate(
|
||||
input_variables=["adjective", "content"],
|
||||
template="Tell me a {adjective} joke about {content}.",
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_loading_from_JSON() -> None:
|
||||
"""Test loading from json file."""
|
||||
with change_directory():
|
||||
prompt = load_prompt("simple_prompt.json")
|
||||
expected_prompt = PromptTemplate(
|
||||
input_variables=["adjective", "content"],
|
||||
template="Tell me a {adjective} joke about {content}.",
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_loading_with_template_as_file() -> None:
|
||||
"""Test loading when the template is a file."""
|
||||
with change_directory():
|
||||
prompt = load_prompt("simple_prompt_with_template_file.json")
|
||||
expected_prompt = PromptTemplate(
|
||||
input_variables=["adjective", "content"],
|
||||
template="Tell me a {adjective} joke about {content}.",
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_loading_few_shot_prompt_from_yaml() -> None:
|
||||
"""Test loading few shot prompt from yaml."""
|
||||
with change_directory():
|
||||
prompt = load_prompt("few_shot_prompt.yaml")
|
||||
expected_prompt = FewShotPromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
prefix="Write antonyms for the following words.",
|
||||
example_prompt=PromptTemplate(
|
||||
input_variables=["input", "output"],
|
||||
template="Input: {input}\nOutput: {output}",
|
||||
),
|
||||
examples=[
|
||||
{"input": "happy", "output": "sad"},
|
||||
{"input": "tall", "output": "short"},
|
||||
],
|
||||
suffix="Input: {adjective}\nOutput:",
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_loading_few_shot_prompt_from_json() -> None:
|
||||
"""Test loading few shot prompt from json."""
|
||||
with change_directory():
|
||||
prompt = load_prompt("few_shot_prompt.json")
|
||||
expected_prompt = FewShotPromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
prefix="Write antonyms for the following words.",
|
||||
example_prompt=PromptTemplate(
|
||||
input_variables=["input", "output"],
|
||||
template="Input: {input}\nOutput: {output}",
|
||||
),
|
||||
examples=[
|
||||
{"input": "happy", "output": "sad"},
|
||||
{"input": "tall", "output": "short"},
|
||||
],
|
||||
suffix="Input: {adjective}\nOutput:",
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_loading_few_shot_prompt_when_examples_in_config() -> None:
|
||||
"""Test loading few shot prompt when the examples are in the config."""
|
||||
with change_directory():
|
||||
prompt = load_prompt("few_shot_prompt_examples_in.json")
|
||||
expected_prompt = FewShotPromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
prefix="Write antonyms for the following words.",
|
||||
example_prompt=PromptTemplate(
|
||||
input_variables=["input", "output"],
|
||||
template="Input: {input}\nOutput: {output}",
|
||||
),
|
||||
examples=[
|
||||
{"input": "happy", "output": "sad"},
|
||||
{"input": "tall", "output": "short"},
|
||||
],
|
||||
suffix="Input: {adjective}\nOutput:",
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_loading_few_shot_prompt_example_prompt() -> None:
|
||||
"""Test loading few shot when the example prompt is in its own file."""
|
||||
with change_directory():
|
||||
prompt = load_prompt("few_shot_prompt_example_prompt.json")
|
||||
expected_prompt = FewShotPromptTemplate(
|
||||
input_variables=["adjective"],
|
||||
prefix="Write antonyms for the following words.",
|
||||
example_prompt=PromptTemplate(
|
||||
input_variables=["input", "output"],
|
||||
template="Input: {input}\nOutput: {output}",
|
||||
),
|
||||
examples=[
|
||||
{"input": "happy", "output": "sad"},
|
||||
{"input": "tall", "output": "short"},
|
||||
],
|
||||
suffix="Input: {adjective}\nOutput:",
|
||||
)
|
||||
assert prompt == expected_prompt
|
@@ -1,14 +1,14 @@
|
||||
"""Test functionality related to prompts."""
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.prompt import Prompt
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
def test_prompt_valid() -> None:
|
||||
"""Test prompts can be constructed."""
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo"]
|
||||
prompt = Prompt(input_variables=input_variables, template=template)
|
||||
prompt = PromptTemplate(input_variables=input_variables, template=template)
|
||||
assert prompt.template == template
|
||||
assert prompt.input_variables == input_variables
|
||||
|
||||
@@ -18,7 +18,7 @@ def test_prompt_missing_input_variables() -> None:
|
||||
template = "This is a {foo} test."
|
||||
input_variables: list = []
|
||||
with pytest.raises(ValueError):
|
||||
Prompt(input_variables=input_variables, template=template)
|
||||
PromptTemplate(input_variables=input_variables, template=template)
|
||||
|
||||
|
||||
def test_prompt_extra_input_variables() -> None:
|
||||
@@ -26,7 +26,7 @@ def test_prompt_extra_input_variables() -> None:
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo", "bar"]
|
||||
with pytest.raises(ValueError):
|
||||
Prompt(input_variables=input_variables, template=template)
|
||||
PromptTemplate(input_variables=input_variables, template=template)
|
||||
|
||||
|
||||
def test_prompt_wrong_input_variables() -> None:
|
||||
@@ -34,7 +34,7 @@ def test_prompt_wrong_input_variables() -> None:
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["bar"]
|
||||
with pytest.raises(ValueError):
|
||||
Prompt(input_variables=input_variables, template=template)
|
||||
PromptTemplate(input_variables=input_variables, template=template)
|
||||
|
||||
|
||||
def test_prompt_from_examples_valid() -> None:
|
||||
@@ -57,14 +57,16 @@ Answer:"""
|
||||
"""Question: who are you?\nAnswer: foo""",
|
||||
"""Question: what are you?\nAnswer: bar""",
|
||||
]
|
||||
prompt_from_examples = Prompt.from_examples(
|
||||
prompt_from_examples = PromptTemplate.from_examples(
|
||||
examples,
|
||||
suffix,
|
||||
input_variables,
|
||||
example_separator=example_separator,
|
||||
prefix=prefix,
|
||||
)
|
||||
prompt_from_template = Prompt(input_variables=input_variables, template=template)
|
||||
prompt_from_template = PromptTemplate(
|
||||
input_variables=input_variables, template=template
|
||||
)
|
||||
assert prompt_from_examples.template == prompt_from_template.template
|
||||
assert prompt_from_examples.input_variables == prompt_from_template.input_variables
|
||||
|
||||
@@ -74,7 +76,7 @@ def test_prompt_invalid_template_format() -> None:
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo"]
|
||||
with pytest.raises(ValueError):
|
||||
Prompt(
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="bar"
|
||||
)
|
||||
|
||||
@@ -83,5 +85,5 @@ def test_prompt_from_file() -> None:
|
||||
"""Test prompt can be successfully constructed from a file."""
|
||||
template_file = "tests/unit_tests/data/prompt_file.txt"
|
||||
input_variables = ["question"]
|
||||
prompt = Prompt.from_file(template_file, input_variables)
|
||||
prompt = PromptTemplate.from_file(template_file, input_variables)
|
||||
assert prompt.template == "Question: {question}\nAnswer:"
|
@@ -1,111 +0,0 @@
|
||||
"""Test functionality related to dynamic prompts."""
|
||||
from langchain.prompts.dynamic import DynamicPrompt
|
||||
from langchain.prompts.prompt import Prompt
|
||||
|
||||
# FULL TEMPLATES
|
||||
LONGER_TEMPLATE = """Test Prompt:
|
||||
|
||||
Question: who are you?
|
||||
Answer: foo
|
||||
|
||||
Question: what are you?
|
||||
Answer: bar
|
||||
|
||||
Question: {question}
|
||||
Answer:"""
|
||||
SHORTER_TEMPLATE = """Test Prompt:
|
||||
|
||||
Question: who are you?
|
||||
Answer: foo
|
||||
|
||||
Question: {question}
|
||||
Answer:"""
|
||||
SHORTEST_TEMPLATE = """Test Prompt:
|
||||
|
||||
Question: {question}
|
||||
Answer:"""
|
||||
|
||||
# DYNAMIC PROMPT COMPONENTS
|
||||
PREFIX = """Test Prompt:"""
|
||||
SUFFIX = """Question: {question}\nAnswer:"""
|
||||
EXAMPLES = [
|
||||
"""Question: who are you?\nAnswer: foo""",
|
||||
"""Question: what are you?\nAnswer: bar""",
|
||||
]
|
||||
|
||||
# INPUTS
|
||||
TEST_LONG_QUESTION = """I am writing a really long question,
|
||||
this probably is going to affect the example right?"""
|
||||
TEST_LONGEST_QUESTION = """This question is super super super,
|
||||
super super super super super super super super super super super,
|
||||
super super super super long, this will affect the example right?"""
|
||||
TEST_SHORT_QUESTION = "Short question?"
|
||||
|
||||
|
||||
def test_dynamic_prompt_valid() -> None:
|
||||
"""Test dynamic prompt can be successfully constructed from examples."""
|
||||
input_variables = ["question"]
|
||||
example_separator = "\n\n"
|
||||
dynamic_prompt_cls = DynamicPrompt(
|
||||
examples=EXAMPLES,
|
||||
suffix=SUFFIX,
|
||||
input_variables=input_variables,
|
||||
example_separator=example_separator,
|
||||
prefix=PREFIX,
|
||||
)
|
||||
prompt_cls = Prompt(input_variables=input_variables, template=LONGER_TEMPLATE)
|
||||
dynamic_prompt_template = dynamic_prompt_cls.format(question="foo?")
|
||||
prompt_template = prompt_cls.format(question="foo?")
|
||||
assert dynamic_prompt_template == prompt_template
|
||||
assert dynamic_prompt_cls.input_variables == prompt_cls.input_variables
|
||||
|
||||
|
||||
def test_dynamic_prompt_trims_one_example() -> None:
|
||||
"""Test dynamic prompt can trim one example."""
|
||||
input_variables = ["question"]
|
||||
example_separator = "\n\n"
|
||||
dynamic_prompt_cls = DynamicPrompt(
|
||||
examples=EXAMPLES,
|
||||
suffix=SUFFIX,
|
||||
input_variables=input_variables,
|
||||
example_separator=example_separator,
|
||||
prefix=PREFIX,
|
||||
max_length=30,
|
||||
)
|
||||
dynamic_prompt = dynamic_prompt_cls.format(question=TEST_LONG_QUESTION)
|
||||
shorter_prompt = SHORTER_TEMPLATE.format(question=TEST_LONG_QUESTION)
|
||||
assert dynamic_prompt == shorter_prompt
|
||||
|
||||
|
||||
def test_dynamic_prompt_trims_no_examples() -> None:
|
||||
"""Test dynamic prompt can trim no examples."""
|
||||
input_variables = ["question"]
|
||||
example_separator = "\n\n"
|
||||
dynamic_prompt_cls = DynamicPrompt(
|
||||
examples=EXAMPLES,
|
||||
suffix=SUFFIX,
|
||||
input_variables=input_variables,
|
||||
example_separator=example_separator,
|
||||
prefix=PREFIX,
|
||||
max_length=30,
|
||||
)
|
||||
dynamic_prompt = dynamic_prompt_cls.format(question=TEST_SHORT_QUESTION)
|
||||
full_prompt = LONGER_TEMPLATE.format(question=TEST_SHORT_QUESTION)
|
||||
assert dynamic_prompt == full_prompt
|
||||
|
||||
|
||||
def test_dynamic_prompt_trims_all_examples() -> None:
|
||||
"""Test dynamic prompt can trim all examples."""
|
||||
input_variables = ["question"]
|
||||
example_separator = "\n\n"
|
||||
dynamic_prompt_cls = DynamicPrompt(
|
||||
examples=EXAMPLES,
|
||||
suffix=SUFFIX,
|
||||
input_variables=input_variables,
|
||||
example_separator=example_separator,
|
||||
prefix=PREFIX,
|
||||
max_length=30,
|
||||
)
|
||||
dynamic_prompt = dynamic_prompt_cls.format(question=TEST_LONGEST_QUESTION)
|
||||
full_prompt = SHORTEST_TEMPLATE.format(question=TEST_LONGEST_QUESTION)
|
||||
assert dynamic_prompt == full_prompt
|
Reference in New Issue
Block a user