diff --git a/langchain/prompts/example_selector/semantic_similarity.py b/langchain/prompts/example_selector/semantic_similarity.py index 499bd9fc7a8..a78ca12daae 100644 --- a/langchain/prompts/example_selector/semantic_similarity.py +++ b/langchain/prompts/example_selector/semantic_similarity.py @@ -8,6 +8,11 @@ from langchain.prompts.example_selector.base import BaseExampleSelector from langchain.vectorstores.base import VectorStore +def sorted_values(values: Dict[str, str]) -> List[Any]: + """Return a list of values in dict sorted by key.""" + return [values[val] for val in sorted(values)] + + class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel): """Example selector that selects examples based on SemanticSimilarity.""" @@ -26,13 +31,13 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel): def add_example(self, example: Dict[str, str]) -> None: """Add new example to vectorstore.""" - string_example = " ".join(example.values()) + string_example = " ".join(sorted_values(example)) self.vectorstore.add_texts([string_example], metadatas=[example]) def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: """Select which examples to use based on semantic similarity.""" # Get the docs with the highest similarity. - query = " ".join(input_variables.values()) + query = " ".join(sorted_values(input_variables)) example_docs = self.vectorstore.similarity_search(query, k=self.k) # Get the examples from the metadata. # This assumes that examples are stored in metadata. @@ -73,7 +78,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel): Returns: The ExampleSelector instantiated, backed by a vector store. """ - string_examples = [" ".join(eg.values()) for eg in examples] + string_examples = [" ".join(sorted_values(eg)) for eg in examples] vectorstore = vectorstore_cls.from_texts( string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs ) diff --git a/tests/unit_tests/prompts/test_utils.py b/tests/unit_tests/prompts/test_utils.py new file mode 100644 index 00000000000..479d02e8bd9 --- /dev/null +++ b/tests/unit_tests/prompts/test_utils.py @@ -0,0 +1,9 @@ +"""Test functionality related to prompt utils.""" +from langchain.prompts.example_selector.semantic_similarity import sorted_values + + +def test_sorted_vals() -> None: + """Test sorted values from dictionary.""" + test_dict = {"key2": "val2", "key1": "val1"} + expected_response = ["val1", "val2"] + assert sorted_values(test_dict) == expected_response