From dbfa2c7c48073cb9ff0c6ce359264423510bf04e Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 17 Mar 2025 14:30:56 +0800 Subject: [PATCH] chore: Merge main codes --- packages/dbgpt-app/src/dbgpt_app/config.py | 5 + .../dbgpt-app/src/dbgpt_app/dbgpt_server.py | 3 + .../initialization/app_initialization.py | 19 +++ .../dbgpt-app/src/dbgpt_app/knowledge/api.py | 146 +++++++++++------- .../src/dbgpt_app/knowledge/service.py | 18 ++- .../src/dbgpt_app/openapi/api_v1/api_v1.py | 29 ++-- .../dbgpt-app/src/dbgpt_app/openapi/api_v2.py | 23 +-- .../dbgpt-app/src/dbgpt_app/scene/__init__.py | 2 +- .../dbgpt-app/src/dbgpt_app/scene/base.py | 6 - .../src/dbgpt_app/scene/base_chat.py | 142 +++++++++++------ .../dbgpt_app/scene/chat_dashboard/chat.py | 41 +++-- .../dbgpt_app/scene/chat_dashboard/config.py | 25 +++ .../scene/chat_dashboard/out_parser.py | 2 +- .../dbgpt_app/scene/chat_dashboard/prompt.py | 12 +- .../scene/chat_data/chat_excel/config.py | 45 ++++++ .../chat_excel/excel_analyze/chat.py | 60 ++++--- .../chat_excel/excel_analyze/out_parser.py | 2 +- .../chat_excel/excel_analyze/prompt.py | 3 +- .../chat_excel/excel_learning/chat.py | 30 ++-- .../chat_excel/excel_learning/out_parser.py | 2 +- .../chat_excel/excel_learning/prompt.py | 3 +- .../chat_data/chat_excel/excel_reader.py | 89 +++++++++-- .../scene/chat_db/auto_execute/chat.py | 34 ++-- .../scene/chat_db/auto_execute/config.py | 40 +++++ .../scene/chat_db/auto_execute/out_parser.py | 2 +- .../scene/chat_db/auto_execute/prompt.py | 5 +- .../chat_db/auto_execute/prompt_baichuan.py | 68 -------- .../scene/chat_db/professional_qa/chat.py | 37 +++-- .../scene/chat_db/professional_qa/config.py | 40 +++++ .../scene/chat_db/professional_qa/prompt.py | 3 +- .../src/dbgpt_app/scene/chat_factory.py | 16 +- .../chat_knowledge/refine_summary/chat.py | 23 +-- .../refine_summary/out_parser.py | 2 +- .../chat_knowledge/refine_summary/prompt.py | 7 +- .../dbgpt_app/scene/chat_knowledge/v1/chat.py | 57 ++++--- .../scene/chat_knowledge/v1/config.py | 37 +++++ .../scene/chat_knowledge/v1/prompt.py | 3 +- .../scene/chat_knowledge/v1/prompt_chatglm.py | 3 +- .../src/dbgpt_app/scene/chat_normal/chat.py | 17 +- .../src/dbgpt_app/scene/chat_normal/config.py | 23 +++ .../src/dbgpt_app/scene/chat_normal/prompt.py | 6 +- .../dbgpt_app/scene/operators/app_operator.py | 35 +++-- .../src/dbgpt/configs/model_config.py | 3 + .../src/dbgpt/core/interface/file.py | 1 + .../src/dbgpt/core/interface/llm.py | 11 +- .../dbgpt-core/src/dbgpt/util/config_utils.py | 38 ++++- .../dbgpt-core/src/dbgpt/util/module_utils.py | 7 +- .../src/dbgpt/util/parameter_utils.py | 6 +- .../src/dbgpt_serve/core/config.py | 105 ++++++++++++- .../dbgpt_serve/datasource/api/endpoints.py | 18 ++- .../dbgpt_serve/datasource/service/service.py | 26 ++-- .../src/dbgpt_serve/rag/api/endpoints.py | 17 +- .../src/dbgpt_serve/rag/service/service.py | 80 +++++++--- .../src/dbgpt_serve/rag/tests/test_service.py | 2 +- 54 files changed, 1024 insertions(+), 455 deletions(-) create mode 100644 packages/dbgpt-app/src/dbgpt_app/initialization/app_initialization.py create mode 100644 packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/config.py create mode 100644 packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/config.py create mode 100644 packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/config.py delete mode 100644 packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt_baichuan.py create mode 100644 
packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/config.py create mode 100644 packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/config.py create mode 100644 packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/config.py diff --git a/packages/dbgpt-app/src/dbgpt_app/config.py b/packages/dbgpt-app/src/dbgpt_app/config.py index d13b4b304..6d1ddf637 100644 --- a/packages/dbgpt-app/src/dbgpt_app/config.py +++ b/packages/dbgpt-app/src/dbgpt_app/config.py @@ -17,6 +17,7 @@ from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig from dbgpt_ext.storage.vector_store.chroma_store import ChromaVectorConfig from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig from dbgpt_serve.core import BaseServeConfig +from dbgpt_serve.core.config import GPTsAppConfig @dataclass @@ -361,6 +362,10 @@ class ApplicationConfig(BaseParameters): default_factory=lambda: RagParameters(), metadata={"help": _("Rag Knowledge Parameters")}, ) + app: GPTsAppConfig = field( + default_factory=lambda: GPTsAppConfig(), + metadata={"help": _("GPTs application configuration")}, + ) trace: TracerParameters = field( default_factory=TracerParameters, metadata={ diff --git a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py index 908a9a6c4..4f72a2ccf 100644 --- a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py +++ b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py @@ -256,6 +256,7 @@ def run_webserver(config_file: str): def scan_configs(): from dbgpt.model import scan_model_providers + from dbgpt_app.initialization.app_initialization import scan_app_configs from dbgpt_app.initialization.serve_initialization import scan_serve_configs from dbgpt_ext.storage import scan_storage_configs from dbgpt_serve.datasource.manages.connector_manager import ConnectorManager @@ -269,6 +270,8 @@ def scan_configs(): scan_serve_configs() # Register all storage configs scan_storage_configs() + # Register all app configs + scan_app_configs() def load_config(config_file: str = None) -> ApplicationConfig: diff --git a/packages/dbgpt-app/src/dbgpt_app/initialization/app_initialization.py b/packages/dbgpt-app/src/dbgpt_app/initialization/app_initialization.py new file mode 100644 index 000000000..e1c15d9f3 --- /dev/null +++ b/packages/dbgpt-app/src/dbgpt_app/initialization/app_initialization.py @@ -0,0 +1,19 @@ +from dbgpt_serve.core.config import GPTsAppCommonConfig + + +def scan_app_configs(): + """Scan and register all app configs.""" + from dbgpt.util.module_utils import ModelScanner, ScannerConfig + + modules = ["dbgpt_app.scene"] + + scanner = ModelScanner[GPTsAppCommonConfig]() + for module in modules: + config = ScannerConfig( + module_path=module, + base_class=GPTsAppCommonConfig, + recursive=True, + specific_files=["config"], + ) + scanner.scan_and_register(config) + return scanner.get_registered_items() diff --git a/packages/dbgpt-app/src/dbgpt_app/knowledge/api.py b/packages/dbgpt-app/src/dbgpt_app/knowledge/api.py index 775f14156..b6ce9833d 100644 --- a/packages/dbgpt-app/src/dbgpt_app/knowledge/api.py +++ b/packages/dbgpt-app/src/dbgpt_app/knowledge/api.py @@ -1,10 +1,9 @@ import logging import os import shutil -import tempfile from typing import List -from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile +from fastapi import APIRouter, Depends, File, Form, UploadFile from dbgpt._private.config import Config from dbgpt.configs import TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE @@ -12,8 +11,10 @@ from 
dbgpt.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, ) from dbgpt.core.awel.dag.dag_manager import DAGManager +from dbgpt.core.interface.file import FileStorageClient from dbgpt.rag.retriever import BaseRetriever from dbgpt.rag.retriever.embedding import EmbeddingRetriever +from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.i18n_utils import _ from dbgpt.util.tracer import SpanType, root_tracer from dbgpt_app.knowledge.request.request import ( @@ -34,7 +35,11 @@ from dbgpt_app.knowledge.request.response import ( KnowledgeQueryResponse, ) from dbgpt_app.knowledge.service import KnowledgeService -from dbgpt_app.openapi.api_v1.api_v1 import no_stream_generator, stream_generator +from dbgpt_app.openapi.api_v1.api_v1 import ( + get_executor, + no_stream_generator, + stream_generator, +) from dbgpt_app.openapi.api_view_model import Result from dbgpt_ext.rag import ChunkParameters from dbgpt_ext.rag.chunk_manager import ChunkStrategy @@ -71,21 +76,30 @@ def get_dag_manager() -> DAGManager: return DAGManager.get_instance(CFG.SYSTEM_APP) +def get_fs() -> FileStorageClient: + return FileStorageClient.get_instance(CFG.SYSTEM_APP) + + @router.post("/knowledge/space/add") -def space_add(request: KnowledgeSpaceRequest): - print(f"/space/add params: {request}") +async def space_add(request: KnowledgeSpaceRequest): + logger.info(f"/space/add params: {request}") try: - knowledge_space_service.create_knowledge_space(request) + await blocking_func_to_async( + get_executor(), knowledge_space_service.create_knowledge_space, request + ) return Result.succ([]) except Exception as e: return Result.failed(code="E000X", msg=f"space add error {e}") @router.post("/knowledge/space/list") -def space_list(request: KnowledgeSpaceRequest): - print("/space/list params:") +async def space_list(request: KnowledgeSpaceRequest): + logger.info(f"/space/list params: {request}") try: - return Result.succ(knowledge_space_service.get_knowledge_space(request)) + res = await blocking_func_to_async( + get_executor(), knowledge_space_service.get_knowledge_space, request + ) + return Result.succ(res) except Exception as e: logger.exception(f"Space list error!{str(e)}") return Result.failed(code="E000X", msg=f"space list error {e}") @@ -93,8 +107,7 @@ def space_list(request: KnowledgeSpaceRequest): @router.post("/knowledge/space/delete") def space_delete(request: KnowledgeSpaceRequest): - print("/space/delete params:") - print(request.name) + logger.info(f"/space/delete params: {request}") try: # delete Files in 'pilot/data/ safe_space_name = os.path.basename(request.name) @@ -107,17 +120,20 @@ def space_delete(request: KnowledgeSpaceRequest): if os.path.exists(space_dir): shutil.rmtree(space_dir) except Exception as e: - print(e) + logger.error(f"Failed to remove {safe_space_name}: {str(e)}") return Result.succ(knowledge_space_service.delete_space(request.name)) except Exception as e: return Result.failed(code="E000X", msg=f"space delete error {e}") @router.post("/knowledge/{space_id}/arguments") -def arguments(space_id: str): - print("/knowledge/space/arguments params:") +async def arguments(space_id: str): + logger.info(f"/knowledge/{space_id}/arguments params: {space_id}") try: - return Result.succ(knowledge_space_service.arguments(space_id)) + res = await blocking_func_to_async( + get_executor(), knowledge_space_service.arguments, space_id + ) + return Result.succ(res) except Exception as e: return Result.failed(code="E000X", msg=f"space arguments error {e}") @@ -127,7 +143,7 @@ async def 
recall_test( space_name: str, request: DocumentRecallTestRequest, ): - print(f"/knowledge/{space_name}/recall_test params:") + logger.info(f"/knowledge/{space_name}/recall_test params: {request}") try: return Result.succ( await knowledge_space_service.recall_test(space_name, request) @@ -140,7 +156,7 @@ async def recall_test( def recall_retrievers( space_id: str, ): - print(f"/knowledge/{space_id}/recall_retrievers params:") + logger.info(f"/knowledge/{space_id}/recall_retrievers params:") try: logger.info(f"get_recall_retrievers {space_id}") @@ -177,25 +193,31 @@ def recall_retrievers( @router.post("/knowledge/{space_id}/argument/save") -def arguments_save(space_id: str, argument_request: SpaceArgumentRequest): +async def arguments_save(space_id: str, argument_request: SpaceArgumentRequest): print("/knowledge/space/argument/save params:") try: - return Result.succ( - knowledge_space_service.argument_save(space_id, argument_request) + res = await blocking_func_to_async( + get_executor(), + knowledge_space_service.argument_save, + space_id, + argument_request, ) + return Result.succ(res) except Exception as e: return Result.failed(code="E000X", msg=f"space save error {e}") @router.post("/knowledge/{space_name}/document/add") -def document_add(space_name: str, request: KnowledgeDocumentRequest): - print(f"/document/add params: {space_name}, {request}") +async def document_add(space_name: str, request: KnowledgeDocumentRequest): + logger.info(f"/document/add params: {space_name}, {request}") try: - return Result.succ( - knowledge_space_service.create_knowledge_document( - space=space_name, request=request - ) + res = await blocking_func_to_async( + get_executor(), + knowledge_space_service.create_knowledge_document, + space=space_name, + request=request, ) + return Result.succ(res) # return Result.succ([]) except Exception as e: return Result.failed(code="E000X", msg=f"document add error {e}") @@ -207,7 +229,7 @@ def document_edit( request: KnowledgeDocumentRequest, service: Service = Depends(get_rag_service), ): - print(f"/document/edit params: {space_name}, {request}") + logger.info(f"/document/edit params: {space_name}, {request}") space = service.get({"name": space_name}) if space is None: return Result.failed( @@ -263,7 +285,11 @@ async def space_config() -> Result[KnowledgeConfigResponse]: dag_manager: DAGManager = get_dag_manager() # Vector Storage vs_domain_types = [KnowledgeDomainType(name="Normal", desc="Normal")] - dag_map = dag_manager.get_dags_by_tag_key(TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE) + dag_map = await blocking_func_to_async( + get_executor(), + dag_manager.get_dags_by_tag_key, + TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE, + ) for domain_type, dags in dag_map.items(): vs_domain_types.append( KnowledgeDomainType( @@ -318,8 +344,7 @@ def document_list(space_name: str, query_request: DocumentQueryRequest): @router.post("/knowledge/{space_name}/graphvis") def graph_vis(space_name: str, query_request: GraphVisRequest): - print(f"/document/list params: {space_name}, {query_request}") - print(query_request.limit) + logger.info(f"/document/list params: {space_name}, {query_request}") try: return Result.succ( knowledge_space_service.query_graph( @@ -347,63 +372,64 @@ async def document_upload( doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...), + fs: FileStorageClient = Depends(get_fs), ): print(f"/document/upload params: {space_name}") try: if doc_file: + safe_filename = os.path.basename(doc_file.filename) # Sanitize inputs to prevent path traversal 
safe_space_name = os.path.basename(space_name) - safe_filename = os.path.basename(doc_file.filename) - # Create absolute paths and verify they are within allowed directory - upload_dir = os.path.abspath( - os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, safe_space_name) + custom_metadata = { + "space_name": space_name, + "doc_name": doc_name, + "doc_type": doc_type, + } + bucket = "dbgpt_knowledge_file" + file_uri = await blocking_func_to_async( + get_executor(), + fs.save_file, + bucket, + safe_filename, + doc_file.file, + storage_type="distributed", + custom_metadata=custom_metadata, ) - target_path = os.path.abspath(os.path.join(upload_dir, safe_filename)) - - if os.path.abspath(KNOWLEDGE_UPLOAD_ROOT_PATH) not in target_path: - raise HTTPException(status_code=400, detail="Invalid path detected") - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) - - # Create temp file - tmp_fd, tmp_path = tempfile.mkstemp(dir=upload_dir) try: - with os.fdopen(tmp_fd, "wb") as tmp: - tmp.write(await doc_file.read()) - - shutil.move(tmp_path, target_path) - request = KnowledgeDocumentRequest() request.doc_name = doc_name request.doc_type = doc_type - request.content = target_path + request.content = file_uri - space_res = knowledge_space_service.get_knowledge_space( - KnowledgeSpaceRequest(name=safe_space_name) + space_res = await blocking_func_to_async( + get_executor(), + knowledge_space_service.get_knowledge_space, + KnowledgeSpaceRequest(name=safe_space_name), ) if len(space_res) == 0: # create default space if "default" != safe_space_name: raise Exception("you have not create your knowledge space.") - knowledge_space_service.create_knowledge_space( + await blocking_func_to_async( + get_executor(), + knowledge_space_service.create_knowledge_space, KnowledgeSpaceRequest( name=safe_space_name, desc="first db-gpt rag application", owner="dbgpt", - ) - ) - return Result.succ( - knowledge_space_service.create_knowledge_document( - space=safe_space_name, request=request + ), ) + res = await blocking_func_to_async( + get_executor(), + knowledge_space_service.create_knowledge_document, + space=safe_space_name, + request=request, ) + return Result.succ(res) except Exception as e: # Clean up temp file if anything goes wrong - if os.path.exists(tmp_path): - os.unlink(tmp_path) raise e return Result.failed(code="E000X", msg="doc_file is None") diff --git a/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py b/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py index 2f8eaec78..6e8068daf 100644 --- a/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py +++ b/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py @@ -586,15 +586,16 @@ class KnowledgeService: Returns: chat: BaseChat, refine summary chat. 
""" - from dbgpt_app.scene import ChatScene + from dbgpt_app.scene import ChatParam, ChatScene - chat_param = { - "chat_session_id": conn_uid, - "current_user_input": "", - "select_param": doc, - "model_name": model_name, - "model_cache_enable": False, - } + chat_param = ChatParam( + chat_session_id=conn_uid, + current_user_input="", + select_param=doc, + model_name=model_name, + model_cache_enable=False, + chat_mode=ChatScene.ExtractRefineSummary, + ) executor = CFG.SYSTEM_APP.get_component( ComponentType.EXECUTOR_DEFAULT, ExecutorFactory ).create() @@ -604,6 +605,7 @@ class KnowledgeService: executor, CHAT_FACTORY.get_implementation, ChatScene.ExtractRefineSummary.value(), + CFG.SYSTEM_APP, **{"chat_param": chat_param}, ) return chat diff --git a/packages/dbgpt-app/src/dbgpt_app/openapi/api_v1/api_v1.py b/packages/dbgpt-app/src/dbgpt_app/openapi/api_v1/api_v1.py index 84324dfff..b87875078 100644 --- a/packages/dbgpt-app/src/dbgpt_app/openapi/api_v1/api_v1.py +++ b/packages/dbgpt-app/src/dbgpt_app/openapi/api_v1/api_v1.py @@ -44,7 +44,7 @@ from dbgpt_app.openapi.api_view_model import ( MessageVo, Result, ) -from dbgpt_app.scene import BaseChat, ChatFactory, ChatScene +from dbgpt_app.scene import BaseChat, ChatFactory, ChatParam, ChatScene from dbgpt_serve.agent.db.gpts_app import UserRecentAppsDao, adapt_native_app_model from dbgpt_serve.core import blocking_func_to_async from dbgpt_serve.datasource.manages.db_conn_info import DBConfig, DbTypeInfo @@ -454,19 +454,20 @@ async def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: Result.failed("Unsupported Chat Mode," + dialogue.chat_mode + "!") ) - chat_param = { - "chat_session_id": dialogue.conv_uid, - "user_name": dialogue.user_name, - "sys_code": dialogue.sys_code, - "current_user_input": dialogue.user_input, - "select_param": dialogue.select_param, - "model_name": dialogue.model_name, - "app_code": dialogue.app_code, - "ext_info": dialogue.ext_info, - "temperature": dialogue.temperature, - "max_new_tokens": dialogue.max_new_tokens, - "prompt_code": dialogue.prompt_code, - } + chat_param = ChatParam( + chat_session_id=dialogue.conv_uid, + user_name=dialogue.user_name, + sys_code=dialogue.sys_code, + current_user_input=dialogue.user_input, + select_param=dialogue.select_param, + model_name=dialogue.model_name, + app_code=dialogue.app_code, + ext_info=dialogue.ext_info, + temperature=dialogue.temperature, + max_new_tokens=dialogue.max_new_tokens, + prompt_code=dialogue.prompt_code, + chat_mode=ChatScene.of_mode(dialogue.chat_mode), + ) chat: BaseChat = await blocking_func_to_async( CFG.SYSTEM_APP, CHAT_FACTORY.get_implementation, diff --git a/packages/dbgpt-app/src/dbgpt_app/openapi/api_v2.py b/packages/dbgpt-app/src/dbgpt_app/openapi/api_v2.py index 11874c983..cb3454828 100644 --- a/packages/dbgpt-app/src/dbgpt_app/openapi/api_v2.py +++ b/packages/dbgpt-app/src/dbgpt_app/openapi/api_v2.py @@ -30,7 +30,7 @@ from dbgpt_app.openapi.api_v1.api_v1 import ( get_executor, stream_generator, ) -from dbgpt_app.scene import BaseChat, ChatScene +from dbgpt_app.scene import BaseChat, ChatParam, ChatScene from dbgpt_client.schema import ChatCompletionRequestBody, ChatMode from dbgpt_serve.agent.agents.controller import multi_agents from dbgpt_serve.flow.api.endpoints import get_service @@ -188,16 +188,17 @@ async def get_chat_instance( if not ChatScene.is_valid_mode(dialogue.chat_mode): raise StopAsyncIteration(f"Unsupported Chat Mode,{dialogue.chat_mode}!") - chat_param = { - "chat_session_id": dialogue.conv_uid, - "user_name": 
dialogue.user_name, - "sys_code": dialogue.sys_code, - "current_user_input": dialogue.single_prompt(), - "select_param": dialogue.chat_param, - "model_name": dialogue.model, - "temperature": dialogue.temperature, - "max_new_tokens": dialogue.max_new_tokens, - } + chat_param = ChatParam( + chat_session_id=dialogue.conv_uid, + user_name=dialogue.user_name, + sys_code=dialogue.sys_code, + current_user_input=dialogue.single_prompt(), + select_param=dialogue.chat_param, + model_name=dialogue.model, + temperature=dialogue.temperature, + max_new_tokens=dialogue.max_new_tokens, + chat_mode=ChatScene.of_mode(dialogue.chat_mode), + ) chat: BaseChat = await blocking_func_to_async( get_executor(), CHAT_FACTORY.get_implementation, diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/__init__.py b/packages/dbgpt-app/src/dbgpt_app/scene/__init__.py index 920ed81e7..0025d44a5 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/__init__.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/__init__.py @@ -1,3 +1,3 @@ from dbgpt_app.scene.base import AppScenePromptTemplateAdapter, ChatScene # noqa: F401 -from dbgpt_app.scene.base_chat import BaseChat # noqa: F401 +from dbgpt_app.scene.base_chat import BaseChat, ChatParam # noqa: F401 from dbgpt_app.scene.chat_factory import ChatFactory # noqa: F401 diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/base.py b/packages/dbgpt-app/src/dbgpt_app/scene/base.py index 08c3cec27..b179a542a 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/base.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/base.py @@ -167,9 +167,6 @@ class AppScenePromptTemplateAdapter(BaseModel): output_parser: Optional[BaseOutputParser] = Field( default=None, description="The output parser of this scene" ) - sep: Optional[str] = Field( - default="###", description="The default separator of this scene" - ) stream_out: Optional[bool] = Field( default=True, description="Whether to stream out" @@ -177,9 +174,6 @@ class AppScenePromptTemplateAdapter(BaseModel): example_selector: Optional[ExampleSelector] = Field( default=None, description="Example selector" ) - need_historical_messages: Optional[bool] = Field( - default=False, description="Whether to need historical messages" - ) temperature: Optional[float] = Field( default=0.6, description="The default temperature of this scene" ) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/base_chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/base_chat.py index 70b2f61c0..b061398ef 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/base_chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/base_chat.py @@ -2,7 +2,8 @@ import datetime import logging import traceback from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, Dict, Optional, Union +from dataclasses import dataclass, field +from typing import Any, AsyncIterator, Dict, Optional, Type, TypeVar, Union from dbgpt._private.config import Config from dbgpt.component import ComponentType, SystemApp @@ -31,6 +32,7 @@ from dbgpt_app.scene.operators.app_operator import ( build_cached_chat_operator, ) from dbgpt_serve.conversation.serve import Serve as ConversationServe +from dbgpt_serve.core.config import BufferWindowGPTsAppMemoryConfig, GPTsAppCommonConfig from dbgpt_serve.prompt.service.service import Service as PromptService from .exceptions import BaseAppException, ContextAppException @@ -38,24 +40,52 @@ from .exceptions import BaseAppException, ContextAppException logger = logging.getLogger(__name__) CFG = Config() +C = TypeVar("C", bound="GPTsAppCommonConfig") + + +@dataclass +class 
ChatParam: + chat_session_id: str + current_user_input: str + model_name: str + select_param: Any + chat_mode: ChatScene + user_name: str = "" + sys_code: str = "" + app_code: str = "" + temperature: Optional[float] = field(default=None) + max_new_tokens: Optional[int] = field(default=None) + message_version: str = "v2" + model_cache_enable: bool = False + prompt_code: Optional[str] = None + ext_info: Optional[Dict[str, Any]] = None + app_config: Optional[GPTsAppCommonConfig] = None + + def real_app_config(self, type_class: Type[C]) -> C: + if self.app_config is None: + return type_class() + if not isinstance(self.app_config, type_class): + return type_class(**self.app_config.to_dict()) + return self.app_config + def _build_conversation( chat_mode: ChatScene, - chat_param: Dict[str, Any], + chat_param: ChatParam, model_name: str, conv_serve: ConversationServe, ) -> StorageConversation: param_type = "" param_value = "" - if chat_param["select_param"]: + if chat_param.select_param: if len(chat_mode.param_types()) > 0: param_type = chat_mode.param_types()[0] - param_value = chat_param["select_param"] + param_value = chat_param.select_param return StorageConversation( - chat_param["chat_session_id"], + chat_param.chat_session_id, chat_mode=chat_mode.value(), - user_name=chat_param.get("user_name"), - sys_code=chat_param.get("sys_code"), + user_name=chat_param.user_name, + sys_code=chat_param.sys_code, model_name=model_name, param_type=param_type, param_value=param_value, @@ -73,16 +103,17 @@ class BaseChat(ABC): chat_scene: str = None llm_model: Any = None - # By default, keep the last two rounds of conversation records as the context - keep_start_rounds: int = 0 - keep_end_rounds: int = 0 - # Some model not support system role, this config is used to control whether to # convert system message to human message auto_convert_message: bool = True + @classmethod + @abstractmethod + def param_class(cls) -> Type[GPTsAppCommonConfig]: + """Return the parameter class of the chat""" + @trace("BaseChat.__init__") - def __init__(self, chat_param: Dict, system_app: SystemApp = None): + def __init__(self, chat_param: ChatParam, system_app: SystemApp): """Chat Module Initialization Args: - chat_param: Dict @@ -95,20 +126,20 @@ class BaseChat(ABC): self.app_config = self.system_app.config.configs.get("app_config") self.web_config = self.app_config.service.web self.model_config = self.app_config.models - self.chat_session_id = chat_param["chat_session_id"] - self.chat_mode = chat_param["chat_mode"] - self.current_user_input: str = chat_param["current_user_input"] + self.chat_session_id = chat_param.chat_session_id + self.chat_mode = chat_param.chat_mode + self.current_user_input: str = chat_param.current_user_input self.llm_model = ( - chat_param["model_name"] - if chat_param["model_name"] + chat_param.model_name + if chat_param.model_name else self.model_config.default_llm ) self.llm_echo = False self.worker_manager = self.system_app.get_component( ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory ).create() - self.model_cache_enable = chat_param.get("model_cache_enable", False) - self.prompt_code = chat_param.get("prompt_code", None) + self.model_cache_enable = chat_param.model_cache_enable + self.prompt_code = chat_param.prompt_code self.prompt_template: AppScenePromptTemplateAdapter = ( CFG.prompt_template_registry.get_prompt_template( @@ -135,7 +166,6 @@ class BaseChat(ABC): template_scene=self.prompt_template.template_scene, stream_out=self.prompt_template.stream_out, 
output_parser=self.prompt_template.output_parser, - need_historical_messages=False, ) self._conv_serve = ConversationServe.get_instance(self.system_app) self.current_message: StorageConversation = _build_conversation( @@ -151,13 +181,9 @@ class BaseChat(ABC): # In v1, we will transform the message to compatible format of specific model # In the future, we will upgrade the message version to v2, and the message # will be compatible with all models - self._message_version = chat_param.get("message_version", "v2") + self._message_version = chat_param.message_version self._chat_param = chat_param - @property - def chat_type(self) -> str: - raise NotImplementedError("Not supported for this chat type.") - @abstractmethod async def generate_input_values(self) -> Dict: """Generate input to LLM @@ -221,6 +247,42 @@ class BaseChat(ABC): speak_to_user = prompt_define_response return speak_to_user + def llm_max_new_tokens(self): + """Get the max new tokens for LLM generation. + + The order of priority is: + 1. chat_param.max_new_tokens(From API) + 2. app_config.max_new_tokens(From config file) + 3. prompt_template.max_new_tokens(From prompt template) + """ + if self._chat_param.max_new_tokens: + return int(self._chat_param.max_new_tokens) + elif self._chat_param.app_config and self._chat_param.app_config.max_new_tokens: + return int(self._chat_param.app_config.max_new_tokens) + return self.prompt_template.max_new_tokens + + def llm_temperature(self): + """Get the temperature for LLM generation. + + The order of priority is: + 1. chat_param.temperature(From API) + 2. app_config.temperature(From config file) + 3. prompt_template.temperature(From prompt template) + """ + if self._chat_param.temperature is not None: + return float(self._chat_param.temperature) + elif ( + self._chat_param.app_config + and self._chat_param.app_config.temperature is not None + ): + return float(self._chat_param.app_config.temperature) + return self.prompt_template.temperature + + def memory_config(self): + if self._chat_param.app_config and self._chat_param.app_config.memory: + return self._chat_param.app_config.memory + return BufferWindowGPTsAppMemoryConfig() + async def _build_model_request(self) -> ModelRequest: input_values = await self.generate_input_values() # Load history @@ -232,41 +294,23 @@ class BaseChat(ABC): ) self.current_message.tokens = 0 - keep_start_rounds = ( - self.keep_start_rounds - if self.prompt_template.need_historical_messages - else 0 - ) - keep_end_rounds = ( - self.keep_end_rounds if self.prompt_template.need_historical_messages else 0 - ) req_ctx = ModelRequestContext( stream=self.prompt_template.stream_out, - user_name=self._chat_param.get("user_name"), - sys_code=self._chat_param.get("sys_code"), + user_name=self._chat_param.user_name, + sys_code=self._chat_param.sys_code, chat_mode=self.chat_mode.value(), span_id=root_tracer.get_current_span_id(), ) - temperature = float( - self._chat_param.get("temperature") - if self._chat_param.get("temperature") - else self.prompt_template.temperature - ) - max_new_tokens = int( - self._chat_param.get("max_new_tokens") - if self._chat_param.get("max_new_tokens") - else self.prompt_template.max_new_tokens - ) node = AppChatComposerOperator( model=self.llm_model, - temperature=temperature, - max_new_tokens=max_new_tokens, + temperature=self.llm_temperature(), + max_new_tokens=self.llm_max_new_tokens(), prompt=self.prompt_template.prompt, + llm_client=self.llm_client, + memory=self.memory_config(), message_version=self._message_version, echo=self.llm_echo,
streaming=self.prompt_template.stream_out, - keep_start_rounds=keep_start_rounds, - keep_end_rounds=keep_end_rounds, str_history=self.prompt_template.str_history, request_context=req_ctx, ) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/chat.py index 0f23cce96..9c49c361a 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/chat.py @@ -1,12 +1,15 @@ import json +import logging import os import uuid -from typing import Dict, List +from typing import Dict, List, Type from dbgpt import SystemApp from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.tracer import trace from dbgpt_app.scene import BaseChat, ChatScene +from dbgpt_app.scene.base_chat import ChatParam +from dbgpt_app.scene.chat_dashboard.config import ChatDashboardConfig from dbgpt_app.scene.chat_dashboard.data_loader import DashboardDataLoader from dbgpt_app.scene.chat_dashboard.data_preparation.report_schma import ( ChartData, @@ -14,13 +17,19 @@ from dbgpt_app.scene.chat_dashboard.data_preparation.report_schma import ( ) from dbgpt_serve.datasource.manages import ConnectorManager +logger = logging.getLogger(__name__) + class ChatDashboard(BaseChat): chat_scene: str = ChatScene.ChatDashboard.value() report_name: str """Chat Dashboard to generate dashboard chart""" - def __init__(self, chat_param: Dict, system_app: SystemApp = None): + @classmethod + def param_class(cls) -> Type[ChatDashboardConfig]: + return ChatDashboardConfig + + def __init__(self, chat_param: ChatParam, system_app: SystemApp): """Chat Dashboard Module Initialization Args: - chat_param: Dict @@ -29,17 +38,16 @@ class ChatDashboard(BaseChat): - model_name:(str) llm model name - select_param:(str) dbname """ - self.db_name = chat_param["select_param"] - chat_param["chat_mode"] = ChatScene.ChatDashboard + self.db_name = chat_param.select_param super().__init__(chat_param=chat_param, system_app=system_app) if not self.db_name: raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!") self.db_name = self.db_name - self.report_name = chat_param.get("report_name", "report") + self.report_name = "report" local_db_manager = ConnectorManager.get_instance(self.system_app) self.database = local_db_manager.get_connector(self.db_name) + self.curr_config = chat_param.real_app_config(ChatDashboardConfig) - self.top_k: int = 5 self.dashboard_template = self.__load_dashboard_template(self.report_name) def __load_dashboard_template(self, template_name): @@ -65,12 +73,19 @@ class ChatDashboard(BaseChat): client.get_db_summary, self.db_name, self.current_user_input, - self.top_k, + self.curr_config.schema_retrieve_top_k, ) - print("dashboard vector find tables:{}", table_infos) + logger.info(f"Retrieved table info: {table_infos}") except Exception as e: - print("db summary find error!" 
+ str(e)) - table_infos = self.database.table_simple_info() + logger.error(f"Retrieved table info error: {str(e)}") + table_infos = await blocking_func_to_async( + self._executor, self.database.table_simple_info + ) + if len(table_infos) > self.curr_config.schema_max_tokens: + # Load all table schemas, must be less than schema_max_tokens + # Here we just truncate the table_infos + # TODO: Count the number of tokens by LLMClient + table_infos = table_infos[: self.curr_config.schema_max_tokens] input_values = { "input": self.current_user_input, @@ -82,7 +97,8 @@ class ChatDashboard(BaseChat): return input_values def do_action(self, prompt_response): - ### TODO 记录整体信息,处理成功的,和未成功的分开记录处理 + # TODO: Record the overall information, and record the successful and + # unsuccessful processing separately chart_datas: List[ChartData] = [] dashboard_data_loader = DashboardDataLoader() for chart_item in prompt_response: @@ -102,8 +118,7 @@ class ChatDashboard(BaseChat): ) ) except Exception as e: - # TODO 修复流程 - print(str(e)) + logger.warning(f"Failed to get chart data: {str(e)}") return ReportData( conv_uid=self.chat_session_id, template_name=self.report_name, diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/config.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/config.py new file mode 100644 index 000000000..f5792ffb2 --- /dev/null +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/config.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass, field + +from dbgpt.util.i18n_utils import _ +from dbgpt_app.scene import ChatScene +from dbgpt_serve.core.config import GPTsAppCommonConfig + + +@dataclass +class ChatDashboardConfig(GPTsAppCommonConfig): + """Chat Dashboard Configuration""" + + name = ChatScene.ChatDashboard.value() + schema_retrieve_top_k: int = field( + default=10, + metadata={"help": _("The number of tables to retrieve from the database.")}, + ) + schema_max_tokens: int = field( + default=100 * 1024, + metadata={ + "help": _( + "The maximum number of tokens to pass to the model, default 100 * 1024." + " Only used when schema retrieval fails and the full table schema is loaded."
+ ) + }, + ) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/out_parser.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/out_parser.py index 8201b6d0e..66e630a6e 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/out_parser.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/out_parser.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) class ChatDashboardOutputParser(BaseOutputParser): - def __init__(self, is_stream_out: bool, **kwargs): + def __init__(self, is_stream_out: bool = False, **kwargs): super().__init__(is_stream_out=is_stream_out, **kwargs) def parse_prompt_response(self, model_out_text): diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/prompt.py index 28fae5f5e..61ad6e955 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/prompt.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_dashboard/prompt.py @@ -1,7 +1,12 @@ import json from dbgpt._private.config import Config -from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, SystemPromptTemplate +from dbgpt.core import ( + ChatPromptTemplate, + HumanPromptTemplate, + MessagesPlaceholder, + SystemPromptTemplate, +) from dbgpt_app.scene import AppScenePromptTemplateAdapter, ChatScene from dbgpt_app.scene.chat_dashboard.out_parser import ChatDashboardOutputParser @@ -49,7 +54,6 @@ RESPONSE_FORMAT = [ } ] -PROMPT_NEED_STREAM_OUT = False prompt = ChatPromptTemplate( messages=[ @@ -57,6 +61,7 @@ prompt = ChatPromptTemplate( PROMPT_SCENE_DEFINE + _DEFAULT_TEMPLATE, response_format=json.dumps(RESPONSE_FORMAT, indent=4), ), + MessagesPlaceholder(variable_name="chat_history"), HumanPromptTemplate.from_template("{input}"), ] ) @@ -65,7 +70,6 @@ prompt_adapter = AppScenePromptTemplateAdapter( prompt=prompt, template_scene=ChatScene.ChatDashboard.value(), stream_out=True, - output_parser=ChatDashboardOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - need_historical_messages=False, + output_parser=ChatDashboardOutputParser(), ) CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/config.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/config.py new file mode 100644 index 000000000..2283a32fa --- /dev/null +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/config.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass, field +from typing import List, Optional + +from dbgpt.util.i18n_utils import _ +from dbgpt_app.scene import ChatScene +from dbgpt_serve.core.config import ( + BaseGPTsAppMemoryConfig, + BufferWindowGPTsAppMemoryConfig, + GPTsAppCommonConfig, +) + + +@dataclass +class ChatExcelConfig(GPTsAppCommonConfig): + """Chat Excel Configuration""" + + name = ChatScene.ChatExcel.value() + duckdb_extensions_dir: List[str] = field( + default_factory=list, + metadata={ + "help": _( + "The directory of the duckdb extensions." + "Duckdb will download the extensions from the internet if not provided." + "This configuration is used to tell duckdb where to find the extensions" + " and avoid downloading. Note that the extensions are platform-specific" + " and version-specific." + ) + }, + ) + force_install: bool = field( + default=False, + metadata={ + "help": _( + "Whether to force install the duckdb extensions. If True, the " + "extensions will be installed even if they are already installed." 
+ ) + }, + ) + + memory: Optional[BaseGPTsAppMemoryConfig] = field( + default_factory=lambda: BufferWindowGPTsAppMemoryConfig( + keep_start_rounds=0, keep_end_rounds=10 + ), + metadata={"help": _("Memory configuration")}, + ) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/chat.py index 7cf86822f..4e7bd6c44 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -1,7 +1,7 @@ import json import logging import os -from typing import Any, Dict, Union +from typing import Any, Dict, Type, Union from dbgpt import SystemApp from dbgpt.agent.util.api_call import ApiCall @@ -12,6 +12,8 @@ from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.json_utils import EnhancedJSONEncoder from dbgpt.util.tracer import root_tracer, trace from dbgpt_app.scene import BaseChat, ChatScene +from dbgpt_app.scene.base_chat import ChatParam +from dbgpt_app.scene.chat_data.chat_excel.config import ChatExcelConfig from dbgpt_app.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning from dbgpt_app.scene.chat_data.chat_excel.excel_reader import ExcelReader @@ -22,10 +24,12 @@ class ChatExcel(BaseChat): """a Excel analyzer to analyze Excel Data""" chat_scene: str = ChatScene.ChatExcel.value() - keep_start_rounds = 0 - keep_end_rounds = 2 - def __init__(self, chat_param: Dict, system_app: SystemApp = None): + @classmethod + def param_class(cls) -> Type[ChatExcelConfig]: + return ChatExcelConfig + + def __init__(self, chat_param: ChatParam, system_app: SystemApp): """Chat Excel Module Initialization Args: - chat_param: Dict @@ -35,16 +39,16 @@ class ChatExcel(BaseChat): - select_param:(str) file path """ self.fs_client = FileStorageClient.get_instance(system_app) - self.select_param = chat_param["select_param"] + self.select_param = chat_param.select_param if not self.select_param: raise ValueError("Please upload the Excel document you want to talk to!") - self.model_name = chat_param["model_name"] - chat_param["chat_mode"] = ChatScene.ChatExcel + self.model_name = chat_param.model_name + self.curr_config = chat_param.real_app_config(ChatExcelConfig) self.chat_param = chat_param self._bucket = "dbgpt_app_file" file_path, file_name, database_file_path, database_file_id = self._resolve_path( self.select_param, - chat_param["chat_session_id"], + chat_param.chat_session_id, self.fs_client, self._bucket, ) @@ -53,12 +57,14 @@ class ChatExcel(BaseChat): self._database_file_path = database_file_path self._database_file_id = database_file_id self.excel_reader = ExcelReader( - chat_param["chat_session_id"], + chat_param.chat_session_id, file_path, file_name, read_type="direct", database_name=database_file_path, table_name=self._curr_table, + duckdb_extensions_dir=self.curr_config.duckdb_extensions_dir, + force_install=self.curr_config.force_install, ) self.api_call = ApiCall() @@ -141,20 +147,28 @@ class ChatExcel(BaseChat): if self.has_history_messages(): return None - chat_param = { - "chat_session_id": self.chat_session_id, - "user_input": "[" + self.excel_reader.excel_file_name + "]" + " Analyze!", - "parent_mode": self.chat_mode, - "select_param": self.select_param, - "excel_reader": self.excel_reader, - "model_name": self.model_name, - "user_name": self.chat_param.get("user_name", None), - } - if "temperature" in self._chat_param: - 
chat_param["temperature"] = self._chat_param["temperature"] - if "max_new_tokens" in self._chat_param: - chat_param["max_new_tokens"] = self._chat_param["max_new_tokens"] - learn_chat = ExcelLearning(**chat_param, system_app=self.system_app) + chat_param = ChatParam( + chat_session_id=self.chat_session_id, + current_user_input="[" + + self.excel_reader.excel_file_name + + "]" + + " Analyze!", + select_param=self.select_param, + chat_mode=ChatScene.ExcelLearning, + model_name=self.model_name, + user_name=self.chat_param.user_name, + sys_code=self.chat_param.sys_code, + ) + if self._chat_param.temperature is not None: + chat_param.temperature = self._chat_param.temperature + if self._chat_param.max_new_tokens is not None: + chat_param.max_new_tokens = self._chat_param.max_new_tokens + learn_chat = ExcelLearning( + chat_param, + system_app=self.system_app, + parent_mode=self.chat_mode, + excel_reader=self.excel_reader, + ) result = await learn_chat.nostream_call() if ( diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/out_parser.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/out_parser.py index 6671fcc23..ed537a803 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/out_parser.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/out_parser.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) class ChatExcelOutputParser(BaseOutputParser): - def __init__(self, is_stream_out: bool, **kwargs): + def __init__(self, is_stream_out: bool = True, **kwargs): super().__init__(is_stream_out=is_stream_out, **kwargs) def parse_prompt_response(self, model_out_text): diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/prompt.py index 6877dd539..cfed524b5 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/prompt.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_analyze/prompt.py @@ -227,8 +227,7 @@ prompt_adapter = AppScenePromptTemplateAdapter( prompt=prompt, template_scene=ChatScene.ChatExcel.value(), stream_out=PROMPT_NEED_STREAM_OUT, - output_parser=ChatExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - need_historical_messages=True, + output_parser=ChatExcelOutputParser(), temperature=PROMPT_TEMPERATURE, ) CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/chat.py index 6c9d8d7d2..576f8b51f 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/chat.py @@ -1,12 +1,13 @@ import json -from typing import Any, Dict +from typing import Any, Dict, Type from dbgpt import SystemApp from dbgpt.core.interface.message import ModelMessageRoleType from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.json_utils import EnhancedJSONEncoder from dbgpt.util.tracer import trace -from dbgpt_app.scene import BaseChat, ChatScene +from dbgpt_app.scene import BaseChat, ChatParam, ChatScene +from dbgpt_serve.core.config import GPTsAppCommonConfig from .out_parser import TransformedExcelResponse from .prompt import USER_INPUT @@ -15,34 +16,21 @@ from .prompt import 
USER_INPUT class ExcelLearning(BaseChat): chat_scene: str = ChatScene.ExcelLearning.value() + @classmethod + def param_class(cls) -> Type[GPTsAppCommonConfig]: + return GPTsAppCommonConfig + def __init__( self, - chat_session_id, - user_input, - temperature: float, - max_new_tokens: int, + chat_param: ChatParam, + system_app: SystemApp, parent_mode: Any = None, - select_param: str = None, excel_reader: Any = None, - model_name: str = None, - user_name: str = None, - system_app: SystemApp = None, ): - chat_mode = ChatScene.ExcelLearning from ..excel_reader import ExcelReader """ """ self.excel_reader: ExcelReader = excel_reader - chat_param = { - "chat_mode": chat_mode, - "chat_session_id": chat_session_id, - "current_user_input": user_input, - "select_param": select_param, - "model_name": model_name, - "user_name": user_name, - "temperature": temperature, - "max_new_tokens": max_new_tokens, - } self._curr_table = self.excel_reader.temp_table_name super().__init__(chat_param=chat_param, system_app=system_app) if parent_mode: diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/out_parser.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/out_parser.py index 40d4f316b..4226cab89 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/out_parser.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/out_parser.py @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) class LearningExcelOutputParser(BaseOutputParser): - def __init__(self, is_stream_out: bool, **kwargs): + def __init__(self, is_stream_out: bool = False, **kwargs): super().__init__(is_stream_out=is_stream_out, **kwargs) self.is_downgraded = False diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/prompt.py index d64032526..5f207b073 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/prompt.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_learning/prompt.py @@ -202,8 +202,7 @@ prompt_adapter = AppScenePromptTemplateAdapter( prompt=prompt, template_scene=ChatScene.ExcelLearning.value(), stream_out=PROMPT_NEED_STREAM_OUT, - output_parser=LearningExcelOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - need_historical_messages=False, + output_parser=LearningExcelOutputParser(), temperature=PROMPT_TEMPERATURE, ) CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_reader.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_reader.py index 7fb938850..3a6811375 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_reader.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_data/chat_excel/excel_reader.py @@ -231,6 +231,9 @@ class ExcelReader: read_type: str = "df", database_name: str = ":memory:", table_name: str = "data_analysis_table", + duckdb_extensions_dir: Optional[List[str]] = None, + force_install: bool = False, + show_columns: bool = False, ): if not file_name: file_name = os.path.basename(file_path) @@ -246,6 +249,9 @@ class ExcelReader: self.excel_file_name = file_name + if duckdb_extensions_dir: + self.install_extension(duckdb_extensions_dir, force_install) + if not db_exists: curr_table = self.temp_table_name if read_type == "df": @@ -255,11 +261,12 @@ class ExcelReader: else: 
curr_table = self.table_name - # Print table schema - result = self.db.sql(f"DESCRIBE {curr_table}") - columns = result.fetchall() - for column in columns: - print(column) + if show_columns: + # Print table schema + result = self.db.sql(f"DESCRIBE {curr_table}") + columns = result.fetchall() + for column in columns: + print(column) def close(self): if self.db: @@ -278,15 +285,18 @@ class ExcelReader: logger.info(f"To be executed SQL: {sql}") if df_res: return self.db.sql(sql).df() - results = self.db.sql(sql) - colunms = [] - for descrip in results.description: - colunms.append(descrip[0]) - return colunms, results.fetchall() + return self._run_sql(sql) except Exception as e: logger.error(f"excel sql run error!, {str(e)}") raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}") + def _run_sql(self, sql: str): + results = self.db.sql(sql) + columns = [] + for desc in results.description: + columns.append(desc[0]) + return columns, results.fetchall() + def get_df_by_sql_ex(self, sql: str, table_name: Optional[str] = None): table_name = table_name or self.table_name return self.run(sql, table_name, df_res=True) @@ -425,3 +435,62 @@ AND dc.schema_name = 'main'; ) return new_table + + def install_extension( + self, duckdb_extensions_dir: Optional[List[str]], force_install: bool = False + ) -> int: + if not duckdb_extensions_dir: + return 0 + cnt = 0 + for extension_dir in duckdb_extensions_dir: + if not os.path.exists(extension_dir): + logger.warning(f"Extension directory does not exist: {extension_dir}") + continue + extension_files = [ + os.path.join(extension_dir, f) + for f in os.listdir(extension_dir) + if f.endswith(".duckdb_extension.gz") or f.endswith(".duckdb_extension") + ] + _, extensions = self._query_extension() + installed_extensions = [ext[0] for ext in extensions if ext[1]] + for extension_file in extension_files: + try: + extension_name = os.path.basename(extension_file).split(".")[0] + if not force_install and extension_name in installed_extensions: + logger.info( + f"Extension {extension_name} is already installed, skipping" + ) + continue + self.db.install_extension( + extension_file, force_install=force_install + ) + self.db.load_extension(extension_name) + cnt += 1 + logger.info(f"Installed extension {extension_name} for DuckDB") + except Exception as e: + logger.warning( + f"Error while installing extension {extension_file}: {str(e)}" + ) + logger.debug(f"Installed extensions: {cnt}") + self.list_extensions() + return cnt + + def list_extensions(self, stdout=False): + from prettytable import PrettyTable + + table = PrettyTable() + columns, datas = self._query_extension() + table.field_names = columns + for data in datas: + table.add_row(data) + show_str = "DuckDB Extensions:\n" + show_str += table.get_formatted_string() + if stdout: + print(show_str) + else: + logger.info(show_str) + + def _query_extension(self): + return self._run_sql( + "SELECT extension_name, installed, description FROM duckdb_extensions();" + ) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/chat.py index 4baecac1d..15319e4a1 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/chat.py @@ -1,19 +1,29 @@ -from typing import Dict +import logging +from typing import Dict, Type from dbgpt import SystemApp from dbgpt.agent.util.api_call import ApiCall from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.tracer import root_tracer, trace from dbgpt_app.scene import BaseChat, ChatScene +from dbgpt_app.scene.base_chat import ChatParam +from dbgpt_app.scene.chat_db.auto_execute.config import ChatWithDBExecuteConfig +from dbgpt_serve.core.config import GPTsAppCommonConfig from dbgpt_serve.datasource.manages import ConnectorManager +logger = logging.getLogger(__name__) + class ChatWithDbAutoExecute(BaseChat): chat_scene: str = ChatScene.ChatWithDbExecute.value() """Number of results to return from the query""" - def __init__(self, chat_param: Dict, system_app: SystemApp = None): + @classmethod + def param_class(cls) -> Type[GPTsAppCommonConfig]: + return ChatWithDBExecuteConfig + + def __init__(self, chat_param: ChatParam, system_app: SystemApp): """Chat Data Module Initialization Args: - chat_param: Dict @@ -22,10 +32,8 @@ class ChatWithDbAutoExecute(BaseChat): - model_name:(str) llm model name - select_param:(str) dbname """ - chat_mode = ChatScene.ChatWithDbExecute - self.db_name = chat_param["select_param"] - chat_param["chat_mode"] = chat_mode - """ """ + self.db_name = chat_param.select_param + self.curr_config = chat_param.real_app_config(ChatWithDBExecuteConfig) super().__init__(chat_param=chat_param, system_app=system_app) if not self.db_name: raise ValueError( @@ -36,7 +44,6 @@ class ChatWithDbAutoExecute(BaseChat): ): local_db_manager = ConnectorManager.get_instance(self.system_app) self.database = local_db_manager.get_connector(self.db_name) - self.top_k: int = 50 self.api_call = ApiCall() @trace() @@ -49,7 +56,6 @@ class ChatWithDbAutoExecute(BaseChat): except ImportError: raise ValueError("Could not import DBSummaryClient. ") client = DBSummaryClient(system_app=self.system_app) - table_infos = None try: with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"): table_infos = await blocking_func_to_async( @@ -57,19 +63,23 @@ class ChatWithDbAutoExecute(BaseChat): client.get_db_summary, self.db_name, self.current_user_input, - self.app_config.rag.similarity_top_k, + self.curr_config.schema_retrieve_top_k, ) except Exception as e: - print("db summary find error!" 
+ str(e)) - if not table_infos: + logger.error(f"Retrieved table info error: {str(e)}") table_infos = await blocking_func_to_async( self._executor, self.database.table_simple_info ) + if len(table_infos) > self.curr_config.schema_max_tokens: + # Load all table schemas, must be less than schema_max_tokens + # Here we just truncate the table_infos + # TODO: Count the number of tokens by LLMClient + table_infos = table_infos[: self.curr_config.schema_max_tokens] input_values = { "db_name": self.db_name, "user_input": self.current_user_input, - "top_k": str(self.top_k), + "top_k": self.curr_config.max_num_results, "dialect": self.database.dialect, "table_info": table_infos, "display_type": self._generate_numbered_list(), diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/config.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/config.py new file mode 100644 index 000000000..840bc0951 --- /dev/null +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/config.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.util.i18n_utils import _ +from dbgpt_app.scene import ChatScene +from dbgpt_serve.core.config import ( + BaseGPTsAppMemoryConfig, + BufferWindowGPTsAppMemoryConfig, + GPTsAppCommonConfig, +) + + +@dataclass +class ChatWithDBExecuteConfig(GPTsAppCommonConfig): + """Chat With DB Execute Configuration""" + + name = ChatScene.ChatWithDbExecute.value() + schema_retrieve_top_k: int = field( + default=10, + metadata={"help": _("The number of tables to retrieve from the database.")}, + ) + schema_max_tokens: int = field( + default=100 * 1024, + metadata={ + "help": _( + "The maximum number of tokens to pass to the model, default 100 * 1024." + " Only used when schema retrieval fails and the full table schema is loaded." + ) + }, + ) + max_num_results: int = field( + default=50, + metadata={"help": _("The maximum number of results to return from the query.")}, + ) + memory: Optional[BaseGPTsAppMemoryConfig] = field( + default_factory=lambda: BufferWindowGPTsAppMemoryConfig( + keep_start_rounds=0, keep_end_rounds=10 + ), + metadata={"help": _("Memory configuration")}, + ) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/out_parser.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/out_parser.py index 72759e9c3..30fcd3749 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/out_parser.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/out_parser.py @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) class DbChatOutputParser(BaseOutputParser): - def __init__(self, is_stream_out: bool, **kwargs): + def __init__(self, is_stream_out: bool = False, **kwargs): super().__init__(is_stream_out=is_stream_out, **kwargs) def is_sql_statement(self, statement): diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt.py index a45dac679..d1157467b 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt.py @@ -96,8 +96,6 @@ RESPONSE_FORMAT_SIMPLE = { } -PROMPT_NEED_STREAM_OUT = False - # Temperature is a configuration hyperparameter that controls the randomness of
diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt.py
index a45dac679..d1157467b 100644
--- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt.py
+++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt.py
@@ -96,8 +96,6 @@ RESPONSE_FORMAT_SIMPLE = {
 }
 
 
-PROMPT_NEED_STREAM_OUT = False
-
 # Temperature is a configuration hyperparameter that controls the randomness of
 # language model output.
 # A high temperature produces more unpredictable and creative results, while a low
@@ -124,8 +122,7 @@ prompt_adapter = AppScenePromptTemplateAdapter(
     prompt=prompt,
     template_scene=ChatScene.ChatWithDbExecute.value(),
     stream_out=True,
-    output_parser=DbChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
+    output_parser=DbChatOutputParser(),
     temperature=PROMPT_TEMPERATURE,
-    need_historical_messages=False,
 )
 CFG.prompt_template_registry.register(prompt_adapter, is_default=True)
diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt_baichuan.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt_baichuan.py
deleted file mode 100644
index 1ea1893c0..000000000
--- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/auto_execute/prompt_baichuan.py
+++ /dev/null
@@ -1,68 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import json
-
-from dbgpt._private.config import Config
-from dbgpt.core.interface.prompt import PromptTemplate
-from dbgpt_app.scene import ChatScene
-from dbgpt_app.scene.chat_db.auto_execute.out_parser import DbChatOutputParser
-
-CFG = Config()
-
-PROMPT_SCENE_DEFINE = None
-
-_DEFAULT_TEMPLATE = """
-你是一个 SQL 专家,给你一个用户的问题,你会生成一条对应的 {dialect} 语法的 SQL 语句。
-
-如果用户没有在问题中指定 sql 返回多少条数据,那么你生成的 sql 最多返回 {top_k} 条数据。
-你应该尽可能少地使用表。
-
-已知表结构信息如下:
-{table_info}
-
-注意:
-1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,\
-请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
-2. 不要查询不存在的列,注意哪一列位于哪张表中。
-3. 使用 json 格式回答,确保你的回答是必须是正确的 json 格式,\
-并且能被 python 语言的 `json.loads` 库解析, 格式如下:
-{response}
-"""
-
-RESPONSE_FORMAT_SIMPLE = {
-    "thoughts": "对用户说的想法摘要",
-    "sql": "生成的将被执行的 SQL",
-}
-
-
-PROMPT_NEED_STREAM_OUT = False
-
-# Temperature is a configuration hyperparameter that controls the randomness of
-# language model output.
-# A high temperature produces more unpredictable and creative results, while a low
-# temperature produces more common and conservative output.
-# For example, if you adjust the temperature to 0.5, the model will usually generate
-# text that is more predictable and less creative than if you set the temperature to
-# 1.0.
-PROMPT_TEMPERATURE = 0.5 - -prompt = PromptTemplate( - template_scene=ChatScene.ChatWithDbExecute.value(), - input_variables=["input", "table_info", "dialect", "top_k", "response"], - response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4), - template_is_strict=False, - template_define=PROMPT_SCENE_DEFINE, - template=_DEFAULT_TEMPLATE, - stream_out=PROMPT_NEED_STREAM_OUT, - output_parser=DbChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - # example_selector=sql_data_example, - temperature=PROMPT_TEMPERATURE, -) - -CFG.prompt_template_registry.register( - prompt, - language=CFG.LANGUAGE, - is_default=False, - model_names=["baichuan-13b", "baichuan-7b"], -) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/chat.py index 0f5cbd5db..13f77fb4c 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/chat.py @@ -1,20 +1,24 @@ -from typing import Dict +from typing import Dict, Type from dbgpt.component import SystemApp, logger from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.tracer import trace from dbgpt_app.scene import BaseChat, ChatScene +from dbgpt_app.scene.base_chat import ChatParam +from dbgpt_app.scene.chat_db.professional_qa.config import ChatWithDBQAConfig from dbgpt_serve.datasource.manages import ConnectorManager class ChatWithDbQA(BaseChat): + """As a DBA, Chat DB Module, chat with combine DB meta schema""" + chat_scene: str = ChatScene.ChatWithDbQA.value() - keep_end_rounds = 5 + @classmethod + def param_class(cls) -> Type[ChatWithDBQAConfig]: + return ChatWithDBQAConfig - """As a DBA, Chat DB Module, chat with combine DB meta schema """ - - def __init__(self, chat_param: Dict, system_app: SystemApp = None): + def __init__(self, chat_param: ChatParam, system_app: SystemApp): """Chat DB Module Initialization Args: - chat_param: Dict @@ -23,8 +27,8 @@ class ChatWithDbQA(BaseChat): - model_name:(str) llm model name - select_param:(str) dbname """ - self.db_name = chat_param["select_param"] - chat_param["chat_mode"] = ChatScene.ChatWithDbQA + self.db_name = chat_param.select_param + self.curr_config = chat_param.real_app_config(ChatWithDBQAConfig) super().__init__(chat_param=chat_param, system_app=system_app) if self.db_name: @@ -38,12 +42,8 @@ class ChatWithDbQA(BaseChat): self.tables["edge_tables"] ) else: - print(self.database.db_type) - self.top_k = ( - self.app_config.rag.similarity_top_k - if len(self.tables) > self.app_config.rag.similarity_top_k - else len(self.tables) - ) + logger.info(f"Dialect: {self.database.db_type}") + self.top_k = self.curr_config.schema_retrieve_top_k @trace() async def generate_input_values(self) -> Dict: @@ -51,6 +51,7 @@ class ChatWithDbQA(BaseChat): from dbgpt_serve.datasource.service.db_summary_client import DBSummaryClient except ImportError: raise ValueError("Could not import DBSummaryClient. ") + table_infos = None if self.db_name: client = DBSummaryClient(system_app=self.system_app) try: @@ -62,16 +63,18 @@ class ChatWithDbQA(BaseChat): self.top_k, ) except Exception as e: - logger.error("db summary find error!" 
+ str(e))
-                # table_infos = self.database.table_simple_info()
+                logger.error(f"Retrieved table info error: {str(e)}")
                 table_infos = await blocking_func_to_async(
                     self._executor, self.database.table_simple_info
                 )
+        if table_infos and len(table_infos) > self.curr_config.schema_max_tokens:
+            # All table schemas were loaded, so the result must stay below
+            # schema_max_tokens; for now we simply truncate table_infos
+            # TODO: Count the number of tokens by LLMClient
+            table_infos = table_infos[: self.curr_config.schema_max_tokens]
 
         input_values = {
             "input": self.current_user_input,
-            # "top_k": str(self.top_k),
-            # "dialect": dialect,
             "table_info": table_infos,
         }
         return input_values
diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/config.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/config.py
new file mode 100644
index 000000000..fccc95e9f
--- /dev/null
+++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/config.py
@@ -0,0 +1,40 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+from dbgpt.util.i18n_utils import _
+from dbgpt_app.scene import ChatScene
+from dbgpt_serve.core.config import (
+    BaseGPTsAppMemoryConfig,
+    BufferWindowGPTsAppMemoryConfig,
+    GPTsAppCommonConfig,
+)
+
+
+@dataclass
+class ChatWithDBQAConfig(GPTsAppCommonConfig):
+    """Chat With DB QA Configuration"""
+
+    name = ChatScene.ChatWithDbQA.value()
+    schema_retrieve_top_k: int = field(
+        default=10,
+        metadata={"help": _("The number of tables to retrieve from the database.")},
+    )
+    schema_max_tokens: int = field(
+        default=100 * 1024,
+        metadata={
+            "help": _(
+                "The maximum number of tokens to pass to the model, default 100 * 1024. "
+                "Only used when schema retrieval fails and all table schemas are loaded."
+            )
+        },
+    )
+    max_num_results: int = field(
+        default=50,
+        metadata={"help": _("The maximum number of results to return from the query.")},
+    )
+    memory: Optional[BaseGPTsAppMemoryConfig] = field(
+        default_factory=lambda: BufferWindowGPTsAppMemoryConfig(
+            keep_start_rounds=0, keep_end_rounds=10
+        ),
+        metadata={"help": _("Memory configuration")},
+    )
diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/prompt.py
index df0ad7e11..c22d7da17 100644
--- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/prompt.py
+++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_db/professional_qa/prompt.py
@@ -55,8 +55,7 @@ prompt_adapter = AppScenePromptTemplateAdapter(
     prompt=prompt,
     template_scene=ChatScene.ChatWithDbQA.value(),
     stream_out=PROMPT_NEED_STREAM_OUT,
-    output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT),
-    need_historical_messages=True,
+    output_parser=NormalChatOutputParser(),
 )
 
 
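Note: the chat_factory.py diff that follows is where the per-scene config classes are wired in: get_implementation now calls parse_config (added to dbgpt_serve/core/config.py later in this patch) before instantiating the scene. Roughly, the lookup behaves like this sketch, where app_config mirrors the global GPTsAppConfig:

from typing import Optional, Type


def resolve_scene_config(app_config, scene_name: str, type_class: Optional[Type]):
    # An entry in the app's `configs` list whose name matches the scene
    # takes precedence over the global settings.
    for custom in app_config.configs:
        if custom.name == scene_name:
            return custom
    # Otherwise the global settings are coerced into the scene's config
    # class, silently dropping fields the scene does not declare.
    if type_class is not None:
        return type_class.from_dict(app_config.to_dict(), ignore_extra_fields=True)
    return app_config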
diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_factory.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_factory.py
index eaf576846..02a3e5b3a 100644
--- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_factory.py
+++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_factory.py
@@ -1,11 +1,15 @@
+from dbgpt.component import SystemApp
 from dbgpt.util.singleton import Singleton
 from dbgpt.util.tracer import root_tracer
-from dbgpt_app.scene.base_chat import BaseChat
+from dbgpt_app.scene.base_chat import BaseChat, ChatParam
+from dbgpt_serve.core.config import parse_config
 
 
 class ChatFactory(metaclass=Singleton):
     @staticmethod
-    def get_implementation(chat_mode, system_app, **kwargs):
+    def get_implementation(
+        chat_mode: str, system_app: SystemApp, chat_param: ChatParam, **kwargs
+    ):
         # Lazy loading
         from dbgpt_app.scene.chat_dashboard.chat import ChatDashboard  # noqa: F401
         from dbgpt_app.scene.chat_dashboard.prompt import prompt  # noqa: F401
@@ -49,7 +53,13 @@ class ChatFactory(metaclass=Singleton):
         with root_tracer.start_span(
             "get_implementation_of_chat", metadata=metadata
         ):
-            implementation = cls(**kwargs, system_app=system_app)
+            config = parse_config(
+                system_app, chat_mode, type_class=cls.param_class()
+            )
+            chat_param.app_config = config
+            implementation = cls(
+                **kwargs, chat_param=chat_param, system_app=system_app
+            )
         if implementation is None:
             raise Exception(f"Invalid implementation name:{chat_mode}")
         return implementation
diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/chat.py
index a8e0646fe..8e5fdae84 100644
--- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/chat.py
+++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/chat.py
@@ -1,21 +1,28 @@
-from typing import Dict
+from typing import Type
 
+from dbgpt import SystemApp
 from dbgpt_app.scene import BaseChat, ChatScene
+from dbgpt_app.scene.base_chat import ChatParam
+from dbgpt_serve.core.config import GPTsAppCommonConfig
 
 
 class ExtractRefineSummary(BaseChat):
-    chat_scene: str = ChatScene.ExtractRefineSummary.value()
-
     """extract final summary by llm"""
 
-    def __init__(self, chat_param: Dict):
+    chat_scene: str = ChatScene.ExtractRefineSummary.value()
+
+    @classmethod
+    def param_class(cls) -> Type[GPTsAppCommonConfig]:
+        return GPTsAppCommonConfig
+
+    def __init__(self, chat_param: ChatParam, system_app: SystemApp):
         """ """
-        chat_param["chat_mode"] = ChatScene.ExtractRefineSummary
         super().__init__(
             chat_param=chat_param,
+            system_app=system_app,
         )
 
-        self.existing_answer = chat_param["select_param"]
+        self.existing_answer = chat_param.select_param
 
     async def generate_input_values(self):
         input_values = {
@@ -27,7 +34,3 @@
     def stream_plugin_call(self, text):
         """return summary label"""
         return f"{text}"
-
-    @property
-    def chat_type(self) -> str:
-        return ChatScene.ExtractRefineSummary.value
diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/out_parser.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/out_parser.py
index 7b79ff7fb..e923c981f 100644
--- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/out_parser.py
+++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/out_parser.py
@@ -7,7 +7,7 @@ logger = logging.getLogger(__name__)
 
 
 class ExtractRefineSummaryParser(BaseOutputParser):
-    def __init__(self, is_stream_out: bool, **kwargs):
+    def __init__(self, is_stream_out: bool = True, **kwargs):
         super().__init__(is_stream_out=is_stream_out, **kwargs)
 
     def parse_prompt_response(
diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/prompt.py
index 79927c507..b75da4749 100644
--- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/prompt.py
+++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/refine_summary/prompt.py
@@ -29,8 +29,6 @@ _DEFAULT_TEMPLATE = (
 
 PROMPT_RESPONSE = """"""
 
-PROMPT_NEED_NEED_STREAM_OUT = True
-
 prompt = ChatPromptTemplate(
     messages=[
         # SystemPromptTemplate.from_template(PROMPT_SCENE_DEFINE),
@@ -41,9 +39,8 @@
prompt_adapter = AppScenePromptTemplateAdapter( prompt=prompt, template_scene=ChatScene.ExtractRefineSummary.value(), - stream_out=PROMPT_NEED_NEED_STREAM_OUT, - output_parser=ExtractRefineSummaryParser(is_stream_out=PROMPT_NEED_NEED_STREAM_OUT), - need_historical_messages=False, + stream_out=True, + output_parser=ExtractRefineSummaryParser(), ) CFG.prompt_template_registry.register(prompt_adapter, is_default=True) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/chat.py index be453003c..a11d98de1 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/chat.py @@ -1,7 +1,7 @@ import json import os from functools import reduce -from typing import Dict, List +from typing import Dict, List, Type from dbgpt import SystemApp from dbgpt.core import ( @@ -17,6 +17,8 @@ from dbgpt.util.tracer import root_tracer, trace from dbgpt_app.knowledge.request.request import KnowledgeSpaceRequest from dbgpt_app.knowledge.service import KnowledgeService from dbgpt_app.scene import BaseChat, ChatScene +from dbgpt_app.scene.base_chat import ChatParam +from dbgpt_app.scene.chat_knowledge.v1.config import ChatKnowledgeConfig from dbgpt_serve.rag.models.chunk_db import DocumentChunkDao, DocumentChunkEntity from dbgpt_serve.rag.models.document_db import ( KnowledgeDocumentDao, @@ -26,10 +28,15 @@ from dbgpt_serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever class ChatKnowledge(BaseChat): - chat_scene: str = ChatScene.ChatKnowledge.value() """KBQA Chat Module""" - def __init__(self, chat_param: Dict, system_app: SystemApp = None): + chat_scene: str = ChatScene.ChatKnowledge.value() + + @classmethod + def param_class(cls) -> Type[ChatKnowledgeConfig]: + return ChatKnowledgeConfig + + def __init__(self, chat_param: ChatParam, system_app: SystemApp): """Chat Knowledge Module Initialization Args: - chat_param: Dict @@ -40,8 +47,8 @@ class ChatKnowledge(BaseChat): """ from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory - self.knowledge_space = chat_param["select_param"] - chat_param["chat_mode"] = ChatScene.ChatKnowledge + self.curr_config = chat_param.real_app_config(ChatKnowledgeConfig) + self.knowledge_space = chat_param.select_param super().__init__(chat_param=chat_param, system_app=system_app) from dbgpt_serve.rag.models.models import ( KnowledgeSpaceDao, @@ -55,16 +62,9 @@ class ChatKnowledge(BaseChat): raise Exception(f"have not found knowledge space:{self.knowledge_space}") self.rag_config = self.app_config.rag self.space_context = self.get_space_context(space.name) - self.top_k = ( - self.get_knowledge_search_top_size(space.name) - if self.space_context is None - else int(self.space_context["embedding"]["topk"]) - ) - self.recall_score = ( - self.rag_config.similarity_score_threshold - if self.space_context is None - else float(self.space_context["embedding"]["recall_score"]) - ) + + self.top_k = self.get_knowledge_search_top_size(space.name) + self.recall_score = self.get_similarity_score_threshold() query_rewrite = None if self.rag_config.query_rewrite: @@ -81,13 +81,14 @@ class ChatKnowledge(BaseChat): rerank_embeddings = RerankEmbeddingFactory.get_instance( self.system_app ).create() - reranker = RerankEmbeddingsRanker( - rerank_embeddings, topk=self.rag_config.rerank_top_k - ) - if retriever_top_k < self.rag_config.rerank_top_k or retriever_top_k < 20: + rerank_top_k = 
self.curr_config.knowledge_retrieve_rerank_top_k + if not rerank_top_k: + rerank_top_k = self.rag_config.rerank_top_k + reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=rerank_top_k) + if retriever_top_k < rerank_top_k or retriever_top_k < 20: # We use reranker, so if the top_k is less than 20, # we need to set it to 20 - retriever_top_k = max(self.rag_config.rerank_top_k, 20) + retriever_top_k = max(rerank_top_k, 20) self._space_retriever = KnowledgeSpaceRetriever( space_id=space.id, embedding_model=self.model_config.default_embedding, @@ -214,10 +215,6 @@ class ChatKnowledge(BaseChat): reference = html.decode("utf-8") return reference.replace("\\n", "") - @property - def chat_type(self) -> str: - return ChatScene.ChatKnowledge.value() - def get_space_context_by_id(self, space_id): service = KnowledgeService() return service.get_space_context_by_space_id(space_id) @@ -227,6 +224,9 @@ class ChatKnowledge(BaseChat): return service.get_space_context(space_name) def get_knowledge_search_top_size(self, space_name) -> int: + if self.space_context: + return int(self.space_context["embedding"]["topk"]) + service = KnowledgeService() request = KnowledgeSpaceRequest(name=space_name) spaces = service.get_knowledge_space(request) @@ -235,9 +235,18 @@ class ChatKnowledge(BaseChat): if spaces[0].vector_type in graph_storages: return self.rag_config.kg_chunk_search_top_k + if self.curr_config.knowledge_retrieve_top_k: + return self.curr_config.knowledge_retrieve_top_k return self.rag_config.similarity_top_k + def get_similarity_score_threshold(self): + if self.space_context: + return float(self.space_context["embedding"]["recall_score"]) + if self.curr_config.similarity_score_threshold >= 0: + return self.curr_config.similarity_score_threshold + return self.rag_config.similarity_score_threshold + async def execute_similar_search(self, query): """execute similarity search""" with root_tracer.start_span( diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/config.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/config.py new file mode 100644 index 000000000..7c068793b --- /dev/null +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/config.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.util.i18n_utils import _ +from dbgpt_app.scene import ChatScene +from dbgpt_serve.core.config import ( + BaseGPTsAppMemoryConfig, + BufferWindowGPTsAppMemoryConfig, + GPTsAppCommonConfig, +) + + +@dataclass +class ChatKnowledgeConfig(GPTsAppCommonConfig): + """Chat Knowledge Configuration""" + + name = ChatScene.ChatKnowledge.value() + knowledge_retrieve_top_k: int = field( + default=10, + metadata={ + "help": _("The number of chunks to retrieve from the knowledge space.") + }, + ) + knowledge_retrieve_rerank_top_k: int = field( + default=10, + metadata={"help": _("The number of chunks after reranking.")}, + ) + similarity_score_threshold: float = field( + default=0.0, + metadata={"help": _("The minimum similarity score to return from the query.")}, + ) + memory: Optional[BaseGPTsAppMemoryConfig] = field( + default_factory=lambda: BufferWindowGPTsAppMemoryConfig( + keep_start_rounds=0, keep_end_rounds=10 + ), + metadata={"help": _("Memory configuration")}, + ) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/prompt.py index 2b9f41cd6..cb6abd37a 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/prompt.py +++ 
b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/prompt.py @@ -67,8 +67,7 @@ prompt_adapter = AppScenePromptTemplateAdapter( prompt=prompt, template_scene=ChatScene.ChatKnowledge.value(), stream_out=PROMPT_NEED_STREAM_OUT, - output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - need_historical_messages=False, + output_parser=NormalChatOutputParser(), ) CFG.prompt_template_registry.register( diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/prompt_chatglm.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/prompt_chatglm.py index 1e3c9310f..767db644c 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/prompt_chatglm.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/prompt_chatglm.py @@ -54,8 +54,7 @@ prompt_adapter = AppScenePromptTemplateAdapter( prompt=prompt, template_scene=ChatScene.ChatKnowledge.value(), stream_out=True, - output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - need_historical_messages=False, + output_parser=NormalChatOutputParser(), ) CFG.prompt_template_registry.register( diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/chat.py index ee43bf08d..b460a2707 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/chat.py @@ -1,27 +1,24 @@ -from typing import Dict +from typing import Dict, Type from dbgpt import SystemApp from dbgpt.util.tracer import trace from dbgpt_app.scene import BaseChat, ChatScene +from dbgpt_app.scene.base_chat import ChatParam +from dbgpt_app.scene.chat_normal.config import ChatNormalConfig class ChatNormal(BaseChat): chat_scene: str = ChatScene.ChatNormal.value() - keep_end_rounds: int = 10 + @classmethod + def param_class(cls) -> Type[ChatNormalConfig]: + return ChatNormalConfig - """Number of results to return from the query""" - - def __init__(self, chat_param: Dict, system_app: SystemApp = None): + def __init__(self, chat_param: ChatParam, system_app: SystemApp): """ """ - chat_param["chat_mode"] = ChatScene.ChatNormal super().__init__(chat_param=chat_param, system_app=system_app) @trace() async def generate_input_values(self) -> Dict: input_values = {"input": self.current_user_input} return input_values - - @property - def chat_type(self) -> str: - return ChatScene.ChatNormal.value diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/config.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/config.py new file mode 100644 index 000000000..1fa0c8639 --- /dev/null +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/config.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.util.i18n_utils import _ +from dbgpt_app.scene import ChatScene +from dbgpt_serve.core.config import ( + BaseGPTsAppMemoryConfig, + GPTsAppCommonConfig, + TokenBufferGPTsAppMemoryConfig, +) + + +@dataclass +class ChatNormalConfig(GPTsAppCommonConfig): + """Chat Normal Configuration""" + + name = ChatScene.ChatNormal.value() + memory: Optional[BaseGPTsAppMemoryConfig] = field( + default_factory=lambda: TokenBufferGPTsAppMemoryConfig( + max_token_limit=20 * 1024 + ), + metadata={"help": _("Memory configuration")}, + ) diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/prompt.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/prompt.py index 9b1f59f2f..88f17f864 100644 --- 
a/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/prompt.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_normal/prompt.py @@ -17,7 +17,6 @@ PROMPT_SCENE_DEFINE = ( PROMPT_SCENE_DEFINE_ZH if CFG.LANGUAGE == "zh" else PROMPT_SCENE_DEFINE_EN ) -PROMPT_NEED_STREAM_OUT = True prompt = ChatPromptTemplate( messages=[ @@ -30,9 +29,8 @@ prompt = ChatPromptTemplate( prompt_adapter = AppScenePromptTemplateAdapter( prompt=prompt, template_scene=ChatScene.ChatNormal.value(), - stream_out=PROMPT_NEED_STREAM_OUT, - output_parser=NormalChatOutputParser(is_stream_out=PROMPT_NEED_STREAM_OUT), - need_historical_messages=True, + stream_out=True, + output_parser=NormalChatOutputParser(), ) CFG.prompt_template_registry.register( diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/operators/app_operator.py b/packages/dbgpt-app/src/dbgpt_app/scene/operators/app_operator.py index eb2335a40..4ad15ff7a 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/operators/app_operator.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/operators/app_operator.py @@ -21,6 +21,7 @@ from dbgpt.core.awel import ( from dbgpt.core.operators import ( BufferedConversationMapperOperator, HistoryPromptBuilderOperator, + TokenBufferedConversationMapperOperator, ) from dbgpt.model.operators import LLMOperator, StreamingLLMOperator from dbgpt.storage.cache.operators import ( @@ -31,6 +32,11 @@ from dbgpt.storage.cache.operators import ( ModelSaveCacheOperator, ModelStreamSaveCacheOperator, ) +from dbgpt_serve.core.config import ( + BaseGPTsAppMemoryConfig, + BufferWindowGPTsAppMemoryConfig, + TokenBufferGPTsAppMemoryConfig, +) @dataclasses.dataclass @@ -53,13 +59,12 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]): temperature: float, max_new_tokens: int, prompt: ChatPromptTemplate, + llm_client: LLMClient, + memory: BaseGPTsAppMemoryConfig, message_version: str = "v2", echo: bool = False, streaming: bool = True, history_key: str = "chat_history", - history_merge_mode: str = "window", - keep_start_rounds: Optional[int] = None, - keep_end_rounds: Optional[int] = None, str_history: bool = False, request_context: ModelRequestContext = None, **kwargs, @@ -68,10 +73,9 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]): if not request_context: request_context = ModelRequestContext(stream=streaming) self._prompt_template = prompt + self._llm_client = llm_client self._history_key = history_key - self._history_merge_mode = history_merge_mode - self._keep_start_rounds = keep_start_rounds - self._keep_end_rounds = keep_end_rounds + self._memory = memory self._str_history = str_history self._model_name = model self._temperature = temperature @@ -104,10 +108,21 @@ class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]): with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag: input_task = InputOperator(input_source=SimpleCallDataInputSource()) # History transform task - history_transform_task = BufferedConversationMapperOperator( - keep_start_rounds=self._keep_start_rounds, - keep_end_rounds=self._keep_end_rounds, - ) + if isinstance(self._memory, BufferWindowGPTsAppMemoryConfig): + history_transform_task = BufferedConversationMapperOperator( + keep_start_rounds=self._memory.keep_start_rounds, + keep_end_rounds=self._memory.keep_end_rounds, + ) + elif isinstance(self._memory, TokenBufferGPTsAppMemoryConfig): + history_transform_task = TokenBufferedConversationMapperOperator( + model=self._model_name, + llm_client=self._llm_client, + 
max_token_limit=self._memory.max_token_limit, + ) + else: + raise ValueError( + f"Unsupported memory configuration: {self._memory.__class__}" + ) history_prompt_build_task = HistoryPromptBuilderOperator( prompt=self._prompt_template, history_key=self._history_key, diff --git a/packages/dbgpt-core/src/dbgpt/configs/model_config.py b/packages/dbgpt-core/src/dbgpt/configs/model_config.py index 093efa754..5616213c5 100644 --- a/packages/dbgpt-core/src/dbgpt/configs/model_config.py +++ b/packages/dbgpt-core/src/dbgpt/configs/model_config.py @@ -349,3 +349,6 @@ EMBEDDING_MODEL_CONFIG = { KNOWLEDGE_UPLOAD_ROOT_PATH = DATA_DIR +KNOWLEDGE_CACHE_ROOT_PATH = os.path.join( + KNOWLEDGE_UPLOAD_ROOT_PATH, "_knowledge_cache_" +) diff --git a/packages/dbgpt-core/src/dbgpt/core/interface/file.py b/packages/dbgpt-core/src/dbgpt/core/interface/file.py index c275902d6..855043a75 100644 --- a/packages/dbgpt-core/src/dbgpt/core/interface/file.py +++ b/packages/dbgpt-core/src/dbgpt/core/interface/file.py @@ -607,6 +607,7 @@ class FileStorageClient(BaseComponent): if dest_path: target_path = dest_path elif dest_dir: + os.makedirs(dest_dir, exist_ok=True) target_path = os.path.join(dest_dir, file_metadata.file_id + extension) else: from pathlib import Path diff --git a/packages/dbgpt-core/src/dbgpt/core/interface/llm.py b/packages/dbgpt-core/src/dbgpt/core/interface/llm.py index 8042fde23..66ec91459 100644 --- a/packages/dbgpt-core/src/dbgpt/core/interface/llm.py +++ b/packages/dbgpt-core/src/dbgpt/core/interface/llm.py @@ -346,7 +346,16 @@ class ModelOutput: def to_dict(self) -> Dict: """Convert the model output to dict.""" - return asdict(self) + text = self.gen_text_with_thinking() + return { + "error_code": self.error_code, + "text": text, + "incremental": self.incremental, + "model_context": self.model_context, + "finish_reason": self.finish_reason, + "usage": self.usage, + "metrics": self.metrics, + } @property def success(self) -> bool: diff --git a/packages/dbgpt-core/src/dbgpt/util/config_utils.py b/packages/dbgpt-core/src/dbgpt/util/config_utils.py index 66172c2e4..bf8440133 100644 --- a/packages/dbgpt-core/src/dbgpt/util/config_utils.py +++ b/packages/dbgpt-core/src/dbgpt/util/config_utils.py @@ -1,6 +1,8 @@ import os from functools import cache -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional, Type, TypeVar, cast + +T = TypeVar("T") class AppConfig: @@ -29,6 +31,40 @@ class AppConfig: """ return self.configs.get(key, default) + def get_typed( + self, key: str, type_class: Type[T], default: Optional[T] = None + ) -> T: + """Get config value by key with specific type + Args: + key (str): The key of config + type_class (Type[T]): The expected return type + default (Optional[T], optional): The default value if key not found. + Defaults to None. 
+ Returns: + T: The value of config with specified type + Raises: + TypeError: If the value is not of the expected type and cannot be converted + """ + value = self.configs.get(key, default) + if value is None: + return cast(T, value) + + # If the value is already of the expected type, return it directly + if isinstance(value, type_class): + return value + + # Try to convert the value to the expected type + try: + if type_class is bool and isinstance(value, str): + # Handle boolean values as strings + return cast(T, value.lower() in ("true", "yes", "1", "y")) + # Convert the value to the expected type + return type_class(value) + except (ValueError, TypeError): + raise TypeError( + f"Cannot convert config value '{value}' to type {type_class.__name__}" + ) + @cache def get_all_by_prefix(self, prefix) -> Dict[str, Any]: """Get all config values by prefix diff --git a/packages/dbgpt-core/src/dbgpt/util/module_utils.py b/packages/dbgpt-core/src/dbgpt/util/module_utils.py index a819b7137..d4ac6a612 100644 --- a/packages/dbgpt-core/src/dbgpt/util/module_utils.py +++ b/packages/dbgpt-core/src/dbgpt/util/module_utils.py @@ -169,8 +169,8 @@ class ModelScanner(Generic[T]): logger.warning(f"Directory not found: {base_path}") return results - # If specific files are provided, only scan those files - if config.specific_files: + # If specific files are provided, only scan those files, but not recursively + if config.specific_files and not config.recursive: for file_name in config.specific_files: # Construct the full file path file_path = base_dir / f"{file_name}.py" @@ -199,6 +199,7 @@ class ModelScanner(Generic[T]): # Regular directory scanning pattern = "**/*.py" if config.recursive else "*.py" + specific_files = set(config.specific_files or []) for item in base_dir.glob(pattern): if item.name.startswith("__"): continue @@ -207,6 +208,8 @@ class ModelScanner(Generic[T]): if self._should_skip_file(item.name, config.skip_files): logger.debug(f"Skipping file {item.name} due to skip_files pattern") continue + if specific_files and item.stem not in specific_files: + continue try: # Get the module name relative to the base module diff --git a/packages/dbgpt-core/src/dbgpt/util/parameter_utils.py b/packages/dbgpt-core/src/dbgpt/util/parameter_utils.py index f40b946f4..85112a039 100644 --- a/packages/dbgpt-core/src/dbgpt/util/parameter_utils.py +++ b/packages/dbgpt-core/src/dbgpt/util/parameter_utils.py @@ -98,7 +98,11 @@ class BaseParameters: """ all_field_names = {f.name for f in fields(cls)} if ignore_extra_fields: - data = {key: value for key, value in data.items() if key in all_field_names} + data = { + key: value + for key, value in data.items() + if key in all_field_names and value is not None + } else: extra_fields = set(data.keys()) - all_field_names if extra_fields: diff --git a/packages/dbgpt-serve/src/dbgpt_serve/core/config.py b/packages/dbgpt-serve/src/dbgpt_serve/core/config.py index 4fead4f69..2f2c3ec74 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/core/config.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/core/config.py @@ -1,7 +1,8 @@ from dataclasses import dataclass, field -from typing import Optional +from functools import cache +from typing import List, Optional, Type -from dbgpt.component import AppConfig +from dbgpt.component import AppConfig, SystemApp from dbgpt.util import BaseParameters, RegisterParameters from dbgpt.util.i18n_utils import _ @@ -44,3 +45,103 @@ class BaseServeConfig(BaseParameters, RegisterParameters): if k not in config_dict and k[len(global_prefix) :] in 
cls().__dict__: config_dict[k[len(global_prefix) :]] = v return cls(**config_dict) + + +@dataclass +class BaseGPTsAppMemoryConfig(BaseParameters, RegisterParameters): + __type__ = "___memory_placeholder___" + + +@dataclass +class BufferWindowGPTsAppMemoryConfig(BaseGPTsAppMemoryConfig): + """Buffer window memory configuration. + + This configuration is used to control the buffer window memory. + """ + + __type__ = "window" + keep_start_rounds: int = field( + default=0, + metadata={"help": _("The number of start rounds to keep in memory")}, + ) + keep_end_rounds: int = field( + default=0, + metadata={"help": _("The number of end rounds to keep in memory")}, + ) + + +@dataclass +class TokenBufferGPTsAppMemoryConfig(BaseGPTsAppMemoryConfig): + """Token buffer memory configuration. + + This configuration is used to control the token buffer memory. + """ + + __type__ = "token" + max_token_limit: int = field( + default=100 * 1024, + metadata={"help": _("The max token limit. Default is 100k")}, + ) + + +@dataclass +class GPTsAppCommonConfig(BaseParameters, RegisterParameters): + __type_field__ = "name" + top_k: Optional[int] = field( + default=None, + metadata={"help": _("The top k for LLM generation")}, + ) + top_p: Optional[float] = field( + default=None, + metadata={"help": _("The top p for LLM generation")}, + ) + temperature: Optional[float] = field( + default=None, + metadata={"help": _("The temperature for LLM generation")}, + ) + max_new_tokens: Optional[int] = field( + default=None, + metadata={"help": _("The max new tokens for LLM generation")}, + ) + name: Optional[str] = field( + default=None, metadata={"help": _("The name of your app")} + ) + memory: Optional[BaseGPTsAppMemoryConfig] = field( + default=None, metadata={"help": _("The memory configuration")} + ) + + +@dataclass +class GPTsAppConfig(GPTsAppCommonConfig, BaseParameters): + """GPTs application configuration. + + For global configuration, you can set the parameters here. 
+ """ + + name: str = "app" + configs: List[GPTsAppCommonConfig] = field( + default_factory=list, + metadata={"help": _("The configs for specific app")}, + ) + + +@cache +def parse_config( + system_app: SystemApp, + config_name: str, + type_class: Optional[Type[GPTsAppCommonConfig]], +): + from dbgpt_app.config import ApplicationConfig + + app_config = system_app.config.get_typed("app_config", ApplicationConfig) + # Global config for the chat scene + config = app_config.app + + for custom_config in config.configs: + if custom_config.name == config_name: + return custom_config + + if type_class is not None: + return type_class.from_dict(config.to_dict(), ignore_extra_fields=True) + + return config diff --git a/packages/dbgpt-serve/src/dbgpt_serve/datasource/api/endpoints.py b/packages/dbgpt-serve/src/dbgpt_serve/datasource/api/endpoints.py index 41dfe88d3..2981a42f8 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/datasource/api/endpoints.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/datasource/api/endpoints.py @@ -114,7 +114,8 @@ async def create( Returns: ServerResponse: The response """ - return Result.succ(service.create(request)) + res = await blocking_func_to_async(global_system_app, service.create, request) + return Result.succ(res) @router.put( @@ -134,7 +135,8 @@ async def update( Returns: ServerResponse: The response """ - return Result.succ(service.update(request)) + res = await blocking_func_to_async(global_system_app, service.update, request) + return Result.succ(res) @router.delete( @@ -153,7 +155,7 @@ async def delete( Returns: ServerResponse: The response """ - service.delete(datasource_id) + await blocking_func_to_async(global_system_app, service.delete, datasource_id) return Result.succ(None) @@ -173,7 +175,8 @@ async def query( Returns: List[ServeResponse]: The response """ - return Result.succ(service.get(datasource_id)) + res = await blocking_func_to_async(global_system_app, service.get, datasource_id) + return Result.succ(res) @router.get( @@ -194,7 +197,9 @@ async def query_page( Returns: ServerResponse: The response """ - res = service.get_list(db_type=db_type) + res = await blocking_func_to_async( + global_system_app, service.get_list, db_type=db_type + ) return Result.succ(res) @@ -207,7 +212,8 @@ async def get_datasource_types( service: Service = Depends(get_service), ) -> Result[ResourceTypes]: """Get the datasource types.""" - return Result.succ(service.datasource_types()) + res = await blocking_func_to_async(global_system_app, service.datasource_types) + return Result.succ(res) @router.post( diff --git a/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/service.py b/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/service.py index 1ec6b44e7..10de6b012 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/service.py @@ -233,12 +233,8 @@ class Service( DatasourceServeResponse: The data after deletion """ db_config = self._dao.get_one({"id": datasource_id}) - vector_name = db_config.db_name + "_profile" - vector_connector = self.storage_manager.create_vector_store( - index_name=vector_name - ) - vector_connector.delete_vector_name(vector_name) if db_config: + self._db_summary_client.delete_db_profile(db_config.db_name) self._dao.delete({"id": datasource_id}) return db_config @@ -301,12 +297,18 @@ class Service( bool: The refresh result """ db_config = self._dao.get_one({"id": datasource_id}) - vector_name = db_config.db_name + "_profile" - vector_connector = 
self.storage_manager.create_vector_store( - index_name=vector_name - ) - vector_connector.delete_vector_name(vector_name) - self._db_summary_client.db_summary_embedding( - db_config.db_name, db_config.db_type + if not db_config: + raise HTTPException(status_code=404, detail="datasource not found") + + self._db_summary_client.delete_db_profile(db_config.db_name) + + # async embedding + executor = self._system_app.get_component( + ComponentType.EXECUTOR_DEFAULT, ExecutorFactory + ).create() # type: ignore + executor.submit( + self._db_summary_client.db_summary_embedding, + db_config.db_name, + db_config.db_type, ) return True diff --git a/packages/dbgpt-serve/src/dbgpt_serve/rag/api/endpoints.py b/packages/dbgpt-serve/src/dbgpt_serve/rag/api/endpoints.py index 77cf4042f..60d93cecc 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/rag/api/endpoints.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/rag/api/endpoints.py @@ -14,7 +14,7 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from dbgpt.component import SystemApp from dbgpt.util import PaginationResult from dbgpt_ext.rag.chunk_manager import ChunkParameters -from dbgpt_serve.core import Result +from dbgpt_serve.core import Result, blocking_func_to_async from dbgpt_serve.rag.api.schemas import ( DocumentServeRequest, DocumentServeResponse, @@ -155,7 +155,9 @@ async def delete( Returns: ServerResponse: The response """ - return Result.succ(service.delete(space_id)) + # TODO: Delete the files in the space + res = await blocking_func_to_async(global_system_app, service.delete, space_id) + return Result.succ(res) @router.get( @@ -248,7 +250,10 @@ async def create_document( doc_file=doc_file, space_id=space_id, ) - return Result.succ(await service.create_document(request)) + res = await blocking_func_to_async( + global_system_app, service.create_document, request + ) + return Result.succ(res) @router.get( @@ -369,7 +374,11 @@ async def delete_document( Returns: ServerResponse: The response """ - return Result.succ(service.delete_document(document_id)) + # TODO: Delete the files of the document + res = await blocking_func_to_async( + global_system_app, service.delete_document, document_id + ) + return Result.succ(res) def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None: diff --git a/packages/dbgpt-serve/src/dbgpt_serve/rag/service/service.py b/packages/dbgpt-serve/src/dbgpt_serve/rag/service/service.py index 8cd1475ba..75dfc4dc3 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/rag/service/service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/rag/service/service.py @@ -2,8 +2,6 @@ import asyncio import json import logging import os -import shutil -import tempfile from datetime import datetime from enum import Enum from typing import List, Optional, cast @@ -13,9 +11,10 @@ from fastapi import HTTPException from dbgpt.component import ComponentType, SystemApp from dbgpt.configs import TAG_KEY_KNOWLEDGE_FACTORY_DOMAIN_TYPE from dbgpt.configs.model_config import ( - KNOWLEDGE_UPLOAD_ROOT_PATH, + KNOWLEDGE_CACHE_ROOT_PATH, ) from dbgpt.core import Chunk, LLMClient +from dbgpt.core.interface.file import _SCHEMA, FileStorageClient from dbgpt.model import DefaultLLMClient from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory @@ -30,7 +29,7 @@ from dbgpt_app.knowledge.request.request import BusinessFieldType from dbgpt_ext.rag.assembler import EmbeddingAssembler from dbgpt_ext.rag.chunk_manager import ChunkParameters from 
dbgpt_ext.rag.knowledge import KnowledgeFactory -from dbgpt_serve.core import BaseService +from dbgpt_serve.core import BaseService, blocking_func_to_async from ..api.schemas import ( ChunkServeRequest, @@ -117,6 +116,10 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes ).create() return DefaultLLMClient(worker_manager, True) + def get_fs(self) -> FileStorageClient: + """Get the FileStorageClient instance""" + return FileStorageClient.get_instance(self.system_app) + def create_space(self, request: SpaceServeRequest) -> SpaceServeResponse: """Create a new Space entity @@ -155,7 +158,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes update_obj = self._dao.update_knowledge_space(self._dao.from_request(request)) return update_obj - async def create_document(self, request: DocumentServeRequest) -> str: + def create_document(self, request: DocumentServeRequest) -> str: """Create a new document entity Args: @@ -173,20 +176,21 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes raise Exception(f"document name:{request.doc_name} have already named") if request.doc_file and request.doc_type == KnowledgeType.DOCUMENT.name: doc_file = request.doc_file - if not os.path.exists(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name)): - os.makedirs(os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name)) - tmp_fd, tmp_path = tempfile.mkstemp( - dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name) - ) - with os.fdopen(tmp_fd, "wb") as tmp: - tmp.write(await request.doc_file.read()) - shutil.move( - tmp_path, - os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space.name, doc_file.filename), - ) - request.content = os.path.join( - KNOWLEDGE_UPLOAD_ROOT_PATH, space.name, doc_file.filename + safe_filename = os.path.basename(doc_file.filename) + custom_metadata = { + "space_name": space.name, + "doc_name": doc_file.filename, + "doc_type": request.doc_type, + } + bucket = "dbgpt_knowledge_file" + file_uri = self.get_fs().save_file( + bucket, + safe_filename, + doc_file.file, + storage_type="distributed", + custom_metadata=custom_metadata, ) + request.content = file_uri document = KnowledgeDocumentEntity( doc_name=request.doc_name, doc_type=request.doc_type, @@ -490,28 +494,58 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes storage_connector = self.storage_manager.get_storage_connector( space.name, space.vector_type ) + knowledge_content = doc.content + if ( + doc.doc_type == KnowledgeType.DOCUMENT.value + and knowledge_content.startswith(_SCHEMA) + ): + logger.info( + f"Download file from file storage, doc: {doc.doc_name}, file url: " + f"{doc.content}" + ) + local_file_path, file_meta = await blocking_func_to_async( + self.system_app, + self.get_fs().download_file, + knowledge_content, + dest_dir=KNOWLEDGE_CACHE_ROOT_PATH, + ) + logger.info(f"Downloaded file to {local_file_path}") + knowledge_content = local_file_path knowledge = None if not space.domain_type or ( space.domain_type.lower() == BusinessFieldType.NORMAL.value.lower() ): knowledge = KnowledgeFactory.create( - datasource=doc.content, + datasource=knowledge_content, knowledge_type=KnowledgeType.get_by_value(doc.doc_type), ) doc.status = SyncStatus.RUNNING.name doc.gmt_modified = datetime.now() - self._document_dao.update_knowledge_document(doc) + await blocking_func_to_async( + self.system_app, self._document_dao.update_knowledge_document, doc + ) asyncio.create_task( self.async_doc_process( - knowledge, chunk_parameters, 
storage_connector, doc, space + knowledge, + chunk_parameters, + storage_connector, + doc, + space, + knowledge_content, ) ) logger.info(f"begin save document chunks, doc:{doc.doc_name}") @trace("async_doc_process") async def async_doc_process( - self, knowledge, chunk_parameters, storage_connector, doc, space + self, + knowledge, + chunk_parameters, + storage_connector, + doc, + space, + knowledge_content: str, ): """async document process into storage Args: @@ -539,7 +573,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes f" and value: {space.domain_type}, dag: {dags[0]}" ) db_name, chunk_docs = await end_task.call( - {"file_path": doc.content, "space": doc.space} + {"file_path": knowledge_content, "space": doc.space} ) doc.chunk_size = len(chunk_docs) vector_ids = [chunk.chunk_id for chunk in chunk_docs] diff --git a/packages/dbgpt-serve/src/dbgpt_serve/rag/tests/test_service.py b/packages/dbgpt-serve/src/dbgpt_serve/rag/tests/test_service.py index ec9345cef..5fa9149cd 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/rag/tests/test_service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/rag/tests/test_service.py @@ -115,7 +115,7 @@ async def test_create_document(service): service._document_dao.get_knowledge_documents = Mock(return_value=[]) service._document_dao.create_knowledge_document = Mock(return_value="2") - response = await service.create_document(request) + response = service.create_document(request) assert response == "2"
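Note: with the rag service changes above, uploaded documents no longer land directly under KNOWLEDGE_UPLOAD_ROOT_PATH; they are saved through FileStorageClient and pulled back into KNOWLEDGE_CACHE_ROOT_PATH at sync time. A rough sketch of that round trip, assuming a configured FileStorageClient instance (the bucket name, metadata, and call signatures mirror the values used in this patch):

from typing import BinaryIO

from dbgpt.configs.model_config import KNOWLEDGE_CACHE_ROOT_PATH
from dbgpt.core.interface.file import FileStorageClient


def save_knowledge_file(
    fs: FileStorageClient, space_name: str, filename: str, fobj: BinaryIO
) -> str:
    # Persist the upload and return a URI; the rag service stores this URI
    # as the document's `content` instead of a local filesystem path.
    return fs.save_file(
        "dbgpt_knowledge_file",  # bucket used by the rag service
        filename,
        fobj,
        storage_type="distributed",
        custom_metadata={"space_name": space_name, "doc_name": filename},
    )


def fetch_for_sync(fs: FileStorageClient, file_uri: str) -> str:
    # At sync time the URI is resolved back to a local file; download_file
    # now creates dest_dir when it is missing (see the file.py hunk above).
    local_path, _meta = fs.download_file(file_uri, dest_dir=KNOWLEDGE_CACHE_ROOT_PATH)
    return local_path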