feat(rag): Support RAG SDK (#1322)

This commit is contained in:
Fangyin Cheng
2024-03-22 15:36:57 +08:00
committed by GitHub
parent e65732d6e4
commit 8a17099dd2
69 changed files with 1332 additions and 558 deletions

View File

@@ -10,7 +10,7 @@ from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.datasource.rdbms.base import RDBMSConnector
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.model.proxy import OpenAILLMClient
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.rag.operators.schema_linking import SchemaLinkingOperator
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
from dbgpt.storage.vector_store.connector import VectorStoreConnector
@@ -181,13 +181,13 @@ class SqlGenOperator(MapOperator[Any, Any]):
class SqlExecOperator(MapOperator[Any, Any]):
"""The Sql Execution Operator."""
def __init__(self, connection: Optional[RDBMSConnector] = None, **kwargs):
def __init__(self, connector: Optional[RDBMSConnector] = None, **kwargs):
"""
Args:
connection (Optional[RDBMSConnector]): RDBMSConnector connection
"""
super().__init__(**kwargs)
self._connection = connection
self._connector = connector
def map(self, sql: str) -> DataFrame:
"""retrieve table schemas.
@@ -196,7 +196,7 @@ class SqlExecOperator(MapOperator[Any, Any]):
Return:
str: sql execution
"""
dataframe = self._connection.run_to_df(command=sql, fetch="all")
dataframe = self._connector.run_to_df(command=sql, fetch="all")
print(f"sql data is \n{dataframe}")
return dataframe
@@ -237,12 +237,12 @@ with DAG("simple_nl_schema_sql_chart_example") as dag:
llm = OpenAILLMClient()
model_name = "gpt-3.5-turbo"
retriever_task = SchemaLinkingOperator(
connection=_create_temporary_connection(), llm=llm, model_name=model_name
connector=_create_temporary_connection(), llm=llm, model_name=model_name
)
prompt_join_operator = JoinOperator(combine_function=_prompt_join_fn)
sql_gen_operator = SqlGenOperator(llm=llm, model_name=model_name)
sql_exec_operator = SqlExecOperator(connection=_create_temporary_connection())
draw_chart_operator = ChartDrawOperator(connection=_create_temporary_connection())
sql_exec_operator = SqlExecOperator(connector=_create_temporary_connection())
draw_chart_operator = ChartDrawOperator(connector=_create_temporary_connection())
trigger >> request_handle_task >> query_operator >> prompt_join_operator
(
trigger