mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 04:36:23 +00:00
feat: Support rag flow template
This commit is contained in:
parent
08749a0110
commit
f1e00a7502
0
dbgpt/serve/dbgpts/__init__.py
Normal file
0
dbgpt/serve/dbgpts/__init__.py
Normal file
@ -14,7 +14,7 @@ from dbgpt.serve.core import Result, blocking_func_to_async
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||
from ..service.service import Service
|
||||
from ..service.service import Service, _parse_flow_template_from_json
|
||||
from ..service.variables_service import VariablesService
|
||||
from .schemas import (
|
||||
FlowDebugRequest,
|
||||
@ -512,7 +512,7 @@ async def import_flow(
|
||||
raise HTTPException(
|
||||
status_code=400, detail="invalid json file, missing 'flow' key"
|
||||
)
|
||||
flow = ServeRequest.parse_obj(json_dict["flow"])
|
||||
flow = _parse_flow_template_from_json(json_dict["flow"])
|
||||
elif file_extension == "zip":
|
||||
from ..service.share_utils import _parse_flow_from_zip_file
|
||||
|
||||
@ -531,6 +531,31 @@ async def import_flow(
|
||||
return Result.succ(flow)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/flow/templates",
|
||||
response_model=Result[PaginationResult[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query_flow_templates(
|
||||
user_name: Optional[str] = Query(default=None, description="user name"),
|
||||
sys_code: Optional[str] = Query(default=None, description="system code"),
|
||||
page: int = Query(default=1, description="current page"),
|
||||
page_size: int = Query(default=20, description="page size"),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Result[PaginationResult[ServerResponse]]:
|
||||
"""Query Flow templates."""
|
||||
|
||||
res = await blocking_func_to_async(
|
||||
global_system_app,
|
||||
service.get_flow_templates,
|
||||
user_name,
|
||||
sys_code,
|
||||
page,
|
||||
page_size,
|
||||
)
|
||||
return Result.succ(res)
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
from .variables_provider import (
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import AsyncIterator, List, Optional, cast
|
||||
|
||||
import schedule
|
||||
@ -399,6 +400,47 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
item.metadata = metadata.to_dict()
|
||||
return page_result
|
||||
|
||||
def get_flow_templates(
|
||||
self,
|
||||
user_name: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> PaginationResult[ServerResponse]:
|
||||
"""Get a list of Flow templates
|
||||
|
||||
Args:
|
||||
user_name (Optional[str]): The user name
|
||||
sys_code (Optional[str]): The system code
|
||||
page (int): The page number
|
||||
page_size (int): The page size
|
||||
Returns:
|
||||
List[ServerResponse]: The response
|
||||
"""
|
||||
local_file_templates = self._get_flow_templates_from_files()
|
||||
return PaginationResult.build_from_all(local_file_templates, page, page_size)
|
||||
|
||||
def _get_flow_templates_from_files(self) -> List[ServerResponse]:
|
||||
"""Get a list of Flow templates from files"""
|
||||
user_lang = self._system_app.config.get_current_lang(default="en")
|
||||
# List files in current directory
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
template_dir = os.path.join(parent_dir, "templates", user_lang)
|
||||
default_template_dir = os.path.join(parent_dir, "templates", "en")
|
||||
if not os.path.exists(template_dir):
|
||||
template_dir = default_template_dir
|
||||
templates = []
|
||||
for root, _, files in os.walk(template_dir):
|
||||
for file in files:
|
||||
if file.endswith(".json"):
|
||||
try:
|
||||
with open(os.path.join(root, file), "r") as f:
|
||||
data = json.load(f)
|
||||
templates.append(_parse_flow_template_from_json(data))
|
||||
except Exception as e:
|
||||
logger.warning(f"Load template {file} error: {str(e)}")
|
||||
return templates
|
||||
|
||||
async def chat_stream_flow_str(
|
||||
self, flow_uid: str, request: CommonLLMHttpRequestBody
|
||||
) -> AsyncIterator[str]:
|
||||
@ -638,3 +680,20 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
break
|
||||
else:
|
||||
yield f"data:{text}\n\n"
|
||||
|
||||
|
||||
def _parse_flow_template_from_json(json_dict: dict) -> ServerResponse:
|
||||
"""Parse the flow from json
|
||||
|
||||
Args:
|
||||
json_dict (dict): The json dict
|
||||
|
||||
Returns:
|
||||
ServerResponse: The flow
|
||||
"""
|
||||
flow_json = json_dict["flow"]
|
||||
flow_json["editable"] = False
|
||||
del flow_json["uid"]
|
||||
flow_json["state"] = State.INITIALIZING
|
||||
flow_json["dag_id"] = None
|
||||
return ServerResponse(**flow_json)
|
||||
|
1088
dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json
Normal file
1088
dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user