mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-08 04:23:35 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
@@ -2,8 +2,8 @@ import os
|
||||
|
||||
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler import DBSchemaAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
@@ -62,7 +62,7 @@ if __name__ == "__main__":
|
||||
connection = _create_temporary_connection()
|
||||
vector_connector = _create_vector_connector()
|
||||
assembler = DBSchemaAssembler.load_from_connection(
|
||||
connection=connection,
|
||||
connector=connection,
|
||||
vector_store_connector=vector_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
|
@@ -2,10 +2,10 @@ import asyncio
|
||||
import os
|
||||
|
||||
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
@@ -27,10 +27,10 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt.rag.embedding import OpenAPIEmbeddings
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
@@ -4,12 +4,12 @@ from typing import Optional
|
||||
|
||||
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.evaluation import RetrieverEvaluator
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.rag.operators import EmbeddingRetrieverOperator
|
||||
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
@@ -3,13 +3,13 @@
|
||||
if you not set vector_store_connector, it will return all tables schema in database.
|
||||
```
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
connection=_create_temporary_connection()
|
||||
connector=_create_temporary_connection()
|
||||
)
|
||||
```
|
||||
if you set vector_store_connector, it will recall topk similarity tables schema in database.
|
||||
```
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
connection=_create_temporary_connection()
|
||||
connector=_create_temporary_connection()
|
||||
top_k=1,
|
||||
vector_store_connector=vector_store_connector
|
||||
)
|
||||
@@ -30,11 +30,10 @@ from pydantic import BaseModel, Field
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, InputOperator, JoinOperator, MapOperator
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.operators import DBSchemaRetrieverOperator
|
||||
from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator
|
||||
from dbgpt.rag.operators import DBSchemaAssemblerOperator, DBSchemaRetrieverOperator
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
@@ -107,18 +106,19 @@ with DAG("simple_rag_db_schema_example") as dag:
|
||||
request_handle_task = RequestHandleOperator()
|
||||
query_operator = MapOperator(lambda request: request["query"])
|
||||
vector_store_connector = _create_vector_connector()
|
||||
connector = _create_temporary_connection()
|
||||
assembler_task = DBSchemaAssemblerOperator(
|
||||
connection=_create_temporary_connection(),
|
||||
connector=connector,
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
join_operator = JoinOperator(combine_function=_join_fn)
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
connection=_create_temporary_connection(),
|
||||
connector=_create_temporary_connection(),
|
||||
top_k=1,
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
|
||||
trigger >> request_handle_task >> assembler_task >> join_operator
|
||||
trigger >> assembler_task >> join_operator
|
||||
trigger >> request_handle_task >> query_operator >> join_operator
|
||||
join_operator >> retriever_task >> result_parse_task
|
||||
|
||||
|
@@ -17,12 +17,11 @@ from typing import Dict, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.knowledge import KnowledgeType
|
||||
from dbgpt.rag.operators import KnowledgeOperator
|
||||
from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
|
||||
from dbgpt.rag.operators import EmbeddingAssemblerOperator, KnowledgeOperator
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
@@ -22,15 +22,17 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dbgpt.configs.model_config import ROOT_PATH
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import SummaryAssembler
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
|
||||
|
||||
|
||||
async def main():
|
||||
file_path = "./docs/docs/awel.md"
|
||||
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
|
||||
llm_client = OpenAILLMClient()
|
||||
knowledge = KnowledgeFactory.from_file_path(file_path)
|
||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||
|
Reference in New Issue
Block a user