mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 22:44:36 +00:00
community[minor]: Add tidb loader support (#17788)
This pull request adds support for loading data from a TiDB database with LangChain. A simple usage: ``` from langchain_community.document_loaders import TiDBLoader CONNECTION_STRING = "mysql+pymysql://root@127.0.0.1:4000/test" QUERY = "select id, name, description from items;" loader = TiDBLoader( connection_string=CONNECTION_STRING, query=QUERY, page_content_columns=["name", "description"], metadata_columns=["id"], ) documents = loader.load() print(documents) ```
This commit is contained in:
@@ -199,6 +199,7 @@ from langchain_community.document_loaders.tensorflow_datasets import (
|
||||
TensorflowDatasetLoader,
|
||||
)
|
||||
from langchain_community.document_loaders.text import TextLoader
|
||||
from langchain_community.document_loaders.tidb import TiDBLoader
|
||||
from langchain_community.document_loaders.tomarkdown import ToMarkdownLoader
|
||||
from langchain_community.document_loaders.toml import TomlLoader
|
||||
from langchain_community.document_loaders.trello import TrelloLoader
|
||||
@@ -380,6 +381,7 @@ __all__ = [
|
||||
"TencentCOSDirectoryLoader",
|
||||
"TencentCOSFileLoader",
|
||||
"TextLoader",
|
||||
"TiDBLoader",
|
||||
"ToMarkdownLoader",
|
||||
"TomlLoader",
|
||||
"TrelloLoader",
|
||||
|
71
libs/community/langchain_community/document_loaders/tidb.py
Normal file
71
libs/community/langchain_community/document_loaders/tidb.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_community.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
class TiDBLoader(BaseLoader):
    """Load documents from TiDB.

    The loader runs a single SQL query through SQLAlchemy and turns every
    returned row into one ``Document``.
    """

    def __init__(
        self,
        connection_string: str,
        query: str,
        page_content_columns: Optional[List[str]] = None,
        metadata_columns: Optional[List[str]] = None,
        engine_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize TiDB document loader.

        Args:
            connection_string: The connection string for the TiDB database,
                format: "mysql+pymysql://root@127.0.0.1:4000/test".
            query: The query to run in TiDB.
            page_content_columns: Optional. Columns written to Document
                `page_content`, default(None) to all columns.
            metadata_columns: Optional. Columns written to Document `metadata`,
                default(None) to no columns.
            engine_args: Optional. Additional keyword arguments to pass to
                sqlalchemy's `create_engine`.
        """
        self.connection_string = connection_string
        self.query = query
        self.page_content_columns = page_content_columns
        self.metadata_columns = metadata_columns if metadata_columns is not None else []
        self.engine_args = engine_args

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load TiDB data into document objects.

        Yields:
            One Document per result row. `page_content` is a newline-joined
            list of "column: value" pairs for the selected content columns
            (all columns when `page_content_columns` is None); `metadata`
            maps each name in `metadata_columns` to that row's value.

        Raises:
            KeyError: If a name in `metadata_columns` is not a column of the
                query result.
        """
        # sqlalchemy is imported lazily so it is only required when the
        # loader is actually used.
        from sqlalchemy import create_engine
        from sqlalchemy.engine import Engine
        from sqlalchemy.sql import text

        # use sqlalchemy to create db connection
        engine: Engine = create_engine(
            self.connection_string, **(self.engine_args or {})
        )

        try:
            # execute query
            with engine.connect() as conn:
                result = conn.execute(text(self.query))

                # convert result to Document objects
                column_names = list(result.keys())
                for row in result:
                    # convert row to dict{column:value}
                    row_data = dict(zip(column_names, row))
                    page_content = "\n".join(
                        f"{k}: {v}"
                        for k, v in row_data.items()
                        if self.page_content_columns is None
                        or k in self.page_content_columns
                    )
                    metadata = {col: row_data[col] for col in self.metadata_columns}
                    yield Document(page_content=page_content, metadata=metadata)
        finally:
            # A new engine (and connection pool) is created on every call, so
            # release it here; otherwise pooled connections stay open after
            # the generator finishes, fails, or is closed early.
            engine.dispose()

    def load(self) -> List[Document]:
        """Load TiDB data into document objects."""
        return list(self.lazy_load())
|
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
|
||||
|
||||
from langchain_community.document_loaders import TiDBLoader
|
||||
|
||||
# Integration tests need a live TiDB instance; its SQLAlchemy URL is supplied
# through the TEST_TiDB_CONNECTION_URL environment variable.
CONNECTION_STRING = os.getenv("TEST_TiDB_CONNECTION_URL", "")

# The tests are skipped when no connection URL is configured.
tidb_available = CONNECTION_STRING != ""
|
||||
|
||||
|
||||
@pytest.mark.skipif(not tidb_available, reason="tidb is not available")
def test_load_documents() -> None:
    """Test loading documents from TiDB.

    Creates a scratch table, inserts three rows, loads them through
    `TiDBLoader`, and checks the resulting page content and metadata. The
    table is dropped again even when loading fails.
    """

    # Connect to the database
    engine = create_engine(CONNECTION_STRING)
    metadata = MetaData()
    table_name = "tidb_loader_intergration_test"

    # Create a test table
    test_table = Table(
        table_name,
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(255)),
        Column("description", String(255)),
    )
    # Clear any table left behind by an earlier aborted run; stale rows would
    # change the row count and the ids asserted below.
    metadata.drop_all(engine)
    metadata.create_all(engine)

    try:
        # engine.begin() commits on success and rolls back on any exception.
        with engine.begin() as connection:
            connection.execute(
                test_table.insert(),
                [
                    {"name": "Item 1", "description": "Description of Item 1"},
                    {"name": "Item 2", "description": "Description of Item 2"},
                    {"name": "Item 3", "description": "Description of Item 3"},
                ],
            )

        loader = TiDBLoader(
            connection_string=CONNECTION_STRING,
            # ORDER BY makes the row order deterministic; the assertions
            # below index into the result by position.
            query=f"SELECT * FROM {table_name} ORDER BY id;",
            page_content_columns=["name", "description"],
            metadata_columns=["id"],
        )
        documents = loader.load()
    finally:
        test_table.drop(bind=engine)
        engine.dispose()

    # check
    assert len(documents) == 3
    assert (
        documents[0].page_content == "name: Item 1\ndescription: Description of Item 1"
    )
    assert documents[0].metadata == {"id": 1}
    assert (
        documents[1].page_content == "name: Item 2\ndescription: Description of Item 2"
    )
    assert documents[1].metadata == {"id": 2}
    assert (
        documents[2].page_content == "name: Item 3\ndescription: Description of Item 3"
    )
    assert documents[2].metadata == {"id": 3}
|
@@ -144,6 +144,7 @@ EXPECTED_ALL = [
|
||||
"TencentCOSDirectoryLoader",
|
||||
"TencentCOSFileLoader",
|
||||
"TextLoader",
|
||||
"TiDBLoader",
|
||||
"ToMarkdownLoader",
|
||||
"TomlLoader",
|
||||
"TrelloLoader",
|
||||
|
Reference in New Issue
Block a user