mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-28 21:12:13 +00:00
feat(rag): Support RAG SDK (#1322)
This commit is contained in:
@@ -10,7 +10,7 @@ from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
||||
from dbgpt.datasource.rdbms.base import RDBMSConnector
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.operators.schema_linking import SchemaLinkingOperator
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
@@ -181,13 +181,13 @@ class SqlGenOperator(MapOperator[Any, Any]):
|
||||
class SqlExecOperator(MapOperator[Any, Any]):
|
||||
"""The Sql Execution Operator."""
|
||||
|
||||
def __init__(self, connection: Optional[RDBMSConnector] = None, **kwargs):
|
||||
def __init__(self, connector: Optional[RDBMSConnector] = None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
connection (Optional[RDBMSConnector]): RDBMSConnector connection
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._connection = connection
|
||||
self._connector = connector
|
||||
|
||||
def map(self, sql: str) -> DataFrame:
|
||||
"""retrieve table schemas.
|
||||
@@ -196,7 +196,7 @@ class SqlExecOperator(MapOperator[Any, Any]):
|
||||
Return:
|
||||
str: sql execution
|
||||
"""
|
||||
dataframe = self._connection.run_to_df(command=sql, fetch="all")
|
||||
dataframe = self._connector.run_to_df(command=sql, fetch="all")
|
||||
print(f"sql data is \n{dataframe}")
|
||||
return dataframe
|
||||
|
||||
@@ -237,12 +237,12 @@ with DAG("simple_nl_schema_sql_chart_example") as dag:
|
||||
llm = OpenAILLMClient()
|
||||
model_name = "gpt-3.5-turbo"
|
||||
retriever_task = SchemaLinkingOperator(
|
||||
connection=_create_temporary_connection(), llm=llm, model_name=model_name
|
||||
connector=_create_temporary_connection(), llm=llm, model_name=model_name
|
||||
)
|
||||
prompt_join_operator = JoinOperator(combine_function=_prompt_join_fn)
|
||||
sql_gen_operator = SqlGenOperator(llm=llm, model_name=model_name)
|
||||
sql_exec_operator = SqlExecOperator(connection=_create_temporary_connection())
|
||||
draw_chart_operator = ChartDrawOperator(connection=_create_temporary_connection())
|
||||
sql_exec_operator = SqlExecOperator(connector=_create_temporary_connection())
|
||||
draw_chart_operator = ChartDrawOperator(connector=_create_temporary_connection())
|
||||
trigger >> request_handle_task >> query_operator >> prompt_join_operator
|
||||
(
|
||||
trigger
|
||||
|
Reference in New Issue
Block a user