mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
Adds embeddings filter option to return scores in state (#12489)
CC @baskaryan @assafelovic
This commit is contained in:
parent
18601bd4c8
commit
76283e9625
@ -67,4 +67,6 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
||||
similarity[included_idxs] > self.similarity_threshold
|
||||
)
|
||||
included_idxs = included_idxs[similar_enough]
|
||||
for i in included_idxs:
|
||||
stateful_documents[i].state["query_similarity_score"] = similarity[i]
|
||||
return [stateful_documents[i] for i in included_idxs]
|
||||
|
@ -35,7 +35,9 @@ def test_embeddings_filter_with_state() -> None:
|
||||
state = {"embedded_doc": np.zeros(len(embedded_query))}
|
||||
docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
|
||||
docs[-1].state = {"embedded_doc": embedded_query}
|
||||
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
|
||||
relevant_filter = EmbeddingsFilter(
|
||||
embeddings=embeddings, similarity_threshold=0.75, return_similarity_scores=True
|
||||
)
|
||||
actual = relevant_filter.compress_documents(docs, query)
|
||||
assert len(actual) == 1
|
||||
assert texts[-1] == actual[0].page_content
|
||||
|
Loading…
Reference in New Issue
Block a user