feat(base): base dao in query support

This commit is contained in:
yhjun1026
2024-08-20 15:39:23 +08:00
parent cd82feebc9
commit 7e00ea940a
5 changed files with 69 additions and 16 deletions

View File

@@ -19,6 +19,15 @@ def is_chat_flow_type(output_obj: Any, is_class: bool = False) -> bool:
return isinstance(output_obj, chat_types)
def is_agent_flow_type(output_obj: Any, is_class: bool = False) -> bool:
"""Check whether the output object is a agent flow type."""
if is_class:
return output_obj in (str, CommonLLMHttpResponseBody, ModelOutput)
else:
chat_types = (str, CommonLLMHttpResponseBody)
return isinstance(output_obj, chat_types)
async def safe_chat_with_dag_task(
task: BaseOperator, request: Any, covert_to_str: bool = False
) -> ModelOutput:

View File

@@ -33,6 +33,7 @@ from dbgpt.app.dbgpt_server import system_app
from dbgpt.app.scene.base import ChatScene
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import PromptTemplate
from dbgpt.core.awel.flow.flow_factory import FlowCategory
from dbgpt.core.interface.message import StorageConversation
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.cluster.client import DefaultLLMClient
@@ -222,14 +223,11 @@ class MultiAgents(BaseComponent, ABC):
)
)
is_agent_flow = True
if TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode:
from dbgpt.agent import AWELTeamContext
team_context: AWELTeamContext = gpt_app.team_context
if team_context.flow_category == "chat_flow":
is_agent_flow = False
if not is_agent_flow:
if (
TeamMode.AWEL_LAYOUT.value == gpt_app.team_mode
and gpt_app.team_context.flow_category == FlowCategory.CHAT_FLOW
):
team_context = gpt_app.team_context
from dbgpt.core.awel import CommonLLMHttpRequestBody
flow_req = CommonLLMHttpRequestBody(

View File

@@ -6,6 +6,7 @@ from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from dbgpt.component import SystemApp
from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata
from dbgpt.core.awel.flow.flow_factory import FlowCategory
from dbgpt.serve.core import Result
from dbgpt.util import PaginationResult
@@ -172,6 +173,35 @@ async def get_flows(
return Result.succ(flow)
@router.get(
"/chat/flows",
response_model=Result[PaginationResult[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def query_chat_flows(
user_name: Optional[str] = Query(default=None, description="user name"),
sys_code: Optional[str] = Query(default=None, description="system code"),
page: int = Query(default=1, description="current page"),
page_size: int = Query(default=20, description="page size"),
name: Optional[str] = Query(default=None, description="flow name"),
uid: Optional[str] = Query(default=None, description="flow uid"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
return Result.succ(
service.get_list_by_page(
{
"user_name": user_name,
"sys_code": sys_code,
"name": name,
"uid": uid,
"flow_category": [FlowCategory.CHAT_AGENT, FlowCategory.CHAT_FLOW],
},
page,
page_size,
)
)
@router.get(
"/flows",
response_model=Result[PaginationResult[ServerResponse]],

View File

@@ -6,6 +6,7 @@ import schedule
from fastapi import HTTPException
from dbgpt._private.pydantic import model_to_json
from dbgpt.agent import AgentDummyTrigger
from dbgpt.component import SystemApp
from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody
from dbgpt.core.awel.dag.dag_manager import DAGManager
@@ -548,18 +549,28 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
or not isinstance(leaf_nodes[0], BaseOperator)
):
return FlowCategory.COMMON
leaf_node = cast(BaseOperator, leaf_nodes[0])
if not leaf_node.metadata or not leaf_node.metadata.outputs:
return FlowCategory.COMMON
common_http_trigger = False
agent_trigger = False
for trigger in triggers:
if isinstance(trigger, CommonLLMHttpTrigger):
common_http_trigger = True
break
leaf_node = cast(BaseOperator, leaf_nodes[0])
if not leaf_node.metadata or not leaf_node.metadata.outputs:
return FlowCategory.COMMON
if isinstance(trigger, AgentDummyTrigger):
agent_trigger = True
break
output = leaf_node.metadata.outputs[0]
try:
real_class = _get_type_cls(output.type_cls)
if common_http_trigger and is_chat_flow_type(real_class, is_class=True):
if agent_trigger:
return FlowCategory.CHAT_AGENT
elif common_http_trigger and is_chat_flow_type(real_class, is_class=True):
return FlowCategory.CHAT_FLOW
except Exception:
return FlowCategory.COMMON

View File

@@ -16,7 +16,6 @@ REQ = TypeVar("REQ")
# The response schema type
RES = TypeVar("RES")
QUERY_SPEC = Union[REQ, Dict[str, Any]]
@@ -286,10 +285,16 @@ class BaseDao(Generic[T, REQ, RES]):
else model_to_dict(query_request)
)
for key, value in query_dict.items():
if value is not None:
if isinstance(value, (list, tuple, dict, set)):
if value is not None and hasattr(model_cls, key):
if isinstance(value, list):
if len(value) > 0:
query = query.filter(getattr(model_cls, key).in_(value))
else:
continue
elif isinstance(value, (tuple, dict, set)):
continue
query = query.filter(getattr(model_cls, key) == value)
else:
query = query.filter(getattr(model_cls, key) == value)
if desc_order_column:
query = query.order_by(desc(getattr(model_cls, desc_order_column)))