DocArray as a Retriever (#6031)

## DocArray as a Retriever

[DocArray](https://github.com/docarray/docarray) is an open-source tool
for managing your multi-modal data. It offers the flexibility to store and
search through your data using various document index backends. This PR
introduces `DocArrayRetriever`, which works with any available backend
and serves as a retriever for LangChain apps.

Also, I added 2 notebooks:
- DocArray Backends - an intro to all 5 currently supported backends: how to
  initialize them, index data, and use them as a retriever
- DocArray Usage - showcases the additional search parameters you can
  pass to create versatile retrievers

Example:
```python
from docarray.index import InMemoryExactNNIndex
from docarray import BaseDoc, DocList
from docarray.typing import NdArray
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.retrievers import DocArrayRetriever


# define document schema
class MyDoc(BaseDoc):
    description: str
    description_embedding: NdArray[1536]


embeddings = OpenAIEmbeddings()
# create documents
descriptions = ["description 1", "description 2"]
desc_embeddings = embeddings.embed_documents(texts=descriptions)
docs = DocList[MyDoc](
    [
        MyDoc(description=desc, description_embedding=embedding)
        for desc, embedding in zip(descriptions, desc_embeddings)
    ]
)

# initialize document index with data
db = InMemoryExactNNIndex[MyDoc](docs)

# create a retriever
retriever = DocArrayRetriever(
    index=db,
    embeddings=embeddings,
    search_field="description_embedding",
    content_field="description",
)

# find the relevant document
doc = retriever.get_relevant_documents("action movies")
print(doc)
```

#### Who can review?

@dev2049

---------

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>
This commit is contained in:
Saba Sturua
2023-06-17 18:09:33 +02:00
committed by GitHub
parent 7bb437146d
commit 427551eabf
6 changed files with 1263 additions and 0 deletions

View File

@@ -0,0 +1,195 @@
from pathlib import Path
from typing import Any, Dict, Generator, Tuple
import numpy as np
import pytest
from docarray import BaseDoc
from docarray.index import (
ElasticDocIndex,
HnswDocumentIndex,
InMemoryExactNNIndex,
QdrantDocumentIndex,
WeaviateDocumentIndex,
)
from docarray.typing import NdArray
from pydantic import Field
from qdrant_client.http import models as rest
from langchain.embeddings import FakeEmbeddings
class MyDoc(BaseDoc):
    """Document schema used by the Elastic, Qdrant, in-memory and HNSW fixtures."""

    # Text stored for each document.
    title: str
    # Two separate 32-dimensional vector fields.
    title_embedding: NdArray[32]  # type: ignore
    other_emb: NdArray[32]  # type: ignore
    # Integer field targeted by the fixtures' filter queries.
    year: int
class WeaviateDoc(BaseDoc):
    """Document schema used by the Weaviate fixture.

    For a Weaviate index, the field to search on is denoted with
    `is_embedding=True`.
    """

    # Text stored for each document.
    title: str
    # The vector field marked as the one to search on.
    title_embedding: NdArray[32] = Field(is_embedding=True)  # type: ignore
    # A second 32-dimensional vector field, not marked as the embedding.
    other_emb: NdArray[32]  # type: ignore
    # Integer field targeted by the fixture's filter query.
    year: int
@pytest.fixture
def init_weaviate() -> (
    Generator[
        Tuple[WeaviateDocumentIndex[WeaviateDoc], Dict[str, Any], FakeEmbeddings],
        None,
        None,
    ]
):
    """Yield a populated Weaviate index, a filter query and the embeddings.

    Connects to http://localhost:8080. To start the service:

        cd tests/integration_tests/vectorstores/docker-compose
        docker compose -f weaviate.yml up

    After the test has run, the whole Weaviate schema is deleted.
    """
    fake_embeddings = FakeEmbeddings(size=32)

    # Connect to the Weaviate instance and create the document index.
    index = WeaviateDocumentIndex[WeaviateDoc](
        db_config=WeaviateDocumentIndex.DBConfig(host="http://localhost:8080"),
        index_name="docarray_retriever",
    )

    # Populate the index with 100 documents whose `year` runs from 0 to 99.
    documents = []
    for i in range(100):
        documents.append(
            WeaviateDoc(
                title=f"My document {i}",
                title_embedding=np.array(fake_embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(
                    fake_embeddings.embed_query(f"other fake emb {i}")
                ),
                year=i,
            )
        )
    index.index(documents)

    # Filter on `year` using the LessThanEqual operator.
    # NOTE(review): `valueInt` is passed as the string "90" -- confirm the
    # backend accepts a string here.
    where_filter = {"path": ["year"], "operator": "LessThanEqual", "valueInt": "90"}

    yield index, where_filter, fake_embeddings

    # Teardown: remove everything that was created in Weaviate.
    index._client.schema.delete_all()
@pytest.fixture
def init_elastic() -> (
    Generator[Tuple[ElasticDocIndex[MyDoc], Dict[str, Any], FakeEmbeddings], None, None]
):
    """Yield a populated Elasticsearch index, a filter query and the embeddings.

    Connects to http://localhost:9200. To start the service:

        cd tests/integration_tests/vectorstores/docker-compose
        docker-compose -f elasticsearch.yml up

    After the test has run, the index is deleted.
    """
    index_name = "docarray_retriever"
    fake_embeddings = FakeEmbeddings(size=32)

    # Connect to Elasticsearch and create the document index.
    index = ElasticDocIndex[MyDoc](hosts="http://localhost:9200", index_name=index_name)

    # Populate the index with 100 documents whose `year` runs from 0 to 99.
    documents = [
        MyDoc(
            title=f"My document {i}",
            title_embedding=np.array(fake_embeddings.embed_query(f"fake emb {i}")),
            other_emb=np.array(fake_embeddings.embed_query(f"other fake emb {i}")),
            year=i,
        )
        for i in range(100)
    ]
    index.index(documents)

    # Range filter on `year` with an upper bound of 90 (`lte`).
    range_filter = {"range": {"year": {"lte": 90}}}

    yield index, range_filter, fake_embeddings

    # Teardown: drop the index created above.
    index._client.indices.delete(index=index_name)
@pytest.fixture
def init_qdrant() -> Tuple[QdrantDocumentIndex[MyDoc], rest.Filter, FakeEmbeddings]:
    """Return a populated Qdrant index, a filter and the embeddings.

    The index is configured with the ":memory:" path.
    """
    fake_embeddings = FakeEmbeddings(size=32)

    # Create the Qdrant document index.
    index = QdrantDocumentIndex[MyDoc](QdrantDocumentIndex.DBConfig(path=":memory:"))

    # Populate the index with 100 documents whose `year` runs from 0 to 99.
    documents = [
        MyDoc(
            title=f"My document {i}",
            title_embedding=np.array(fake_embeddings.embed_query(f"fake emb {i}")),
            other_emb=np.array(fake_embeddings.embed_query(f"other fake emb {i}")),
            year=i,
        )
        for i in range(100)
    ]
    index.index(documents)

    # Build the filter step by step: a range on `year` with gte=10 and lt=90.
    year_range = rest.Range(gte=10, lt=90)
    year_condition = rest.FieldCondition(key="year", range=year_range)
    year_filter = rest.Filter(must=[year_condition])

    return index, year_filter, fake_embeddings
@pytest.fixture
def init_in_memory() -> (
    Tuple[InMemoryExactNNIndex[MyDoc], Dict[str, Any], FakeEmbeddings]
):
    """Return a populated InMemoryExactNNIndex, a filter query and the embeddings."""
    fake_embeddings = FakeEmbeddings(size=32)

    # Create an empty in-memory index.
    index = InMemoryExactNNIndex[MyDoc]()

    # Populate the index with 100 documents whose `year` runs from 0 to 99.
    documents = []
    for i in range(100):
        documents.append(
            MyDoc(
                title=f"My document {i}",
                title_embedding=np.array(fake_embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(
                    fake_embeddings.embed_query(f"other fake emb {i}")
                ),
                year=i,
            )
        )
    index.index(documents)

    # Filter on `year` with the `$lte` operator and a bound of 90.
    year_filter = {"year": {"$lte": 90}}

    return index, year_filter, fake_embeddings
@pytest.fixture
def init_hnsw(
    tmp_path: Path,
) -> Tuple[HnswDocumentIndex[MyDoc], Dict[str, Any], FakeEmbeddings]:
    """Return a populated HnswDocumentIndex, a filter query and the embeddings.

    Args:
        tmp_path: Directory passed to the index as its `work_dir`.
    """
    fake_embeddings = FakeEmbeddings(size=32)

    # Create the HnswDocumentIndex with `tmp_path` as its working directory.
    index = HnswDocumentIndex[MyDoc](work_dir=tmp_path)

    # Populate the index with 100 documents whose `year` runs from 0 to 99.
    documents = [
        MyDoc(
            title=f"My document {i}",
            title_embedding=np.array(fake_embeddings.embed_query(f"fake emb {i}")),
            other_emb=np.array(fake_embeddings.embed_query(f"other fake emb {i}")),
            year=i,
        )
        for i in range(100)
    ]
    index.index(documents)

    # Filter on `year` with the `$lte` operator and a bound of 90.
    year_filter = {"year": {"$lte": 90}}

    return index, year_filter, fake_embeddings

View File

@@ -0,0 +1,71 @@
from typing import Any
import pytest
from vcr.request import Request
from langchain.retrievers import DocArrayRetriever
from tests.integration_tests.retrievers.docarray.fixtures import ( # noqa: F401
init_elastic,
init_hnsw,
init_in_memory,
init_qdrant,
init_weaviate,
)
def _assert_single_relevant_doc(docs: Any, filtered: bool) -> None:
    """Check the result list shared by every retriever configuration.

    Args:
        docs: Documents returned by `get_relevant_documents`.
        filtered: When True, also assert that the document's `year`
            metadata is at most 90.
    """
    assert len(docs) == 1
    doc = docs[0]
    assert "My document" in doc.page_content
    assert "id" in doc.metadata and "year" in doc.metadata
    # The non-searched embedding field must not appear in the metadata.
    assert "other_emb" not in doc.metadata
    if filtered:
        assert doc.metadata["year"] <= 90


@pytest.mark.parametrize(
    "backend",
    ["init_hnsw", "init_in_memory", "init_qdrant", "init_elastic", "init_weaviate"],
)
def test_backends(request: pytest.FixtureRequest, backend: str) -> None:
    """Run DocArrayRetriever against each backend fixture.

    Three configurations are checked per backend: a plain retriever, a
    retriever with filters, and a retriever with filters and MMR search.

    Args:
        request: pytest's fixture request object, used to resolve the
            backend fixture by name at run time.
        backend: Name of the fixture that yields
            `(index, filter_query, embeddings)`.
    """
    index, filter_query, embeddings = request.getfixturevalue(backend)

    # create a retriever
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        content_field="title",
    )
    docs = retriever.get_relevant_documents("my docs")
    _assert_single_relevant_doc(docs, filtered=False)

    # create a retriever with filters
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        content_field="title",
        filters=filter_query,
    )
    docs = retriever.get_relevant_documents("my docs")
    _assert_single_relevant_doc(docs, filtered=True)

    # create a retriever with MMR search
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        search_type="mmr",
        content_field="title",
        filters=filter_query,
    )
    docs = retriever.get_relevant_documents("my docs")
    _assert_single_relevant_doc(docs, filtered=True)