diff --git a/libs/core/langchain_core/example_selectors/semantic_similarity.py b/libs/core/langchain_core/example_selectors/semantic_similarity.py
index c3bb86fa9c5..1a6a9044e98 100644
--- a/libs/core/langchain_core/example_selectors/semantic_similarity.py
+++ b/libs/core/langchain_core/example_selectors/semantic_similarity.py
@@ -1,6 +1,7 @@
 """Example selector that selects examples based on SemanticSimilarity."""
 from __future__ import annotations
 
+from abc import ABC
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 from langchain_core.documents import Document
@@ -17,7 +18,7 @@ def sorted_values(values: Dict[str, str]) -> List[Any]:
     return [values[val] for val in sorted(values)]
 
 
-class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
+class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
     """Example selector that selects examples based on SemanticSimilarity."""
 
     vectorstore: VectorStore
@@ -70,6 +71,10 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
         )
         return ids[0]
 
+
+class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
+    """Example selector that selects examples based on SemanticSimilarity."""
+
     def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
         """Select which examples to use based on semantic similarity."""
         # Get the docs with the highest similarity.
@@ -116,6 +121,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
             k: Number of examples to select
             input_keys: If provided, the search is based on the input variables
                 instead of all variables.
+            example_keys: If provided, keys to filter examples to.
+            vectorstore_kwargs: Extra arguments passed to similarity_search function
+                of the vectorstore.
             vectorstore_cls_kwargs: optional kwargs containing url for vector store
 
         Returns:
@@ -157,6 +165,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
             k: Number of examples to select
             input_keys: If provided, the search is based on the input variables
                 instead of all variables.
+            example_keys: If provided, keys to filter examples to.
+            vectorstore_kwargs: Extra arguments passed to similarity_search function
+                of the vectorstore.
             vectorstore_cls_kwargs: optional kwargs containing url for vector store
 
         Returns:
@@ -175,7 +186,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
         )
 
 
-class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
+class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
     """ExampleSelector that selects examples based on Max Marginal Relevance.
 
     This was shown to improve performance in this paper:
@@ -186,21 +197,20 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
     """Number of examples to fetch to rerank."""
 
     def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
-        """Select which examples to use based on semantic similarity."""
-        # Get the docs with the highest similarity.
-        if self.input_keys:
-            input_variables = {key: input_variables[key] for key in self.input_keys}
-        query = " ".join(sorted_values(input_variables))
         example_docs = self.vectorstore.max_marginal_relevance_search(
-            query, k=self.k, fetch_k=self.fetch_k
+            self._example_to_text(input_variables, self.input_keys),
+            k=self.k,
+            fetch_k=self.fetch_k,
         )
-        # Get the examples from the metadata.
-        # This assumes that examples are stored in metadata.
-        examples = [dict(e.metadata) for e in example_docs]
-        # If example keys are provided, filter examples to those keys.
-        if self.example_keys:
-            examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
-        return examples
+        return self._documents_to_examples(example_docs)
+
+    async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+        example_docs = await self.vectorstore.amax_marginal_relevance_search(
+            self._example_to_text(input_variables, self.input_keys),
+            k=self.k,
+            fetch_k=self.fetch_k,
+        )
+        return self._documents_to_examples(example_docs)
 
     @classmethod
     def from_examples(
@@ -211,32 +221,86 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
         k: int = 4,
         input_keys: Optional[List[str]] = None,
         fetch_k: int = 20,
+        example_keys: Optional[List[str]] = None,
+        vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> MaxMarginalRelevanceExampleSelector:
         """Create k-shot example selector using example list and embeddings.
 
-        Reshuffles examples dynamically based on query similarity.
+        Reshuffles examples dynamically based on Max Marginal Relevance.
 
         Args:
            examples: List of examples to use in the prompt.
-            embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings().
+            embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
            vectorstore_cls: A vector store DB interface class, e.g. FAISS.
            k: Number of examples to select
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            input_keys: If provided, the search is based on the input variables
                instead of all variables.
