mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
Harrison/ngram example (#846)
Co-authored-by: Sean Spriggens <ssprigge@syr.edu>
This commit is contained in:
parent
0de55048b7
commit
23d5f64bda
@ -23,7 +23,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 1,
|
||||||
"id": "8244ff60",
|
"id": "8244ff60",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@ -81,7 +81,7 @@
|
|||||||
" template=\"Input: {input}\\nOutput: {output}\",\n",
|
" template=\"Input: {input}\\nOutput: {output}\",\n",
|
||||||
")\n",
|
")\n",
|
||||||
"example_selector = LengthBasedExampleSelector(\n",
|
"example_selector = LengthBasedExampleSelector(\n",
|
||||||
" # These are the examples is has available to choose from.\n",
|
" # These are the examples it has available to choose from.\n",
|
||||||
" examples=examples, \n",
|
" examples=examples, \n",
|
||||||
" # This is the PromptTemplate being used to format the examples.\n",
|
" # This is the PromptTemplate being used to format the examples.\n",
|
||||||
" example_prompt=example_prompt, \n",
|
" example_prompt=example_prompt, \n",
|
||||||
@ -439,10 +439,242 @@
|
|||||||
"print(similar_prompt.format(adjective=\"worried\"))"
|
"print(similar_prompt.format(adjective=\"worried\"))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "4aaeed2f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## NGram Overlap ExampleSelector\n",
|
||||||
|
"\n",
|
||||||
|
"The NGramOverlapExampleSelector selects and orders examples based on which examples are most similar to the input, according to an ngram overlap score. The ngram overlap score is a float between 0.0 and 1.0, inclusive. \n",
|
||||||
|
"\n",
|
||||||
|
"The selector allows for a threshold score to be set. Examples with an ngram overlap score less than or equal to the threshold are excluded. The threshold is set to -1.0, by default, so will not exclude any examples, only reorder them. Setting the threshold to 0.0 will exclude examples that have no ngram overlaps with the input.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "9cbc0acc",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.prompts import PromptTemplate\n",
|
||||||
|
"from langchain.prompts.example_selector.ngram_overlap import NGramOverlapExampleSelector"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "4f318f4b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# These are examples of a fictional translation task.\n",
|
||||||
|
"examples = [\n",
|
||||||
|
" {\"input\": \"See Spot run.\", \"output\": \"Ver correr a Spot.\"},\n",
|
||||||
|
" {\"input\": \"My dog barks.\", \"output\": \"Mi perro ladra.\"},\n",
|
||||||
|
" {\"input\": \"Spot can run.\", \"output\": \"Spot puede correr.\"},\n",
|
||||||
|
"]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "bf75e0fe",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"example_prompt = PromptTemplate(\n",
|
||||||
|
" input_variables=[\"input\", \"output\"],\n",
|
||||||
|
" template=\"Input: {input}\\nOutput: {output}\",\n",
|
||||||
|
")\n",
|
||||||
|
"example_selector = NGramOverlapExampleSelector(\n",
|
||||||
|
" # These are the examples it has available to choose from.\n",
|
||||||
|
" examples=examples, \n",
|
||||||
|
" # This is the PromptTemplate being used to format the examples.\n",
|
||||||
|
" example_prompt=example_prompt, \n",
|
||||||
|
" # This is the threshold, at which selector stops.\n",
|
||||||
|
" # It is set to -1.0 by default.\n",
|
||||||
|
" threshold=-1.0,\n",
|
||||||
|
" # For negative threshold:\n",
|
||||||
|
" # Selector sorts examples by ngram overlap score, and excludes none.\n",
|
||||||
|
" # For threshold greater than 1.0:\n",
|
||||||
|
" # Selector excludes all examples, and returns an empty list.\n",
|
||||||
|
" # For threshold equal to 0.0:\n",
|
||||||
|
" # Selector sorts examples by ngram overlap score,\n",
|
||||||
|
" # and excludes those with no ngram overlap with input.\n",
|
||||||
|
")\n",
|
||||||
|
"dynamic_prompt = FewShotPromptTemplate(\n",
|
||||||
|
" # We provide an ExampleSelector instead of examples.\n",
|
||||||
|
" example_selector=example_selector,\n",
|
||||||
|
" example_prompt=example_prompt,\n",
|
||||||
|
" prefix=\"Give the Spanish translation of every input\",\n",
|
||||||
|
" suffix=\"Input: {sentence}\\nOutput:\", \n",
|
||||||
|
" input_variables=[\"sentence\"],\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "83fb218a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Give the Spanish translation of every input\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot can run.\n",
|
||||||
|
"Output: Spot puede correr.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: See Spot run.\n",
|
||||||
|
"Output: Ver correr a Spot.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: My dog barks.\n",
|
||||||
|
"Output: Mi perro ladra.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot can run fast.\n",
|
||||||
|
"Output:\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# An example input with large ngram overlap with \"Spot can run.\"\n",
|
||||||
|
"# and no overlap with \"My dog barks.\"\n",
|
||||||
|
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "485f5307",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Give the Spanish translation of every input\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot can run.\n",
|
||||||
|
"Output: Spot puede correr.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: See Spot run.\n",
|
||||||
|
"Output: Ver correr a Spot.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot plays fetch.\n",
|
||||||
|
"Output: Spot juega a buscar.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: My dog barks.\n",
|
||||||
|
"Output: Mi perro ladra.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot can run fast.\n",
|
||||||
|
"Output:\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# You can add examples to NGramOverlapExampleSelector as well.\n",
|
||||||
|
"new_example = {\"input\": \"Spot plays fetch.\", \"output\": \"Spot juega a buscar.\"}\n",
|
||||||
|
"\n",
|
||||||
|
"example_selector.add_example(new_example)\n",
|
||||||
|
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "606ce697",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Give the Spanish translation of every input\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot can run.\n",
|
||||||
|
"Output: Spot puede correr.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: See Spot run.\n",
|
||||||
|
"Output: Ver correr a Spot.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot plays fetch.\n",
|
||||||
|
"Output: Spot juega a buscar.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot can run fast.\n",
|
||||||
|
"Output:\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# You can set a threshold at which examples are excluded.\n",
|
||||||
|
"# For example, setting threshold equal to 0.0\n",
|
||||||
|
"# excludes examples with no ngram overlaps with input.\n",
|
||||||
|
"# Since \"My dog barks.\" has no ngram overlaps with \"Spot can run fast.\"\n",
|
||||||
|
"# it is excluded.\n",
|
||||||
|
"example_selector.threshold=0.0\n",
|
||||||
|
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 87,
|
||||||
|
"id": "7f8d72f7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Give the Spanish translation of every input\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot can run.\n",
|
||||||
|
"Output: Spot puede correr.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot plays fetch.\n",
|
||||||
|
"Output: Spot juega a buscar.\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot can play fetch.\n",
|
||||||
|
"Output:\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Setting small nonzero threshold\n",
|
||||||
|
"example_selector.threshold=0.09\n",
|
||||||
|
"print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 88,
|
||||||
|
"id": "09633aa8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Give the Spanish translation of every input\n",
|
||||||
|
"\n",
|
||||||
|
"Input: Spot can play fetch.\n",
|
||||||
|
"Output:\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Setting threshold greater than 1.0\n",
|
||||||
|
"example_selector.threshold=1.0+1e-9\n",
|
||||||
|
"print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "c746d6f4",
|
"id": "39f30097",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
|
112
langchain/prompts/example_selector/ngram_overlap.py
Normal file
112
langchain/prompts/example_selector/ngram_overlap.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
"""Select and order examples based on ngram overlap score (sentence_bleu score).
|
||||||
|
|
||||||
|
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
|
||||||
|
https://aclanthology.org/P02-1040.pdf
|
||||||
|
"""
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||||
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
|
|
||||||
|
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
|
||||||
|
"""Compute ngram overlap score of source and example as sentence_bleu score.
|
||||||
|
|
||||||
|
Use sentence_bleu with method1 smoothing function and auto reweighting.
|
||||||
|
Return float value between 0.0 and 1.0 inclusive.
|
||||||
|
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
|
||||||
|
https://aclanthology.org/P02-1040.pdf
|
||||||
|
"""
|
||||||
|
from nltk.translate.bleu_score import ( # type: ignore
|
||||||
|
SmoothingFunction,
|
||||||
|
sentence_bleu,
|
||||||
|
)
|
||||||
|
|
||||||
|
hypotheses = source[0].split()
|
||||||
|
references = [s.split() for s in example]
|
||||||
|
|
||||||
|
return float(
|
||||||
|
sentence_bleu(
|
||||||
|
references,
|
||||||
|
hypotheses,
|
||||||
|
smoothing_function=SmoothingFunction().method1,
|
||||||
|
auto_reweigh=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
|
||||||
|
"""Select and order examples based on ngram overlap score (sentence_bleu score).
|
||||||
|
|
||||||
|
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
|
||||||
|
https://aclanthology.org/P02-1040.pdf
|
||||||
|
"""
|
||||||
|
|
||||||
|
examples: List[dict]
|
||||||
|
"""A list of the examples that the prompt template expects."""
|
||||||
|
|
||||||
|
example_prompt: PromptTemplate
|
||||||
|
"""Prompt template used to format the examples."""
|
||||||
|
|
||||||
|
threshold: float = -1.0
|
||||||
|
"""Threshold at which algorithm stops. Set to -1.0 by default.
|
||||||
|
|
||||||
|
For negative threshold:
|
||||||
|
select_examples sorts examples by ngram_overlap_score, but excludes none.
|
||||||
|
For threshold greater than 1.0:
|
||||||
|
select_examples excludes all examples, and returns an empty list.
|
||||||
|
For threshold equal to 0.0:
|
||||||
|
select_examples sorts examples by ngram_overlap_score,
|
||||||
|
and excludes examples with no ngram overlap with input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@root_validator(pre=True)
|
||||||
|
def check_dependencies(cls, values: Dict) -> Dict:
|
||||||
|
"""Check that valid dependencies exist."""
|
||||||
|
try:
|
||||||
|
from nltk.translate.bleu_score import ( # noqa: disable=F401
|
||||||
|
SmoothingFunction,
|
||||||
|
sentence_bleu,
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ValueError(
|
||||||
|
"Not all the correct dependencies for this ExampleSelect exist"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
def add_example(self, example: Dict[str, str]) -> None:
|
||||||
|
"""Add new example to list."""
|
||||||
|
self.examples.append(example)
|
||||||
|
|
||||||
|
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||||
|
"""Return list of examples sorted by ngram_overlap_score with input.
|
||||||
|
|
||||||
|
Descending order.
|
||||||
|
Excludes any examples with ngram_overlap_score less than or equal to threshold.
|
||||||
|
"""
|
||||||
|
inputs = list(input_variables.values())
|
||||||
|
examples = []
|
||||||
|
k = len(self.examples)
|
||||||
|
score = [0.0] * k
|
||||||
|
first_prompt_template_key = self.example_prompt.input_variables[0]
|
||||||
|
|
||||||
|
for i in range(k):
|
||||||
|
score[i] = ngram_overlap_score(
|
||||||
|
inputs, [self.examples[i][first_prompt_template_key]]
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
arg_max = np.argmax(score)
|
||||||
|
if (score[arg_max] < self.threshold) or abs(
|
||||||
|
score[arg_max] - self.threshold
|
||||||
|
) < 1e-9:
|
||||||
|
break
|
||||||
|
|
||||||
|
examples.append(self.examples[arg_max])
|
||||||
|
score[arg_max] = self.threshold - 1.0
|
||||||
|
|
||||||
|
return examples
|
@ -0,0 +1,73 @@
|
|||||||
|
"""Test functionality related to ngram overlap based selector."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.prompts.example_selector.ngram_overlap import (
|
||||||
|
NGramOverlapExampleSelector,
|
||||||
|
ngram_overlap_score,
|
||||||
|
)
|
||||||
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
|
EXAMPLES = [
|
||||||
|
{"input": "See Spot run.", "output": "foo1"},
|
||||||
|
{"input": "My dog barks.", "output": "foo2"},
|
||||||
|
{"input": "Spot can run.", "output": "foo3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def selector() -> NGramOverlapExampleSelector:
|
||||||
|
"""Get ngram overlap based selector to use in tests."""
|
||||||
|
prompts = PromptTemplate(
|
||||||
|
input_variables=["input", "output"], template="Input: {input}\nOutput: {output}"
|
||||||
|
)
|
||||||
|
selector = NGramOverlapExampleSelector(
|
||||||
|
examples=EXAMPLES,
|
||||||
|
example_prompt=prompts,
|
||||||
|
)
|
||||||
|
return selector
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
|
||||||
|
"""Test NGramOverlapExampleSelector can select examples."""
|
||||||
|
sentence = "Spot can run."
|
||||||
|
output = selector.select_examples({"input": sentence})
|
||||||
|
assert output == [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
|
||||||
|
"""Test NGramOverlapExampleSelector can add an example."""
|
||||||
|
new_example = {"input": "Spot plays fetch.", "output": "foo4"}
|
||||||
|
selector.add_example(new_example)
|
||||||
|
sentence = "Spot can run."
|
||||||
|
output = selector.select_examples({"input": sentence})
|
||||||
|
assert output == [EXAMPLES[2], EXAMPLES[0]] + [new_example] + [EXAMPLES[1]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None:
|
||||||
|
"""Tests NGramOverlapExampleSelector threshold set to 0.0."""
|
||||||
|
selector.threshold = 0.0
|
||||||
|
sentence = "Spot can run."
|
||||||
|
output = selector.select_examples({"input": sentence})
|
||||||
|
assert output == [EXAMPLES[2], EXAMPLES[0]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_selector_threshold_more_than_one(
|
||||||
|
selector: NGramOverlapExampleSelector,
|
||||||
|
) -> None:
|
||||||
|
"""Tests NGramOverlapExampleSelector threshold greater than 1.0."""
|
||||||
|
selector.threshold = 1.0 + 1e-9
|
||||||
|
sentence = "Spot can run."
|
||||||
|
output = selector.select_examples({"input": sentence})
|
||||||
|
assert output == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
|
||||||
|
"""Tests that ngram_overlap_score returns correct values."""
|
||||||
|
selector.threshold = 1.0 + 1e-9
|
||||||
|
none = ngram_overlap_score(["Spot can run."], ["My dog barks."])
|
||||||
|
some = ngram_overlap_score(["Spot can run."], ["See Spot run."])
|
||||||
|
complete = ngram_overlap_score(["Spot can run."], ["Spot can run."])
|
||||||
|
|
||||||
|
check = [abs(none - 0.0) < 1e-9, 0.0 < some < 1.0, abs(complete - 1.0) < 1e-9]
|
||||||
|
assert check == [True, True, True]
|
Loading…
Reference in New Issue
Block a user