mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 19:11:33 +00:00
Adds embeddings filter option to return scores in state (#12489)
CC @baskaryan @assafelovic
This commit is contained in:
parent
18601bd4c8
commit
76283e9625
@@ -67,4 +67,6 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
|||||||
similarity[included_idxs] > self.similarity_threshold
|
similarity[included_idxs] > self.similarity_threshold
|
||||||
)
|
)
|
||||||
included_idxs = included_idxs[similar_enough]
|
included_idxs = included_idxs[similar_enough]
|
||||||
|
for i in included_idxs:
|
||||||
|
stateful_documents[i].state["query_similarity_score"] = similarity[i]
|
||||||
return [stateful_documents[i] for i in included_idxs]
|
return [stateful_documents[i] for i in included_idxs]
|
||||||
|
@@ -35,7 +35,9 @@ def test_embeddings_filter_with_state() -> None:
|
|||||||
state = {"embedded_doc": np.zeros(len(embedded_query))}
|
state = {"embedded_doc": np.zeros(len(embedded_query))}
|
||||||
docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
|
docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
|
||||||
docs[-1].state = {"embedded_doc": embedded_query}
|
docs[-1].state = {"embedded_doc": embedded_query}
|
||||||
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
|
relevant_filter = EmbeddingsFilter(
|
||||||
|
embeddings=embeddings, similarity_threshold=0.75, return_similarity_scores=True
|
||||||
|
)
|
||||||
actual = relevant_filter.compress_documents(docs, query)
|
actual = relevant_filter.compress_documents(docs, query)
|
||||||
assert len(actual) == 1
|
assert len(actual) == 1
|
||||||
assert texts[-1] == actual[0].page_content
|
assert texts[-1] == actual[0].page_content
|
||||||
|
Loading…
Reference in New Issue
Block a user