+            example_keys: If provided, keys to filter examples to.
+            vectorstore_kwargs: Extra arguments passed to similarity_search function
+                of the vectorstore.
            vectorstore_cls_kwargs: optional kwargs containing url for vector store
 
         Returns:
            The ExampleSelector instantiated, backed by a vector store.
         """
-        if input_keys:
-            string_examples = [
-                " ".join(sorted_values({k: eg[k] for k in input_keys}))
-                for eg in examples
-            ]
-        else:
-            string_examples = [" ".join(sorted_values(eg)) for eg in examples]
+        string_examples = [cls._example_to_text(eg, input_keys) for eg in examples]
         vectorstore = vectorstore_cls.from_texts(
             string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
         )
-        return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys)
+        return cls(
+            vectorstore=vectorstore,
+            k=k,
+            fetch_k=fetch_k,
+            input_keys=input_keys,
+            example_keys=example_keys,
+            vectorstore_kwargs=vectorstore_kwargs,
+        )
+
+    @classmethod
+    async def afrom_examples(
+        cls,
+        examples: List[dict],
+        embeddings: Embeddings,
+        vectorstore_cls: Type[VectorStore],
+        *,
+        k: int = 4,
+        input_keys: Optional[List[str]] = None,
+        fetch_k: int = 20,
+        example_keys: Optional[List[str]] = None,
+        vectorstore_kwargs: Optional[dict] = None,
+        **vectorstore_cls_kwargs: Any,
+    ) -> MaxMarginalRelevanceExampleSelector:
+        """Create k-shot example selector using example list and embeddings.
+
+        Reshuffles examples dynamically based on Max Marginal Relevance.
+
+        Args:
+            examples: List of examples to use in the prompt.
+            embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
+            vectorstore_cls: A vector store DB interface class, e.g. FAISS.
+            k: Number of examples to select
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            input_keys: If provided, the search is based on the input variables
+                instead of all variables.
+            example_keys: If provided, keys to filter examples to.
+            vectorstore_kwargs: Extra arguments passed to similarity_search function
+                of the vectorstore.
+            vectorstore_cls_kwargs: optional kwargs containing url for vector store
+
+        Returns:
+            The ExampleSelector instantiated, backed by a vector store.
+        """
+        string_examples = [cls._example_to_text(eg, input_keys) for eg in examples]
+        vectorstore = await vectorstore_cls.afrom_texts(
+            string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
+        )
+        return cls(
+            vectorstore=vectorstore,
+            k=k,
+            fetch_k=fetch_k,
+            input_keys=input_keys,
+            example_keys=example_keys,
+            vectorstore_kwargs=vectorstore_kwargs,
+        )
diff --git a/libs/core/langchain_core/vectorstores.py b/libs/core/langchain_core/vectorstores.py
index 64713c0139f..d6299c61d77 100644
--- a/libs/core/langchain_core/vectorstores.py
+++ b/libs/core/langchain_core/vectorstores.py
@@ -462,7 +462,22 @@ class VectorStore(ABC):
         lambda_mult: float = 0.5,
         **kwargs: Any,
     ) -> List[Document]:
-        """Return docs selected using the maximal marginal relevance."""
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
 
         # This is a temporary workaround to make the similarity search
         # asynchronous. The proper solution is to make the similarity search
diff --git a/libs/core/tests/unit_tests/example_selectors/test_similarity.py b/libs/core/tests/unit_tests/example_selectors/test_similarity.py
index 3f6f0972f7a..2cd50ca8dd2 100644
--- a/libs/core/tests/unit_tests/example_selectors/test_similarity.py
+++ b/libs/core/tests/unit_tests/example_selectors/test_similarity.py
@@ -2,7 +2,10 @@ from typing import Any, Iterable, List, Optional, cast
 
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings, FakeEmbeddings
-from langchain_core.example_selectors import SemanticSimilarityExampleSelector
+from langchain_core.example_selectors import (
+    MaxMarginalRelevanceExampleSelector,
+    SemanticSimilarityExampleSelector,
+)
 from langchain_core.vectorstores import VectorStore
 
 
@@ -32,7 +35,24 @@ class DummyVectorStore(VectorStore):
         self, query: str, k: int = 4, **kwargs: Any
     ) -> List[Document]:
         return [
-            Document(page_content=query, metadata={"metadata": query, "other": "other"})
+            Document(
+                page_content=query, metadata={"query": query, "k": k, "other": "other"}
+            )
+        ] * k
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        return [
+            Document(
+                page_content=query,
+                metadata={"query": query, "k": k, "fetch_k": fetch_k, "other": "other"},
+            )
         ] * k
 
     @classmethod
@@ -72,19 +92,19 @@ async def test_aadd_example() -> None:
 
 def test_select_examples() -> None:
     vector_store = DummyVectorStore()
     selector = SemanticSimilarityExampleSelector(
-        vectorstore=vector_store, input_keys=["foo2"], example_keys=["metadata"], k=2
+        vectorstore=vector_store, input_keys=["foo2"], example_keys=["query", "k"], k=2
     )
     examples = selector.select_examples({"foo": "bar", "foo2": "bar2"})
"bar2"}) - assert examples == [{"metadata": "bar2"}] * 2 + assert examples == [{"query": "bar2", "k": 2}] * 2 async def test_aselect_examples() -> None: vector_store = DummyVectorStore() selector = SemanticSimilarityExampleSelector( - vectorstore=vector_store, input_keys=["foo2"], example_keys=["metadata"], k=2 + vectorstore=vector_store, input_keys=["foo2"], example_keys=["query", "k"], k=2 ) examples = await selector.aselect_examples({"foo": "bar", "foo2": "bar2"}) - assert examples == [{"metadata": "bar2"}] * 2 + assert examples == [{"query": "bar2", "k": 2}] * 2 def test_from_examples() -> None: @@ -137,3 +157,85 @@ async def test_afrom_examples() -> None: assert vector_store.init_arg == "some_init_arg" assert vector_store.texts == ["bar"] assert vector_store.metadatas == [{"foo": "bar"}] + + +def test_mmr_select_examples() -> None: + vector_store = DummyVectorStore() + selector = MaxMarginalRelevanceExampleSelector( + vectorstore=vector_store, + input_keys=["foo2"], + example_keys=["query", "k", "fetch_k"], + k=2, + fetch_k=5, + ) + examples = selector.select_examples({"foo": "bar", "foo2": "bar2"}) + assert examples == [{"query": "bar2", "k": 2, "fetch_k": 5}] * 2 + + +async def test_mmr_aselect_examples() -> None: + vector_store = DummyVectorStore() + selector = MaxMarginalRelevanceExampleSelector( + vectorstore=vector_store, + input_keys=["foo2"], + example_keys=["query", "k", "fetch_k"], + k=2, + fetch_k=5, + ) + examples = await selector.aselect_examples({"foo": "bar", "foo2": "bar2"}) + assert examples == [{"query": "bar2", "k": 2, "fetch_k": 5}] * 2 + + +def test_mmr_from_examples() -> None: + examples = [{"foo": "bar"}] + embeddings = FakeEmbeddings(size=1) + selector = MaxMarginalRelevanceExampleSelector.from_examples( + examples=examples, + embeddings=embeddings, + vectorstore_cls=DummyVectorStore, + k=2, + fetch_k=5, + input_keys=["foo"], + example_keys=["some_example_key"], + vectorstore_kwargs={"vs_foo": "vs_bar"}, + init_arg="some_init_arg", + ) + assert selector.input_keys == ["foo"] + assert selector.example_keys == ["some_example_key"] + assert selector.k == 2 + assert selector.fetch_k == 5 + assert selector.vectorstore_kwargs == {"vs_foo": "vs_bar"} + + assert isinstance(selector.vectorstore, DummyVectorStore) + vector_store = cast(DummyVectorStore, selector.vectorstore) + assert vector_store.embeddings is embeddings + assert vector_store.init_arg == "some_init_arg" + assert vector_store.texts == ["bar"] + assert vector_store.metadatas == [{"foo": "bar"}] + + +async def test_mmr_afrom_examples() -> None: + examples = [{"foo": "bar"}] + embeddings = FakeEmbeddings(size=1) + selector = await MaxMarginalRelevanceExampleSelector.afrom_examples( + examples=examples, + embeddings=embeddings, + vectorstore_cls=DummyVectorStore, + k=2, + fetch_k=5, + input_keys=["foo"], + example_keys=["some_example_key"], + vectorstore_kwargs={"vs_foo": "vs_bar"}, + init_arg="some_init_arg", + ) + assert selector.input_keys == ["foo"] + assert selector.example_keys == ["some_example_key"] + assert selector.k == 2 + assert selector.fetch_k == 5 + assert selector.vectorstore_kwargs == {"vs_foo": "vs_bar"} + + assert isinstance(selector.vectorstore, DummyVectorStore) + vector_store = cast(DummyVectorStore, selector.vectorstore) + assert vector_store.embeddings is embeddings + assert vector_store.init_arg == "some_init_arg" + assert vector_store.texts == ["bar"] + assert vector_store.metadatas == [{"foo": "bar"}]