Files
langchain/libs/community/langchain_community/example_selectors/ngram_overlap.py
Harrison Chase 8516a03a02 langchain-community[major]: Upgrade community to pydantic 2 (#26011)
This PR upgrades langchain-community to pydantic 2.


* Most of this PR was auto-generated using code mods with gritql
(https://github.com/eyurtsev/migrate-pydantic/tree/main)
* Subsequently, some code was fixed manually to accommodate
differences between pydantic 1 and 2

Breaking Changes:

- Use TEXTEMBED_API_KEY and TEXTEMBED_API_URL for env variables for text
embed integrations:
cbea780492

Other changes:

- Added pydantic_settings as a required dependency for community. This
may be removed if we have enough time to convert the dependency into an
optional one.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2024-09-05 14:07:10 -04:00

117 lines
3.8 KiB
Python

"""Select and order examples based on ngram overlap score (sentence_bleu score).
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
https://aclanthology.org/P02-1040.pdf
"""
from typing import Any, Dict, List
import numpy as np
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.prompts import PromptTemplate
from pydantic import BaseModel, model_validator
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
    """Score the ngram overlap between an input and a list of example strings.

    The score is the NLTK ``sentence_bleu`` value, computed with the
    ``method1`` smoothing function and ``auto_reweigh=True``.

    Only the first element of ``source`` is scored; it is whitespace-tokenized
    and used as the hypothesis. Every string in ``example`` is
    whitespace-tokenized and used as a reference.

    Args:
        source: Input strings; only ``source[0]`` is used.
        example: Reference strings to compare the input against.

    Returns:
        The sentence_bleu score as a float between 0.0 and 1.0 inclusive.

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """
    # Imported lazily so the module can be imported without nltk installed.
    from nltk.translate.bleu_score import (
        SmoothingFunction,  # type: ignore
        sentence_bleu,
    )

    candidate_tokens = source[0].split()

    reference_token_lists = []
    for reference_text in example:
        reference_token_lists.append(reference_text.split())

    smoother = SmoothingFunction().method1
    bleu = sentence_bleu(
        reference_token_lists,
        candidate_tokens,
        smoothing_function=smoother,
        auto_reweigh=True,
    )
    return float(bleu)
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
    """Select and order examples based on ngram overlap score (sentence_bleu score
    from NLTK package).

    Each stored example is scored against the input with
    ``ngram_overlap_score``; examples are returned in descending score order,
    and those whose score does not exceed ``threshold`` are dropped.

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    threshold: float = -1.0
    """Threshold at which algorithm stops. Set to -1.0 by default.

    For negative threshold:
    select_examples sorts examples by ngram_overlap_score, but excludes none.
    For threshold greater than 1.0:
    select_examples excludes all examples, and returns an empty list.
    For threshold equal to 0.0:
    select_examples sorts examples by ngram_overlap_score,
    and excludes examples with no ngram overlap with input.
    """

    @model_validator(mode="before")
    @classmethod
    def check_dependencies(cls, values: Dict) -> Any:
        """Check that valid dependencies exist.

        Args:
            values: The raw input data for the model; returned unchanged.

        Returns:
            ``values``, untouched.

        Raises:
            ImportError: If the nltk BLEU utilities cannot be imported.
        """
        try:
            from nltk.translate.bleu_score import (  # noqa: F401
                SmoothingFunction,
                sentence_bleu,
            )
        except ImportError as e:
            raise ImportError(
                "Not all the correct dependencies for this ExampleSelect exist. "
                "Please install nltk with `pip install nltk`."
            ) from e

        return values

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list.

        Args:
            example: The example to append to ``self.examples``.
        """
        self.examples.append(example)

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Return list of examples sorted by ngram_overlap_score with input.

        Descending order; examples with equal scores keep their original
        relative order.
        Excludes any examples with ngram_overlap_score less than or equal to
        threshold (equality is checked with a 1e-9 tolerance).

        Each example is compared through the value stored under the first
        input variable of ``example_prompt``.

        Args:
            input_variables: Mapping of input names to values; the values are
                passed, in order, to ``ngram_overlap_score``.

        Returns:
            The selected examples, best match first. Empty when there are no
            stored examples or none scores above the threshold.
        """
        inputs = list(input_variables.values())
        first_prompt_template_key = self.example_prompt.input_variables[0]

        scores = [
            ngram_overlap_score(inputs, [candidate[first_prompt_template_key]])
            for candidate in self.examples
        ]

        # A stable descending sort reproduces repeated "first argmax" selection
        # in one pass and also handles an empty example list gracefully.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

        selected: List[dict] = []
        for idx in ranked:
            current = scores[idx]
            # Scores are descending, so once one fails the cut all others do.
            if current < self.threshold or abs(current - self.threshold) < 1e-9:
                break
            selected.append(self.examples[idx])

        return selected