langchain/libs/community/langchain_community/example_selectors/ngram_overlap.py
Erick Friis c2a3021bb0
multiple: pydantic 2 compatibility, v0.3 (#26443)
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com>
Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com>
Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com>
Co-authored-by: ZhangShenao <15201440436@163.com>
Co-authored-by: Friso H. Kingma <fhkingma@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
Co-authored-by: Morgante Pell <morgantep@google.com>
2024-09-13 14:38:45 -07:00

117 lines
3.8 KiB
Python

"""Select and order examples based on ngram overlap score (sentence_bleu score).
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
https://aclanthology.org/P02-1040.pdf
"""
from typing import Any, Dict, List
import numpy as np
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, model_validator
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
    """Score the ngram overlap between a source text and example texts.

    The score is NLTK's ``sentence_bleu`` computed with the ``method1``
    smoothing function and automatic reweighting enabled.

    Only the first element of ``source`` is used; it is whitespace-tokenized
    and passed as the hypothesis. Every string in ``example`` is
    whitespace-tokenized and passed as a reference.

    Args:
        source: List of input strings; only ``source[0]`` is scored.
        example: List of reference strings to compare against.

    Returns:
        The sentence BLEU score as a float, between 0.0 and 1.0 inclusive.

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """
    # Imported lazily so that nltk is only required when scoring is performed.
    from nltk.translate.bleu_score import (
        SmoothingFunction,  # type: ignore
        sentence_bleu,
    )

    candidate_tokens = source[0].split()
    reference_token_lists = [text.split() for text in example]
    bleu = sentence_bleu(
        reference_token_lists,
        candidate_tokens,
        smoothing_function=SmoothingFunction().method1,
        auto_reweigh=True,
    )
    return float(bleu)
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
    """Select and order examples based on ngram overlap score (sentence_bleu score
    from NLTK package).

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    threshold: float = -1.0
    """Threshold at which algorithm stops. Set to -1.0 by default.

    For negative threshold:
    select_examples sorts examples by ngram_overlap_score, but excludes none.
    For threshold greater than 1.0:
    select_examples excludes all examples, and returns an empty list.
    For threshold equal to 0.0:
    select_examples sorts examples by ngram_overlap_score,
    and excludes examples with no ngram overlap with input.
    """

    @model_validator(mode="before")
    @classmethod
    def check_dependencies(cls, values: Dict) -> Any:
        """Check that valid dependencies exist.

        Args:
            values: Raw field values passed to the model; returned unchanged.

        Returns:
            The unmodified ``values``.

        Raises:
            ImportError: If the required names cannot be imported from
                ``nltk.translate.bleu_score``.
        """
        try:
            from nltk.translate.bleu_score import (  # noqa: F401
                SmoothingFunction,
                sentence_bleu,
            )
        except ImportError as e:
            # The trailing space keeps the two sentences from running together.
            raise ImportError(
                "Not all the correct dependencies for this ExampleSelect exist. "
                "Please install nltk with `pip install nltk`."
            ) from e

        return values

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list.

        Args:
            example: Example to append to ``self.examples``.
        """
        self.examples.append(example)

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Return list of examples sorted by ngram_overlap_score with input.

        Descending order; examples with equal scores keep their relative
        order from ``self.examples``.
        Excludes any examples with ngram_overlap_score less than or equal to
        threshold. Returns an empty list when there are no examples.

        Args:
            input_variables: Input values; they are passed as a list to
                ``ngram_overlap_score``, which scores the first one.

        Returns:
            The selected examples, highest score first.
        """
        inputs = list(input_variables.values())
        # Each example is compared through the first variable of the prompt.
        first_prompt_template_key = self.example_prompt.input_variables[0]

        scores = [
            ngram_overlap_score(inputs, [example[first_prompt_template_key]])
            for example in self.examples
        ]

        # A stable sort on the negated scores yields descending order while
        # keeping ties in their original order. It also handles an empty
        # score list, for which np.argmax would raise ValueError.
        ranking = np.argsort(-np.asarray(scores, dtype=float), kind="stable")

        selected = []
        for position in ranking:
            score = scores[position]
            # Scores are visited in descending order, so the first score at or
            # below the threshold (within 1e-9) ends the selection.
            if score < self.threshold or abs(score - self.threshold) < 1e-9:
                break
            selected.append(self.examples[position])

        return selected