mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 12:00:46 +00:00
fix: MySQL Database not support DDL init and upgrade. (#1133)
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
134
examples/rag/simple_dbschema_retriever_example.py
Normal file
134
examples/rag/simple_dbschema_retriever_example.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnect
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.operators.db_schema import DBSchemaRetrieverOperator
|
||||
from dbgpt.serve.rag.operators.db_schema import DBSchemaAssemblerOperator
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
"""AWEL: Simple rag db schema embedding operator example
|
||||
|
||||
if you not set vector_store_connector, it will return all tables schema in database.
|
||||
```
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
connection=_create_temporary_connection()
|
||||
)
|
||||
```
|
||||
if you set vector_store_connector, it will recall topk similarity tables schema in database.
|
||||
```
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
connection=_create_temporary_connection()
|
||||
top_k=1,
|
||||
vector_store_connector=vector_store_connector
|
||||
)
|
||||
```
|
||||
|
||||
Examples:
|
||||
..code-block:: shell
|
||||
curl --location 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/dbschema' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{"query": "what is user name?"}'
|
||||
"""
|
||||
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def _create_vector_connector():
|
||||
"""Create vector connector."""
|
||||
return VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="vector_name",
|
||||
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||
),
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
).create(),
|
||||
)
|
||||
|
||||
|
||||
def _create_temporary_connection():
|
||||
"""Create a temporary database connection for testing."""
|
||||
connect = SQLiteTempConnect.create_temporary_db()
|
||||
connect.create_temp_tables(
|
||||
{
|
||||
"user": {
|
||||
"columns": {
|
||||
"id": "INTEGER PRIMARY KEY",
|
||||
"name": "TEXT",
|
||||
"age": "INTEGER",
|
||||
},
|
||||
"data": [
|
||||
(1, "Tom", 10),
|
||||
(2, "Jerry", 16),
|
||||
(3, "Jack", 18),
|
||||
(4, "Alice", 20),
|
||||
(5, "Bob", 22),
|
||||
],
|
||||
}
|
||||
}
|
||||
)
|
||||
return connect
|
||||
|
||||
|
||||
def _join_fn(chunks: List[Chunk], query: str) -> str:
|
||||
print(f"db schema info is {[chunk.content for chunk in chunks]}")
|
||||
return query
|
||||
|
||||
|
||||
class TriggerReqBody(BaseModel):
|
||||
query: str = Field(..., description="User query")
|
||||
|
||||
|
||||
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: TriggerReqBody) -> Dict:
|
||||
params = {
|
||||
"query": input_value.query,
|
||||
}
|
||||
print(f"Receive input value: {input_value}")
|
||||
return params
|
||||
|
||||
|
||||
with DAG("simple_rag_db_schema_example") as dag:
|
||||
trigger = HttpTrigger(
|
||||
"/examples/rag/dbschema", methods="POST", request_body=TriggerReqBody
|
||||
)
|
||||
request_handle_task = RequestHandleOperator()
|
||||
query_operator = MapOperator(lambda request: request["query"])
|
||||
vector_store_connector = _create_vector_connector()
|
||||
assembler_task = DBSchemaAssemblerOperator(
|
||||
connection=_create_temporary_connection(),
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
join_operator = JoinOperator(combine_function=_join_fn)
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
connection=_create_temporary_connection(),
|
||||
top_k=1,
|
||||
vector_store_connector=vector_store_connector,
|
||||
)
|
||||
result_parse_task = MapOperator(lambda chunks: [chunk.content for chunk in chunks])
|
||||
trigger >> request_handle_task >> assembler_task >> join_operator
|
||||
trigger >> request_handle_task >> query_operator >> join_operator
|
||||
join_operator >> retriever_task >> result_parse_task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if dag.leaf_nodes[0].dev_mode:
|
||||
# Development mode, you can run the dag locally for debugging.
|
||||
from dbgpt.core.awel import setup_dev_environment
|
||||
|
||||
setup_dev_environment([dag], port=5555)
|
||||
else:
|
||||
pass
|
102
examples/rag/simple_rag_embedding_example.py
Normal file
102
examples/rag/simple_rag_embedding_example.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, MapOperator
|
||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.knowledge.base import KnowledgeType
|
||||
from dbgpt.rag.operators.knowledge import KnowledgeOperator
|
||||
from dbgpt.serve.rag.operators.embedding import EmbeddingAssemblerOperator
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
"""AWEL: Simple rag embedding operator example
|
||||
|
||||
Examples:
|
||||
pre-requirements:
|
||||
python examples/awel/simple_rag_embedding_example.py
|
||||
..code-block:: shell
|
||||
curl --location --request POST 'http://127.0.0.1:5555/api/v1/awel/trigger/examples/rag/embedding' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"url": "https://docs.dbgpt.site/docs/awel"
|
||||
}'
|
||||
"""
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def _create_vector_connector() -> VectorStoreConnector:
|
||||
"""Create vector connector."""
|
||||
return VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="vector_name",
|
||||
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||
),
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
).create(),
|
||||
)
|
||||
|
||||
|
||||
class TriggerReqBody(BaseModel):
|
||||
url: str = Field(..., description="url")
|
||||
|
||||
|
||||
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: TriggerReqBody) -> Dict:
|
||||
params = {
|
||||
"url": input_value.url,
|
||||
}
|
||||
print(f"Receive input value: {input_value}")
|
||||
return params
|
||||
|
||||
|
||||
class ResultOperator(MapOperator):
|
||||
"""The Result Operator."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, chunks: List) -> str:
|
||||
result = f"embedding success, there are {len(chunks)} chunks."
|
||||
print(result)
|
||||
return result
|
||||
|
||||
|
||||
with DAG("simple_sdk_rag_embedding_example") as dag:
|
||||
trigger = HttpTrigger(
|
||||
"/examples/rag/embedding", methods="POST", request_body=TriggerReqBody
|
||||
)
|
||||
request_handle_task = RequestHandleOperator()
|
||||
knowledge_operator = KnowledgeOperator(knowledge_type=KnowledgeType.URL)
|
||||
vector_connector = _create_vector_connector()
|
||||
url_parser_operator = MapOperator(map_function=lambda x: x["url"])
|
||||
embedding_operator = EmbeddingAssemblerOperator(
|
||||
vector_store_connector=vector_connector,
|
||||
)
|
||||
output_task = ResultOperator()
|
||||
(
|
||||
trigger
|
||||
>> request_handle_task
|
||||
>> url_parser_operator
|
||||
>> knowledge_operator
|
||||
>> embedding_operator
|
||||
>> output_task
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if dag.leaf_nodes[0].dev_mode:
|
||||
# Development mode, you can run the dag locally for debugging.
|
||||
from dbgpt.core.awel import setup_dev_environment
|
||||
|
||||
setup_dev_environment([dag], port=5555)
|
||||
else:
|
||||
pass
|
131
examples/rag/simple_rag_retriever_example.py
Normal file
131
examples/rag/simple_rag_retriever_example.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, MODEL_PATH, PILOT_PATH
|
||||
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
|
||||
from dbgpt.model.proxy import OpenAILLMClient
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from dbgpt.rag.operators.embedding import EmbeddingRetrieverOperator
|
||||
from dbgpt.rag.operators.rerank import RerankOperator
|
||||
from dbgpt.rag.operators.rewrite import QueryRewriteOperator
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
from dbgpt.storage.vector_store.connector import VectorStoreConnector
|
||||
|
||||
"""AWEL: Simple rag embedding operator example
|
||||
|
||||
pre-requirements:
|
||||
1. install openai python sdk
|
||||
|
||||
```
|
||||
pip install openai
|
||||
```
|
||||
2. set openai key and base
|
||||
```
|
||||
export OPENAI_API_KEY={your_openai_key}
|
||||
export OPENAI_API_BASE={your_openai_base}
|
||||
```
|
||||
3. make sure you have vector store.
|
||||
if there are no data in vector store, please run examples/awel/simple_rag_embedding_example.py
|
||||
|
||||
|
||||
ensure your embedding model in DB-GPT/models/.
|
||||
|
||||
Examples:
|
||||
..code-block:: shell
|
||||
DBGPT_SERVER="http://127.0.0.1:5555"
|
||||
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/rag/retrieve \
|
||||
-H "Content-Type: application/json" -d '{
|
||||
"query": "what is awel talk about?"
|
||||
}'
|
||||
"""
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class TriggerReqBody(BaseModel):
|
||||
query: str = Field(..., description="User query")
|
||||
|
||||
|
||||
class RequestHandleOperator(MapOperator[TriggerReqBody, Dict]):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: TriggerReqBody) -> Dict:
|
||||
params = {
|
||||
"query": input_value.query,
|
||||
}
|
||||
print(f"Receive input value: {input_value}")
|
||||
return params
|
||||
|
||||
|
||||
def _context_join_fn(context_dict: Dict, chunks: List[Chunk]) -> Dict:
|
||||
"""context Join function for JoinOperator.
|
||||
|
||||
Args:
|
||||
context_dict (Dict): context dict
|
||||
chunks (List[Chunk]): chunks
|
||||
Returns:
|
||||
Dict: context dict
|
||||
"""
|
||||
context_dict["context"] = "\n".join([chunk.content for chunk in chunks])
|
||||
return context_dict
|
||||
|
||||
|
||||
def _create_vector_connector():
|
||||
"""Create vector connector."""
|
||||
model_name = os.getenv("EMBEDDING_MODEL", "text2vec")
|
||||
return VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="vector_name",
|
||||
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||
),
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
|
||||
).create(),
|
||||
)
|
||||
|
||||
|
||||
with DAG("simple_sdk_rag_retriever_example") as dag:
|
||||
vector_connector = _create_vector_connector()
|
||||
trigger = HttpTrigger(
|
||||
"/examples/rag/retrieve", methods="POST", request_body=TriggerReqBody
|
||||
)
|
||||
request_handle_task = RequestHandleOperator()
|
||||
query_parser = MapOperator(map_function=lambda x: x["query"])
|
||||
context_join_operator = JoinOperator(combine_function=_context_join_fn)
|
||||
rewrite_operator = QueryRewriteOperator(llm_client=OpenAILLMClient())
|
||||
retriever_context_operator = EmbeddingRetrieverOperator(
|
||||
top_k=3,
|
||||
vector_store_connector=vector_connector,
|
||||
)
|
||||
retriever_operator = EmbeddingRetrieverOperator(
|
||||
top_k=3,
|
||||
vector_store_connector=vector_connector,
|
||||
)
|
||||
rerank_operator = RerankOperator()
|
||||
model_parse_task = MapOperator(lambda out: out.to_dict())
|
||||
|
||||
trigger >> request_handle_task >> context_join_operator
|
||||
(
|
||||
trigger
|
||||
>> request_handle_task
|
||||
>> query_parser
|
||||
>> retriever_context_operator
|
||||
>> context_join_operator
|
||||
)
|
||||
context_join_operator >> rewrite_operator >> retriever_operator >> rerank_operator
|
||||
|
||||
if __name__ == "__main__":
|
||||
if dag.leaf_nodes[0].dev_mode:
|
||||
# Development mode, you can run the dag locally for debugging.
|
||||
from dbgpt.core.awel import setup_dev_environment
|
||||
|
||||
setup_dev_environment([dag], port=5555)
|
||||
else:
|
||||
pass
|
Reference in New Issue
Block a user