mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-04 04:28:58 +00:00
community[minor]: Added document loader for SurrealDB (#15995)
Added a simple document loader to work with SurrealDB.
This commit is contained in:
@@ -0,0 +1,97 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
|
||||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SurrealDBLoader(BaseLoader):
    """Load records from a SurrealDB table as ``Document`` objects.

    Every record returned by the query becomes one ``Document`` whose
    ``page_content`` is the JSON-serialized record and whose ``metadata``
    combines the record ``id``, the record's own ``metadata`` field (if any)
    and the namespace / database / table the record was read from.
    """

    def __init__(
        self,
        filter_criteria: Optional[Dict] = None,
        **kwargs: Any,
    ) -> None:
        """Configure the loader; no network connection is opened here.

        Args:
            filter_criteria: Optional mapping of field name to required value.
                Each entry becomes a ``field = $field`` equality condition and
                all conditions are combined with ``AND``. The key ``table`` is
                rejected because that name is used for the query's table
                parameter.
            **kwargs: Optional settings:
                ``dburl`` (default ``"ws://localhost:8000/rpc"``),
                ``ns`` (default ``"langchain"``),
                ``db`` (default ``"database"``),
                ``table`` (default ``"documents"``),
                and ``db_user`` / ``db_pass``, which are used to sign in when
                both are present.

        Raises:
            ImportError: If the ``surrealdb`` package is not installed.
            ValueError: If ``dburl`` does not start with ``ws`` or if
                ``filter_criteria`` contains the key ``table``.
        """
        try:
            from surrealdb import Surreal
        except ImportError as e:
            raise ImportError(
                """Cannot import from surrealdb.
                please install with `pip install surrealdb`."""
            ) from e

        self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc")
        if not self.dburl.startswith("ws"):
            raise ValueError("Only websocket connections are supported at this time.")

        self.filter_criteria = filter_criteria or {}
        if "table" in self.filter_criteria:
            raise ValueError(
                "key `table` is not a valid criteria for `filter_criteria` argument."
            )

        self.ns = kwargs.pop("ns", "langchain")
        self.db = kwargs.pop("db", "database")
        self.table = kwargs.pop("table", "documents")
        # Whatever is left (e.g. db_user / db_pass) is consumed by initialize().
        self.kwargs = kwargs

        # Create the client exactly once. Connecting is deferred to
        # initialize(): running an event loop from the constructor would fail
        # inside an already running loop and would bind the connection to a
        # loop that is closed before load() gets to use it.
        self.sdb = Surreal(self.dburl)

    async def initialize(self) -> None:
        """Connect to the database and select the namespace and database.

        Signs in first when both ``db_user`` and ``db_pass`` were passed to
        the constructor.
        """
        await self.sdb.connect()
        if "db_user" in self.kwargs and "db_pass" in self.kwargs:
            user = self.kwargs.get("db_user")
            password = self.kwargs.get("db_pass")
            await self.sdb.signin({"user": user, "pass": password})

        await self.sdb.use(self.ns, self.db)

    def load(self) -> List[Document]:
        """Connect and load all matching records synchronously.

        Uses ``asyncio.run`` and therefore must not be called from a running
        event loop; in async code ``await initialize()`` followed by
        ``await aload()`` instead.

        Returns:
            One ``Document`` per record returned by the query.
        """

        async def _load() -> List[Document]:
            # Connect and query inside the same event loop.
            await self.initialize()
            return await self.aload()

        return asyncio.run(_load())

    async def aload(self) -> List[Document]:
        """Load data into Document objects.

        ``initialize()`` must have been awaited in the current event loop
        before calling this method.

        Returns:
            One ``Document`` per record in the first statement's result set.
        """
        query = "SELECT * FROM type::table($table)"
        if self.filter_criteria:
            # NOTE: keys are interpolated verbatim as field names (only the
            # values are bound as parameters), so they must not come from
            # untrusted input.
            conditions = " AND ".join(
                f"{key} = ${key}" for key in self.filter_criteria
            )
            query += f" WHERE {conditions}"

        source_info = {
            "ns": self.ns,
            "db": self.db,
            "table": self.table,
        }
        results = await self.sdb.query(
            query, {"table": self.table, **self.filter_criteria}
        )

        return [
            Document(
                page_content=json.dumps(result),
                metadata={
                    "id": result["id"],
                    # Records without a (non-null) metadata field are allowed.
                    **(result.get("metadata") or {}),
                    # Source location wins over same-named record metadata.
                    **source_info,
                },
            )
            for result in results[0]["result"]
        ]
|
Reference in New Issue
Block a user