langchain/libs/community/langchain_community/retrievers/docarray.py
Eugene Yurtsev bf5193bb99
community[patch]: Upgrade pydantic extra (#25185)
Upgrade to using a literal for specifying `extra`, which is the
recommended approach in pydantic 2.

This works correctly also in pydantic v1.

```python
from pydantic.v1 import BaseModel

class Foo(BaseModel, extra="forbid"):
    x: int

Foo(x=5, y=1)
```

And the same behavior using an inner `Config` class:


```python
from pydantic.v1 import BaseModel

class Foo(BaseModel):
    x: int

    class Config:
      extra = "forbid"

Foo(x=5, y=1)
```


## Enum -> literal using grit pattern:

```
engine marzano(0.1)
language python
or {
    `extra=Extra.allow` => `extra="allow"`,
    `extra=Extra.forbid` => `extra="forbid"`,
    `extra=Extra.ignore` => `extra="ignore"`
}
```
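
As a sketch of the effect on a hypothetical model (`Foo` is not from the codebase), the rewrite turns the `Extra` enum member into its string literal:

```python
# Before: configuring `extra` with the Extra enum (hypothetical example)
from pydantic.v1 import BaseModel, Extra

class Foo(BaseModel):
    x: int

    class Config:
        extra = Extra.forbid

# After the grit rewrite: the enum member becomes a plain string literal
class Foo(BaseModel):
    x: int

    class Config:
        extra = "forbid"
```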

Re-sorted attributes in `Config` and removed the doc-string, in case we
need to go back and forth between pydantic v1 and v2 during the 0.3
release. (This will reduce merge conflicts.)


## Sort attributes in Config:

```
engine marzano(0.1)
language python


function sort($values) js {
    return $values.text.split(',').sort().join("\n");
}


class_definition($name, $body) as $C where {
    $name <: `Config`,
    $body <: block($statements),
    $values = [],
    $statements <: some bubble($values) assignment() as $A where {
        $values += $A
    },
    $body => sort($values),
}

```
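
As a sketch, applied to a hypothetical `Config` class, the pattern rewrites the body with its assignments sorted alphabetically:

```python
# Before: hypothetical Config with unsorted assignments
class Config:
    extra = "forbid"
    arbitrary_types_allowed = True

# After the sort: assignments appear in alphabetical order
class Config:
    arbitrary_types_allowed = True
    extra = "forbid"
```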
2024-08-08 17:20:39 +00:00


from enum import Enum
from typing import Any, Dict, List, Optional, Union

import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever

from langchain_community.vectorstores.utils import maximal_marginal_relevance


class SearchType(str, Enum):
    """Enumerator of the types of search to perform."""

    similarity = "similarity"
    mmr = "mmr"


class DocArrayRetriever(BaseRetriever):
    """`DocArray Document Indices` retriever.

    Currently, it supports 5 backends:
    InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
    ElasticDocIndex, and WeaviateDocumentIndex.

    Args:
        index: One of the above-mentioned index instances
        embeddings: Embedding model to represent text as vectors
        search_field: Field to consider for searching in the documents.
            Should be an embedding/vector/tensor.
        content_field: Field that represents the main content in your document schema.
            Will be used as a `page_content`. Everything else will go into `metadata`.
        search_type: Type of search to perform (similarity / mmr)
        filters: Filters applied for document retrieval.
        top_k: Number of documents to return
    """

    index: Any
    embeddings: Embeddings
    search_field: str
    content_field: str
    search_type: SearchType = SearchType.similarity
    top_k: int = 1
    filters: Optional[Any] = None

    class Config:
        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        query_emb = np.array(self.embeddings.embed_query(query))

        if self.search_type == SearchType.similarity:
            results = self._similarity_search(query_emb)
        elif self.search_type == SearchType.mmr:
            results = self._mmr_search(query_emb)
        else:
            raise ValueError(
                f"Search type {self.search_type} does not exist. "
                f"Choose either 'similarity' or 'mmr'."
            )

        return results

    def _search(
        self, query_emb: np.ndarray, top_k: int
    ) -> List[Union[Dict[str, Any], Any]]:
        """
        Perform a search using the query embedding and return top_k documents.

        Args:
            query_emb: Query represented as an embedding
            top_k: Number of documents to return

        Returns:
            A list of top_k documents matching the query
        """
        from docarray.index import ElasticDocIndex, WeaviateDocumentIndex

        filter_args = {}
        search_field = self.search_field
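        # Each backend expects the filter under a different keyword argument.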
        if isinstance(self.index, WeaviateDocumentIndex):
            filter_args["where_filter"] = self.filters
            search_field = ""
        elif isinstance(self.index, ElasticDocIndex):
            filter_args["query"] = self.filters
        else:
            filter_args["filter_query"] = self.filters

        if self.filters:
            query = (
                self.index.build_query()  # get empty query object
                .find(
                    query=query_emb, search_field=search_field
                )  # add vector similarity search
                .filter(**filter_args)  # add filter search
                .build(limit=top_k)  # build the query
            )
            # execute the combined query and return the results
            docs = self.index.execute_query(query)
            if hasattr(docs, "documents"):
                docs = docs.documents
            docs = docs[:top_k]
        else:
            docs = self.index.find(
                query=query_emb, search_field=search_field, limit=top_k
            ).documents

        return docs

    def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a similarity search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of documents most similar to the query
        """
        docs = self._search(query_emb=query_emb, top_k=self.top_k)
        results = [self._docarray_to_langchain_doc(doc) for doc in docs]
        return results

    def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a maximal marginal relevance (mmr) search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of diverse documents related to the query
        """
        docs = self._search(query_emb=query_emb, top_k=20)

        mmr_selected = maximal_marginal_relevance(
            query_emb,
            [
                doc[self.search_field]
                if isinstance(doc, dict)
                else getattr(doc, self.search_field)
                for doc in docs
            ],
            k=self.top_k,
        )
        results = [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected]
        return results

    def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
        """
        Convert a DocArray document (which also might be a dict)
        to a langchain document format.

        DocArray document can contain arbitrary fields, so the mapping is done
        in the following way:

        page_content <-> content_field
        metadata <-> all other fields excluding
            tensors and embeddings (so float, int, string)

        Args:
            doc: DocArray document

        Returns:
            Document in langchain format

        Raises:
            ValueError: If the document doesn't contain the content field
        """
        fields = doc.keys() if isinstance(doc, dict) else doc.__fields__

        if self.content_field not in fields:
            raise ValueError(
                f"Document does not contain the content field - {self.content_field}."
            )

        lc_doc = Document(
            page_content=doc[self.content_field]
            if isinstance(doc, dict)
            else getattr(doc, self.content_field)
        )
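
        # Keep scalar fields (str/int/float/bool) as metadata; tensors and embeddings are skipped.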
        for name in fields:
            value = doc[name] if isinstance(doc, dict) else getattr(doc, name)
            if (
                isinstance(value, (str, int, float, bool))
                and name != self.content_field
            ):
                lc_doc.metadata[name] = value

        return lc_doc