core[minor]: Add async methods to MaxMarginalRelevanceExampleSelector (#19639)
parent 72c8b3127d
commit 33fa8cfcd0
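
For orientation, here is a minimal usage sketch of the async path this commit adds (afrom_examples / aselect_examples). The example data is invented, and FAISS and OpenAIEmbeddings are only the illustrative backends the docstrings themselves name (assumed importable from langchain_community / langchain_openai), not requirements of the API.

    import asyncio

    from langchain_core.example_selectors import MaxMarginalRelevanceExampleSelector
    from langchain_community.vectorstores import FAISS  # assumed available
    from langchain_openai import OpenAIEmbeddings  # assumed available


    async def main() -> None:
        examples = [
            {"input": "happy", "output": "sad"},
            {"input": "tall", "output": "short"},
        ]
        # Build the selector without blocking: afrom_texts is awaited on the
        # vector store class instead of the synchronous from_texts.
        selector = await MaxMarginalRelevanceExampleSelector.afrom_examples(
            examples,
            OpenAIEmbeddings(),
            FAISS,
            k=1,
            fetch_k=2,
        )
        # aselect_examples mirrors select_examples but awaits the vector store's
        # amax_marginal_relevance_search.
        print(await selector.aselect_examples({"input": "cheerful"}))


    asyncio.run(main())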
@@ -1,6 +1,7 @@
 """Example selector that selects examples based on SemanticSimilarity."""
 from __future__ import annotations

+from abc import ABC
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

 from langchain_core.documents import Document
@@ -17,7 +18,7 @@ def sorted_values(values: Dict[str, str]) -> List[Any]:
     return [values[val] for val in sorted(values)]


-class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
+class _VectorStoreExampleSelector(BaseExampleSelector, BaseModel, ABC):
     """Example selector that selects examples based on SemanticSimilarity."""

     vectorstore: VectorStore
@@ -70,6 +71,10 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
         )
         return ids[0]

+
+class SemanticSimilarityExampleSelector(_VectorStoreExampleSelector):
+    """Example selector that selects examples based on SemanticSimilarity."""
+
     def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
         """Select which examples to use based on semantic similarity."""
         # Get the docs with the highest similarity.
@@ -116,6 +121,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
             k: Number of examples to select
             input_keys: If provided, the search is based on the input variables
                 instead of all variables.
+            example_keys: If provided, keys to filter examples to.
+            vectorstore_kwargs: Extra arguments passed to similarity_search function
+                of the vectorstore.
             vectorstore_cls_kwargs: optional kwargs containing url for vector store

         Returns:
@@ -157,6 +165,9 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
             k: Number of examples to select
             input_keys: If provided, the search is based on the input variables
                 instead of all variables.
+            example_keys: If provided, keys to filter examples to.
+            vectorstore_kwargs: Extra arguments passed to similarity_search function
+                of the vectorstore.
             vectorstore_cls_kwargs: optional kwargs containing url for vector store

         Returns:
@@ -175,7 +186,7 @@ class SemanticSimilarityExampleSelector(BaseExampleSelector, BaseModel):
         )


-class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
+class MaxMarginalRelevanceExampleSelector(_VectorStoreExampleSelector):
     """ExampleSelector that selects examples based on Max Marginal Relevance.

     This was shown to improve performance in this paper:
@@ -186,21 +197,20 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
     """Number of examples to fetch to rerank."""

     def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
-        """Select which examples to use based on semantic similarity."""
-        # Get the docs with the highest similarity.
-        if self.input_keys:
-            input_variables = {key: input_variables[key] for key in self.input_keys}
-        query = " ".join(sorted_values(input_variables))
         example_docs = self.vectorstore.max_marginal_relevance_search(
-            query, k=self.k, fetch_k=self.fetch_k
+            self._example_to_text(input_variables, self.input_keys),
+            k=self.k,
+            fetch_k=self.fetch_k,
         )
-        # Get the examples from the metadata.
-        # This assumes that examples are stored in metadata.
-        examples = [dict(e.metadata) for e in example_docs]
-        # If example keys are provided, filter examples to those keys.
-        if self.example_keys:
-            examples = [{k: eg[k] for k in self.example_keys} for eg in examples]
-        return examples
+        return self._documents_to_examples(example_docs)
+
+    async def aselect_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+        example_docs = await self.vectorstore.amax_marginal_relevance_search(
+            self._example_to_text(input_variables, self.input_keys),
+            k=self.k,
+            fetch_k=self.fetch_k,
+        )
+        return self._documents_to_examples(example_docs)

     @classmethod
     def from_examples(
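The sync and async bodies above both delegate to two helpers on the new _VectorStoreExampleSelector base class, whose definitions fall outside the hunks shown here. The sketch below is what they presumably do, inferred from the inline logic they replace; it is written as free functions for illustration and is not copied from the base class.

    from typing import Any, Dict, List, Optional

    from langchain_core.documents import Document


    def example_to_text(example: Dict[str, str], input_keys: Optional[List[str]]) -> str:
        # Restrict to input_keys if given, then join values in sorted-key order,
        # matching the removed " ".join(sorted_values(...)) code path.
        if input_keys:
            example = {key: example[key] for key in input_keys}
        return " ".join(example[key] for key in sorted(example))


    def documents_to_examples(
        documents: List[Document], example_keys: Optional[List[str]]
    ) -> List[Dict[str, Any]]:
        # Examples are assumed to live in each Document's metadata; optionally
        # filter each example down to example_keys, as the removed code did.
        examples = [dict(doc.metadata) for doc in documents]
        if example_keys:
            examples = [{k: eg[k] for k in example_keys} for eg in examples]
        return examples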
@@ -211,32 +221,86 @@ class MaxMarginalRelevanceExampleSelector(SemanticSimilarityExampleSelector):
         k: int = 4,
         input_keys: Optional[List[str]] = None,
         fetch_k: int = 20,
+        example_keys: Optional[List[str]] = None,
+        vectorstore_kwargs: Optional[dict] = None,
         **vectorstore_cls_kwargs: Any,
     ) -> MaxMarginalRelevanceExampleSelector:
         """Create k-shot example selector using example list and embeddings.

-        Reshuffles examples dynamically based on query similarity.
+        Reshuffles examples dynamically based on Max Marginal Relevance.

         Args:
             examples: List of examples to use in the prompt.
-            embeddings: An iniialized embedding API interface, e.g. OpenAIEmbeddings().
+            embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
             vectorstore_cls: A vector store DB interface class, e.g. FAISS.
             k: Number of examples to select
             fetch_k: Number of Documents to fetch to pass to MMR algorithm.
             input_keys: If provided, the search is based on the input variables
                 instead of all variables.
+            example_keys: If provided, keys to filter examples to.
+            vectorstore_kwargs: Extra arguments passed to similarity_search function
+                of the vectorstore.
             vectorstore_cls_kwargs: optional kwargs containing url for vector store

         Returns:
             The ExampleSelector instantiated, backed by a vector store.
         """
-        if input_keys:
-            string_examples = [
-                " ".join(sorted_values({k: eg[k] for k in input_keys}))
-                for eg in examples
-            ]
-        else:
-            string_examples = [" ".join(sorted_values(eg)) for eg in examples]
+        string_examples = [cls._example_to_text(eg, input_keys) for eg in examples]
         vectorstore = vectorstore_cls.from_texts(
             string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
         )
-        return cls(vectorstore=vectorstore, k=k, fetch_k=fetch_k, input_keys=input_keys)
+        return cls(
+            vectorstore=vectorstore,
+            k=k,
+            fetch_k=fetch_k,
+            input_keys=input_keys,
+            example_keys=example_keys,
+            vectorstore_kwargs=vectorstore_kwargs,
+        )
+
+    @classmethod
+    async def afrom_examples(
+        cls,
+        examples: List[dict],
+        embeddings: Embeddings,
+        vectorstore_cls: Type[VectorStore],
+        *,
+        k: int = 4,
+        input_keys: Optional[List[str]] = None,
+        fetch_k: int = 20,
+        example_keys: Optional[List[str]] = None,
+        vectorstore_kwargs: Optional[dict] = None,
+        **vectorstore_cls_kwargs: Any,
+    ) -> MaxMarginalRelevanceExampleSelector:
+        """Create k-shot example selector using example list and embeddings.
+
+        Reshuffles examples dynamically based on Max Marginal Relevance.
+
+        Args:
+            examples: List of examples to use in the prompt.
+            embeddings: An initialized embedding API interface, e.g. OpenAIEmbeddings().
+            vectorstore_cls: A vector store DB interface class, e.g. FAISS.
+            k: Number of examples to select
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            input_keys: If provided, the search is based on the input variables
+                instead of all variables.
+            example_keys: If provided, keys to filter examples to.
+            vectorstore_kwargs: Extra arguments passed to similarity_search function
+                of the vectorstore.
+            vectorstore_cls_kwargs: optional kwargs containing url for vector store
+
+        Returns:
+            The ExampleSelector instantiated, backed by a vector store.
+        """
+        string_examples = [cls._example_to_text(eg, input_keys) for eg in examples]
+        vectorstore = await vectorstore_cls.afrom_texts(
+            string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs
+        )
+        return cls(
+            vectorstore=vectorstore,
+            k=k,
+            fetch_k=fetch_k,
+            input_keys=input_keys,
+            example_keys=example_keys,
+            vectorstore_kwargs=vectorstore_kwargs,
+        )
@@ -462,7 +462,22 @@ class VectorStore(ABC):
         lambda_mult: float = 0.5,
         **kwargs: Any,
     ) -> List[Document]:
-        """Return docs selected using the maximal marginal relevance."""
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """

         # This is a temporary workaround to make the similarity search
         # asynchronous. The proper solution is to make the similarity search
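The docstring added above describes lambda_mult but not the selection rule itself. As an illustration only, here is a greedy re-ranking sketch of the standard MMR criterion that lambda_mult parameterizes; it is not the library's implementation, and the cosine helper is a plain stand-in for whatever similarity the vector store uses.

    from typing import List, Sequence


    def cosine(a: Sequence[float], b: Sequence[float]) -> float:
        dot = sum(x * y for x, y in zip(a, b))
        na = sum(x * x for x in a) ** 0.5
        nb = sum(y * y for y in b) ** 0.5
        return dot / (na * nb) if na and nb else 0.0


    def mmr_order(
        query_vec: Sequence[float],
        candidate_vecs: Sequence[Sequence[float]],
        k: int = 4,
        lambda_mult: float = 0.5,
    ) -> List[int]:
        """Return candidate indices in greedy maximal-marginal-relevance order."""
        selected: List[int] = []
        while len(selected) < min(k, len(candidate_vecs)):
            best_idx, best_score = -1, float("-inf")
            for i, vec in enumerate(candidate_vecs):
                if i in selected:
                    continue
                relevance = cosine(query_vec, vec)
                redundancy = max(
                    (cosine(vec, candidate_vecs[j]) for j in selected), default=0.0
                )
                # lambda_mult near 1 favors relevance (minimum diversity);
                # near 0 favors diversity among the already-selected documents.
                score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
                if score > best_score:
                    best_idx, best_score = i, score
            selected.append(best_idx)
        return selected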
@@ -2,7 +2,10 @@ from typing import Any, Iterable, List, Optional, cast

 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings, FakeEmbeddings
-from langchain_core.example_selectors import SemanticSimilarityExampleSelector
+from langchain_core.example_selectors import (
+    MaxMarginalRelevanceExampleSelector,
+    SemanticSimilarityExampleSelector,
+)
 from langchain_core.vectorstores import VectorStore
@@ -32,7 +35,24 @@ class DummyVectorStore(VectorStore):
         self, query: str, k: int = 4, **kwargs: Any
     ) -> List[Document]:
         return [
-            Document(page_content=query, metadata={"metadata": query, "other": "other"})
+            Document(
+                page_content=query, metadata={"query": query, "k": k, "other": "other"}
+            )
         ] * k

+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> List[Document]:
+        return [
+            Document(
+                page_content=query,
+                metadata={"query": query, "k": k, "fetch_k": fetch_k, "other": "other"},
+            )
+        ] * k
+
     @classmethod
@@ -72,19 +92,19 @@ async def test_aadd_example() -> None:
 def test_select_examples() -> None:
     vector_store = DummyVectorStore()
     selector = SemanticSimilarityExampleSelector(
-        vectorstore=vector_store, input_keys=["foo2"], example_keys=["metadata"], k=2
+        vectorstore=vector_store, input_keys=["foo2"], example_keys=["query", "k"], k=2
     )
     examples = selector.select_examples({"foo": "bar", "foo2": "bar2"})
-    assert examples == [{"metadata": "bar2"}] * 2
+    assert examples == [{"query": "bar2", "k": 2}] * 2


 async def test_aselect_examples() -> None:
     vector_store = DummyVectorStore()
     selector = SemanticSimilarityExampleSelector(
-        vectorstore=vector_store, input_keys=["foo2"], example_keys=["metadata"], k=2
+        vectorstore=vector_store, input_keys=["foo2"], example_keys=["query", "k"], k=2
     )
     examples = await selector.aselect_examples({"foo": "bar", "foo2": "bar2"})
-    assert examples == [{"metadata": "bar2"}] * 2
+    assert examples == [{"query": "bar2", "k": 2}] * 2


 def test_from_examples() -> None:
@@ -137,3 +157,85 @@ async def test_afrom_examples() -> None:
     assert vector_store.init_arg == "some_init_arg"
     assert vector_store.texts == ["bar"]
     assert vector_store.metadatas == [{"foo": "bar"}]
+
+
+def test_mmr_select_examples() -> None:
+    vector_store = DummyVectorStore()
+    selector = MaxMarginalRelevanceExampleSelector(
+        vectorstore=vector_store,
+        input_keys=["foo2"],
+        example_keys=["query", "k", "fetch_k"],
+        k=2,
+        fetch_k=5,
+    )
+    examples = selector.select_examples({"foo": "bar", "foo2": "bar2"})
+    assert examples == [{"query": "bar2", "k": 2, "fetch_k": 5}] * 2
+
+
+async def test_mmr_aselect_examples() -> None:
+    vector_store = DummyVectorStore()
+    selector = MaxMarginalRelevanceExampleSelector(
+        vectorstore=vector_store,
+        input_keys=["foo2"],
+        example_keys=["query", "k", "fetch_k"],
+        k=2,
+        fetch_k=5,
+    )
+    examples = await selector.aselect_examples({"foo": "bar", "foo2": "bar2"})
+    assert examples == [{"query": "bar2", "k": 2, "fetch_k": 5}] * 2
+
+
+def test_mmr_from_examples() -> None:
+    examples = [{"foo": "bar"}]
+    embeddings = FakeEmbeddings(size=1)
+    selector = MaxMarginalRelevanceExampleSelector.from_examples(
+        examples=examples,
+        embeddings=embeddings,
+        vectorstore_cls=DummyVectorStore,
+        k=2,
+        fetch_k=5,
+        input_keys=["foo"],
+        example_keys=["some_example_key"],
+        vectorstore_kwargs={"vs_foo": "vs_bar"},
+        init_arg="some_init_arg",
+    )
+    assert selector.input_keys == ["foo"]
+    assert selector.example_keys == ["some_example_key"]
+    assert selector.k == 2
+    assert selector.fetch_k == 5
+    assert selector.vectorstore_kwargs == {"vs_foo": "vs_bar"}
+
+    assert isinstance(selector.vectorstore, DummyVectorStore)
+    vector_store = cast(DummyVectorStore, selector.vectorstore)
+    assert vector_store.embeddings is embeddings
+    assert vector_store.init_arg == "some_init_arg"
+    assert vector_store.texts == ["bar"]
+    assert vector_store.metadatas == [{"foo": "bar"}]
+
+
+async def test_mmr_afrom_examples() -> None:
+    examples = [{"foo": "bar"}]
+    embeddings = FakeEmbeddings(size=1)
+    selector = await MaxMarginalRelevanceExampleSelector.afrom_examples(
+        examples=examples,
+        embeddings=embeddings,
+        vectorstore_cls=DummyVectorStore,
+        k=2,
+        fetch_k=5,
+        input_keys=["foo"],
+        example_keys=["some_example_key"],
+        vectorstore_kwargs={"vs_foo": "vs_bar"},
+        init_arg="some_init_arg",
+    )
+    assert selector.input_keys == ["foo"]
+    assert selector.example_keys == ["some_example_key"]
+    assert selector.k == 2
+    assert selector.fetch_k == 5
+    assert selector.vectorstore_kwargs == {"vs_foo": "vs_bar"}
+
+    assert isinstance(selector.vectorstore, DummyVectorStore)
+    vector_store = cast(DummyVectorStore, selector.vectorstore)
+    assert vector_store.embeddings is embeddings
+    assert vector_store.init_arg == "some_init_arg"
+    assert vector_store.texts == ["bar"]
+    assert vector_store.metadatas == [{"foo": "bar"}]