mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 11:08:55 +00:00
parent
1db7b18341
commit
d368c43648
@ -8,6 +8,11 @@ from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
def sorted_values(values: Dict[str, str]) -> List[Any]:
|
||||
"""Return a list of values in dict sorted by key."""
|
||||
return [values[val] for val in sorted(values)]
|
||||
|
||||
|
||||
class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||
"""Example selector that selects examples based on SemanticSimilarity."""
|
||||
|
||||
@ -26,13 +31,13 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||
|
||||
def add_example(self, example: Dict[str, str]) -> None:
|
||||
"""Add new example to vectorstore."""
|
||||
string_example = " ".join(example.values())
|
||||
string_example = " ".join(sorted_values(example))
|
||||
self.vectorstore.add_texts([string_example], metadatas=[example])
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
"""Select which examples to use based on semantic similarity."""
|
||||
# Get the docs with the highest similarity.
|
||||
query = " ".join(input_variables.values())
|
||||
query = " ".join(sorted_values(input_variables))
|
||||
example_docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||
# Get the examples from the metadata.
|
||||
# This assumes that examples are stored in metadata.
|
||||
@ -73,7 +78,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||
Returns:
|
||||
The ExampleSelector instantiated, backed by a vector store.
|
||||
"""
|
||||
string_examples = [" ".join(eg.values()) for eg in examples]
|
||||
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
|
||||
vectorstore = vectorstore_cls.from_texts(
|
||||
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
|
||||
)
|
||||
|
9
tests/unit_tests/prompts/test_utils.py
Normal file
9
tests/unit_tests/prompts/test_utils.py
Normal file
@ -0,0 +1,9 @@
|
||||
"""Test functionality related to prompt utils."""
|
||||
from langchain.prompts.example_selector.semantic_similarity import sorted_values
|
||||
|
||||
|
||||
def test_sorted_vals() -> None:
|
||||
"""Test sorted values from dictionary."""
|
||||
test_dict = {"key2": "val2", "key1": "val1"}
|
||||
expected_response = ["val1", "val2"]
|
||||
assert sorted_values(test_dict) == expected_response
|
Loading…
Reference in New Issue
Block a user