mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 23:01:38 +00:00
feat: Support rag flow template
This commit is contained in:
parent
08749a0110
commit
f1e00a7502
0
dbgpt/serve/dbgpts/__init__.py
Normal file
0
dbgpt/serve/dbgpts/__init__.py
Normal file
@ -14,7 +14,7 @@ from dbgpt.serve.core import Result, blocking_func_to_async
|
|||||||
from dbgpt.util import PaginationResult
|
from dbgpt.util import PaginationResult
|
||||||
|
|
||||||
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
|
||||||
from ..service.service import Service
|
from ..service.service import Service, _parse_flow_template_from_json
|
||||||
from ..service.variables_service import VariablesService
|
from ..service.variables_service import VariablesService
|
||||||
from .schemas import (
|
from .schemas import (
|
||||||
FlowDebugRequest,
|
FlowDebugRequest,
|
||||||
@ -512,7 +512,7 @@ async def import_flow(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="invalid json file, missing 'flow' key"
|
status_code=400, detail="invalid json file, missing 'flow' key"
|
||||||
)
|
)
|
||||||
flow = ServeRequest.parse_obj(json_dict["flow"])
|
flow = _parse_flow_template_from_json(json_dict["flow"])
|
||||||
elif file_extension == "zip":
|
elif file_extension == "zip":
|
||||||
from ..service.share_utils import _parse_flow_from_zip_file
|
from ..service.share_utils import _parse_flow_from_zip_file
|
||||||
|
|
||||||
@ -531,6 +531,31 @@ async def import_flow(
|
|||||||
return Result.succ(flow)
|
return Result.succ(flow)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/flow/templates",
|
||||||
|
response_model=Result[PaginationResult[ServerResponse]],
|
||||||
|
dependencies=[Depends(check_api_key)],
|
||||||
|
)
|
||||||
|
async def query_flow_templates(
|
||||||
|
user_name: Optional[str] = Query(default=None, description="user name"),
|
||||||
|
sys_code: Optional[str] = Query(default=None, description="system code"),
|
||||||
|
page: int = Query(default=1, description="current page"),
|
||||||
|
page_size: int = Query(default=20, description="page size"),
|
||||||
|
service: Service = Depends(get_service),
|
||||||
|
) -> Result[PaginationResult[ServerResponse]]:
|
||||||
|
"""Query Flow templates."""
|
||||||
|
|
||||||
|
res = await blocking_func_to_async(
|
||||||
|
global_system_app,
|
||||||
|
service.get_flow_templates,
|
||||||
|
user_name,
|
||||||
|
sys_code,
|
||||||
|
page,
|
||||||
|
page_size,
|
||||||
|
)
|
||||||
|
return Result.succ(res)
|
||||||
|
|
||||||
|
|
||||||
def init_endpoints(system_app: SystemApp) -> None:
|
def init_endpoints(system_app: SystemApp) -> None:
|
||||||
"""Initialize the endpoints"""
|
"""Initialize the endpoints"""
|
||||||
from .variables_provider import (
|
from .variables_provider import (
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import AsyncIterator, List, Optional, cast
|
from typing import AsyncIterator, List, Optional, cast
|
||||||
|
|
||||||
import schedule
|
import schedule
|
||||||
@ -399,6 +400,47 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
|||||||
item.metadata = metadata.to_dict()
|
item.metadata = metadata.to_dict()
|
||||||
return page_result
|
return page_result
|
||||||
|
|
||||||
|
def get_flow_templates(
|
||||||
|
self,
|
||||||
|
user_name: Optional[str] = None,
|
||||||
|
sys_code: Optional[str] = None,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
) -> PaginationResult[ServerResponse]:
|
||||||
|
"""Get a list of Flow templates
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_name (Optional[str]): The user name
|
||||||
|
sys_code (Optional[str]): The system code
|
||||||
|
page (int): The page number
|
||||||
|
page_size (int): The page size
|
||||||
|
Returns:
|
||||||
|
List[ServerResponse]: The response
|
||||||
|
"""
|
||||||
|
local_file_templates = self._get_flow_templates_from_files()
|
||||||
|
return PaginationResult.build_from_all(local_file_templates, page, page_size)
|
||||||
|
|
||||||
|
def _get_flow_templates_from_files(self) -> List[ServerResponse]:
|
||||||
|
"""Get a list of Flow templates from files"""
|
||||||
|
user_lang = self._system_app.config.get_current_lang(default="en")
|
||||||
|
# List files in current directory
|
||||||
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
template_dir = os.path.join(parent_dir, "templates", user_lang)
|
||||||
|
default_template_dir = os.path.join(parent_dir, "templates", "en")
|
||||||
|
if not os.path.exists(template_dir):
|
||||||
|
template_dir = default_template_dir
|
||||||
|
templates = []
|
||||||
|
for root, _, files in os.walk(template_dir):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith(".json"):
|
||||||
|
try:
|
||||||
|
with open(os.path.join(root, file), "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
templates.append(_parse_flow_template_from_json(data))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Load template {file} error: {str(e)}")
|
||||||
|
return templates
|
||||||
|
|
||||||
async def chat_stream_flow_str(
|
async def chat_stream_flow_str(
|
||||||
self, flow_uid: str, request: CommonLLMHttpRequestBody
|
self, flow_uid: str, request: CommonLLMHttpRequestBody
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[str]:
|
||||||
@ -638,3 +680,20 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
yield f"data:{text}\n\n"
|
yield f"data:{text}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_flow_template_from_json(json_dict: dict) -> ServerResponse:
|
||||||
|
"""Parse the flow from json
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_dict (dict): The json dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ServerResponse: The flow
|
||||||
|
"""
|
||||||
|
flow_json = json_dict["flow"]
|
||||||
|
flow_json["editable"] = False
|
||||||
|
del flow_json["uid"]
|
||||||
|
flow_json["state"] = State.INITIALIZING
|
||||||
|
flow_json["dag_id"] = None
|
||||||
|
return ServerResponse(**flow_json)
|
||||||
|
1088
dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json
Normal file
1088
dbgpt/serve/flow/templates/en/rag-chat-awel-flow-template.json
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user