From 7d451d004106aa949d2e86052f9cfd01eb207df0 Mon Sep 17 00:00:00 2001 From: Raghav Dixit <34462078+raghavdixit99@users.noreply.github.com> Date: Thu, 2 May 2024 13:06:39 -0400 Subject: [PATCH] community[patch]: Update lancedb.py (#21192) very minor update in LanceDB integration, 'metric' argument was missing. --- .../vectorstores/lancedb.py | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/lancedb.py b/libs/community/langchain_community/vectorstores/lancedb.py index 2d2f859b766..671ac47efa0 100644 --- a/libs/community/langchain_community/vectorstores/lancedb.py +++ b/libs/community/langchain_community/vectorstores/lancedb.py @@ -151,10 +151,30 @@ class LanceDB(VectorStore): self._connection.create_table(self._table_name, data=docs) return ids - def get_table(self, name: Optional[str] = None) -> Any: + def get_table( + self, name: Optional[str] = None, set_default: Optional[bool] = False + ) -> Any: + """ + Fetches a table object from the database. + + Args: + name (str, optional): The name of the table to fetch. Defaults to None + and fetches current table object. + set_default (bool, optional): Sets fetched table as the default table. + Defaults to False. + + Returns: + Any: The fetched table object. + + Raises: + ValueError: If the specified table is not found in the database. + + """ if name is not None: try: - self._connection.open_table(name) + if set_default: + self._table_name = name + return self._connection.open_table(name) except Exception: raise ValueError(f"Table {name} not found in the database") else: @@ -167,6 +187,7 @@ class LanceDB(VectorStore): num_partitions: Optional[int] = 256, num_sub_vectors: Optional[int] = 96, index_cache_size: Optional[int] = None, + metric: Optional[str] = "L2", ) -> None: """ Create a scalar(for non-vector cols) or a vector index on a table. @@ -181,15 +202,18 @@ class LanceDB(VectorStore): Returns: None """ + tbl = self.get_table() + if vector_col: - self._connection.create_index( + tbl.create_index( + metric=metric, vector_column_name=vector_col, num_partitions=num_partitions, num_sub_vectors=num_sub_vectors, index_cache_size=index_cache_size, ) elif col_name: - self._connection.create_scalar_index(col_name) + tbl.create_scalar_index(col_name) else: raise ValueError("Provide either vector_col or col_name")