mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
Harrison/ngram example (#846)
Co-authored-by: Sean Spriggens <ssprigge@syr.edu>
This commit is contained in:
parent
0de55048b7
commit
23d5f64bda
@ -23,7 +23,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"id": "8244ff60",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -81,7 +81,7 @@
|
||||
" template=\"Input: {input}\\nOutput: {output}\",\n",
|
||||
")\n",
|
||||
"example_selector = LengthBasedExampleSelector(\n",
|
||||
" # These are the examples is has available to choose from.\n",
|
||||
" # These are the examples it has available to choose from.\n",
|
||||
" examples=examples, \n",
|
||||
" # This is the PromptTemplate being used to format the examples.\n",
|
||||
" example_prompt=example_prompt, \n",
|
||||
@ -439,10 +439,242 @@
|
||||
"print(similar_prompt.format(adjective=\"worried\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4aaeed2f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## NGram Overlap ExampleSelector\n",
|
||||
"\n",
|
||||
"The NGramOverlapExampleSelector selects and orders examples based on which examples are most similar to the input, according to an ngram overlap score. The ngram overlap score is a float between 0.0 and 1.0, inclusive. \n",
|
||||
"\n",
|
||||
"The selector allows for a threshold score to be set. Examples with an ngram overlap score less than or equal to the threshold are excluded. The threshold is set to -1.0, by default, so will not exclude any examples, only reorder them. Setting the threshold to 0.0 will exclude examples that have no ngram overlaps with the input.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "9cbc0acc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.prompts.example_selector.ngram_overlap import NGramOverlapExampleSelector"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "4f318f4b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# These are examples of a fictional translation task.\n",
|
||||
"examples = [\n",
|
||||
" {\"input\": \"See Spot run.\", \"output\": \"Ver correr a Spot.\"},\n",
|
||||
" {\"input\": \"My dog barks.\", \"output\": \"Mi perro ladra.\"},\n",
|
||||
" {\"input\": \"Spot can run.\", \"output\": \"Spot puede correr.\"},\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "bf75e0fe",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"example_prompt = PromptTemplate(\n",
|
||||
" input_variables=[\"input\", \"output\"],\n",
|
||||
" template=\"Input: {input}\\nOutput: {output}\",\n",
|
||||
")\n",
|
||||
"example_selector = NGramOverlapExampleSelector(\n",
|
||||
" # These are the examples it has available to choose from.\n",
|
||||
" examples=examples, \n",
|
||||
" # This is the PromptTemplate being used to format the examples.\n",
|
||||
" example_prompt=example_prompt, \n",
|
||||
" # This is the threshold, at which selector stops.\n",
|
||||
" # It is set to -1.0 by default.\n",
|
||||
" threshold=-1.0,\n",
|
||||
" # For negative threshold:\n",
|
||||
" # Selector sorts examples by ngram overlap score, and excludes none.\n",
|
||||
" # For threshold greater than 1.0:\n",
|
||||
" # Selector excludes all examples, and returns an empty list.\n",
|
||||
" # For threshold equal to 0.0:\n",
|
||||
" # Selector sorts examples by ngram overlap score,\n",
|
||||
" # and excludes those with no ngram overlap with input.\n",
|
||||
")\n",
|
||||
"dynamic_prompt = FewShotPromptTemplate(\n",
|
||||
" # We provide an ExampleSelector instead of examples.\n",
|
||||
" example_selector=example_selector,\n",
|
||||
" example_prompt=example_prompt,\n",
|
||||
" prefix=\"Give the Spanish translation of every input\",\n",
|
||||
" suffix=\"Input: {sentence}\\nOutput:\", \n",
|
||||
" input_variables=[\"sentence\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "83fb218a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Give the Spanish translation of every input\n",
|
||||
"\n",
|
||||
"Input: Spot can run.\n",
|
||||
"Output: Spot puede correr.\n",
|
||||
"\n",
|
||||
"Input: See Spot run.\n",
|
||||
"Output: Ver correr a Spot.\n",
|
||||
"\n",
|
||||
"Input: My dog barks.\n",
|
||||
"Output: Mi perro ladra.\n",
|
||||
"\n",
|
||||
"Input: Spot can run fast.\n",
|
||||
"Output:\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# An example input with large ngram overlap with \"Spot can run.\"\n",
|
||||
"# and no overlap with \"My dog barks.\"\n",
|
||||
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "485f5307",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Give the Spanish translation of every input\n",
|
||||
"\n",
|
||||
"Input: Spot can run.\n",
|
||||
"Output: Spot puede correr.\n",
|
||||
"\n",
|
||||
"Input: See Spot run.\n",
|
||||
"Output: Ver correr a Spot.\n",
|
||||
"\n",
|
||||
"Input: Spot plays fetch.\n",
|
||||
"Output: Spot juega a buscar.\n",
|
||||
"\n",
|
||||
"Input: My dog barks.\n",
|
||||
"Output: Mi perro ladra.\n",
|
||||
"\n",
|
||||
"Input: Spot can run fast.\n",
|
||||
"Output:\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# You can add examples to NGramOverlapExampleSelector as well.\n",
|
||||
"new_example = {\"input\": \"Spot plays fetch.\", \"output\": \"Spot juega a buscar.\"}\n",
|
||||
"\n",
|
||||
"example_selector.add_example(new_example)\n",
|
||||
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "606ce697",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Give the Spanish translation of every input\n",
|
||||
"\n",
|
||||
"Input: Spot can run.\n",
|
||||
"Output: Spot puede correr.\n",
|
||||
"\n",
|
||||
"Input: See Spot run.\n",
|
||||
"Output: Ver correr a Spot.\n",
|
||||
"\n",
|
||||
"Input: Spot plays fetch.\n",
|
||||
"Output: Spot juega a buscar.\n",
|
||||
"\n",
|
||||
"Input: Spot can run fast.\n",
|
||||
"Output:\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# You can set a threshold at which examples are excluded.\n",
|
||||
"# For example, setting threshold equal to 0.0\n",
|
||||
"# excludes examples with no ngram overlaps with input.\n",
|
||||
"# Since \"My dog barks.\" has no ngram overlaps with \"Spot can run fast.\"\n",
|
||||
"# it is excluded.\n",
|
||||
"example_selector.threshold=0.0\n",
|
||||
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 87,
|
||||
"id": "7f8d72f7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Give the Spanish translation of every input\n",
|
||||
"\n",
|
||||
"Input: Spot can run.\n",
|
||||
"Output: Spot puede correr.\n",
|
||||
"\n",
|
||||
"Input: Spot plays fetch.\n",
|
||||
"Output: Spot juega a buscar.\n",
|
||||
"\n",
|
||||
"Input: Spot can play fetch.\n",
|
||||
"Output:\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Setting small nonzero threshold\n",
|
||||
"example_selector.threshold=0.09\n",
|
||||
"print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 88,
|
||||
"id": "09633aa8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Give the Spanish translation of every input\n",
|
||||
"\n",
|
||||
"Input: Spot can play fetch.\n",
|
||||
"Output:\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Setting threshold greater than 1.0\n",
|
||||
"example_selector.threshold=1.0+1e-9\n",
|
||||
"print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c746d6f4",
|
||||
"id": "39f30097",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
112
langchain/prompts/example_selector/ngram_overlap.py
Normal file
112
langchain/prompts/example_selector/ngram_overlap.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Select and order examples based on ngram overlap score (sentence_bleu score).
|
||||
|
||||
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
|
||||
https://aclanthology.org/P02-1040.pdf
|
||||
"""
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
|
||||
"""Compute ngram overlap score of source and example as sentence_bleu score.
|
||||
|
||||
Use sentence_bleu with method1 smoothing function and auto reweighting.
|
||||
Return float value between 0.0 and 1.0 inclusive.
|
||||
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
|
||||
https://aclanthology.org/P02-1040.pdf
|
||||
"""
|
||||
from nltk.translate.bleu_score import ( # type: ignore
|
||||
SmoothingFunction,
|
||||
sentence_bleu,
|
||||
)
|
||||
|
||||
hypotheses = source[0].split()
|
||||
references = [s.split() for s in example]
|
||||
|
||||
return float(
|
||||
sentence_bleu(
|
||||
references,
|
||||
hypotheses,
|
||||
smoothing_function=SmoothingFunction().method1,
|
||||
auto_reweigh=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
|
||||
"""Select and order examples based on ngram overlap score (sentence_bleu score).
|
||||
|
||||
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
|
||||
https://aclanthology.org/P02-1040.pdf
|
||||
"""
|
||||
|
||||
examples: List[dict]
|
||||
"""A list of the examples that the prompt template expects."""
|
||||
|
||||
example_prompt: PromptTemplate
|
||||
"""Prompt template used to format the examples."""
|
||||
|
||||
threshold: float = -1.0
|
||||
"""Threshold at which algorithm stops. Set to -1.0 by default.
|
||||
|
||||
For negative threshold:
|
||||
select_examples sorts examples by ngram_overlap_score, but excludes none.
|
||||
For threshold greater than 1.0:
|
||||
select_examples excludes all examples, and returns an empty list.
|
||||
For threshold equal to 0.0:
|
||||
select_examples sorts examples by ngram_overlap_score,
|
||||
and excludes examples with no ngram overlap with input.
|
||||
"""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_dependencies(cls, values: Dict) -> Dict:
|
||||
"""Check that valid dependencies exist."""
|
||||
try:
|
||||
from nltk.translate.bleu_score import ( # noqa: disable=F401
|
||||
SmoothingFunction,
|
||||
sentence_bleu,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
"Not all the correct dependencies for this ExampleSelect exist"
|
||||
) from e
|
||||
|
||||
return values
|
||||
|
||||
def add_example(self, example: Dict[str, str]) -> None:
|
||||
"""Add new example to list."""
|
||||
self.examples.append(example)
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
"""Return list of examples sorted by ngram_overlap_score with input.
|
||||
|
||||
Descending order.
|
||||
Excludes any examples with ngram_overlap_score less than or equal to threshold.
|
||||
"""
|
||||
inputs = list(input_variables.values())
|
||||
examples = []
|
||||
k = len(self.examples)
|
||||
score = [0.0] * k
|
||||
first_prompt_template_key = self.example_prompt.input_variables[0]
|
||||
|
||||
for i in range(k):
|
||||
score[i] = ngram_overlap_score(
|
||||
inputs, [self.examples[i][first_prompt_template_key]]
|
||||
)
|
||||
|
||||
while True:
|
||||
arg_max = np.argmax(score)
|
||||
if (score[arg_max] < self.threshold) or abs(
|
||||
score[arg_max] - self.threshold
|
||||
) < 1e-9:
|
||||
break
|
||||
|
||||
examples.append(self.examples[arg_max])
|
||||
score[arg_max] = self.threshold - 1.0
|
||||
|
||||
return examples
|
@ -0,0 +1,73 @@
|
||||
"""Test functionality related to ngram overlap based selector."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.example_selector.ngram_overlap import (
|
||||
NGramOverlapExampleSelector,
|
||||
ngram_overlap_score,
|
||||
)
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
EXAMPLES = [
|
||||
{"input": "See Spot run.", "output": "foo1"},
|
||||
{"input": "My dog barks.", "output": "foo2"},
|
||||
{"input": "Spot can run.", "output": "foo3"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def selector() -> NGramOverlapExampleSelector:
|
||||
"""Get ngram overlap based selector to use in tests."""
|
||||
prompts = PromptTemplate(
|
||||
input_variables=["input", "output"], template="Input: {input}\nOutput: {output}"
|
||||
)
|
||||
selector = NGramOverlapExampleSelector(
|
||||
examples=EXAMPLES,
|
||||
example_prompt=prompts,
|
||||
)
|
||||
return selector
|
||||
|
||||
|
||||
def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
|
||||
"""Test NGramOverlapExampleSelector can select examples."""
|
||||
sentence = "Spot can run."
|
||||
output = selector.select_examples({"input": sentence})
|
||||
assert output == [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]]
|
||||
|
||||
|
||||
def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
|
||||
"""Test NGramOverlapExampleSelector can add an example."""
|
||||
new_example = {"input": "Spot plays fetch.", "output": "foo4"}
|
||||
selector.add_example(new_example)
|
||||
sentence = "Spot can run."
|
||||
output = selector.select_examples({"input": sentence})
|
||||
assert output == [EXAMPLES[2], EXAMPLES[0]] + [new_example] + [EXAMPLES[1]]
|
||||
|
||||
|
||||
def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None:
|
||||
"""Tests NGramOverlapExampleSelector threshold set to 0.0."""
|
||||
selector.threshold = 0.0
|
||||
sentence = "Spot can run."
|
||||
output = selector.select_examples({"input": sentence})
|
||||
assert output == [EXAMPLES[2], EXAMPLES[0]]
|
||||
|
||||
|
||||
def test_selector_threshold_more_than_one(
|
||||
selector: NGramOverlapExampleSelector,
|
||||
) -> None:
|
||||
"""Tests NGramOverlapExampleSelector threshold greater than 1.0."""
|
||||
selector.threshold = 1.0 + 1e-9
|
||||
sentence = "Spot can run."
|
||||
output = selector.select_examples({"input": sentence})
|
||||
assert output == []
|
||||
|
||||
|
||||
def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
|
||||
"""Tests that ngram_overlap_score returns correct values."""
|
||||
selector.threshold = 1.0 + 1e-9
|
||||
none = ngram_overlap_score(["Spot can run."], ["My dog barks."])
|
||||
some = ngram_overlap_score(["Spot can run."], ["See Spot run."])
|
||||
complete = ngram_overlap_score(["Spot can run."], ["Spot can run."])
|
||||
|
||||
check = [abs(none - 0.0) < 1e-9, 0.0 < some < 1.0, abs(complete - 1.0) < 1e-9]
|
||||
assert check == [True, True, True]
|
Loading…
Reference in New Issue
Block a user