mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-21 12:01:47 +00:00
parent
1db7b18341
commit
d368c43648
@ -8,6 +8,11 @@ from langchain.prompts.example_selector.base import BaseExampleSelector
|
|||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
|
||||||
|
|
||||||
|
def sorted_values(values: Dict[str, str]) -> List[Any]:
|
||||||
|
"""Return a list of values in dict sorted by key."""
|
||||||
|
return [values[val] for val in sorted(values)]
|
||||||
|
|
||||||
|
|
||||||
class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
||||||
"""Example selector that selects examples based on SemanticSimilarity."""
|
"""Example selector that selects examples based on SemanticSimilarity."""
|
||||||
|
|
||||||
@ -26,13 +31,13 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
|||||||
|
|
||||||
def add_example(self, example: Dict[str, str]) -> None:
|
def add_example(self, example: Dict[str, str]) -> None:
|
||||||
"""Add new example to vectorstore."""
|
"""Add new example to vectorstore."""
|
||||||
string_example = " ".join(example.values())
|
string_example = " ".join(sorted_values(example))
|
||||||
self.vectorstore.add_texts([string_example], metadatas=[example])
|
self.vectorstore.add_texts([string_example], metadatas=[example])
|
||||||
|
|
||||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||||
"""Select which examples to use based on semantic similarity."""
|
"""Select which examples to use based on semantic similarity."""
|
||||||
# Get the docs with the highest similarity.
|
# Get the docs with the highest similarity.
|
||||||
query = " ".join(input_variables.values())
|
query = " ".join(sorted_values(input_variables))
|
||||||
example_docs = self.vectorstore.similarity_search(query, k=self.k)
|
example_docs = self.vectorstore.similarity_search(query, k=self.k)
|
||||||
# Get the examples from the metadata.
|
# Get the examples from the metadata.
|
||||||
# This assumes that examples are stored in metadata.
|
# This assumes that examples are stored in metadata.
|
||||||
@ -73,7 +78,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
The ExampleSelector instantiated, backed by a vector store.
|
The ExampleSelector instantiated, backed by a vector store.
|
||||||
"""
|
"""
|
||||||
string_examples = [" ".join(eg.values()) for eg in examples]
|
string_examples = [" ".join(sorted_values(eg)) for eg in examples]
|
||||||
vectorstore = vectorstore_cls.from_texts(
|
vectorstore = vectorstore_cls.from_texts(
|
||||||
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
|
string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
|
||||||
)
|
)
|
||||||
|
9
tests/unit_tests/prompts/test_utils.py
Normal file
9
tests/unit_tests/prompts/test_utils.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
"""Test functionality related to prompt utils."""
|
||||||
|
from langchain.prompts.example_selector.semantic_similarity import sorted_values
|
||||||
|
|
||||||
|
|
||||||
|
def test_sorted_vals() -> None:
|
||||||
|
"""Test sorted values from dictionary."""
|
||||||
|
test_dict = {"key2": "val2", "key1": "val1"}
|
||||||
|
expected_response = ["val1", "val2"]
|
||||||
|
assert sorted_values(test_dict) == expected_response
|
Loading…
Reference in New Issue
Block a user