add few shot example (#148)

This commit is contained in:
Harrison Chase
2022-11-19 20:32:45 -08:00
committed by GitHub
parent 8869b0ab0e
commit c02eb199b6
68 changed files with 2494 additions and 713 deletions

View File

@@ -2,14 +2,14 @@
import pytest
from langchain.chains.llm import LLMChain
from langchain.prompts.prompt import Prompt
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.fixture
def fake_llm_chain() -> LLMChain:
"""Fake LLM chain for testing purposes."""
prompt = Prompt(input_variables=["bar"], template="This is a {bar}:")
prompt = PromptTemplate(input_variables=["bar"], template="This is a {bar}:")
return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")

View File

@@ -4,7 +4,7 @@ import pytest
from langchain.chains.mrkl.base import ChainConfig, MRKLChain, get_action_and_input
from langchain.chains.mrkl.prompt import BASE_TEMPLATE
from langchain.prompts import Prompt
from langchain.prompts import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -66,5 +66,5 @@ def test_from_chains() -> None:
tools=expected_tools_prompt, tool_names=expected_tool_names
)
prompt = mrkl_chain.prompt
assert isinstance(prompt, Prompt)
assert isinstance(prompt, PromptTemplate)
assert prompt.template == expected_template

View File

@@ -9,7 +9,7 @@ from langchain.chains.react.base import ReActChain, predict_until_observation
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import LLM
from langchain.prompts.prompt import Prompt
from langchain.prompts.prompt import PromptTemplate
_PAGE_CONTENT = """This is a page about LangChain.
@@ -19,7 +19,7 @@ What isn't there to love about langchain?
Made in 2022."""
_FAKE_PROMPT = Prompt(input_variables=["input"], template="{input}")
_FAKE_PROMPT = PromptTemplate(input_variables=["input"], template="{input}")
class FakeListLLM(LLM):

View File

@@ -0,0 +1 @@
"""Test prompt functionality."""

View File

@@ -0,0 +1,87 @@
"""Test few shot prompt template."""
import pytest
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
EXAMPLE_PROMPT = PromptTemplate(
input_variables=["question", "answer"], template="{question}: {answer}"
)
def test_suffix_only() -> None:
"""Test prompt works with just a suffix."""
suffix = "This is a {foo} test."
input_variables = ["foo"]
prompt = FewShotPromptTemplate(
input_variables=input_variables,
suffix=suffix,
examples=[],
example_prompt=EXAMPLE_PROMPT,
)
output = prompt.format(foo="bar")
expected_output = "This is a bar test."
assert output == expected_output
def test_prompt_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
# Test when missing in suffix
template = "This is a {foo} test."
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=[],
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
)
# Test when missing in prefix
template = "This is a {foo} test."
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=[],
suffix="foo",
examples=[],
prefix=template,
example_prompt=EXAMPLE_PROMPT,
)
def test_prompt_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
template = "This is a {foo} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=input_variables,
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
)
def test_few_shot_functionality() -> None:
"""Test that few shot works with examples."""
prefix = "This is a test about {content}."
suffix = "Now you try to talk about {new_content}."
examples = [
{"question": "foo", "answer": "bar"},
{"question": "baz", "answer": "foo"},
]
prompt = FewShotPromptTemplate(
suffix=suffix,
prefix=prefix,
input_variables=["content", "new_content"],
examples=examples,
example_prompt=EXAMPLE_PROMPT,
example_separator="\n",
)
output = prompt.format(content="animals", new_content="party")
expected_output = (
"This is a test about animals.\n"
"foo: bar\n"
"baz: foo\n"
"Now you try to talk about party."
)
assert output == expected_output

View File

