feat(base): base dao in query support

This commit is contained in:
yhjun1026
2024-08-20 15:39:23 +08:00
parent cd82feebc9
commit 7e00ea940a
5 changed files with 69 additions and 16 deletions

View File

@@ -19,6 +19,15 @@ def is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
return isinstance(output_obj, chat_types) return isinstance(output_obj, chat_types)
def is_agent_flow_type(output_obj: Any, is_class: bool = False) -> bool:
"""Check whether the output object is a agent flow type."""
if is_class:
return output_obj in (str, CommonLLMHttpResponseBody, ModelOutput)
else:
chat_types = (str, CommonLLMHttpResponseBody)
return isinstance(output_obj, chat_types)
async def safe_chat_with_dag_task( async def safe_chat_with_dag_task(
task: BaseOperator, request: Any, covert_to_str: bool = False task: BaseOperator, request: Any, covert_to_str: bool = False
) -> ModelOutput: ) -> ModelOutput:

View File

@@ -33,6 +33,7 @@ from dbgpt.app.dbgpt_server import system_app
from dbgpt.app.scene.base import ChatScene from dbgpt.app.scene.base import ChatScene
from dbgpt.component import BaseComponent, ComponentType, SystemApp from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import PromptTemplate from dbgpt.core import PromptTemplate
from dbgpt.core.awel.flow.flow_factory import FlowCategory
from dbgpt.core.interface.message import StorageConversation from dbgpt.core.interface.message import StorageConversation
from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.cluster.client import DefaultLLMClient from dbgpt.model.cluster.client import DefaultLLMClient
@@ -222,14 +223,11 @@ class MultiAgents(BaseComponent, ABC):
) )
) )
is_agent_flow = True if (
if TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode: TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode
from dbgpt.agent import AWELTeamContext and gpt_app.team_context.flow_category == FlowCategory.CHAT_FLOW
):
team_context: AWELTeamContext = gpt_app.team_context team_context = gpt_app.team_context
if team_context.flow_category == "chat_flow":
is_agent_flow = False
if not is_agent_flow:
from dbgpt.core.awel import CommonLLMHttpRequestBody from dbgpt.core.awel import CommonLLMHttpRequestBody
flow_req = CommonLLMHttpRequestBody( flow_req = CommonLLMHttpRequestBody(

View File

@@ -6,6 +6,7 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp from dbgpt.component import SystemApp
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
from dbgpt.core.awel.flow.flow_factory import FlowCategory
from dbgpt.serve.core import Result from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult from dbgpt.util import PaginationResult
@@ -172,6 +173,35 @@ async def get_flows(
return Result.succ(flow) return Result.succ(flow)
@router.get(
"/chat/flows",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_chat_flows(
user_name: Optional[str] = Query(default=None, description="user name"),
sys_code: Optional[str] = Query(default=None, description="system code"),
page: int = Query(default=1, description="current page"),
page_size: int = Query(default=20, description="page size"),
name: Optional[str] = Query(default=None, description="flow name"),
uid: Optional[str] = Query(default=None, description="flow uid"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
return Result.succ(
service.get_list_by_page(
{
"user_name": user_name,
"sys_code": sys_code,
"name": name,
"uid": uid,
"flow_category": [FlowCategory.CHAT_AGENT, FlowCategory.CHAT_FLOW],
},
page,
page_size,
)
)
@router.get( @router.get(
"/flows", "/flows",
response_model=Result[PaginationResult[ServerResponse]], response_model=Result[PaginationResult[ServerResponse]],

View File

@@ -6,6 +6,7 @@ import schedule
from fastapi import HTTPException from fastapi import HTTPException
from dbgpt._private.pydantic import model_to_json from dbgpt._private.pydantic import model_to_json
from dbgpt.agent import AgentDummyTrigger
from dbgpt.component import SystemApp from dbgpt.component import SystemApp
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager from dbgpt.core.awel.dag.dag_manager import DAGManager
@@ -548,18 +549,28 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
or not isinstance(leaf_nodes[0], BaseOperator) or not isinstance(leaf_nodes[0], BaseOperator)
): ):
return FlowCategory.COMMON return FlowCategory.COMMON
leaf_node = cast(BaseOperator, leaf_nodes[0])
if not leaf_node.metadata or not leaf_node.metadata.outputs:
return FlowCategory.COMMON
common_http_trigger = False common_http_trigger = False
agent_trigger = False
for trigger in triggers: for trigger in triggers:
if isinstance(trigger, CommonLLMHttpTrigger): if isinstance(trigger, CommonLLMHttpTrigger):
common_http_trigger = True common_http_trigger = True
break break
leaf_node = cast(BaseOperator, leaf_nodes[0])
if not leaf_node.metadata or not leaf_node.metadata.outputs: if isinstance(trigger, AgentDummyTrigger):
return FlowCategory.COMMON agent_trigger = True
break
output = leaf_node.metadata.outputs[0] output = leaf_node.metadata.outputs[0]
try: try:
real_class = _get_type_cls(output.type_cls) real_class = _get_type_cls(output.type_cls)
if common_http_trigger and is_chat_flow_type(real_class, is_class=True): if agent_trigger:
return FlowCategory.CHAT_AGENT
elif common_http_trigger and is_chat_flow_type(real_class, is_class=True):
return FlowCategory.CHAT_FLOW return FlowCategory.CHAT_FLOW
except Exception: except Exception:
return FlowCategory.COMMON return FlowCategory.COMMON

View File

@@ -16,7 +16,6 @@ REQ = TypeVar("REQ")
# The response schema type # The response schema type
RES = TypeVar("RES") RES = TypeVar("RES")
QUERY_SPEC = Union[REQ, Dict[str, Any]] QUERY_SPEC = Union[REQ, Dict[str, Any]]
@@ -286,10 +285,16 @@ class BaseDao(Generic[T, REQ, RES]):
else model_to_dict(query_request) else model_to_dict(query_request)
) )
for key, value in query_dict.items(): for key, value in query_dict.items():
if value is not None: if value is not None and hasattr(model_cls, key):
if isinstance(value, (list, tuple, dict, set)): if isinstance(value, list):
if len(value) > 0:
query = query.filter(getattr(model_cls, key).in_(value))
else:
continue
elif isinstance(value, (tuple, dict, set)):
continue continue
query = query.filter(getattr(model_cls, key) == value) else:
query = query.filter(getattr(model_cls, key) == value)
if desc_order_column: if desc_order_column:
query = query.order_by(desc(getattr(model_cls, desc_order_column))) query = query.order_by(desc(getattr(model_cls, desc_order_column)))