DocArray as a Retriever (#6031)
## DocArray as a Retriever

[DocArray](https://github.com/docarray/docarray) is an open-source tool for managing your multi-modal data. It offers the flexibility to store and search through your data using various document index backends. This PR introduces `DocArrayRetriever`, which works with any available backend and serves as a retriever for LangChain apps.

I also added two notebooks:
- DocArray Backends - an intro to all 5 currently supported backends: how to initialize them, index data, and use them as a retriever
- DocArray Usage - a showcase of the additional search parameters you can pass to create versatile retrievers

Example:

```python
from docarray.index import InMemoryExactNNIndex
from docarray import BaseDoc, DocList
from docarray.typing import NdArray
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.retrievers import DocArrayRetriever

# define document schema
class MyDoc(BaseDoc):
    description: str
    description_embedding: NdArray[1536]


embeddings = OpenAIEmbeddings()

# create documents
descriptions = ["description 1", "description 2"]
desc_embeddings = embeddings.embed_documents(texts=descriptions)
docs = DocList[MyDoc](
    [
        MyDoc(description=desc, description_embedding=embedding)
        for desc, embedding in zip(descriptions, desc_embeddings)
    ]
)

# initialize document index with data
db = InMemoryExactNNIndex[MyDoc](docs)

# create a retriever
retriever = DocArrayRetriever(
    index=db,
    embeddings=embeddings,
    search_field="description_embedding",
    content_field="description",
)

# find the relevant document
doc = retriever.get_relevant_documents("action movies")
print(doc)
```

#### Who can review?

@dev2049

---------

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>
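To illustrate the "additional search parameters" covered by the DocArray Usage notebook, here is a minimal sketch of a retriever with a metadata filter. It assumes a hypothetical `MovieDoc` schema with an extra `year` field and an `InMemoryExactNNIndex` backend; the MongoDB-style filter dict matches what the in-memory index accepts in the test fixtures below (other backends take their own filter formats, as those fixtures show).

```python
import numpy as np
from docarray import BaseDoc, DocList
from docarray.index import InMemoryExactNNIndex
from docarray.typing import NdArray

from langchain.embeddings import FakeEmbeddings
from langchain.retrievers import DocArrayRetriever


# hypothetical schema with an extra metadata field to filter on
class MovieDoc(BaseDoc):
    description: str
    description_embedding: NdArray[32]
    year: int


embeddings = FakeEmbeddings(size=32)

# index a handful of documents
docs = DocList[MovieDoc](
    [
        MovieDoc(
            description=f"action movie {i}",
            description_embedding=np.array(embeddings.embed_query(f"action movie {i}")),
            year=2000 + i,
        )
        for i in range(10)
    ]
)
db = InMemoryExactNNIndex[MovieDoc](docs)

# retriever that only considers movies released up to 2005
retriever = DocArrayRetriever(
    index=db,
    embeddings=embeddings,
    search_field="description_embedding",
    content_field="description",
    filters={"year": {"$lte": 2005}},
)

print(retriever.get_relevant_documents("action movies"))
```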
tests/integration_tests/retrievers/docarray/fixtures.py (new file, 195 lines)
@@ -0,0 +1,195 @@
```python
from pathlib import Path
from typing import Any, Dict, Generator, Tuple

import numpy as np
import pytest
from docarray import BaseDoc
from docarray.index import (
    ElasticDocIndex,
    HnswDocumentIndex,
    InMemoryExactNNIndex,
    QdrantDocumentIndex,
    WeaviateDocumentIndex,
)
from docarray.typing import NdArray
from pydantic import Field
from qdrant_client.http import models as rest

from langchain.embeddings import FakeEmbeddings


class MyDoc(BaseDoc):
    title: str
    title_embedding: NdArray[32]  # type: ignore
    other_emb: NdArray[32]  # type: ignore
    year: int


class WeaviateDoc(BaseDoc):
    # When initializing the Weaviate index, denote the field
    # you want to search on with `is_embedding=True`
    title: str
    title_embedding: NdArray[32] = Field(is_embedding=True)  # type: ignore
    other_emb: NdArray[32]  # type: ignore
    year: int


@pytest.fixture
def init_weaviate() -> (
    Generator[
        Tuple[WeaviateDocumentIndex[WeaviateDoc], Dict[str, Any], FakeEmbeddings],
        None,
        None,
    ]
):
    """
    cd tests/integration_tests/vectorstores/docker-compose
    docker compose -f weaviate.yml up
    """
    embeddings = FakeEmbeddings(size=32)

    # initialize WeaviateDocumentIndex
    dbconfig = WeaviateDocumentIndex.DBConfig(host="http://localhost:8080")
    weaviate_db = WeaviateDocumentIndex[WeaviateDoc](
        db_config=dbconfig, index_name="docarray_retriever"
    )

    # index data
    weaviate_db.index(
        [
            WeaviateDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = {"path": ["year"], "operator": "LessThanEqual", "valueInt": "90"}

    yield weaviate_db, filter_query, embeddings

    weaviate_db._client.schema.delete_all()


@pytest.fixture
def init_elastic() -> (
    Generator[Tuple[ElasticDocIndex[MyDoc], Dict[str, Any], FakeEmbeddings], None, None]
):
    """
    cd tests/integration_tests/vectorstores/docker-compose
    docker-compose -f elasticsearch.yml up
    """
    embeddings = FakeEmbeddings(size=32)

    # initialize ElasticDocIndex
    elastic_db = ElasticDocIndex[MyDoc](
        hosts="http://localhost:9200", index_name="docarray_retriever"
    )
    # index data
    elastic_db.index(
        [
            MyDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = {"range": {"year": {"lte": 90}}}

    yield elastic_db, filter_query, embeddings

    elastic_db._client.indices.delete(index="docarray_retriever")


@pytest.fixture
def init_qdrant() -> Tuple[QdrantDocumentIndex[MyDoc], rest.Filter, FakeEmbeddings]:
    embeddings = FakeEmbeddings(size=32)

    # initialize QdrantDocumentIndex
    qdrant_config = QdrantDocumentIndex.DBConfig(path=":memory:")
    qdrant_db = QdrantDocumentIndex[MyDoc](qdrant_config)
    # index data
    qdrant_db.index(
        [
            MyDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = rest.Filter(
        must=[
            rest.FieldCondition(
                key="year",
                range=rest.Range(
                    gte=10,
                    lt=90,
                ),
            )
        ]
    )

    return qdrant_db, filter_query, embeddings


@pytest.fixture
def init_in_memory() -> (
    Tuple[InMemoryExactNNIndex[MyDoc], Dict[str, Any], FakeEmbeddings]
):
    embeddings = FakeEmbeddings(size=32)

    # initialize InMemoryExactNNIndex
    in_memory_db = InMemoryExactNNIndex[MyDoc]()
    # index data
    in_memory_db.index(
        [
            MyDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = {"year": {"$lte": 90}}

    return in_memory_db, filter_query, embeddings


@pytest.fixture
def init_hnsw(
    tmp_path: Path,
) -> Tuple[HnswDocumentIndex[MyDoc], Dict[str, Any], FakeEmbeddings]:
    embeddings = FakeEmbeddings(size=32)

    # initialize HnswDocumentIndex
    hnsw_db = HnswDocumentIndex[MyDoc](work_dir=tmp_path)
    # index data
    hnsw_db.index(
        [
            MyDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = {"year": {"$lte": 90}}

    return hnsw_db, filter_query, embeddings
```
tests/integration_tests/retrievers/docarray/test_backends.py (new file, 71 lines)
@@ -0,0 +1,71 @@
```python
from typing import Any

import pytest
from vcr.request import Request

from langchain.retrievers import DocArrayRetriever
from tests.integration_tests.retrievers.docarray.fixtures import (  # noqa: F401
    init_elastic,
    init_hnsw,
    init_in_memory,
    init_qdrant,
    init_weaviate,
)


@pytest.mark.parametrize(
    "backend",
    ["init_hnsw", "init_in_memory", "init_qdrant", "init_elastic", "init_weaviate"],
)
def test_backends(request: Request, backend: Any) -> None:
    index, filter_query, embeddings = request.getfixturevalue(backend)

    # create a retriever
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        content_field="title",
    )

    docs = retriever.get_relevant_documents("my docs")

    assert len(docs) == 1
    assert "My document" in docs[0].page_content
    assert "id" in docs[0].metadata and "year" in docs[0].metadata
    assert "other_emb" not in docs[0].metadata

    # create a retriever with filters
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        content_field="title",
        filters=filter_query,
    )

    docs = retriever.get_relevant_documents("my docs")

    assert len(docs) == 1
    assert "My document" in docs[0].page_content
    assert "id" in docs[0].metadata and "year" in docs[0].metadata
    assert "other_emb" not in docs[0].metadata
    assert docs[0].metadata["year"] <= 90

    # create a retriever with MMR search
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        search_type="mmr",
        content_field="title",
        filters=filter_query,
    )

    docs = retriever.get_relevant_documents("my docs")

    assert len(docs) == 1
    assert "My document" in docs[0].page_content
    assert "id" in docs[0].metadata and "year" in docs[0].metadata
    assert "other_emb" not in docs[0].metadata
    assert docs[0].metadata["year"] <= 90
```
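A note on the design of this test: every fixture returns or yields the same `(document index, backend-specific filter query, FakeEmbeddings)` tuple, and `request.getfixturevalue` resolves the fixture named in the parametrize list at run time, so the test body stays backend-agnostic. A sketch of how one more backend could be plugged into the same pattern, assuming a hypothetical `init_new_backend` fixture that follows that contract:

```python
# hypothetical: extending the parametrized test with one more backend fixture
import pytest

from langchain.retrievers import DocArrayRetriever


@pytest.mark.parametrize(
    "backend",
    [
        "init_hnsw",
        "init_in_memory",
        "init_qdrant",
        "init_elastic",
        "init_weaviate",
        "init_new_backend",  # hypothetical fixture returning (index, filter_query, embeddings)
    ],
)
def test_backends_extended(request: pytest.FixtureRequest, backend: str) -> None:
    # the shared fixture contract keeps the test body identical across backends
    index, filter_query, embeddings = request.getfixturevalue(backend)
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        content_field="title",
        filters=filter_query,
    )
    docs = retriever.get_relevant_documents("my docs")
    assert len(docs) == 1
```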