@@ -0,0 +1,48 @@
"""Test functionality related to dynamic prompts."""
import pytest
from langchain.prompts.example_selector.length_based import LengthBasedExampleSelector
from langchain.prompts.prompt import PromptTemplate
EXAMPLES = [
{"question": "Question: who are you?\nAnswer: foo"},
{"question": "Question: who are you?\nAnswer: foo"},
]
@pytest.fixture
def selector() -> LengthBasedExampleSelector:
"""Get length based selector to use in tests."""
prompts = PromptTemplate(input_variables=["question"], template="{question}")
selector = LengthBasedExampleSelector(
examples=EXAMPLES,
example_prompt=prompts,
max_length=25,
)
return selector
def test_dynamic_prompt_valid(selector: LengthBasedExampleSelector) -> None:
"""Test dynamic prompt can be successfully constructed from examples."""
short_question = "Short question?"
output = selector.select_examples({"question": short_question})
assert output == EXAMPLES
def test_dynamic_prompt_trims_one_example(selector: LengthBasedExampleSelector) -> None:
"""Test dynamic prompt can trim one example."""
long_question = """I am writing a really long question,
this probably is going to affect the example right?"""
output = selector.select_examples({"question": long_question})
assert output == EXAMPLES[:1]
def test_dynamic_prompt_trims_all_examples(
selector: LengthBasedExampleSelector,
) -> None:
"""Test dynamic prompt can trim all examples."""
longest_question = """This question is super super super,
super super super super super super super super super super super,
super super super super long, this will affect the example right?"""
output = selector.select_examples({"question": longest_question})
assert output == []

View File

@@ -0,0 +1,134 @@
"""Test loading functionality."""
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.loading import load_prompt
from langchain.prompts.prompt import PromptTemplate
@contextmanager
def change_directory() -> Iterator:
"""Change the working directory to the right folder."""
origin = Path().absolute()
try:
os.chdir("docs/examples/prompts")
yield
finally:
os.chdir(origin)
def test_loading_from_YAML() -> None:
"""Test loading from yaml file."""
with change_directory():
prompt = load_prompt("simple_prompt.yaml")
expected_prompt = PromptTemplate(
input_variables=["adjective", "content"],
template="Tell me a {adjective} joke about {content}.",
)
assert prompt == expected_prompt
def test_loading_from_JSON() -> None:
"""Test loading from json file."""
with change_directory():
prompt = load_prompt("simple_prompt.json")
expected_prompt = PromptTemplate(
input_variables=["adjective", "content"],
template="Tell me a {adjective} joke about {content}.",
)
assert prompt == expected_prompt
def test_loading_with_template_as_file() -> None:
"""Test loading when the template is a file."""
with change_directory():
prompt = load_prompt("simple_prompt_with_template_file.json")
expected_prompt = PromptTemplate(
input_variables=["adjective", "content"],
template="Tell me a {adjective} joke about {content}.",
)
assert prompt == expected_prompt
def test_loading_few_shot_prompt_from_yaml() -> None:
"""Test loading few shot prompt from yaml."""
with change_directory():
prompt = load_prompt("few_shot_prompt.yaml")
expected_prompt = FewShotPromptTemplate(
input_variables=["adjective"],
prefix="Write antonyms for the following words.",
example_prompt=PromptTemplate(
input_variables=["input", "output"],
template="Input: {input}\nOutput: {output}",
),
examples=[
{"input": "happy", "output": "sad"},
{"input": "tall", "output": "short"},
],
suffix="Input: {adjective}\nOutput:",
)
assert prompt == expected_prompt
def test_loading_few_shot_prompt_from_json() -> None:
"""Test loading few shot prompt from json."""
with change_directory():
prompt = load_prompt("few_shot_prompt.json")
expected_prompt = FewShotPromptTemplate(
input_variables=["adjective"],
prefix="Write antonyms for the following words.",
example_prompt=PromptTemplate(
input_variables=["input", "output"],
template="Input: {input}\nOutput: {output}",
),
examples=[
{"input": "happy", "output": "sad"},
{"input": "tall", "output": "short"},
],
suffix="Input: {adjective}\nOutput:",
)
assert prompt == expected_prompt
def test_loading_few_shot_prompt_when_examples_in_config() -> None:
"""Test loading few shot prompt when the examples are in the config."""
with change_directory():
prompt = load_prompt("few_shot_prompt_examples_in.json")
expected_prompt = FewShotPromptTemplate(
input_variables=["adjective"],
prefix="Write antonyms for the following words.",
example_prompt=PromptTemplate(
input_variables=["input", "output"],
template="Input: {input}\nOutput: {output}",
),
examples=[
{"input": "happy", "output": "sad"},
{"input": "tall", "output": "short"},
],
suffix="Input: {adjective}\nOutput:",
)
assert prompt == expected_prompt
def test_loading_few_shot_prompt_example_prompt() -> None:
"""Test loading few shot when the example prompt is in its own file."""
with change_directory():
prompt = load_prompt("few_shot_prompt_example_prompt.json")
expected_prompt = FewShotPromptTemplate(
input_variables=["adjective"],
prefix="Write antonyms for the following words.",
example_prompt=PromptTemplate(
input_variables=["input", "output"],
template="Input: {input}\nOutput: {output}",
),
examples=[
{"input": "happy", "output": "sad"},
{"input": "tall", "output": "short"},
],
suffix="Input: {adjective}\nOutput:",
)
assert prompt == expected_prompt

