mirror of https://github.com/hwchase17/langchain.git
synced 2025-09-04 12:39:32 +00:00
LanceDB integration update (#22869)
Added:

- [x] relevance search (with and without scores)
- [x] maximal marginal relevance search
- [x] image ingestion
- [x] filtering support
- [x] hybrid search with reranking

`make test`, `make lint_diff`, and `make format` checked.
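A quick sketch of the new search surface (illustrative: `OpenAIEmbeddings` is an assumed embedding backend, and fts/hybrid queries require a local, non-cloud connection):

```python
from langchain_community.vectorstores import LanceDB
from langchain_openai import OpenAIEmbeddings  # assumed embedding backend

store = LanceDB(embedding=OpenAIEmbeddings(), table_name="vectorstore")
store.add_texts(["hello world", "goodbye world"], metadatas=[{"source": "a"}, {"source": "b"}])

docs = store.similarity_search("hello", k=2)                 # relevance search
scored = store.similarity_search_with_score("hello", k=2)    # ... with scores
diverse = store.max_marginal_relevance_search("hello", k=2)  # maximal marginal relevance
hybrid = store.similarity_search_with_score("hello", k=2, query_type="hybrid")  # hybrid (+ optional reranker)
```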
@@ -1,21 +1,32 @@
from __future__ import annotations

import base64
import os
import uuid
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Type

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

DEFAULT_K = 4  # Number of Documents to return.


def import_lancedb() -> Any:
    """Import lancedb package."""
    return guard_import("lancedb")


def to_lance_filter(filter: Dict[str, str]) -> str:
    """Converts a dict filter to a LanceDB filter string."""
    return " AND ".join([f"{k} = '{v}'" for k, v in filter.items()])


class LanceDB(VectorStore):
    """`LanceDB` vector store.
@@ -55,6 +66,11 @@ class LanceDB(VectorStore):
        api_key: Optional[str] = None,
        region: Optional[str] = None,
        mode: Optional[str] = "overwrite",
        table: Optional[Any] = None,
        distance: Optional[str] = "l2",
        reranker: Optional[Any] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
        limit: int = DEFAULT_K,
    ):
        """Initialize with the LanceDB vector store."""
        lancedb = guard_import("lancedb")
@@ -62,10 +78,22 @@ class LanceDB(VectorStore):
        self._vector_key = vector_key
        self._id_key = id_key
        self._text_key = text_key
        self._table_name = table_name
        self.api_key = (api_key or os.getenv("LANCE_API_KEY")) if api_key != "" else None
        self.region = region
        self.mode = mode
        self.distance = distance
        self.override_relevance_score_fn = relevance_score_fn
        self.limit = limit
        self._fts_index = None

        if isinstance(reranker, lancedb.rerankers.Reranker):
            self._reranker = reranker
        elif reranker is None:
            self._reranker = None
        else:
            raise ValueError(
                "`reranker` has to be a lancedb.rerankers.Reranker object."
            )

        if isinstance(uri, str) and self.api_key is None:
            if uri.startswith("db://"):
@@ -96,6 +124,52 @@ class LanceDB(VectorStore):
                        "api key provided with local uri.\
                        The data will be stored locally"
                    )
        if table is not None:
            try:
                assert isinstance(
                    table, (lancedb.db.LanceTable, lancedb.remote.table.RemoteTable)
                )
                self._table = table
                self._table_name = (
                    table.name if hasattr(table, "name") else "remote_table"
                )
            except AssertionError:
                raise ValueError(
                    """`table` has to be a lancedb.db.LanceTable or
                    lancedb.remote.table.RemoteTable object."""
                )
        else:
            self._table = self.get_table(table_name, set_default=True)

    def results_to_docs(self, results: Any, score: bool = False) -> Any:
        columns = results.schema.names

        if "_distance" in columns:
            score_col = "_distance"
        elif "_relevance_score" in columns:
            score_col = "_relevance_score"
        else:
            score_col = None

        if score_col is None or not score:
            return [
                Document(
                    page_content=results[self._text_key][idx].as_py(),
                    metadata=results["metadata"][idx].as_py(),
                )
                for idx in range(len(results))
            ]
        elif score_col and score:
            return [
                (
                    Document(
                        page_content=results[self._text_key][idx].as_py(),
                        metadata=results["metadata"][idx].as_py(),
                    ),
                    results[score_col][idx].as_py(),
                )
                for idx in range(len(results))
            ]

    @property
    def embeddings(self) -> Optional[Embeddings]:
@@ -114,11 +188,11 @@ class LanceDB(VectorStore):
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of ids to associate with the texts.

        Returns:
            List of ids of the added texts.
        """
        # Embed texts and create documents
        docs = []
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        embeddings = self._embedding.embed_documents(list(texts))  # type: ignore
@@ -134,14 +208,19 @@ class LanceDB(VectorStore):
                }
            )

        tbl = self.get_table()

        if tbl is None:
            tbl = self._connection.create_table(self._table_name, data=docs)
            self._table = tbl
        else:
            if self.api_key is None:
                tbl.add(docs, mode=self.mode)
            else:
                tbl.add(docs)

        self._fts_index = None

        return ids
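
    # Example usage (illustrative; assumes `store` is a LanceDB instance constructed
    # with an embedding function):
    #   ids = store.add_texts(["hello", "world"], metadatas=[{"source": "s"}] * 2)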

    def get_table(
@@ -164,14 +243,18 @@ class LanceDB(VectorStore):
        """
        if name is not None:
            if set_default:
                self._table_name = name
                _name = self._table_name
            else:
                _name = name
        else:
            _name = self._table_name

        try:
            return self._connection.open_table(_name)
        except Exception:
            return None

    def create_index(
        self,
@@ -181,6 +264,7 @@ class LanceDB(VectorStore):
        num_sub_vectors: Optional[int] = 96,
        index_cache_size: Optional[int] = None,
        metric: Optional[str] = "L2",
        name: Optional[str] = None,
    ) -> None:
        """
        Create a scalar (for non-vector cols) or a vector index on a table.
@@ -191,11 +275,15 @@ class LanceDB(VectorStore):
            col_name: Provide if you want to create index on a non-vector column.
            metric: Provide the metric to use for vector index. Defaults to 'L2'.
                Choice of metrics: 'L2', 'dot', 'cosine'.
            num_partitions: Number of partitions to use for the index. Defaults to 256.
            num_sub_vectors: Number of sub-vectors to use for the index. Defaults to 96.
            index_cache_size: Size of the index cache. Defaults to None.
            name: Name of the table to create index on. Defaults to None.

        Returns:
            None
        """
        tbl = self.get_table(name)

        if vector_col:
            tbl.create_index(
@@ -210,8 +298,205 @@ class LanceDB(VectorStore):
        else:
            raise ValueError("Provide either vector_col or col_name")

    def encode_image(self, uri: str) -> str:
        """Get base64 string from image URI."""
        with open(uri, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
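
    # Illustrative call (hypothetical path):
    #   b64 = store.encode_image("images/example.png")  # base64-encoded file contents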

    def add_images(
        self,
        uris: List[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more images through the embeddings and add to the vectorstore.

        Args:
            uris (List[str]): File paths to the images.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added images.
        """
        tbl = self.get_table()

        # Map from uris to b64 encoded strings
        b64_texts = [self.encode_image(uri=uri) for uri in uris]
        # Populate IDs
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in uris]
        embeddings = None
        # Set embeddings
        if self._embedding is not None and hasattr(self._embedding, "embed_image"):
            embeddings = self._embedding.embed_image(uris=uris)
        else:
            raise ValueError(
                "embedding object should be provided and must have embed_image method."
            )

        data = []
        for idx, emb in enumerate(embeddings):
            metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
            data.append(
                {
                    self._vector_key: emb,
                    self._id_key: ids[idx],
                    self._text_key: b64_texts[idx],
                    "metadata": metadata,
                }
            )
        if tbl is None:
            tbl = self._connection.create_table(self._table_name, data=data)
            self._table = tbl
        else:
            tbl.add(data)

        return ids
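
    # Example usage (illustrative; requires an embedding object implementing
    # `embed_image`, such as a CLIP-style multimodal embedding):
    #   ids = store.add_images(["images/example.png"], metadatas=[{"label": "cat"}])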

    def _query(
        self,
        query: Any,
        k: Optional[int] = None,
        filter: Optional[Any] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        if k is None:
            k = self.limit
        tbl = self.get_table(name)
        if isinstance(filter, dict):
            filter = to_lance_filter(filter)

        prefilter = kwargs.get("prefilter", False)
        query_type = kwargs.get("query_type", "vector")

        lance_query = (
            tbl.search(query=query, vector_column_name=self._vector_key)
            .limit(k)
            .where(filter, prefilter=prefilter)
        )
        if query_type == "hybrid" and self._reranker is not None:
            lance_query.rerank(reranker=self._reranker)

        docs = lance_query.to_arrow()
        if len(docs) == 0:
            warnings.warn("No results found for the query.")
        return docs
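
    # Illustrative internal call (hypothetical vector): returns an Arrow table,
    # reranked when query_type="hybrid" and a reranker is configured:
    #   tbl = store._query([0.1, 0.2, 0.3], k=4, prefilter=True, query_type="vector")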

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        The 'correct' relevance function
        may differ depending on a few things, including:
        - the distance / similarity metric used by the VectorStore
        - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
        - embedding dimensionality
        - etc.
        """
        if self.override_relevance_score_fn:
            return self.override_relevance_score_fn

        if self.distance == "cosine":
            return self._cosine_relevance_score_fn
        elif self.distance == "l2":
            return self._euclidean_relevance_score_fn
        elif self.distance == "ip":
            return self._max_inner_product_relevance_score_fn
        else:
            raise ValueError(
                "No supported normalization function"
                f" for distance metric of type: {self.distance}."
                " Consider providing relevance_score_fn to the LanceDB constructor."
            )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: Optional[int] = None,
        filter: Optional[Dict[str, str]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Return documents most similar to the query vector.
        """
        if k is None:
            k = self.limit

        res = self._query(embedding, k, filter=filter, name=name, **kwargs)
        return self.results_to_docs(res, score=kwargs.pop("score", False))

    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: Optional[int] = None,
        filter: Optional[Dict[str, str]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Return documents most similar to the query vector with relevance scores.
        """
        if k is None:
            k = self.limit

        relevance_score_fn = self._select_relevance_score_fn()
        docs_and_scores = self.similarity_search_by_vector(
            embedding, k, score=True, **kwargs
        )
        return [
            (doc, relevance_score_fn(float(score))) for doc, score in docs_and_scores
        ]

    def similarity_search_with_score(
        self,
        query: str,
        k: Optional[int] = None,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Return documents most similar to the query with relevance scores."""
        if k is None:
            k = self.limit

        score = kwargs.get("score", True)
        name = kwargs.get("name", None)
        query_type = kwargs.get("query_type", "vector")

        if self._embedding is None:
            raise ValueError("search needs an embedding function to be specified.")

        if query_type == "fts" or query_type == "hybrid":
            if self.api_key is None and self._fts_index is None:
                tbl = self.get_table(name)
                self._fts_index = tbl.create_fts_index(self._text_key, replace=True)

                if query_type == "hybrid":
                    embedding = self._embedding.embed_query(query)
                    _query = (embedding, query)
                else:
                    _query = query  # type: ignore

                res = self._query(_query, k, filter=filter, name=name, **kwargs)
                return self.results_to_docs(res, score=score)
            else:
                raise NotImplementedError(
                    "Full-text/hybrid search is not supported in LanceDB Cloud yet."
                )
        else:
            embedding = self._embedding.embed_query(query)
            res = self._query(embedding, k, filter=filter, **kwargs)
            return self.results_to_docs(res, score=score)
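
    # Example usage (illustrative; fts/hybrid query types need a local connection,
    # and hybrid additionally benefits from a configured reranker):
    #   results = store.similarity_search_with_score("hello", k=4)
    #   hybrid = store.similarity_search_with_score("hello", k=4, query_type="hybrid")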

    def similarity_search(
        self,
        query: str,
        k: Optional[int] = None,
        name: Optional[str] = None,
        filter: Optional[Any] = None,
        fts: Optional[bool] = False,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents most similar to the query
@@ -227,60 +512,118 @@ class LanceDB(VectorStore):

        Returns:
            List of documents most similar to the query.

        Examples:

        .. code-block:: python

            # Retrieve documents with filtering based on a metadata file_type
            vector_store.as_retriever(search_kwargs={"k": 4, "filter": {
                'sql_filter': "file_type='notice'",
                'prefilter': True
            }})

            # Retrieve documents with filtering on a specific file name
            vector_store.as_retriever(search_kwargs={"k": 4, "filter": {
                'sql_filter': "source='my-file.txt'",
                'prefilter': True
            }})
        """
        res = self.similarity_search_with_score(
            query=query, k=k, name=name, filter=filter, fts=fts, score=False, **kwargs
        )
        return res
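
    # Example usage (illustrative; a dict filter is converted by to_lance_filter into
    # a SQL clause, e.g. {"id": "<some-id>"} becomes "id = '<some-id>'"):
    #   docs = store.similarity_search("hello", k=4, filter={"id": "<some-id>"})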

    def max_marginal_relevance_search(
        self,
        query: str,
        k: Optional[int] = None,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        if k is None:
            k = self.limit

        if self._embedding is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function on creation."
            )

        embedding = self._embedding.embed_query(query)
        docs = self.max_marginal_relevance_search_by_vector(
            embedding,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
        )
        return docs
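
    # Example usage (illustrative): lambda_mult near 0 favors diversity, near 1
    # favors pure similarity:
    #   docs = store.max_marginal_relevance_search("hello", k=4, fetch_k=20,
    #                                              lambda_mult=0.5)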

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: Optional[int] = None,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        results = self._query(
            query=embedding,
            k=fetch_k,
            filter=filter,
            **kwargs,
        )
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            results["vector"].to_pylist(),
            k=k or self.limit,
            lambda_mult=lambda_mult,
        )

        candidates = self.results_to_docs(results)

        selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
        return selected_results
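
    # fetch_k sets the candidate pool handed to maximal_marginal_relevance above;
    # it should normally exceed k (e.g. fetch_k=20 with k=4).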

    @classmethod
    def from_texts(
        cls: Type[LanceDB],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        connection: Optional[Any] = None,
        vector_key: Optional[str] = "vector",
        id_key: Optional[str] = "id",
        text_key: Optional[str] = "text",
        table_name: Optional[str] = "vectorstore",
        api_key: Optional[str] = None,
        region: Optional[str] = None,
        mode: Optional[str] = "overwrite",
        distance: Optional[str] = "l2",
        reranker: Optional[Any] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
        **kwargs: Any,
    ) -> LanceDB:
        instance = LanceDB(
@@ -290,8 +633,15 @@ class LanceDB(VectorStore):
            id_key=id_key,
            text_key=text_key,
            table_name=table_name,
            api_key=api_key,
            region=region,
            mode=mode,
            distance=distance,
            reranker=reranker,
            relevance_score_fn=relevance_score_fn,
            **kwargs,
        )
        instance.add_texts(texts, metadatas=metadatas)

        return instance
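
A minimal end-to-end sketch of the `from_texts` path above (illustrative: `FakeEmbeddings` is an assumed stand-in for a real embedding model, and the default local uri is used):

```python
from langchain_community.embeddings import FakeEmbeddings  # assumed stand-in embedding
from langchain_community.vectorstores import LanceDB

embeddings = FakeEmbeddings(size=32)  # hypothetical embedding dimensionality
store = LanceDB.from_texts(
    ["hello world", "goodbye world"],
    embedding=embeddings,
    metadatas=[{"source": "a"}, {"source": "b"}],
    table_name="vectorstore",
)
print(store.similarity_search("hello", k=1))
```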