community[minor]: Add tidb loader support (#17788)

This pull request support loading data from TiDB database with
Langchain.

A simple usage:
```
from  langchain_community.document_loaders import TiDBLoader

CONNECTION_STRING = "mysql+pymysql://root@127.0.0.1:4000/test"

QUERY = "select id, name, description from items;"
loader = TiDBLoader(
    connection_string=CONNECTION_STRING,
    query=QUERY,
    page_content_columns=["name", "description"],
    metadata_columns=["id"],
)
documents = loader.load()
print(documents)
```
This commit is contained in:
Ian
2024-02-22 08:42:33 +08:00
committed by GitHub
parent 815ec74298
commit 3019a594b7
5 changed files with 339 additions and 0 deletions

View File

@@ -199,6 +199,7 @@ from langchain_community.document_loaders.tensorflow_datasets import (
TensorflowDatasetLoader,
)
from langchain_community.document_loaders.text import TextLoader
from langchain_community.document_loaders.tidb import TiDBLoader
from langchain_community.document_loaders.tomarkdown import ToMarkdownLoader
from langchain_community.document_loaders.toml import TomlLoader
from langchain_community.document_loaders.trello import TrelloLoader
@@ -380,6 +381,7 @@ __all__ = [
"TencentCOSDirectoryLoader",
"TencentCOSFileLoader",
"TextLoader",
"TiDBLoader",
"ToMarkdownLoader",
"TomlLoader",
"TrelloLoader",

View File

@@ -0,0 +1,71 @@
from typing import Any, Dict, Iterator, List, Optional
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
class TiDBLoader(BaseLoader):
"""Load documents from TiDB."""
def __init__(
self,
connection_string: str,
query: str,
page_content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
engine_args: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize TiDB document loader.
Args:
connection_string (str): The connection string for the TiDB database,
format: "mysql+pymysql://root@127.0.0.1:4000/test".
query: The query to run in TiDB.
page_content_columns: Optional. Columns written to Document `page_content`,
default(None) to all columns.
metadata_columns: Optional. Columns written to Document `metadata`,
default(None) to no columns.
engine_args: Optional. Additional arguments to pass to sqlalchemy engine.
"""
self.connection_string = connection_string
self.query = query
self.page_content_columns = page_content_columns
self.metadata_columns = metadata_columns if metadata_columns is not None else []
self.engine_args = engine_args
def lazy_load(self) -> Iterator[Document]:
"""Lazy load TiDB data into document objects."""
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.sql import text
# use sqlalchemy to create db connection
engine: Engine = create_engine(
self.connection_string, **(self.engine_args or {})
)
# execute query
with engine.connect() as conn:
result = conn.execute(text(self.query))
# convert result to Document objects
column_names = list(result.keys())
for row in result:
# convert row to dict{column:value}
row_data = {
column_names[index]: value for index, value in enumerate(row)
}
page_content = "\n".join(
f"{k}: {v}"
for k, v in row_data.items()
if self.page_content_columns is None
or k in self.page_content_columns
)
metadata = {col: row_data[col] for col in self.metadata_columns}
yield Document(page_content=page_content, metadata=metadata)
def load(self) -> List[Document]:
"""Load TiDB data into document objects."""
return list(self.lazy_load())

View File

@@ -0,0 +1,76 @@
import os
import pytest
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from langchain_community.document_loaders import TiDBLoader
try:
CONNECTION_STRING = os.getenv("TEST_TiDB_CONNECTION_URL", "")
if CONNECTION_STRING == "":
raise OSError("TEST_TiDB_URL environment variable is not set")
tidb_available = True
except (OSError, ImportError):
tidb_available = False
@pytest.mark.skipif(not tidb_available, reason="tidb is not available")
def test_load_documents() -> None:
"""Test loading documents from TiDB."""
# Connect to the database
engine = create_engine(CONNECTION_STRING)
metadata = MetaData()
table_name = "tidb_loader_intergration_test"
# Create a test table
test_table = Table(
table_name,
metadata,
Column("id", Integer, primary_key=True),
Column("name", String(255)),
Column("description", String(255)),
)
metadata.create_all(engine)
with engine.connect() as connection:
transaction = connection.begin()
try:
connection.execute(
test_table.insert(),
[
{"name": "Item 1", "description": "Description of Item 1"},
{"name": "Item 2", "description": "Description of Item 2"},
{"name": "Item 3", "description": "Description of Item 3"},
],
)
transaction.commit()
except:
transaction.rollback()
raise
loader = TiDBLoader(
connection_string=CONNECTION_STRING,
query=f"SELECT * FROM {table_name};",
page_content_columns=["name", "description"],
metadata_columns=["id"],
)
documents = loader.load()
test_table.drop(bind=engine)
# check
assert len(documents) == 3
assert (
documents[0].page_content == "name: Item 1\ndescription: Description of Item 1"
)
assert documents[0].metadata == {"id": 1}
assert (
documents[1].page_content == "name: Item 2\ndescription: Description of Item 2"
)
assert documents[1].metadata == {"id": 2}
assert (
documents[2].page_content == "name: Item 3\ndescription: Description of Item 3"
)
assert documents[2].metadata == {"id": 3}

View File

@@ -144,6 +144,7 @@ EXPECTED_ALL = [
"TencentCOSDirectoryLoader",
"TencentCOSFileLoader",
"TextLoader",
"TiDBLoader",
"ToMarkdownLoader",
"TomlLoader",
"TrelloLoader",