mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-09 12:59:43 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
@@ -10,7 +10,7 @@ from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.operators.schema_linking import SchemaLinkingOperator
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
@@ -181,13 +181,13 @@ class SqlGenOperator(MapOperator[Any, Any]):
|
||||
class SqlExecOperator(MapOperator[Any, Any]):
|
||||
"""The Sql Execution Operator."""
|
||||
|
||||
def __init__(self, connection: Optional[RDBMSConnector] = None, **kwargs):
|
||||
def __init__(self, connector: Optional[RDBMSConnector] = None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
connection (Optional[RDBMSConnector]): RDBMSConnector connection
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._connection = connection
|
||||
self._connector = connector
|
||||
|
||||
def map(self, sql: str) -> DataFrame:
|
||||
"""retrieve table schemas.
|
||||
@@ -196,7 +196,7 @@ class SqlExecOperator(MapOperator[Any, Any]):
|
||||
Return:
|
||||
str: sql execution
|
||||
"""
|
||||
dataframe = self._connection.run_to_df(command=sql, fetch="all")
|
||||
dataframe = self._connector.run_to_df(command=sql, fetch="all")
|
||||
print(f"sql data is \n{dataframe}")
|
||||
return dataframe
|
||||
|
||||
@@ -237,12 +237,12 @@ with DAG("simple_nl_schema_sql_chart_example") as dag:
|
||||
llm = OpenAILLMClient()
|
||||
model_name = "gpt-3.5-turbo"
|
||||
retriever_task = SchemaLinkingOperator(
|
||||
connection=_create_temporary_connection(), llm=llm, model_name=model_name
|
||||
connector=_create_temporary_connection(), llm=llm, model_name=model_name
|
||||
)
|
||||
prompt_join_operator = JoinOperator(combine_function=_prompt_join_fn)
|
||||
sql_gen_operator = SqlGenOperator(llm=llm, model_name=model_name)
|
||||
sql_exec_operator = SqlExecOperator(connection=_create_temporary_connection())
|
||||
draw_chart_operator = ChartDrawOperator(connection=_create_temporary_connection())
|
||||
sql_exec_operator = SqlExecOperator(connector=_create_temporary_connection())
|
||||
draw_chart_operator = ChartDrawOperator(connector=_create_temporary_connection())
|
||||
trigger >> request_handle_task >> query_operator >> prompt_join_operator
|
||||
(
|
||||
trigger
|
||||
|
@@ -33,7 +33,7 @@ from typing import Dict
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.rag.operators.rewrite import QueryRewriteOperator
|
||||
from dbgpt.rag.operators import QueryRewriteOperator
|
||||
|
||||
|
||||
class TriggerReqBody(BaseModel):
|
||||
|
@@ -31,9 +31,8 @@ from typing import Dict
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.rag.knowledge.base import KnowledgeType
|
||||
from dbgpt.rag.operators.knowledge import KnowledgeOperator
|
||||
from dbgpt.rag.operators.summary import SummaryAssemblerOperator
|
||||
from dbgpt.rag.knowledge import KnowledgeType
|
||||
from dbgpt.rag.operators import KnowledgeOperator, SummaryAssemblerOperator
|
||||
|
||||
|
||||
class TriggerReqBody(BaseModel):
|
||||
|
@@ -2,8 +2,8 @@ import os
|
||||
|
||||
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler import DBSchemaAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.serve.rag.assembler.db_schema import DBSchemaAssembler
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
@@ -62,7 +62,7 @@ if __name__ == "__main__":
|
||||
connection = _create_temporary_connection()
|
||||
vector_connector = _create_vector_connector()
|
||||
assembler = DBSchemaAssembler.load_from_connection(
|
||||
connection=connection,
|
||||
connector=connection,
|
||||
vector_store_connector=vector_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
|
@@ -2,10 +2,10 @@ import asyncio
|
||||
import os
|
||||
|
||||
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
@@ -27,10 +27,10 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt.rag.embedding import OpenAPIEmbeddings
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
@@ -4,12 +4,12 @@ from typing import Optional
|
||||
|
||||
from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import EmbeddingAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.evaluation import RetrieverEvaluator
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.rag.operators import EmbeddingRetrieverOperator
|
||||
from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
@@ -3,13 +3,13 @@
|
||||
if you not set vector_store_connector, it will return all tables schema in database.
|
||||
```
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
connection=_create_temporary_connection()
|
||||
connector=_create_temporary_connection()
|
||||
)
|
||||
```
|
||||
if you set vector_store_connector, it will recall topk similarity tables schema in database.
|
||||
```
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
connection=_create_temporary_connection()
|
||||
connector=_create_temporary_connection()
|
||||
top_k=1,
|
||||
vector_store_connector=vector_store_connector
|
||||
)
|
||||
@@ -30,11 +30,10 @@ from pydantic import BaseModel, Field
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, InputOperator, JoinOperator, MapOperator
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.operators import DBSchemaRetrieverOperator
|
||||
from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator
|
||||
from dbgpt.rag.operators import DBSchemaAssemblerOperator, DBSchemaRetrieverOperator
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
@@ -107,18 +106,19 @@ with DAG("simple_rag_db_schema_example") as dag:
|
||||
request_handle_task = RequestHandleOperator()
|
||||
query_operator = MapOperator(lambda request: request["query"])
|
||||
vector_store_connector = _create_vector_connector()
|
||||
connector = _create_temporary_connection()
|
||||
assembler_task = DBSchemaAssemblerOperator(
|
||||
connection=_create_temporary_connection(),
|
||||
connector=connector,
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
join_operator = JoinOperator(combine_function=_join_fn)
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
connection=_create_temporary_connection(),
|
||||
connector=_create_temporary_connection(),
|
||||
top_k=1,
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
|
||||
trigger >> request_handle_task >> assembler_task >> join_operator
|
||||
trigger >> assembler_task >> join_operator
|
||||
trigger >> request_handle_task >> query_operator >> join_operator
|
||||
join_operator >> retriever_task >> result_parse_task
|
||||
|
||||
|
@@ -17,12 +17,11 @@ from typing import Dict, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, PILOT_PATH
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.knowledge import KnowledgeType
|
||||
from dbgpt.rag.operators import KnowledgeOperator
|
||||
from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
|
||||
from dbgpt.rag.operators import EmbeddingAssemblerOperator, KnowledgeOperator
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
|
@@ -22,15 +22,17 @@
|
||||
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from dbgpt.configs.model_config import ROOT_PATH
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.assembler import SummaryAssembler
|
||||
from dbgpt.rag.knowledge import KnowledgeFactory
|
||||
from dbgpt.serve.rag.assembler.summary import SummaryAssembler
|
||||
|
||||
|
||||
async def main():
|
||||
file_path = "./docs/docs/awel.md"
|
||||
file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
|
||||
llm_client = OpenAILLMClient()
|
||||
knowledge = KnowledgeFactory.from_file_path(file_path)
|
||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||
|
@@ -117,7 +117,7 @@ class SQLResultOperator(JoinOperator[Dict]):
|
||||
with DAG("simple_sdk_llm_sql_example") as dag:
|
||||
db_connection = _create_temporary_connection()
|
||||
input_task = InputOperator(input_source=SimpleCallDataInputSource())
|
||||
retriever_task = DatasourceRetrieverOperator(connection=db_connection)
|
||||
retriever_task = DatasourceRetrieverOperator(connector=db_connection)
|
||||
# Merge the input data and the table structure information.
|
||||
prompt_input_task = JoinOperator(combine_function=_join_func)
|
||||
prompt_task = PromptBuilderOperator(_sql_prompt())
|
||||
@@ -125,7 +125,7 @@ with DAG("simple_sdk_llm_sql_example") as dag:
|
||||
llm_task = BaseLLMOperator(OpenAILLMClient())
|
||||
out_parse_task = SQLOutputParser()
|
||||
sql_parse_task = MapOperator(map_function=lambda x: x["sql"])
|
||||
db_query_task = DatasourceOperator(connection=db_connection)
|
||||
db_query_task = DatasourceOperator(connector=db_connection)
|
||||
sql_result_task = SQLResultOperator()
|
||||
input_task >> prompt_input_task
|
||||
input_task >> retriever_task >> prompt_input_task
|
||||
|
Reference in New Issue
Block a user