View File

@@ -1,14 +1,14 @@
"""Test functionality related to prompts."""
import pytest
from langchain.prompts.prompt import Prompt
from langchain.prompts.prompt import PromptTemplate
def test_prompt_valid() -> None:
"""Test prompts can be constructed."""
template = "This is a {foo} test."
input_variables = ["foo"]
prompt = Prompt(input_variables=input_variables, template=template)
prompt = PromptTemplate(input_variables=input_variables, template=template)
assert prompt.template == template
assert prompt.input_variables == input_variables
@@ -18,7 +18,7 @@ def test_prompt_missing_input_variables() -> None:
template = "This is a {foo} test."
input_variables: list = []
with pytest.raises(ValueError):
Prompt(input_variables=input_variables, template=template)
PromptTemplate(input_variables=input_variables, template=template)
def test_prompt_extra_input_variables() -> None:
@@ -26,7 +26,7 @@ def test_prompt_extra_input_variables() -> None:
template = "This is a {foo} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
Prompt(input_variables=input_variables, template=template)
PromptTemplate(input_variables=input_variables, template=template)
def test_prompt_wrong_input_variables() -> None:
@@ -34,7 +34,7 @@ def test_prompt_wrong_input_variables() -> None:
template = "This is a {foo} test."
input_variables = ["bar"]
with pytest.raises(ValueError):
Prompt(input_variables=input_variables, template=template)
PromptTemplate(input_variables=input_variables, template=template)
def test_prompt_from_examples_valid() -> None:
@@ -57,14 +57,16 @@ Answer:"""
"""Question: who are you?\nAnswer: foo""",
"""Question: what are you?\nAnswer: bar""",
]
prompt_from_examples = Prompt.from_examples(
prompt_from_examples = PromptTemplate.from_examples(
examples,
suffix,
input_variables,
example_separator=example_separator,
prefix=prefix,
)
prompt_from_template = Prompt(input_variables=input_variables, template=template)
prompt_from_template = PromptTemplate(
input_variables=input_variables, template=template
)
assert prompt_from_examples.template == prompt_from_template.template
assert prompt_from_examples.input_variables == prompt_from_template.input_variables
@@ -74,7 +76,7 @@ def test_prompt_invalid_template_format() -> None:
template = "This is a {foo} test."
input_variables = ["foo"]
with pytest.raises(ValueError):
Prompt(
PromptTemplate(
input_variables=input_variables, template=template, template_format="bar"
)
@@ -83,5 +85,5 @@ def test_prompt_from_file() -> None:
"""Test prompt can be successfully constructed from a file."""
template_file = "tests/unit_tests/data/prompt_file.txt"
input_variables = ["question"]
prompt = Prompt.from_file(template_file, input_variables)
prompt = PromptTemplate.from_file(template_file, input_variables)
assert prompt.template == "Question: {question}\nAnswer:"

View File

@@ -1,111 +0,0 @@
"""Test functionality related to dynamic prompts."""
from langchain.prompts.dynamic import DynamicPrompt
from langchain.prompts.prompt import Prompt
# FULL TEMPLATES
LONGER_TEMPLATE = """Test Prompt:
Question: who are you?
Answer: foo
Question: what are you?
Answer: bar
Question: {question}
Answer:"""
SHORTER_TEMPLATE = """Test Prompt:
Question: who are you?
Answer: foo
Question: {question}
Answer:"""
SHORTEST_TEMPLATE = """Test Prompt:
Question: {question}
Answer:"""
# DYNAMIC PROMPT COMPONENTS
PREFIX = """Test Prompt:"""
SUFFIX = """Question: {question}\nAnswer:"""
EXAMPLES = [
"""Question: who are you?\nAnswer: foo""",
"""Question: what are you?\nAnswer: bar""",
]
# INPUTS
TEST_LONG_QUESTION = """I am writing a really long question,
this probably is going to affect the example right?"""
TEST_LONGEST_QUESTION = """This question is super super super,
super super super super super super super super super super super,
super super super super long, this will affect the example right?"""
TEST_SHORT_QUESTION = "Short question?"
def test_dynamic_prompt_valid() -> None:
"""Test dynamic prompt can be successfully constructed from examples."""
input_variables = ["question"]
example_separator = "\n\n"
dynamic_prompt_cls = DynamicPrompt(
examples=EXAMPLES,
suffix=SUFFIX,
input_variables=input_variables,
example_separator=example_separator,
prefix=PREFIX,
)
prompt_cls = Prompt(input_variables=input_variables, template=LONGER_TEMPLATE)
dynamic_prompt_template = dynamic_prompt_cls.format(question="foo?")
prompt_template = prompt_cls.format(question="foo?")
assert dynamic_prompt_template == prompt_template
assert dynamic_prompt_cls.input_variables == prompt_cls.input_variables
def test_dynamic_prompt_trims_one_example() -> None:
"""Test dynamic prompt can trim one example."""
input_variables = ["question"]
example_separator = "\n\n"
dynamic_prompt_cls = DynamicPrompt(
examples=EXAMPLES,
suffix=SUFFIX,
input_variables=input_variables,
example_separator=example_separator,
prefix=PREFIX,
max_length=30,
)
dynamic_prompt = dynamic_prompt_cls.format(question=TEST_LONG_QUESTION)
shorter_prompt = SHORTER_TEMPLATE.format(question=TEST_LONG_QUESTION)
assert dynamic_prompt == shorter_prompt
def test_dynamic_prompt_trims_no_examples() -> None:
"""Test dynamic prompt can trim no examples."""
input_variables = ["question"]
example_separator = "\n\n"
dynamic_prompt_cls = DynamicPrompt(
examples=EXAMPLES,
suffix=SUFFIX,
input_variables=input_variables,
example_separator=example_separator,
prefix=PREFIX,
max_length=30,
)
dynamic_prompt = dynamic_prompt_cls.format(question=TEST_SHORT_QUESTION)
full_prompt = LONGER_TEMPLATE.format(question=TEST_SHORT_QUESTION)
assert dynamic_prompt == full_prompt
def test_dynamic_prompt_trims_all_examples() -> None:
"""Test dynamic prompt can trim all examples."""
input_variables = ["question"]
example_separator = "\n\n"
dynamic_prompt_cls = DynamicPrompt(
examples=EXAMPLES,
suffix=SUFFIX,
input_variables=input_variables,
example_separator=example_separator,
prefix=PREFIX,
max_length=30,
)
dynamic_prompt = dynamic_prompt_cls.format(question=TEST_LONGEST_QUESTION)
full_prompt = SHORTEST_TEMPLATE.format(question=TEST_LONGEST_QUESTION)
assert dynamic_prompt == full_prompt