mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-19 00:14:40 +00:00
resolve conflict
This commit is contained in:
@@ -57,12 +57,10 @@ from dbgpt_ext.storage.vector_store.chroma_store import ChromaVectorConfig, Chro
|
||||
shutil.rmtree("/tmp/tmp_ltm_vector_store", ignore_errors=True)
|
||||
vector_store = ChromaStore(
|
||||
ChromaVectorConfig(
|
||||
embedding_fn=embeddings,
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="ltm_vector_store",
|
||||
persist_path="/tmp/tmp_ltm_vector_store",
|
||||
),
|
||||
)
|
||||
persist_path="/tmp/tmp_ltm_vector_store",
|
||||
),
|
||||
name="ltm_vector_store",
|
||||
embedding_fn=embeddings
|
||||
)
|
||||
```
|
||||
|
||||
|
@@ -47,13 +47,11 @@ from dbgpt_ext.storage.vector_store.chroma_store import ChromaVectorConfig, Chro
|
||||
# Delete old vector store directory(/tmp/tmp_ltm_vector_stor)
|
||||
shutil.rmtree("/tmp/tmp_ltm_vector_store", ignore_errors=True)
|
||||
vector_store = ChromaStore(
|
||||
ChromaVectorConfig(
|
||||
embedding_fn=embeddings,
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="ltm_vector_store",
|
||||
persist_path="/tmp/tmp_ltm_vector_store",
|
||||
),
|
||||
)
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
persist_path="/tmp/tmp_ltm_vector_store",
|
||||
),
|
||||
name="ltm_vector_store",
|
||||
embedding_fn=embeddings,
|
||||
)
|
||||
```
|
||||
|
||||
|
@@ -11,7 +11,7 @@ In this example, we will load your knowledge from a URL and store it in a vector
|
||||
First, you need to install the `dbgpt` library.
|
||||
|
||||
```bash
|
||||
pip install "dbgpt[rag]>=0.5.2"
|
||||
pip install "dbgpt[agent,simple_framework, client]>=0.7.1" "dbgpt_ext>=0.7.1" -U
|
||||
````
|
||||
|
||||
### Prepare Embedding Model
|
||||
@@ -84,10 +84,10 @@ shutil.rmtree("/tmp/awel_rag_test_vector_store", ignore_errors=True)
|
||||
|
||||
vector_store = ChromaStore(
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="test_vstore",
|
||||
persist_path="/tmp/awel_rag_test_vector_store",
|
||||
embedding_fn=embeddings
|
||||
)
|
||||
persist_path="/tmp/awel_rag_test_vector_store"
|
||||
),
|
||||
name="test_vstore",
|
||||
embedding_fn=embeddings
|
||||
)
|
||||
|
||||
with DAG("load_knowledge_dag") as knowledge_dag:
|
||||
@@ -274,10 +274,10 @@ shutil.rmtree("/tmp/awel_rag_test_vector_store", ignore_errors=True)
|
||||
|
||||
vector_store = ChromaStore(
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="test_vstore",
|
||||
persist_path="/tmp/awel_rag_test_vector_store",
|
||||
embedding_fn=embeddings
|
||||
),
|
||||
name="test_vstore",
|
||||
embedding_fn=embeddings
|
||||
)
|
||||
|
||||
with DAG("load_knowledge_dag") as knowledge_dag:
|
||||
|
@@ -29,7 +29,8 @@ In this guide, we mainly focus on step 1, 2, and 3.
|
||||
First, you need to install the `dbgpt` library.
|
||||
|
||||
```bash
|
||||
pip install "dbgpt[rag]>=0.7.0" -U
|
||||
pip install "dbgpt[rag, agent, client, simple_framework]>=0.7.0" "dbgpt_ext>=0.7.0" -U
|
||||
pip install openai
|
||||
````
|
||||
|
||||
## Build Knowledge Base
|
||||
@@ -92,9 +93,9 @@ shutil.rmtree("/tmp/awel_with_data_vector_store", ignore_errors=True)
|
||||
vector_store = ChromaStore(
|
||||
ChromaVectorConfig(
|
||||
persist_path="/tmp/tmp_ltm_vector_store",
|
||||
name="ltm_vector_store",
|
||||
embedding_fn=embeddings,
|
||||
)
|
||||
),
|
||||
name="ltm_vector_store",
|
||||
embedding_fn=embeddings,
|
||||
)
|
||||
|
||||
with DAG("load_schema_dag") as load_schema_dag:
|
||||
@@ -102,7 +103,7 @@ with DAG("load_schema_dag") as load_schema_dag:
|
||||
# Load database schema to vector store
|
||||
assembler_task = DBSchemaAssemblerOperator(
|
||||
connector=db_conn,
|
||||
index_store=vector_store,
|
||||
table_vector_store_connector=vector_store,
|
||||
chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||
)
|
||||
input_task >> assembler_task
|
||||
@@ -122,7 +123,8 @@ with DAG("retrieve_schema_dag") as retrieve_schema_dag:
|
||||
# Retrieve database schema from vector store
|
||||
retriever_task = DBSchemaRetrieverOperator(
|
||||
top_k=1,
|
||||
index_store=vector_store,
|
||||
table_vector_store_connector=vector_store,
|
||||
field_vector_store_connector=vector_store
|
||||
)
|
||||
input_task >> retriever_task
|
||||
|
||||
@@ -487,10 +489,10 @@ db_conn.create_temp_tables(
|
||||
|
||||
vector_store = ChromaStore(
|
||||
ChromaVectorConfig(
|
||||
embedding_fn=embeddings,
|
||||
name="db_schema_vector_store",
|
||||
persist_path="/tmp/awel_with_data_vector_store",
|
||||
)
|
||||
),
|
||||
embedding_fn=embeddings,
|
||||
name="db_schema_vector_store",
|
||||
)
|
||||
|
||||
antv_charts = [
|
||||
@@ -623,7 +625,7 @@ with DAG("load_schema_dag") as load_schema_dag:
|
||||
# Load database schema to vector store
|
||||
assembler_task = DBSchemaAssemblerOperator(
|
||||
connector=db_conn,
|
||||
index_store=vector_store,
|
||||
table_vector_store_connector=vector_store,
|
||||
chunk_parameters=ChunkParameters(chunk_strategy="CHUNK_BY_SIZE"),
|
||||
)
|
||||
input_task >> assembler_task
|
||||
|
@@ -51,16 +51,16 @@ INPUT_PROMPT = "\n###Input:\n{}\n###Response:"
|
||||
|
||||
def _create_vector_connector():
|
||||
"""Create vector connector."""
|
||||
config = ChromaVectorConfig(
|
||||
persist_path=PILOT_PATH,
|
||||
config = ChromaVectorConfig(persist_path=PILOT_PATH)
|
||||
|
||||
return ChromaStore(
|
||||
config,
|
||||
name="embedding_rag_test",
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||
).create(),
|
||||
)
|
||||
|
||||
return ChromaStore(config)
|
||||
|
||||
|
||||
def _create_temporary_connection():
|
||||
"""Create a temporary database connection for testing."""
|
||||
|
@@ -60,12 +60,12 @@ db_conn.create_temp_tables(
|
||||
}
|
||||
)
|
||||
|
||||
config = ChromaVectorConfig(
|
||||
persist_path=PILOT_PATH,
|
||||
config = ChromaVectorConfig(persist_path=PILOT_PATH)
|
||||
vector_store = ChromaStore(
|
||||
config,
|
||||
name="db_schema_vector_store",
|
||||
embedding_fn=embeddings,
|
||||
)
|
||||
vector_store = ChromaStore(config)
|
||||
|
||||
antv_charts = [
|
||||
{"response_line_chart": "used to display comparative trend analysis data"},
|
||||
|
@@ -326,7 +326,7 @@ def check_chat_request(request: ChatCompletionRequestBody = Body()):
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": {
|
||||
"message": "chart param is None",
|
||||
"message": "chat param is None",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "invalid_chat_param",
|
||||
|
@@ -94,11 +94,9 @@ class HybridMemory(Memory, Generic[T]):
|
||||
vstore_path = vstore_path or os.path.join(DATA_DIR, "agent_memory")
|
||||
|
||||
vector_store = ChromaStore(
|
||||
ChromaVectorConfig(
|
||||
name=vstore_name,
|
||||
persist_path=vstore_path,
|
||||
embedding_fn=embeddings,
|
||||
)
|
||||
ChromaVectorConfig(persist_path=vstore_path),
|
||||
name=vstore_name,
|
||||
embedding_fn=embeddings,
|
||||
)
|
||||
return cls.from_vstore(
|
||||
vector_store=vector_store,
|
||||
|
@@ -7,12 +7,12 @@ class VisThinking(Vis):
|
||||
@classmethod
|
||||
async def build_message(cls, message: str) -> str:
|
||||
vis = VisThinking()
|
||||
return f"```{vis.vis_tag()}\n{message}\n```"
|
||||
return f"``````{vis.vis_tag()}\n{message}\n``````"
|
||||
|
||||
def sync_display(self, **kwargs) -> str:
|
||||
"""Display the content using the vis protocol."""
|
||||
content = kwargs.get("content")
|
||||
return f"```{self.vis_tag()}\n{content}\n```"
|
||||
return f"``````{self.vis_tag()}\n{content}\n``````"
|
||||
|
||||
@classmethod
|
||||
def vis_tag(cls):
|
||||
|
@@ -1,9 +1,12 @@
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.util import PaginationResult
|
||||
@@ -231,6 +234,89 @@ async def get_history_messages(con_uid: str, service: Service = Depends(get_serv
|
||||
return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid)))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/export_messages",
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def export_all_messages(
|
||||
user_name: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
format: Literal["file", "json"] = Query(
|
||||
"file", description="response format(file or json)"
|
||||
),
|
||||
service: Service = Depends(get_service),
|
||||
):
|
||||
"""Export all conversations and messages for a user
|
||||
|
||||
Args:
|
||||
user_name (str): The user name
|
||||
user_id (str): The user id (alternative to user_name)
|
||||
sys_code (str): The system code
|
||||
format (str): The format of the response, either 'file' or 'json', defaults to
|
||||
'file'
|
||||
|
||||
Returns:
|
||||
A dictionary containing all conversations and their messages
|
||||
"""
|
||||
# 1. Get all conversations for the user
|
||||
request = ServeRequest(
|
||||
user_name=user_name or user_id,
|
||||
sys_code=sys_code,
|
||||
)
|
||||
|
||||
# Initialize pagination variables
|
||||
page = 1
|
||||
page_size = 100 # Adjust based on your needs
|
||||
all_conversations = []
|
||||
|
||||
# Paginate through all conversations
|
||||
while True:
|
||||
pagination_result = service.get_list_by_page(request, page, page_size)
|
||||
all_conversations.extend(pagination_result.items)
|
||||
|
||||
if page >= pagination_result.total_pages:
|
||||
break
|
||||
page += 1
|
||||
|
||||
# 2. For each conversation, get all messages
|
||||
result = {
|
||||
"user_name": user_name or user_id,
|
||||
"sys_code": sys_code,
|
||||
"total_conversations": len(all_conversations),
|
||||
"conversations": [],
|
||||
}
|
||||
|
||||
for conv in all_conversations:
|
||||
messages = service.get_history_messages(ServeRequest(conv_uid=conv.conv_uid))
|
||||
conversation_data = {
|
||||
"conv_uid": conv.conv_uid,
|
||||
"chat_mode": conv.chat_mode,
|
||||
"app_code": conv.app_code,
|
||||
"create_time": conv.gmt_created,
|
||||
"update_time": conv.gmt_modified,
|
||||
"total_messages": len(messages),
|
||||
"messages": [msg.dict() for msg in messages],
|
||||
}
|
||||
result["conversations"].append(conversation_data)
|
||||
|
||||
if format == "json":
|
||||
return JSONResponse(content=result)
|
||||
else:
|
||||
file_name = (
|
||||
f"conversation_export_{user_name or user_id or 'dbgpt'}_"
|
||||
f"{sys_code or 'dbgpt'}"
|
||||
)
|
||||
# Return the json file
|
||||
return StreamingResponse(
|
||||
io.BytesIO(
|
||||
json.dumps(result, ensure_ascii=False, indent=4).encode("utf-8")
|
||||
),
|
||||
media_type="application/file",
|
||||
headers={"Content-Disposition": f"attachment;filename={file_name}.json"},
|
||||
)
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
global global_system_app
|
||||
|
@@ -109,6 +109,20 @@ class ServerResponse(BaseModel):
|
||||
"dbgpt",
|
||||
],
|
||||
)
|
||||
gmt_created: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The record creation time.",
|
||||
examples=[
|
||||
"2023-01-07 09:00:00",
|
||||
],
|
||||
)
|
||||
gmt_modified: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The record update time.",
|
||||
examples=[
|
||||
"2023-01-07 09:00:00",
|
||||
],
|
||||
)
|
||||
|
||||
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Convert the model to a dictionary"""
|
||||
|
@@ -60,6 +60,17 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
RES: The response
|
||||
"""
|
||||
# TODO implement your own logic here, transfer the entity to a response
|
||||
gmt_created = (
|
||||
entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if entity.gmt_created
|
||||
else None
|
||||
)
|
||||
gmt_modified = (
|
||||
entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if entity.gmt_modified
|
||||
else None
|
||||
)
|
||||
|
||||
return ServerResponse(
|
||||
app_code=entity.app_code,
|
||||
conv_uid=entity.conv_uid,
|
||||
@@ -67,6 +78,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
chat_mode=entity.chat_mode,
|
||||
user_name="",
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created,
|
||||
gmt_modified=gmt_modified,
|
||||
)
|
||||
|
||||
def get_latest_message(self, conv_uid: str) -> Optional[MessageStorageItem]:
|
||||
|
Reference in New Issue
Block a user