core[minor]: Add async methods to MaxMarginalRelevanceExampleSelector (#19639)

This commit is contained in:
Christophe Bornet
2024-03-27 21:03:18 +01:00
committed by GitHub
parent 72c8b3127d
commit 33fa8cfcd0
3 changed files with 213 additions and 32 deletions

View File

@@ -2,7 +2,10 @@ from typing import Any, Iterable, List, Optional, cast
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings, FakeEmbeddings
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.example_selectors import (
MaxMarginalRelevanceExampleSelector,
SemanticSimilarityExampleSelector,
)
from langchain_core.vectorstores import VectorStore
@@ -32,7 +35,24 @@ class DummyVectorStore(VectorStore):
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
return [
Document(page_content=query, metadata={"metadata": query, "other": "other"})
Document(
page_content=query, metadata={"query": query, "k": k, "other": "other"}
)
] * k
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
return [
Document(
page_content=query,
metadata={"query": query, "k": k, "fetch_k": fetch_k, "other": "other"},
)
] * k
@classmethod
@@ -72,19 +92,19 @@ async def test_aadd_example() -> None:
def test_select_examples() -> None:
vector_store = DummyVectorStore()
selector = SemanticSimilarityExampleSelector(
vectorstore=vector_store, input_keys=["foo2"], example_keys=["metadata"], k=2
vectorstore=vector_store, input_keys=["foo2"], example_keys=["query", "k"], k=2
)
examples = selector.select_examples({"foo": "bar", "foo2": "bar2"})
assert examples == [{"metadata": "bar2"}] * 2
assert examples == [{"query": "bar2", "k": 2}] * 2
async def test_aselect_examples() -> None:
vector_store = DummyVectorStore()
selector = SemanticSimilarityExampleSelector(
vectorstore=vector_store, input_keys=["foo2"], example_keys=["metadata"], k=2
vectorstore=vector_store, input_keys=["foo2"], example_keys=["query", "k"], k=2
)
examples = await selector.aselect_examples({"foo": "bar", "foo2": "bar2"})
assert examples == [{"metadata": "bar2"}] * 2
assert examples == [{"query": "bar2", "k": 2}] * 2
def test_from_examples() -> None:
@@ -137,3 +157,85 @@ async def test_afrom_examples() -> None:
assert vector_store.init_arg == "some_init_arg"
assert vector_store.texts == ["bar"]
assert vector_store.metadatas == [{"foo": "bar"}]
def test_mmr_select_examples() -> None:
vector_store = DummyVectorStore()
selector = MaxMarginalRelevanceExampleSelector(
vectorstore=vector_store,
input_keys=["foo2"],
example_keys=["query", "k", "fetch_k"],
k=2,
fetch_k=5,
)
examples = selector.select_examples({"foo": "bar", "foo2": "bar2"})
assert examples == [{"query": "bar2", "k": 2, "fetch_k": 5}] * 2
async def test_mmr_aselect_examples() -> None:
vector_store = DummyVectorStore()
selector = MaxMarginalRelevanceExampleSelector(
vectorstore=vector_store,
input_keys=["foo2"],
example_keys=["query", "k", "fetch_k"],
k=2,
fetch_k=5,
)
examples = await selector.aselect_examples({"foo": "bar", "foo2": "bar2"})
assert examples == [{"query": "bar2", "k": 2, "fetch_k": 5}] * 2
def test_mmr_from_examples() -> None:
examples = [{"foo": "bar"}]
embeddings = FakeEmbeddings(size=1)
selector = MaxMarginalRelevanceExampleSelector.from_examples(
examples=examples,
embeddings=embeddings,
vectorstore_cls=DummyVectorStore,
k=2,
fetch_k=5,
input_keys=["foo"],
example_keys=["some_example_key"],
vectorstore_kwargs={"vs_foo": "vs_bar"},
init_arg="some_init_arg",
)
assert selector.input_keys == ["foo"]
assert selector.example_keys == ["some_example_key"]
assert selector.k == 2
assert selector.fetch_k == 5
assert selector.vectorstore_kwargs == {"vs_foo": "vs_bar"}
assert isinstance(selector.vectorstore, DummyVectorStore)
vector_store = cast(DummyVectorStore, selector.vectorstore)
assert vector_store.embeddings is embeddings
assert vector_store.init_arg == "some_init_arg"
assert vector_store.texts == ["bar"]
assert vector_store.metadatas == [{"foo": "bar"}]
async def test_mmr_afrom_examples() -> None:
examples = [{"foo": "bar"}]
embeddings = FakeEmbeddings(size=1)
selector = await MaxMarginalRelevanceExampleSelector.afrom_examples(
examples=examples,
embeddings=embeddings,
vectorstore_cls=DummyVectorStore,
k=2,
fetch_k=5,
input_keys=["foo"],
example_keys=["some_example_key"],
vectorstore_kwargs={"vs_foo": "vs_bar"},
init_arg="some_init_arg",
)
assert selector.input_keys == ["foo"]
assert selector.example_keys == ["some_example_key"]
assert selector.k == 2
assert selector.fetch_k == 5
assert selector.vectorstore_kwargs == {"vs_foo": "vs_bar"}
assert isinstance(selector.vectorstore, DummyVectorStore)
vector_store = cast(DummyVectorStore, selector.vectorstore)
assert vector_store.embeddings is embeddings
assert vector_store.init_arg == "some_init_arg"
assert vector_store.texts == ["bar"]
assert vector_store.metadatas == [{"foo": "bar"}]