mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-03 12:07:36 +00:00
core[minor]: Add async methods to MaxMarginalRelevanceExampleSelector (#19639)
This commit is contained in:
committed by
GitHub
parent
72c8b3127d
commit
33fa8cfcd0
@@ -2,7 +2,10 @@ from typing import Any, Iterable, List, Optional, cast
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings, FakeEmbeddings
|
||||
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
|
||||
from langchain_core.example_selectors import (
|
||||
MaxMarginalRelevanceExampleSelector,
|
||||
SemanticSimilarityExampleSelector,
|
||||
)
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
|
||||
@@ -32,7 +35,24 @@ class DummyVectorStore(VectorStore):
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
return [
|
||||
Document(page_content=query, metadata={"metadata": query, "other": "other"})
|
||||
Document(
|
||||
page_content=query, metadata={"query": query, "k": k, "other": "other"}
|
||||
)
|
||||
] * k
|
||||
|
||||
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Return ``k`` references to one Document that echoes the call arguments.

    The document's ``page_content`` is ``query`` and its metadata records
    ``query``, ``k`` and ``fetch_k`` (plus a constant ``"other"`` entry), so
    callers can assert which values reached this method.

    ``lambda_mult`` and ``kwargs`` are accepted but not used.
    """
    echoed = Document(
        page_content=query,
        metadata={"query": query, "k": k, "fetch_k": fetch_k, "other": "other"},
    )
    # List repetition keeps every slot pointing at the same Document object.
    return [echoed] * k
|
||||
|
||||
@classmethod
|
||||
@@ -72,19 +92,19 @@ async def test_aadd_example() -> None:
|
||||
def test_select_examples() -> None:
|
||||
vector_store = DummyVectorStore()
|
||||
selector = SemanticSimilarityExampleSelector(
|
||||
vectorstore=vector_store, input_keys=["foo2"], example_keys=["metadata"], k=2
|
||||
vectorstore=vector_store, input_keys=["foo2"], example_keys=["query", "k"], k=2
|
||||
)
|
||||
examples = selector.select_examples({"foo": "bar", "foo2": "bar2"})
|
||||
assert examples == [{"metadata": "bar2"}] * 2
|
||||
assert examples == [{"query": "bar2", "k": 2}] * 2
|
||||
|
||||
|
||||
async def test_aselect_examples() -> None:
|
||||
vector_store = DummyVectorStore()
|
||||
selector = SemanticSimilarityExampleSelector(
|
||||
vectorstore=vector_store, input_keys=["foo2"], example_keys=["metadata"], k=2
|
||||
vectorstore=vector_store, input_keys=["foo2"], example_keys=["query", "k"], k=2
|
||||
)
|
||||
examples = await selector.aselect_examples({"foo": "bar", "foo2": "bar2"})
|
||||
assert examples == [{"metadata": "bar2"}] * 2
|
||||
assert examples == [{"query": "bar2", "k": 2}] * 2
|
||||
|
||||
|
||||
def test_from_examples() -> None:
|
||||
@@ -137,3 +157,85 @@ async def test_afrom_examples() -> None:
|
||||
assert vector_store.init_arg == "some_init_arg"
|
||||
assert vector_store.texts == ["bar"]
|
||||
assert vector_store.metadatas == [{"foo": "bar"}]
|
||||
|
||||
|
||||
def test_mmr_select_examples() -> None:
    """Check the synchronous MMR selection path against DummyVectorStore.

    The selector is configured with ``k=2`` and ``fetch_k=5`` and only
    ``"foo2"`` as input key; the selected examples must contain exactly the
    ``query``, ``k`` and ``fetch_k`` metadata entries, repeated twice.
    """
    store = DummyVectorStore()
    mmr_selector = MaxMarginalRelevanceExampleSelector(
        vectorstore=store,
        input_keys=["foo2"],
        example_keys=["query", "k", "fetch_k"],
        k=2,
        fetch_k=5,
    )

    selected = mmr_selector.select_examples({"foo": "bar", "foo2": "bar2"})

    expected_example = {"query": "bar2", "k": 2, "fetch_k": 5}
    assert selected == [expected_example, expected_example]
|
||||
|
||||
|
||||
async def test_mmr_aselect_examples() -> None:
    """Check the asynchronous MMR selection path against DummyVectorStore.

    Same configuration as the synchronous test (``k=2``, ``fetch_k=5``,
    input key ``"foo2"``), but the examples are obtained by awaiting
    ``aselect_examples``.
    """
    store = DummyVectorStore()
    mmr_selector = MaxMarginalRelevanceExampleSelector(
        vectorstore=store,
        input_keys=["foo2"],
        example_keys=["query", "k", "fetch_k"],
        k=2,
        fetch_k=5,
    )

    selected = await mmr_selector.aselect_examples({"foo": "bar", "foo2": "bar2"})

    expected_example = {"query": "bar2", "k": 2, "fetch_k": 5}
    assert selected == [expected_example, expected_example]
|
||||
|
||||
|
||||
def test_mmr_from_examples() -> None:
    """Build an MMR selector with ``from_examples`` and verify its wiring.

    Asserts that the selector keeps the configuration passed in, and that the
    DummyVectorStore it holds received the embeddings object, the extra
    ``init_arg``, and the single example as text plus metadata.
    """
    seed_examples = [{"foo": "bar"}]
    fake_embeddings = FakeEmbeddings(size=1)

    mmr_selector = MaxMarginalRelevanceExampleSelector.from_examples(
        examples=seed_examples,
        embeddings=fake_embeddings,
        vectorstore_cls=DummyVectorStore,
        k=2,
        fetch_k=5,
        input_keys=["foo"],
        example_keys=["some_example_key"],
        vectorstore_kwargs={"vs_foo": "vs_bar"},
        init_arg="some_init_arg",
    )

    # Configuration stored on the selector itself.
    assert mmr_selector.input_keys == ["foo"]
    assert mmr_selector.example_keys == ["some_example_key"]
    assert mmr_selector.k == 2
    assert mmr_selector.fetch_k == 5
    assert mmr_selector.vectorstore_kwargs == {"vs_foo": "vs_bar"}

    # State of the vector store held by the selector.
    assert isinstance(mmr_selector.vectorstore, DummyVectorStore)
    store = cast(DummyVectorStore, mmr_selector.vectorstore)
    assert store.embeddings is fake_embeddings
    assert store.init_arg == "some_init_arg"
    assert store.texts == ["bar"]
    assert store.metadatas == [{"foo": "bar"}]
|
||||
|
||||
|
||||
async def test_mmr_afrom_examples() -> None:
    """Build an MMR selector with ``afrom_examples`` and verify its wiring.

    Asynchronous counterpart of the ``from_examples`` test: the selector is
    obtained by awaiting the factory, then the same configuration and
    vector-store assertions are applied.
    """
    seed_examples = [{"foo": "bar"}]
    fake_embeddings = FakeEmbeddings(size=1)

    mmr_selector = await MaxMarginalRelevanceExampleSelector.afrom_examples(
        examples=seed_examples,
        embeddings=fake_embeddings,
        vectorstore_cls=DummyVectorStore,
        k=2,
        fetch_k=5,
        input_keys=["foo"],
        example_keys=["some_example_key"],
        vectorstore_kwargs={"vs_foo": "vs_bar"},
        init_arg="some_init_arg",
    )

    # Configuration stored on the selector itself.
    assert mmr_selector.input_keys == ["foo"]
    assert mmr_selector.example_keys == ["some_example_key"]
    assert mmr_selector.k == 2
    assert mmr_selector.fetch_k == 5
    assert mmr_selector.vectorstore_kwargs == {"vs_foo": "vs_bar"}

    # State of the vector store held by the selector.
    assert isinstance(mmr_selector.vectorstore, DummyVectorStore)
    store = cast(DummyVectorStore, mmr_selector.vectorstore)
    assert store.embeddings is fake_embeddings
    assert store.init_arg == "some_init_arg"
    assert store.texts == ["bar"]
    assert store.metadatas == [{"foo": "bar"}]
|
||||
|
Reference in New Issue
Block a user