community[patch]: vectorstores import update (#21169)

Issue: we have several helper functions to import third-party libraries
like lancedb.import_lancedb in
[community.vectorstores](https://api.python.langchain.com/en/latest/vectorstores/langchain_community.vectorstores.lancedb.import_lancedb.html#langchain_community.vectorstores.lancedb.import_lancedb).
And we have core.utils.utils.guard_import that works exactly for this
purpose.
The import_<package> functions work inconsistently and rather be private
functions.
Change: replaced these functions with the guard_import function.

Related to #21133
This commit is contained in:
Leonid Ganeline
2024-05-13 07:45:31 -07:00
committed by GitHub
parent 3003363605
commit 500569da48
5 changed files with 34 additions and 54 deletions

View File

@@ -1,4 +1,5 @@
"""Wrapper around TileDB vector database."""
from __future__ import annotations
import pickle
@@ -9,6 +10,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
@@ -24,16 +26,10 @@ MAX_FLOAT = sys.float_info.max
def dependable_tiledb_import() -> Any:
"""Import tiledb-vector-search if available, otherwise raise error."""
try:
import tiledb as tiledb
import tiledb.vector_search as tiledb_vs
except ImportError:
raise ImportError(
"Could not import tiledb-vector-search python package. "
"Please install it with `conda install -c tiledb tiledb-vector-search` "
"or `pip install tiledb-vector-search`"
)
return tiledb_vs, tiledb
return (
guard_import("tiledb.vector_search"),
guard_import("tiledb"),
)
def get_vector_index_uri_from_group(group: Any) -> str:
@@ -115,7 +111,10 @@ class TileDB(VectorStore):
self.metric = metric
self.config = config
tiledb_vs, tiledb = dependable_tiledb_import()
tiledb_vs, tiledb = (
guard_import("tiledb.vector_search"),
guard_import("tiledb"),
)
with tiledb.scope_ctx(ctx_or_config=config):
index_group = tiledb.Group(self.index_uri, "r")
self.vector_index_uri = (
@@ -173,7 +172,7 @@ class TileDB(VectorStore):
Returns:
List of Documents and scores.
"""
tiledb_vs, tiledb = dependable_tiledb_import()
tiledb = guard_import("tiledb")
docs = []
docs_array = tiledb.open(
self.docs_array_uri, "r", timestamp=self.timestamp, config=self.config
@@ -477,7 +476,10 @@ class TileDB(VectorStore):
metadatas: bool = True,
config: Optional[Mapping[str, Any]] = None,
) -> None:
tiledb_vs, tiledb = dependable_tiledb_import()
tiledb_vs, tiledb = (
guard_import("tiledb.vector_search"),
guard_import("tiledb"),
)
with tiledb.scope_ctx(ctx_or_config=config):
try:
tiledb.group_create(index_uri)
@@ -550,7 +552,10 @@ class TileDB(VectorStore):
f"Expected one of {list(INDEX_METRICS)}"
)
)
tiledb_vs, tiledb = dependable_tiledb_import()
tiledb_vs, tiledb = (
guard_import("tiledb.vector_search"),
guard_import("tiledb"),
)
input_vectors = np.array(embeddings).astype(np.float32)
cls.create(
index_uri=index_uri,
@@ -646,7 +651,7 @@ class TileDB(VectorStore):
Returns:
List of ids from adding the texts into the vectorstore.
"""
tiledb_vs, tiledb = dependable_tiledb_import()
tiledb = guard_import("tiledb")
embeddings = self.embedding.embed_documents(list(texts))
if ids is None:
ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]