Contextual compression retriever (#2915)
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
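This change introduces ContextualCompressionRetriever together with a family of document compressors (DocumentCompressorPipeline, EmbeddingsFilter, LLMChainExtractor, LLMChainFilter) and the EmbeddingsRedundantFilter document transformer; the test files below exercise each piece. All compressors appear to share one small contract, sketched here from how the tests call them (a sketch inferred from usage, not the verbatim library definition):

from typing import List
from langchain.schema import Document

class BaseDocumentCompressor:  # base class name assumed from the package layout
    def compress_documents(
        self, documents: List[Document], query: str
    ) -> List[Document]:
        """Given the query, return a smaller or filtered set of documents."""
        raise NotImplementedError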
tests/integration_tests/retrievers/__init__.py (new file, 0 lines)
@@ -0,0 +1,28 @@
"""Integration test for compression pipelines."""
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter


def test_document_compressor_pipeline() -> None:
    embeddings = OpenAIEmbeddings()
    splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator=". ")
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8)
    pipeline_filter = DocumentCompressorPipeline(
        transformers=[splitter, redundant_filter, relevant_filter]
    )
    texts = [
        "This sentence is about cows",
        "This sentence was about cows",
        "foo bar baz",
    ]
    docs = [Document(page_content=". ".join(texts))]
    actual = pipeline_filter.compress_documents(docs, "Tell me about farm animals")
    assert len(actual) == 1
    assert actual[0].page_content in texts[:2]
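The pipeline test relies on each stage feeding the next: the splitter breaks the single document into sentences, the redundant filter collapses the two near-duplicate cow sentences into one, and the relevance filter drops "foo bar baz". A minimal sketch of those assumed pipeline semantics (the hasattr dispatch here is illustrative, not the library's actual type check):

def run_pipeline_sketch(transformers, docs, query):
    # Apply each stage to the previous stage's output; compressors also
    # see the query, plain document transformers do not.
    for stage in transformers:
        if hasattr(stage, "compress_documents"):
            docs = stage.compress_documents(docs, query)
        else:
            docs = stage.transform_documents(docs)
    return list(docs)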
@@ -0,0 +1,36 @@
"""Integration test for LLMChainExtractor."""
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema import Document


def test_llm_chain_extractor() -> None:
    texts = [
        "The Roman Empire followed the Roman Republic.",
        "I love chocolate chip cookies—my mother makes great cookies.",
        "The first Roman emperor was Caesar Augustus.",
        "Don't you just love Caesar salad?",
        "The Roman Empire collapsed in 476 AD after the fall of Rome.",
        "Let's go to Olive Garden!",
    ]
    doc = Document(page_content=" ".join(texts))
    compressor = LLMChainExtractor.from_llm(ChatOpenAI())
    actual = compressor.compress_documents([doc], "Tell me about the Roman Empire")[
        0
    ].page_content
    expected_returned = [0, 2, 4]
    expected_not_returned = [1, 3, 5]
    assert all([texts[i] in actual for i in expected_returned])
    assert all([texts[i] not in actual for i in expected_not_returned])


def test_llm_chain_extractor_empty() -> None:
    texts = [
        "I love chocolate chip cookies—my mother makes great cookies.",
        "Don't you just love Caesar salad?",
        "Let's go to Olive Garden!",
    ]
    doc = Document(page_content=" ".join(texts))
    compressor = LLMChainExtractor.from_llm(ChatOpenAI())
    actual = compressor.compress_documents([doc], "Tell me about the Roman Empire")
    assert len(actual) == 0
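Both assertions hinge on LLMChainExtractor quoting only the query-relevant sentences verbatim and returning nothing when no sentence qualifies, which is why the second test expects an empty result. A hypothetical prompt capturing that behavior (wording assumed, not copied from the library):

# Hypothetical extraction prompt (wording assumed):
extract_prompt = """Given the following question and context, extract any part
of the context *as is* that is relevant to answering the question. If none of
the context is relevant, return NO_OUTPUT.

Question: {question}
Context: {context}
Extracted relevant parts:"""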
@@ -0,0 +1,17 @@
"""Integration test for llm-based relevant doc filtering."""
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.schema import Document


def test_llm_chain_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "I wish there were better Italian restaurants in my neighborhood.",
        "My favorite color is green",
    ]
    docs = [Document(page_content=t) for t in texts]
    relevant_filter = LLMChainFilter.from_llm(llm=ChatOpenAI())
    actual = relevant_filter.compress_documents(docs, "Things I said related to food")
    assert len(actual) == 2
    assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2
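Unlike the extractor, LLMChainFilter keeps or drops whole documents, so the test expects exactly the two food-related texts to survive intact. A hypothetical yes/no relevance prompt with the assumed semantics (a document is kept only when the model answers YES):

# Hypothetical per-document relevance prompt (wording assumed):
filter_prompt = """Given the following question and context, return YES if the
context is relevant to the question and NO if it is not.

Question: {question}
Context: {context}
Relevant (YES / NO):"""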
@@ -0,0 +1,39 @@
"""Integration test for embedding-based relevant doc filtering."""
import numpy as np

from langchain.document_transformers import _DocumentWithState
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import Document


def test_embeddings_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "I wish there were better Italian restaurants in my neighborhood.",
        "My favorite color is green",
    ]
    docs = [Document(page_content=t) for t in texts]
    embeddings = OpenAIEmbeddings()
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
    actual = relevant_filter.compress_documents(docs, "What did I say about food?")
    assert len(actual) == 2
    assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2


def test_embeddings_filter_with_state() -> None:
    texts = [
        "What happened to all of my cookies?",
        "I wish there were better Italian restaurants in my neighborhood.",
        "My favorite color is green",
    ]
    query = "What did I say about food?"
    embeddings = OpenAIEmbeddings()
    embedded_query = embeddings.embed_query(query)
    # Give every doc a precomputed zero-vector embedding, which can never
    # clear the similarity threshold...
    state = {"embedded_doc": np.zeros(len(embedded_query))}
    docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
    # ...except the last doc, whose stored embedding is the query embedding
    # itself, so it alone should survive the filter.
    docs[-1].state = {"embedded_doc": embedded_query}
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
    actual = relevant_filter.compress_documents(docs, query)
    assert len(actual) == 1
    assert texts[-1] == actual[0].page_content
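The with-state test works because EmbeddingsFilter evidently compares the query embedding against each document's (possibly precomputed) embedding: the zero vectors cannot clear a 0.75 threshold, while the document carrying the query's own embedding scores 1.0. A self-contained sketch of that assumed thresholding step:

import numpy as np

def embeddings_filter_sketch(doc_vectors, query_vector, threshold):
    # Keep indices of documents whose cosine similarity to the query
    # exceeds the threshold; zero vectors score 0 rather than NaN.
    q = np.asarray(query_vector, dtype=float)
    kept = []
    for i, vec in enumerate(doc_vectors):
        v = np.asarray(vec, dtype=float)
        denom = np.linalg.norm(v) * np.linalg.norm(q)
        sim = float(v @ q) / denom if denom else 0.0
        if sim > threshold:
            kept.append(i)
    return kept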
@@ -0,0 +1,25 @@
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.vectorstores import Chroma


def test_contextual_compression_retriever_get_relevant_docs() -> None:
    """Test get_relevant_docs."""
    texts = [
        "This is a document about the Boston Celtics",
        "The Boston Celtics won the game by 20 points",
        "I simply love going to the movies",
    ]
    embeddings = OpenAIEmbeddings()
    base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75)
    base_retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever(
        search_kwargs={"k": len(texts)}
    )
    retriever = ContextualCompressionRetriever(
        base_compressor=base_compressor, base_retriever=base_retriever
    )

    actual = retriever.get_relevant_documents("Tell me about the Celtics")
    assert len(actual) == 2
    assert texts[-1] not in [d.page_content for d in actual]
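This test fetches all three texts (k=len(texts)) and then lets the compressor discard the movie sentence. The control flow being exercised is, in essence, a two-step composition (a sketch of the assumed behavior, not the library's verbatim method):

def get_relevant_documents_sketch(base_retriever, base_compressor, query):
    # Retrieve first, then compress with the query as context.
    docs = base_retriever.get_relevant_documents(query)
    return base_compressor.compress_documents(docs, query)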
tests/integration_tests/test_document_transformers.py (new file, 31 lines)
@@ -0,0 +1,31 @@
"""Integration test for embedding-based redundant doc filtering."""
from langchain.document_transformers import (
    EmbeddingsRedundantFilter,
    _DocumentWithState,
)
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document


def test_embeddings_redundant_filter() -> None:
    texts = [
        "What happened to all of my cookies?",
        "Where did all of my cookies go?",
        "I wish there were better Italian restaurants in my neighborhood.",
    ]
    docs = [Document(page_content=t) for t in texts]
    embeddings = OpenAIEmbeddings()
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    actual = redundant_filter.transform_documents(docs)
    assert len(actual) == 2
    assert set(texts[:2]).intersection([d.page_content for d in actual])


def test_embeddings_redundant_filter_with_state() -> None:
    texts = ["What happened to all of my cookies?", "foo bar baz"]
    # Both docs carry the same precomputed embedding, so the filter treats
    # the unrelated texts as exact duplicates and keeps only one.
    state = {"embedded_doc": [0.5] * 10}
    docs = [_DocumentWithState(page_content=t, state=state) for t in texts]
    embeddings = OpenAIEmbeddings()
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
    actual = redundant_filter.transform_documents(docs)
    assert len(actual) == 1
tests/unit_tests/test_document_transformers.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""Unit tests for document transformers."""
from langchain.document_transformers import _filter_similar_embeddings
from langchain.math_utils import cosine_similarity


def test__filter_similar_embeddings() -> None:
    threshold = 0.79
    # Docs 0 and 1 are identical and docs 2 and 3 are nearly parallel; in
    # each over-threshold pair the earlier-indexed doc is the one dropped.
    embedded_docs = [[1.0, 2.0], [1.0, 2.0], [2.0, 1.0], [2.0, 0.5], [0.0, 0.0]]
    expected = [1, 3, 4]
    actual = _filter_similar_embeddings(embedded_docs, cosine_similarity, threshold)
    assert expected == actual


def test__filter_similar_embeddings_empty() -> None:
    assert len(_filter_similar_embeddings([], cosine_similarity, 0.0)) == 0
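The expected indices [1, 3, 4] pin down the dedup rule: a plain left-to-right greedy scan (keep a doc unless it matches an already-kept earlier doc) would instead yield [0, 3, 4]. A sketch consistent with the expectation, assuming over-threshold pairs are processed in descending-similarity order and the earlier document of each pair is dropped:

import numpy as np
from typing import Callable, List

def filter_similar_embeddings_sketch(
    embedded_docs: List[List[float]],
    similarity_fn: Callable,
    threshold: float,
) -> List[int]:
    """Indices to keep after dropping the earlier doc of each too-similar pair."""
    if not embedded_docs:
        return []
    sim = np.asarray(similarity_fn(embedded_docs, embedded_docs))
    included = set(range(len(embedded_docs)))
    # Each unordered pair appears once in the strict lower triangle (i > j);
    # visit the most similar pairs first.
    rows, cols = np.where(np.tril(sim, k=-1) > threshold)
    order = np.argsort(sim[rows, cols])[::-1]
    for i, j in zip(rows[order], cols[order]):
        if i in included and j in included:
            included.discard(int(j))  # drop the earlier-indexed document
    return sorted(included)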
tests/unit_tests/test_math_utils.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Test math utility functions."""
from typing import List

import numpy as np

from langchain.math_utils import cosine_similarity


def test_cosine_similarity_zero() -> None:
    X = np.zeros((3, 3))
    Y = np.random.random((3, 3))
    expected = np.zeros((3, 3))
    actual = cosine_similarity(X, Y)
    assert np.allclose(expected, actual)


def test_cosine_similarity_identity() -> None:
    X = np.random.random((4, 4))
    expected = np.ones(4)
    actual = np.diag(cosine_similarity(X, X))
    assert np.allclose(expected, actual)


def test_cosine_similarity_empty() -> None:
    empty_list: List[List[float]] = []
    assert len(cosine_similarity(empty_list, empty_list)) == 0
    assert len(cosine_similarity(empty_list, np.random.random((3, 3)))) == 0


def test_cosine_similarity() -> None:
    X = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0], [1.0, 2.0, 0.0]]
    Y = [[0.5, 1.0, 1.5], [1.0, 0.0, 0.0], [2.0, 5.0, 2.0]]
    expected = [
        [1.0, 0.26726124, 0.83743579],
        [0.53452248, 0.0, 0.87038828],
        [0.5976143, 0.4472136, 0.93419873],
    ]
    actual = cosine_similarity(X, Y)
    assert np.allclose(expected, actual)
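Together, these four tests specify cosine_similarity fully enough to sketch it: a pairwise row-by-row similarity matrix, an empty array on empty input, and 0.0 (not NaN) wherever a zero vector would make the usual formula divide by zero. A reference sketch under those assumed semantics:

import numpy as np

def cosine_similarity_sketch(X, Y):
    # Pairwise cosine similarity between the rows of X and the rows of Y.
    X, Y = np.asarray(X, dtype=float), np.asarray(Y, dtype=float)
    if X.size == 0 or Y.size == 0:
        return np.array([])
    with np.errstate(divide="ignore", invalid="ignore"):
        sim = (X @ Y.T) / np.outer(
            np.linalg.norm(X, axis=1), np.linalg.norm(Y, axis=1)
        )
    sim[np.isnan(sim) | np.isinf(sim)] = 0.0  # zero vectors -> similarity 0
    return sim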