mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
community: fix for surrealdb client 0.3.2 update + store and retrieve metadata (#14997)
Surrealdb client changes from 0.3.1 to 0.3.2 broke the surrealdb vectore integration. This PR updates the code to work with the updated client. The change is backwards compatible with previous versions of surrealdb client. Also expanded the vector store implementation to store and retrieve metadata that's included with the document object.
This commit is contained in:
parent
c7be59c122
commit
228ddabc3b
@ -62,7 +62,7 @@ class SurrealDBStore(VectorStore):
|
||||
self.db = kwargs.pop("db", "database")
|
||||
self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc")
|
||||
self.embedding_function = embedding_function
|
||||
self.sdb = Surreal()
|
||||
self.sdb = Surreal(self.dburl)
|
||||
self.kwargs = kwargs
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@ -103,8 +103,12 @@ class SurrealDBStore(VectorStore):
|
||||
embeddings = self.embedding_function.embed_documents(list(texts))
|
||||
ids = []
|
||||
for idx, text in enumerate(texts):
|
||||
data = {"text": text, "embedding": embeddings[idx]}
|
||||
if metadatas is not None and idx < len(metadatas):
|
||||
data["metadata"] = metadatas[idx]
|
||||
record = await self.sdb.create(
|
||||
self.collection, {"text": text, "embedding": embeddings[idx]}
|
||||
self.collection,
|
||||
data,
|
||||
)
|
||||
ids.append(record[0]["id"])
|
||||
return ids
|
||||
@ -123,7 +127,16 @@ class SurrealDBStore(VectorStore):
|
||||
Returns:
|
||||
List of ids for the newly inserted documents
|
||||
"""
|
||||
return asyncio.run(self.aadd_texts(texts, metadatas, **kwargs))
|
||||
|
||||
async def _add_texts(
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
await self.initialize()
|
||||
return await self.aadd_texts(texts, metadatas, **kwargs)
|
||||
|
||||
return asyncio.run(_add_texts(texts, metadatas, **kwargs))
|
||||
|
||||
async def adelete(
|
||||
self,
|
||||
@ -195,7 +208,7 @@ class SurrealDBStore(VectorStore):
|
||||
"k": k,
|
||||
"score_threshold": kwargs.get("score_threshold", 0),
|
||||
}
|
||||
query = """select id, text,
|
||||
query = """select id, text, metadata,
|
||||
vector::similarity::cosine(embedding,{embedding}) as similarity
|
||||
from {collection}
|
||||
where vector::similarity::cosine(embedding,{embedding}) >= {score_threshold}
|
||||
@ -208,7 +221,10 @@ class SurrealDBStore(VectorStore):
|
||||
|
||||
return [
|
||||
(
|
||||
Document(page_content=result["text"], metadata={"id": result["id"]}),
|
||||
Document(
|
||||
page_content=result["text"],
|
||||
metadata={"id": result["id"], **result["metadata"]},
|
||||
),
|
||||
result["similarity"],
|
||||
)
|
||||
for result in results[0]["result"]
|
||||
@ -401,7 +417,7 @@ class SurrealDBStore(VectorStore):
|
||||
|
||||
sdb = cls(embedding, **kwargs)
|
||||
await sdb.initialize()
|
||||
await sdb.aadd_texts(texts)
|
||||
await sdb.aadd_texts(texts, metadatas, **kwargs)
|
||||
return sdb
|
||||
|
||||
@classmethod
|
||||
|
Loading…
Reference in New Issue
Block a user