mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 22:09:44 +00:00
feat(base): base dao in query support
This commit is contained in:
@@ -19,6 +19,15 @@ def is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
|
||||
return isinstance(output_obj, chat_types)
|
||||
|
||||
|
||||
def is_agent_flow_type(output_obj: Any, is_class: bool = False) -> bool:
|
||||
"""Check whether the output object is a agent flow type."""
|
||||
if is_class:
|
||||
return output_obj in (str, CommonLLMHttpResponseBody, ModelOutput)
|
||||
else:
|
||||
chat_types = (str, CommonLLMHttpResponseBody)
|
||||
return isinstance(output_obj, chat_types)
|
||||
|
||||
|
||||
async def safe_chat_with_dag_task(
|
||||
task: BaseOperator, request: Any, covert_to_str: bool = False
|
||||
) -> ModelOutput:
|
||||
|
@@ -33,6 +33,7 @@ from dbgpt.app.dbgpt_server import system_app
|
||||
from dbgpt.app.scene.base import ChatScene
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
from dbgpt.core import PromptTemplate
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory
|
||||
from dbgpt.core.interface.message import StorageConversation
|
||||
from dbgpt.model.cluster import WorkerManagerFactory
|
||||
from dbgpt.model.cluster.client import DefaultLLMClient
|
||||
@@ -222,14 +223,11 @@ class MultiAgents(BaseComponent, ABC):
|
||||
)
|
||||
)
|
||||
|
||||
is_agent_flow = True
|
||||
if TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode:
|
||||
from dbgpt.agent import AWELTeamContext
|
||||
|
||||
team_context: AWELTeamContext = gpt_app.team_context
|
||||
if team_context.flow_category == "chat_flow":
|
||||
is_agent_flow = False
|
||||
if not is_agent_flow:
|
||||
if (
|
||||
TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode
|
||||
and gpt_app.team_context.flow_category == FlowCategory.CHAT_FLOW
|
||||
):
|
||||
team_context = gpt_app.team_context
|
||||
from dbgpt.core.awel import CommonLLMHttpRequestBody
|
||||
|
||||
flow_req = CommonLLMHttpRequestBody(
|
||||
|
@@ -6,6 +6,7 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowCategory
|
||||
from dbgpt.serve.core import Result
|
||||
from dbgpt.util import PaginationResult
|
||||
|
||||
@@ -172,6 +173,35 @@ async def get_flows(
|
||||
return Result.succ(flow)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/chat/flows",
|
||||
response_model=Result[PaginationResult[ServerResponse]],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def query_chat_flows(
|
||||
user_name: Optional[str] = Query(default=None, description="user name"),
|
||||
sys_code: Optional[str] = Query(default=None, description="system code"),
|
||||
page: int = Query(default=1, description="current page"),
|
||||
page_size: int = Query(default=20, description="page size"),
|
||||
name: Optional[str] = Query(default=None, description="flow name"),
|
||||
uid: Optional[str] = Query(default=None, description="flow uid"),
|
||||
service: Service = Depends(get_service),
|
||||
) -> Result[PaginationResult[ServerResponse]]:
|
||||
return Result.succ(
|
||||
service.get_list_by_page(
|
||||
{
|
||||
"user_name": user_name,
|
||||
"sys_code": sys_code,
|
||||
"name": name,
|
||||
"uid": uid,
|
||||
"flow_category": [FlowCategory.CHAT_AGENT, FlowCategory.CHAT_FLOW],
|
||||
},
|
||||
page,
|
||||
page_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/flows",
|
||||
response_model=Result[PaginationResult[ServerResponse]],
|
||||
|
@@ -6,6 +6,7 @@ import schedule
|
||||
from fastapi import HTTPException
|
||||
|
||||
from dbgpt._private.pydantic import model_to_json
|
||||
from dbgpt.agent import AgentDummyTrigger
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
@@ -548,18 +549,28 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
|
||||
or not isinstance(leaf_nodes[0], BaseOperator)
|
||||
):
|
||||
return FlowCategory.COMMON
|
||||
|
||||
leaf_node = cast(BaseOperator, leaf_nodes[0])
|
||||
if not leaf_node.metadata or not leaf_node.metadata.outputs:
|
||||
return FlowCategory.COMMON
|
||||
|
||||
common_http_trigger = False
|
||||
agent_trigger = False
|
||||
for trigger in triggers:
|
||||
if isinstance(trigger, CommonLLMHttpTrigger):
|
||||
common_http_trigger = True
|
||||
break
|
||||
leaf_node = cast(BaseOperator, leaf_nodes[0])
|
||||
if not leaf_node.metadata or not leaf_node.metadata.outputs:
|
||||
return FlowCategory.COMMON
|
||||
|
||||
if isinstance(trigger, AgentDummyTrigger):
|
||||
agent_trigger = True
|
||||
break
|
||||
|
||||
output = leaf_node.metadata.outputs[0]
|
||||
try:
|
||||
real_class = _get_type_cls(output.type_cls)
|
||||
if common_http_trigger and is_chat_flow_type(real_class, is_class=True):
|
||||
if agent_trigger:
|
||||
return FlowCategory.CHAT_AGENT
|
||||
elif common_http_trigger and is_chat_flow_type(real_class, is_class=True):
|
||||
return FlowCategory.CHAT_FLOW
|
||||
except Exception:
|
||||
return FlowCategory.COMMON
|
||||
|
@@ -16,7 +16,6 @@ REQ = TypeVar("REQ")
|
||||
# The response schema type
|
||||
RES = TypeVar("RES")
|
||||
|
||||
|
||||
QUERY_SPEC = Union[REQ, Dict[str, Any]]
|
||||
|
||||
|
||||
@@ -286,10 +285,16 @@ class BaseDao(Generic[T, REQ, RES]):
|
||||
else model_to_dict(query_request)
|
||||
)
|
||||
for key, value in query_dict.items():
|
||||
if value is not None:
|
||||
if isinstance(value, (list, tuple, dict, set)):
|
||||
if value is not None and hasattr(model_cls, key):
|
||||
if isinstance(value, list):
|
||||
if len(value) > 0:
|
||||
query = query.filter(getattr(model_cls, key).in_(value))
|
||||
else:
|
||||
continue
|
||||
elif isinstance(value, (tuple, dict, set)):
|
||||
continue
|
||||
query = query.filter(getattr(model_cls, key) == value)
|
||||
else:
|
||||
query = query.filter(getattr(model_cls, key) == value)
|
||||
|
||||
if desc_order_column:
|
||||
query = query.order_by(desc(getattr(model_cls, desc_order_column)))
|
||||
|
Reference in New Issue
Block a user