Files
DB-GPT/dbgpt/rag/assembler/tests/test_embedding_assembler.py
Cooper 9b0161e521 Feat rdb summary wide table (#2035)
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
Co-authored-by: dong <dongzhancai@iie2.com>
2024-12-18 20:34:21 +08:00

83 lines
2.6 KiB
Python

from unittest.mock import MagicMock
import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
from dbgpt.serve.rag.connector import VectorStoreConnector
@pytest.fixture
def mock_db_connection():
    """Provide a temporary SQLite connector holding one wide ``user`` table.

    The table is declared with 18 columns (a primary key plus assorted
    TEXT/INTEGER/REAL fields), giving the tests a wide schema to work with.
    """
    temp_db = SQLiteTempConnector.create_temporary_db()
    user_columns = {
        "id": "INTEGER PRIMARY KEY",
        "name": "TEXT",
        "age": "INTEGER",
        "address": "TEXT",
        "phone": "TEXT",
        "email": "TEXT",
        "gender": "TEXT",
        "birthdate": "TEXT",
        "occupation": "TEXT",
        "education": "TEXT",
        "marital_status": "TEXT",
        "nationality": "TEXT",
        "height": "REAL",
        "weight": "REAL",
        "blood_type": "TEXT",
        "emergency_contact": "TEXT",
        "created_at": "TEXT",
        "updated_at": "TEXT",
    }
    temp_db.create_temp_tables({"user": {"columns": user_columns}})
    return temp_db
@pytest.fixture
def mock_chunk_parameters():
    """Provide a ``MagicMock`` whose spec is ``ChunkParameters``."""
    chunk_params = MagicMock(spec=ChunkParameters)
    return chunk_params
@pytest.fixture
def mock_embedding_factory():
    """Provide a ``MagicMock`` whose spec is ``EmbeddingFactory``."""
    factory = MagicMock(spec=EmbeddingFactory)
    return factory
@pytest.fixture
def mock_table_vector_store_connector():
    """Provide a mocked ``VectorStoreConnector`` for the table vector store.

    The mock carries a fixed ``vector_store_config.name`` and a real
    ``DefaultEmbeddings`` instance as ``current_embeddings``.
    """
    table_store = MagicMock(spec=VectorStoreConnector)
    table_store.vector_store_config.name = "table_vector_store_name"
    table_store.current_embeddings = DefaultEmbeddings()
    return table_store
def test_load_knowledge(
    mock_db_connection,
    mock_chunk_parameters,
    mock_embedding_factory,
    mock_table_vector_store_connector,
):
    """Check that assembling the ``user`` schema yields more than one chunk."""
    # Configure the mocked chunk parameters with a user-defined RDB splitter.
    field_splitter = RDBTextSplitter(separator="--table-field-separator--")
    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
    mock_chunk_parameters.text_splitter = field_splitter
    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE

    # NOTE(review): max_seq_length=10 is presumably kept tiny so that the wide
    # table schema cannot fit in a single chunk -- confirm against
    # DBSchemaAssembler.
    assembler_kwargs = {
        "connector": mock_db_connection,
        "chunk_parameters": mock_chunk_parameters,
        "embeddings": mock_embedding_factory.create(),
        "table_vector_store_connector": mock_table_vector_store_connector,
        "max_seq_length": 10,
    }
    schema_assembler = DBSchemaAssembler(**assembler_kwargs)

    produced_chunks = schema_assembler._chunks
    assert len(produced_chunks) > 1