mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 22:09:44 +00:00
feat(base): base dao in query support
This commit is contained in:
@@ -19,6 +19,15 @@ def is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
|
|||||||
return isinstance(output_obj, chat_types)
|
return isinstance(output_obj, chat_types)
|
||||||
|
|
||||||
|
|
||||||
|
def is_agent_flow_type(output_obj: Any, is_class: bool = False) -> bool:
|
||||||
|
"""Check whether the output object is a agent flow type."""
|
||||||
|
if is_class:
|
||||||
|
return output_obj in (str, CommonLLMHttpResponseBody, ModelOutput)
|
||||||
|
else:
|
||||||
|
chat_types = (str, CommonLLMHttpResponseBody)
|
||||||
|
return isinstance(output_obj, chat_types)
|
||||||
|
|
||||||
|
|
||||||
async def safe_chat_with_dag_task(
|
async def safe_chat_with_dag_task(
|
||||||
task: BaseOperator, request: Any, covert_to_str: bool = False
|
task: BaseOperator, request: Any, covert_to_str: bool = False
|
||||||
) -> ModelOutput:
|
) -> ModelOutput:
|
||||||
|
@@ -33,6 +33,7 @@ from dbgpt.app.dbgpt_server import system_app
|
|||||||
from dbgpt.app.scene.base import ChatScene
|
from dbgpt.app.scene.base import ChatScene
|
||||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||||
from dbgpt.core import PromptTemplate
|
from dbgpt.core import PromptTemplate
|
||||||
|
from dbgpt.core.awel.flow.flow_factory import FlowCategory
|
||||||
from dbgpt.core.interface.message import StorageConversation
|
from dbgpt.core.interface.message import StorageConversation
|
||||||
from dbgpt.model.cluster import WorkerManagerFactory
|
from dbgpt.model.cluster import WorkerManagerFactory
|
||||||
from dbgpt.model.cluster.client import DefaultLLMClient
|
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||||
@@ -222,14 +223,11 @@ class MultiAgents(BaseComponent, ABC):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
is_agent_flow = True
|
if (
|
||||||
if TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode:
|
TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode
|
||||||
from dbgpt.agent import AWELTeamContext
|
and gpt_app.team_context.flow_category == FlowCategory.CHAT_FLOW
|
||||||
|
):
|
||||||
team_context: AWELTeamContext = gpt_app.team_context
|
team_context = gpt_app.team_context
|
||||||
if team_context.flow_category == "chat_flow":
|
|
||||||
is_agent_flow = False
|
|
||||||
if not is_agent_flow:
|
|
||||||
from dbgpt.core.awel import CommonLLMHttpRequestBody
|
from dbgpt.core.awel import CommonLLMHttpRequestBody
|
||||||
|
|
||||||
flow_req = CommonLLMHttpRequestBody(
|
flow_req = CommonLLMHttpRequestBody(
|
||||||
|
@@ -6,6 +6,7 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
|||||||
|
|
||||||
from dbgpt.component import SystemApp
|
from dbgpt.component import SystemApp
|
||||||
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
|
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
|
||||||
|
from dbgpt.core.awel.flow.flow_factory import FlowCategory
|
||||||
from dbgpt.serve.core import Result
|
from dbgpt.serve.core import Result
|
||||||
from dbgpt.util import PaginationResult
|
from dbgpt.util import PaginationResult
|
||||||
|
|
||||||
@@ -172,6 +173,35 @@ async def get_flows(
|
|||||||
return Result.succ(flow)
|
return Result.succ(flow)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/chat/flows",
|
||||||
|
response_model=Result[PaginationResult[ServerResponse]],
|
||||||
|
dependencies=[Depends(check_api_key)],
|
||||||
|
)
|
||||||
|
async def query_chat_flows(
|
||||||
|
user_name: Optional[str] = Query(default=None, description="user name"),
|
||||||
|
sys_code: Optional[str] = Query(default=None, description="system code"),
|
||||||
|
page: int = Query(default=1, description="current page"),
|
||||||
|
page_size: int = Query(default=20, description="page size"),
|
||||||
|
name: Optional[str] = Query(default=None, description="flow name"),
|
||||||
|
uid: Optional[str] = Query(default=None, description="flow uid"),
|
||||||
|
service: Service = Depends(get_service),
|
||||||
|
) -> Result[PaginationResult[ServerResponse]]:
|
||||||
|
return Result.succ(
|
||||||
|
service.get_list_by_page(
|
||||||
|
{
|
||||||
|
"user_name": user_name,
|
||||||
|
"sys_code": sys_code,
|
||||||
|
"name": name,
|
||||||
|
"uid": uid,
|
||||||
|
"flow_category": [FlowCategory.CHAT_AGENT, FlowCategory.CHAT_FLOW],
|
||||||
|
},
|
||||||
|
page,
|
||||||
|
page_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/flows",
|
"/flows",
|
||||||
response_model=Result[PaginationResult[ServerResponse]],
|
response_model=Result[PaginationResult[ServerResponse]],
|
||||||
|
@@ -6,6 +6,7 @@ import schedule
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from dbgpt._private.pydantic import model_to_json
|
from dbgpt._private.pydantic import model_to_json
|
||||||
|
from dbgpt.agent import AgentDummyTrigger
|
||||||
from dbgpt.component import SystemApp
|
from dbgpt.component import SystemApp
|
||||||
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
|
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
|
||||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||||
@@ -548,18 +549,28 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
|||||||
or not isinstance(leaf_nodes[0], BaseOperator)
|
or not isinstance(leaf_nodes[0], BaseOperator)
|
||||||
):
|
):
|
||||||
return FlowCategory.COMMON
|
return FlowCategory.COMMON
|
||||||
|
|
||||||
|
leaf_node = cast(BaseOperator, leaf_nodes[0])
|
||||||
|
if not leaf_node.metadata or not leaf_node.metadata.outputs:
|
||||||
|
return FlowCategory.COMMON
|
||||||
|
|
||||||
common_http_trigger = False
|
common_http_trigger = False
|
||||||
|
agent_trigger = False
|
||||||
for trigger in triggers:
|
for trigger in triggers:
|
||||||
if isinstance(trigger, CommonLLMHttpTrigger):
|
if isinstance(trigger, CommonLLMHttpTrigger):
|
||||||
common_http_trigger = True
|
common_http_trigger = True
|
||||||
break
|
break
|
||||||
leaf_node = cast(BaseOperator, leaf_nodes[0])
|
|
||||||
if not leaf_node.metadata or not leaf_node.metadata.outputs:
|
if isinstance(trigger, AgentDummyTrigger):
|
||||||
return FlowCategory.COMMON
|
agent_trigger = True
|
||||||
|
break
|
||||||
|
|
||||||
output = leaf_node.metadata.outputs[0]
|
output = leaf_node.metadata.outputs[0]
|
||||||
try:
|
try:
|
||||||
real_class = _get_type_cls(output.type_cls)
|
real_class = _get_type_cls(output.type_cls)
|
||||||
if common_http_trigger and is_chat_flow_type(real_class, is_class=True):
|
if agent_trigger:
|
||||||
|
return FlowCategory.CHAT_AGENT
|
||||||
|
elif common_http_trigger and is_chat_flow_type(real_class, is_class=True):
|
||||||
return FlowCategory.CHAT_FLOW
|
return FlowCategory.CHAT_FLOW
|
||||||
except Exception:
|
except Exception:
|
||||||
return FlowCategory.COMMON
|
return FlowCategory.COMMON
|
||||||
|
@@ -16,7 +16,6 @@ REQ = TypeVar("REQ")
|
|||||||
# The response schema type
|
# The response schema type
|
||||||
RES = TypeVar("RES")
|
RES = TypeVar("RES")
|
||||||
|
|
||||||
|
|
||||||
QUERY_SPEC = Union[REQ, Dict[str, Any]]
|
QUERY_SPEC = Union[REQ, Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
@@ -286,10 +285,16 @@ class BaseDao(Generic[T, REQ, RES]):
|
|||||||
else model_to_dict(query_request)
|
else model_to_dict(query_request)
|
||||||
)
|
)
|
||||||
for key, value in query_dict.items():
|
for key, value in query_dict.items():
|
||||||
if value is not None:
|
if value is not None and hasattr(model_cls, key):
|
||||||
if isinstance(value, (list, tuple, dict, set)):
|
if isinstance(value, list):
|
||||||
|
if len(value) > 0:
|
||||||
|
query = query.filter(getattr(model_cls, key).in_(value))
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
elif isinstance(value, (tuple, dict, set)):
|
||||||
continue
|
continue
|
||||||
query = query.filter(getattr(model_cls, key) == value)
|
else:
|
||||||
|
query = query.filter(getattr(model_cls, key) == value)
|
||||||
|
|
||||||
if desc_order_column:
|
if desc_order_column:
|
||||||
query = query.order_by(desc(getattr(model_cls, desc_order_column)))
|
query = query.order_by(desc(getattr(model_cls, desc_order_column)))
|
||||||
|
Reference in New Issue
Block a user