langchain[minor]: Make EmbeddingsFilters async (#22737)
Add native async implementation for EmbeddingsFilter
parent b45bf78d2e
commit 23c22fcbc9
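
For context, here is a minimal usage sketch of the async API this commit adds; the embedding model, documents, query, and threshold are illustrative choices, not taken from the commit:

import asyncio

from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_core.documents import Document


async def main() -> None:
    embeddings = OpenAIEmbeddings()
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
    docs = [
        Document(page_content="This sentence is about cows"),
        Document(page_content="foo bar baz"),
    ]
    # New in this commit: a native async path that awaits aembed_documents
    # and aembed_query instead of delegating to the sync implementation.
    compressed = await relevant_filter.acompress_documents(
        docs, "Tell me about farm animals"
    )
    print([d.page_content for d in compressed])


asyncio.run(main())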
@@ -75,6 +75,20 @@ def _get_embeddings_from_stateful_docs(
     return embedded_documents
 
 
+async def _aget_embeddings_from_stateful_docs(
+    embeddings: Embeddings, documents: Sequence[_DocumentWithState]
+) -> List[List[float]]:
+    if len(documents) and "embedded_doc" in documents[0].state:
+        embedded_documents = [doc.state["embedded_doc"] for doc in documents]
+    else:
+        embedded_documents = await embeddings.aembed_documents(
+            [d.page_content for d in documents]
+        )
+        for doc, embedding in zip(documents, embedded_documents):
+            doc.state["embedded_doc"] = embedding
+    return embedded_documents
+
+
 def _filter_cluster_embeddings(
     embedded_documents: List[List[float]],
     num_clusters: int,
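The new helper mirrors its sync counterpart: embeddings are computed once via aembed_documents and cached on each document's state under "embedded_doc", so repeated calls skip re-embedding. A small sketch of that caching behavior; the CountingEmbeddings fake is illustrative, not part of the library:

import asyncio
from typing import List

from langchain_community.document_transformers.embeddings_redundant_filter import (
    _aget_embeddings_from_stateful_docs,
    _DocumentWithState,
)
from langchain_core.embeddings import Embeddings


class CountingEmbeddings(Embeddings):
    """Fake embedder that records how many embedding batches it computes."""

    def __init__(self) -> None:
        self.calls = 0

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        self.calls += 1
        return [[float(len(t))] for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return [float(len(text))]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.embed_documents(texts)


async def main() -> None:
    embeddings = CountingEmbeddings()
    docs = [_DocumentWithState(page_content=t, state={}) for t in ["a", "bb"]]
    await _aget_embeddings_from_stateful_docs(embeddings, docs)
    await _aget_embeddings_from_stateful_docs(embeddings, docs)
    # The second call found "embedded_doc" in state and skipped re-embedding.
    assert embeddings.calls == 1


asyncio.run(main())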
@@ -27,3 +27,24 @@ def test_document_compressor_pipeline() -> None:
     actual = pipeline_filter.compress_documents(docs, "Tell me about farm animals")
     assert len(actual) == 1
     assert actual[0].page_content in texts[:2]
+
+
+async def test_adocument_compressor_pipeline() -> None:
+    embeddings = OpenAIEmbeddings()
+    splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator=". ")
+    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
+    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8)
+    pipeline_filter = DocumentCompressorPipeline(
+        transformers=[splitter, redundant_filter, relevant_filter]
+    )
+    texts = [
+        "This sentence is about cows",
+        "This sentence was about cows",
+        "foo bar baz",
+    ]
+    docs = [Document(page_content=". ".join(texts))]
+    actual = await pipeline_filter.acompress_documents(
+        docs, "Tell me about farm animals"
+    )
+    assert len(actual) == 1
+    assert actual[0].page_content in texts[:2]
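The async pipeline test above relies on DocumentCompressorPipeline.acompress_documents awaiting each stage in order, so the splitter, the redundant filter, and the relevance filter all run through their async variants. A simplified conceptual sketch of that control flow, not the library's actual implementation:

from typing import Any, Sequence

from langchain_core.documents import Document


async def pipeline_sketch(
    transformers: Sequence[Any], documents: Sequence[Document], query: str
) -> Sequence[Document]:
    # Conceptual dispatch: compressors receive the query, plain document
    # transformers do not.
    for transformer in transformers:
        if hasattr(transformer, "acompress_documents"):
            documents = await transformer.acompress_documents(documents, query)
        else:
            documents = await transformer.atransform_documents(documents)
    return documents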
@@ -23,6 +23,22 @@ def test_embeddings_filter() -> None:
     assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2
 
 
+async def test_aembeddings_filter() -> None:
+    texts = [
+        "What happened to all of my cookies?",
+        "I wish there were better Italian restaurants in my neighborhood.",
+        "My favorite color is green",
+    ]
+    docs = [Document(page_content=t) for t in texts]
+    embeddings = OpenAIEmbeddings()
+    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
+    actual = await relevant_filter.acompress_documents(
+        docs, "What did I say about food?"
+    )
+    assert len(actual) == 2
+    assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2
+
+
 def test_embeddings_filter_with_state() -> None:
     texts = [
         "What happened to all of my cookies?",
@@ -41,3 +55,23 @@ def test_embeddings_filter_with_state() -> None:
     actual = relevant_filter.compress_documents(docs, query)
     assert len(actual) == 1
     assert texts[-1] == actual[0].page_content
+
+
+async def test_aembeddings_filter_with_state() -> None:
+    texts = [
+        "What happened to all of my cookies?",
+        "I wish there were better Italian restaurants in my neighborhood.",
+        "My favorite color is green",
+    ]
+    query = "What did I say about food?"
+    embeddings = OpenAIEmbeddings()
+    embedded_query = embeddings.embed_query(query)
+    state = {"embedded_doc": np.zeros(len(embedded_query))}
+    docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
+    docs[-1].state = {"embedded_doc": embedded_query}
+    relevant_filter = EmbeddingsFilter(  # type: ignore[call-arg]
+        embeddings=embeddings, similarity_threshold=0.75, return_similarity_scores=True
+    )
+    actual = await relevant_filter.acompress_documents(docs, query)
+    assert len(actual) == 1
+    assert texts[-1] == actual[0].page_content
@@ -1,3 +1,4 @@
+import pytest
 from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import EmbeddingsFilter
 
@@ -24,3 +25,25 @@ def test_contextual_compression_retriever_get_relevant_docs() -> None:
     actual = retriever.invoke("Tell me about the Celtics")
     assert len(actual) == 2
     assert texts[-1] not in [d.page_content for d in actual]
+
+
+@pytest.mark.asyncio
+async def test_acontextual_compression_retriever_get_relevant_docs() -> None:
+    """Test aget_relevant_docs."""
+    texts = [
+        "This is a document about the Boston Celtics",
+        "The Boston Celtics won the game by 20 points",
+        "I simply love going to the movies",
+    ]
+    embeddings = OpenAIEmbeddings()
+    base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
+    base_retriever = FAISS.from_texts(texts, embedding=embeddings).as_retriever(
+        search_kwargs={"k": len(texts)}
+    )
+    retriever = ContextualCompressionRetriever(
+        base_compressor=base_compressor, base_retriever=base_retriever
+    )
+
+    actual = await retriever.ainvoke("Tell me about the Celtics")
+    assert len(actual) == 2
+    assert texts[-1] not in [d.page_content for d in actual]
@@ -86,3 +86,38 @@ class EmbeddingsFilter(BaseDocumentCompressor):
         for i in included_idxs:
             stateful_documents[i].state["query_similarity_score"] = similarity[i]
         return [stateful_documents[i] for i in included_idxs]
+
+    async def acompress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        """Filter documents based on similarity of their embeddings to the query."""
+        try:
+            from langchain_community.document_transformers.embeddings_redundant_filter import (  # noqa: E501
+                _aget_embeddings_from_stateful_docs,
+                get_stateful_documents,
+            )
+        except ImportError:
+            raise ImportError(
+                "To use please install langchain-community "
+                "with `pip install langchain-community`."
+            )
+        stateful_documents = get_stateful_documents(documents)
+        embedded_documents = await _aget_embeddings_from_stateful_docs(
+            self.embeddings, stateful_documents
+        )
+        embedded_query = await self.embeddings.aembed_query(query)
+        similarity = self.similarity_fn([embedded_query], embedded_documents)[0]
+        included_idxs = np.arange(len(embedded_documents))
+        if self.k is not None:
+            included_idxs = np.argsort(similarity)[::-1][: self.k]
+        if self.similarity_threshold is not None:
+            similar_enough = np.where(
+                similarity[included_idxs] > self.similarity_threshold
+            )
+            included_idxs = included_idxs[similar_enough]
+        for i in included_idxs:
+            stateful_documents[i].state["query_similarity_score"] = similarity[i]
+        return [stateful_documents[i] for i in included_idxs]
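The async method reuses the sync path's selection logic exactly: take the top-k indices by similarity, then keep only those above similarity_threshold. A standalone sketch of that numpy logic with made-up similarity values:

import numpy as np

similarity = np.array([0.9, 0.2, 0.8, 0.7])
k = 3
similarity_threshold = 0.75

included_idxs = np.arange(len(similarity))
if k is not None:
    # Indices of the k highest similarities, best first.
    included_idxs = np.argsort(similarity)[::-1][:k]
if similarity_threshold is not None:
    similar_enough = np.where(similarity[included_idxs] > similarity_threshold)
    included_idxs = included_idxs[similar_enough]

print(included_idxs)  # [0 2]: only documents 0 and 2 survive both cuts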