mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-27 04:39:58 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
68
dbgpt/rag/assembler/tests/test_embedding_assembler.py
Normal file
68
dbgpt/rag/assembler/tests/test_embedding_assembler.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_connection():
|
||||
"""Create a temporary database connection for testing."""
|
||||
connect = SQLiteTempConnector.create_temporary_db()
|
||||
connect.create_temp_tables(
|
||||
{
|
||||
"user": {
|
||||
"columns": {
|
||||
"id": "INTEGER PRIMARY KEY",
|
||||
"name": "TEXT",
|
||||
"age": "INTEGER",
|
||||
},
|
||||
"data": [
|
||||
(1, "Tom", 10),
|
||||
(2, "Jerry", 16),
|
||||
(3, "Jack", 18),
|
||||
(4, "Alice", 20),
|
||||
(5, "Bob", 22),
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
return connect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chunk_parameters():
|
||||
return MagicMock(spec=ChunkParameters)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding_factory():
|
||||
return MagicMock(spec=EmbeddingFactory)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
return MagicMock(spec=VectorStoreConnector)
|
||||
|
||||
|
||||
def test_load_knowledge(
|
||||
mock_db_connection,
|
||||
mock_chunk_parameters,
|
||||
mock_embedding_factory,
|
||||
mock_vector_store_connector,
|
||||
):
|
||||
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
||||
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
|
||||
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
||||
assembler = DBSchemaAssembler(
|
||||
connector=mock_db_connection,
|
||||
chunk_parameters=mock_chunk_parameters,
|
||||
embeddings=mock_embedding_factory.create(),
|
||||
vector_store_connector=mock_vector_store_connector,
|
||||
)
|
||||
assert len(assembler._chunks) == 1
|
||||
Reference in New Issue
Block a user