From 228ddabc3bd32cbb14b9089984c5d118767dbf28 Mon Sep 17 00:00:00 2001 From: Karim Lalani Date: Thu, 21 Dec 2023 11:04:57 -0600 Subject: [PATCH] community: fix for surrealdb client 0.3.2 update + store and retrieve metadata (#14997) Surrealdb client changes from 0.3.1 to 0.3.2 broke the surrealdb vector store integration. This PR updates the code to work with the updated client. The change is backwards compatible with previous versions of surrealdb client. Also expanded the vector store implementation to store and retrieve metadata that's included with the document object. --- .../vectorstores/surrealdb.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/surrealdb.py b/libs/community/langchain_community/vectorstores/surrealdb.py index a96c2410fe1..773a00cc576 100644 --- a/libs/community/langchain_community/vectorstores/surrealdb.py +++ b/libs/community/langchain_community/vectorstores/surrealdb.py @@ -62,7 +62,7 @@ class SurrealDBStore(VectorStore): self.db = kwargs.pop("db", "database") self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc") self.embedding_function = embedding_function - self.sdb = Surreal() + self.sdb = Surreal(self.dburl) self.kwargs = kwargs async def initialize(self) -> None: @@ -103,8 +103,12 @@ class SurrealDBStore(VectorStore): embeddings = self.embedding_function.embed_documents(list(texts)) ids = [] for idx, text in enumerate(texts): + data = {"text": text, "embedding": embeddings[idx]} + if metadatas is not None and idx < len(metadatas): + data["metadata"] = metadatas[idx] record = await self.sdb.create( - self.collection, {"text": text, "embedding": embeddings[idx]} + self.collection, + data, ) ids.append(record[0]["id"]) return ids @@ -123,7 +127,16 @@ class SurrealDBStore(VectorStore): Returns: List of ids for the newly inserted documents """ - return asyncio.run(self.aadd_texts(texts, metadatas, **kwargs)) + + async def _add_texts( + texts: 
Iterable[str], + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> List[str]: + await self.initialize() + return await self.aadd_texts(texts, metadatas, **kwargs) + + return asyncio.run(_add_texts(texts, metadatas, **kwargs)) async def adelete( self, @@ -195,7 +208,7 @@ class SurrealDBStore(VectorStore): "k": k, "score_threshold": kwargs.get("score_threshold", 0), } - query = """select id, text, + query = """select id, text, metadata, vector::similarity::cosine(embedding,{embedding}) as similarity from {collection} where vector::similarity::cosine(embedding,{embedding}) >= {score_threshold} @@ -208,7 +221,10 @@ class SurrealDBStore(VectorStore): return [ ( - Document(page_content=result["text"], metadata={"id": result["id"]}), + Document( + page_content=result["text"], + metadata={"id": result["id"], **result["metadata"]}, + ), result["similarity"], ) for result in results[0]["result"] @@ -401,7 +417,7 @@ class SurrealDBStore(VectorStore): sdb = cls(embedding, **kwargs) await sdb.initialize() - await sdb.aadd_texts(texts) + await sdb.aadd_texts(texts, metadatas, **kwargs) return sdb @classmethod