Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-17 23:18:20 +00:00)
feat: Support mappers in inputs and outputs
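Summary, as far as these hunks show: an output field declared with IOField.build_from can now carry one or more mapper operator classes (serialized into the flow metadata as type-name strings). When FlowFactory builds a DAG from serialized flow data, it instantiates those mappers and splices them between the upstream task and its downstream task. The datasource and knowledge operators use this to expose their retrieved data a second time as List[Chunk], and CommonLLMHttpTrigger uses it to expose the request messages as a plain string.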
@@ -4,6 +4,7 @@ from typing import List, Optional
 
 from dbgpt._private.config import Config
 from dbgpt.agent.resource.database import DBResource
+from dbgpt.core import Chunk
 from dbgpt.core.awel import DAGContext, MapOperator
 from dbgpt.core.awel.flow import (
     TAGS_ORDER_HIGH,
@@ -193,6 +194,19 @@ class GPTVisMixin:
 class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
     """Retrieve the table schemas from the datasource."""
 
+    _share_data_key = "__datasource_retriever_chunks__"
+
+    class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
+        async def map(self, context: HOContextBody) -> List[Chunk]:
+            schema_info = await self.current_dag_context.get_from_share_data(
+                HODatasourceRetrieverOperator._share_data_key
+            )
+            if isinstance(schema_info, list):
+                chunks = [Chunk(content=table_info) for table_info in schema_info]
+            else:
+                chunks = [Chunk(content=schema_info)]
+            return chunks
+
     metadata = ViewMetadata(
         label=_("Datasource Retriever Operator"),
         name="higher_order_datasource_retriever_operator",
@@ -207,7 +221,17 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
             _PARAMETER_CONTEXT_KEY.new(),
         ],
         inputs=[_INPUTS_QUESTION.new()],
-        outputs=[_OUTPUTS_CONTEXT.new()],
+        outputs=[
+            _OUTPUTS_CONTEXT.new(),
+            IOField.build_from(
+                _("Retrieved schema chunks"),
+                "chunks",
+                Chunk,
+                is_list=True,
+                description=_("The retrieved schema chunks from the datasource"),
+                mappers=[ChunkMapper],
+            ),
+        ],
         tags={"order": TAGS_ORDER_HIGH},
     )
 
@@ -239,6 +263,9 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
             db=db_name,
             question=question,
         )
+        await self.current_dag_context.save_to_share_data(
+            self._share_data_key, schema_info
+        )
         context = self._prompt_template.format(
             db_name=db_name,
             table_info=schema_info,
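The datasource hunks above work as a pair: the retriever saves the raw schema_info into the DAG's share data under _share_data_key, and the nested ChunkMapper that backs the new "chunks" output reads it back and wraps each table description in a Chunk. A rough standalone sketch of that handoff, using a plain dict in place of the DAGContext share data and a trivial Chunk stand-in (none of these names are DB-GPT API):

import asyncio
from dataclasses import dataclass

# Hypothetical stand-ins for DAGContext share data and dbgpt.core.Chunk.
share_data: dict = {}


@dataclass
class FakeChunk:
    content: str


async def retrieve(question: str) -> str:
    """Sketch of the retriever half: save the raw result, return a summary."""
    schema_info = ["CREATE TABLE users(...)", "CREATE TABLE orders(...)"]
    share_data["__datasource_retriever_chunks__"] = schema_info
    return f"Context for: {question}"


async def chunk_mapper(_context: str) -> list:
    """Sketch of the ChunkMapper half: read the saved result, wrap it in chunks."""
    schema_info = share_data["__datasource_retriever_chunks__"]
    if isinstance(schema_info, list):
        return [FakeChunk(content=table_info) for table_info in schema_info]
    return [FakeChunk(content=schema_info)]


async def main() -> None:
    context = await retrieve("Which users ordered the most?")
    chunks = await chunk_mapper(context)
    print(len(chunks), "chunks from:", context)


asyncio.run(main())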
@@ -1,6 +1,7 @@
 from typing import List, Optional
 
 from dbgpt._private.config import Config
+from dbgpt.core import Chunk
 from dbgpt.core.awel import MapOperator
 from dbgpt.core.awel.flow import (
     TAGS_ORDER_HIGH,
@@ -93,6 +94,15 @@ _OUTPUTS_CONTEXT = IOField.build_from(
 
 
 class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
+    _share_data_key = "_higher_order_knowledge_operator_retriever_chunks"
+
+    class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
+        async def map(self, context: HOContextBody) -> List[Chunk]:
+            chunks = await self.current_dag_context.get_from_share_data(
+                HOKnowledgeOperator._share_data_key
+            )
+            return chunks
+
     metadata = ViewMetadata(
         label=_("Knowledge Operator"),
         name="higher_order_knowledge_operator",
@@ -122,6 +132,14 @@ class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
         ],
         outputs=[
             _OUTPUTS_CONTEXT.new(),
+            IOField.build_from(
+                _("Chunks"),
+                "chunks",
+                Chunk,
+                is_list=True,
+                description=_("The retrieved chunks from the knowledge space"),
+                mappers=[ChunkMapper],
+            ),
         ],
         tags={"order": TAGS_ORDER_HIGH},
     )
@@ -185,6 +203,7 @@ class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
         chunks = await self._space_retriever.aretrieve_with_scores(
             query, self._score_threshold
         )
+        await self.current_dag_context.save_to_share_data(self._share_data_key, chunks)
         return HOContextBody(
             context_key=self._context_key,
             context=[chunk.content for chunk in chunks],
@@ -619,6 +619,7 @@ class DAGContext:
         self._node_name_to_ids: Dict[str, str] = node_name_to_ids
         self._event_loop_task_id = event_loop_task_id
         self._dag_variables = dag_variables
+        self._share_data_lock = asyncio.Lock()
 
     @property
     def _task_outputs(self) -> Dict[str, TaskContext]:
@@ -680,6 +681,7 @@ class DAGContext:
         Returns:
             Any: The share data, you can cast it to the real type
         """
-        logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
-        return self._share_data.get(key)
+        async with self._share_data_lock:
+            logger.debug(f"Get share data by key {key} from {id(self._share_data)}")
+            return self._share_data.get(key)
 
@@ -694,6 +696,7 @@ class DAGContext:
             overwrite (bool): Whether overwrite the share data if the key
                 already exists. Defaults to None.
         """
-        if key in self._share_data and not overwrite:
-            raise ValueError(f"Share data key {key} already exists")
-        logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
+        async with self._share_data_lock:
+            if key in self._share_data and not overwrite:
+                raise ValueError(f"Share data key {key} already exists")
+            logger.debug(f"Save share data by key {key} to {id(self._share_data)}")
@@ -687,6 +687,10 @@ class IOField(Resource):
         " True",
         examples=[0, 1, 2],
     )
+    mappers: Optional[List[str]] = Field(
+        default=None,
+        description="The mappers of the field, transform the field to the target type",
+    )
 
     @classmethod
     def build_from(
@@ -698,10 +702,16 @@ class IOField(Resource):
         is_list: bool = False,
         dynamic: bool = False,
         dynamic_minimum: int = 0,
+        mappers: Optional[Union[Type, List[Type]]] = None,
     ):
         """Build the resource from the type."""
         type_name = type.__qualname__
         type_cls = _get_type_name(type)
+        # TODO: Check the mapper instance can be created without required
+        # parameters.
+        if mappers and not isinstance(mappers, list):
+            mappers = [mappers]
+        mappers_cls = [_get_type_name(m) for m in mappers] if mappers else None
         return cls(
             label=label,
             name=name,
@@ -711,6 +721,7 @@ class IOField(Resource):
             description=description or label,
             dynamic=dynamic,
             dynamic_minimum=dynamic_minimum,
+            mappers=mappers_cls,
         )
 
     @model_validator(mode="before")
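build_from now accepts either a single mapper class or a list and stores the mappers as type-name strings, which is what gets serialized into the flow metadata and later resolved back into classes when the factory builds mapper tasks. A rough standalone sketch of that normalization, assuming a module-qualified name format (DB-GPT's _get_type_name may format names differently):

from typing import List, Optional, Type, Union


def _type_name(cls: Type) -> str:
    # Assumption: a module-qualified name; the real _get_type_name may differ.
    return f"{cls.__module__}.{cls.__qualname__}"


def normalize_mappers(
    mappers: Optional[Union[Type, List[Type]]] = None,
) -> Optional[List[str]]:
    """Mirror of the normalization added to IOField.build_from."""
    if mappers and not isinstance(mappers, list):
        mappers = [mappers]
    return [_type_name(m) for m in mappers] if mappers else None


class ChunkMapper:  # placeholder standing in for a MapOperator subclass
    pass


print(normalize_mappers(ChunkMapper))    # ['__main__.ChunkMapper']
print(normalize_mappers([ChunkMapper]))  # same result for a list
print(normalize_mappers(None))           # None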
@@ -1,10 +1,11 @@
 """Build AWEL DAGs from serialized data."""
 
+import dataclasses
 import logging
 import uuid
 from contextlib import suppress
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast
+from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
 
 from typing_extensions import Annotated
 
@@ -565,6 +566,17 @@ class FlowPanel(BaseModel):
         return [FlowVariables(**v) for v in variables]
 
 
+@dataclasses.dataclass
+class _KeyToNodeItem:
+    """Key to node item."""
+
+    key: str
+    source_order: int
+    target_order: int
+    mappers: List[str]
+    edge_index: int
+
+
 class FlowFactory:
     """Flow factory."""
 
@@ -580,8 +592,10 @@ class FlowFactory:
         key_to_operator_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource_nodes: Dict[str, FlowNodeData] = {}
         key_to_resource: Dict[str, ResourceMetadata] = {}
-        key_to_downstream: Dict[str, List[Tuple[str, int, int]]] = {}
-        key_to_upstream: Dict[str, List[Tuple[str, int, int]]] = {}
+        # Record current node's downstream
+        key_to_downstream: Dict[str, List[_KeyToNodeItem]] = {}
+        # Record current node's upstream
+        key_to_upstream: Dict[str, List[_KeyToNodeItem]] = {}
         key_to_upstream_node: Dict[str, List[FlowNodeData]] = {}
         for node in flow_data.nodes:
            key = node.id
@@ -595,7 +609,7 @@ class FlowFactory:
                 key_to_resource_nodes[key] = node
                 key_to_resource[key] = node.data
 
-        for edge in flow_data.edges:
+        for edge_index, edge in enumerate(flow_data.edges):
             source_key = edge.source
             target_key = edge.target
             source_node: FlowNodeData | None = key_to_operator_nodes.get(
@@ -615,12 +629,37 @@ class FlowFactory:
 
             if source_node.data.is_operator and target_node.data.is_operator:
                 # Operator to operator.
+                mappers = []
+                for i, out in enumerate(source_node.data.outputs):
+                    if i != edge.source_order:
+                        continue
+                    if out.mappers:
+                        # Current edge is a mapper edge, find the mappers.
+                        mappers = out.mappers
+                    # Note: Not support mappers in the inputs of the target node now.
+
                 downstream = key_to_downstream.get(source_key, [])
-                downstream.append((target_key, edge.source_order, edge.target_order))
+                downstream.append(
+                    _KeyToNodeItem(
+                        key=target_key,
+                        source_order=edge.source_order,
+                        target_order=edge.target_order,
+                        mappers=mappers,
+                        edge_index=edge_index,
+                    )
+                )
                 key_to_downstream[source_key] = downstream
 
                 upstream = key_to_upstream.get(target_key, [])
-                upstream.append((source_key, edge.source_order, edge.target_order))
+                upstream.append(
+                    _KeyToNodeItem(
+                        key=source_key,
+                        source_order=edge.source_order,
+                        target_order=edge.target_order,
+                        mappers=mappers,
+                        edge_index=edge_index,
+                    )
+                )
                 key_to_upstream[target_key] = upstream
             elif not source_node.data.is_operator and target_node.data.is_operator:
                 # Resource to operator.
@@ -678,10 +717,10 @@ class FlowFactory:
         # Sort the keys by the order of the nodes.
         for key, value in key_to_downstream.items():
             # Sort by source_order.
-            key_to_downstream[key] = sorted(value, key=lambda x: x[1])
+            key_to_downstream[key] = sorted(value, key=lambda x: x.source_order)
         for key, value in key_to_upstream.items():
             # Sort by target_order.
-            key_to_upstream[key] = sorted(value, key=lambda x: x[2])
+            key_to_upstream[key] = sorted(value, key=lambda x: x.target_order)
 
         sorted_key_to_resource_nodes = list(key_to_resource_nodes.values())
         sorted_key_to_resource_nodes = sorted(
@@ -779,8 +818,8 @@ class FlowFactory:
         self,
         flow_panel: FlowPanel,
         key_to_tasks: Dict[str, DAGNode],
-        key_to_downstream: Dict[str, List[Tuple[str, int, int]]],
-        key_to_upstream: Dict[str, List[Tuple[str, int, int]]],
+        key_to_downstream: Dict[str, List[_KeyToNodeItem]],
+        key_to_upstream: Dict[str, List[_KeyToNodeItem]],
         dag_id: Optional[str] = None,
     ) -> DAG:
         """Build the DAG."""
@@ -827,7 +866,8 @@ class FlowFactory:
 
             # This upstream has been sorted according to the order in the downstream
             # So we just need to connect the task to the upstream.
-            for upstream_key, _, _ in upstream:
+            for up_item in upstream:
+                upstream_key = up_item.key
                 # Just one direction.
                 upstream_task = key_to_tasks.get(upstream_key)
                 if not upstream_task:
@@ -838,7 +878,13 @@ class FlowFactory:
                     upstream_task.set_node_id(dag._new_node_id())
                 if upstream_task is None:
                     raise ValueError("Unable to find upstream task.")
-                upstream_task >> task
+                tasks = _build_mapper_operators(dag, up_item.mappers)
+                tasks.append(task)
+                last_task = upstream_task
+                for t in tasks:
+                    # Connect the task to the upstream task.
+                    last_task >> t
+                    last_task = t
         return dag
 
     def pre_load_requirements(self, flow_panel: FlowPanel):
@@ -945,6 +991,23 @@ def _topological_sort(
     return key_to_order
 
 
+def _build_mapper_operators(dag: DAG, mappers: List[str]) -> List[DAGNode]:
+    from .base import _get_type_cls
+
+    tasks = []
+    for mapper in mappers:
+        try:
+            mapper_cls = _get_type_cls(mapper)
+            task = mapper_cls()
+            if not task._node_id:
+                task.set_node_id(dag._new_node_id())
+            tasks.append(task)
+        except Exception as e:
+            err_msg = f"Unable to build mapper task: {mapper}, error: {e}"
+            raise FlowMetadataException(err_msg)
+    return tasks
+
+
 def fill_flow_panel(flow_panel: FlowPanel):
     """Fill the flow panel with the latest metadata.
 
@@ -973,6 +1036,7 @@ def fill_flow_panel(flow_panel: FlowPanel):
                     i.dynamic = new_param.dynamic
                     i.is_list = new_param.is_list
                     i.dynamic_minimum = new_param.dynamic_minimum
+                    i.mappers = new_param.mappers
             for i in node.data.outputs:
                 if i.name in output_parameters:
                     new_param = output_parameters[i.name]
@@ -981,6 +1045,7 @@ def fill_flow_panel(flow_panel: FlowPanel):
                     i.dynamic = new_param.dynamic
                     i.is_list = new_param.is_list
                     i.dynamic_minimum = new_param.dynamic_minimum
+                    i.mappers = new_param.mappers
         else:
             data = cast(ResourceMetadata, node.data)
             key = data.get_origin_id()
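Taken together, the factory changes carry mapper type names along each edge (_KeyToNodeItem.mappers), build operator instances for them in _build_mapper_operators, and then chain them between the upstream task and its target instead of connecting the two directly. A standalone sketch of that chaining step (Node and >> are stand-ins for DAGNode and AWEL's connection operator, not DB-GPT code):

from typing import List


class Node:
    """Hypothetical stand-in for an AWEL DAGNode."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.downstream: List["Node"] = []

    def __rshift__(self, other: "Node") -> "Node":
        # Mimic AWEL's `upstream >> downstream` connection operator.
        self.downstream.append(other)
        return other


def connect(upstream_task: Node, task: Node, mappers: List[Node]) -> None:
    """Mirror of the loop added in _build_dag: upstream -> mappers... -> task."""
    tasks = [*mappers, task]
    last_task = upstream_task
    for t in tasks:
        last_task >> t
        last_task = t


retriever = Node("retriever")
chunk_mapper = Node("chunk_mapper")
consumer = Node("consumer")
connect(retriever, consumer, [chunk_mapper])
print([n.name for n in retriever.downstream])     # ['chunk_mapper']
print([n.name for n in chunk_mapper.downstream])  # ['consumer']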
@@ -945,6 +945,16 @@ class StringHttpTrigger(HttpTrigger):
 class CommonLLMHttpTrigger(HttpTrigger):
     """Common LLM http trigger for AWEL."""
 
+    class MessagesOutputMapper(MapOperator[CommonLLMHttpRequestBody, str]):
+        """Messages output mapper."""
+
+        async def map(self, request_body: CommonLLMHttpRequestBody) -> str:
+            """Map the request body to messages."""
+            if isinstance(request_body.messages, str):
+                return request_body.messages
+            else:
+                raise ValueError("Messages to be transformed is not a string")
+
     metadata = ViewMetadata(
         label=_("Common LLM Http Trigger"),
         name="common_llm_http_trigger",
@@ -965,6 +975,16 @@ class CommonLLMHttpTrigger(HttpTrigger):
                     "LLM http body"
                 ),
             ),
+            IOField.build_from(
+                _("Request String Messages"),
+                "request_string_messages",
+                str,
+                description=_(
+                    "The request string messages of the API endpoint, parsed from "
+                    "'messages' field of the request body"
+                ),
+                mappers=[MessagesOutputMapper],
+            ),
         ],
         parameters=[
             Parameter.build_from(
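With this hunk the trigger exposes a second output, "request_string_messages": MessagesOutputMapper passes the messages field through as a plain str and raises if it is not a simple string, so downstream operators that only need the prompt text can connect to that output instead of unpacking CommonLLMHttpRequestBody themselves. A minimal standalone sketch of the mapping rule (the request type here is a hypothetical stand-in, not the DB-GPT model):

import asyncio
from dataclasses import dataclass
from typing import List, Union


@dataclass
class FakeRequestBody:
    # Stand-in for CommonLLMHttpRequestBody; the real model has more fields.
    messages: Union[str, List[dict]]


async def messages_output_mapper(request_body: FakeRequestBody) -> str:
    """Same rule as MessagesOutputMapper.map: pass through plain strings only."""
    if isinstance(request_body.messages, str):
        return request_body.messages
    raise ValueError("Messages to be transformed is not a string")


print(asyncio.run(messages_output_mapper(FakeRequestBody(messages="Hello"))))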
@@ -388,7 +388,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
         Returns:
             List[ServerResponse]: The response
         """
-        page_result = self.dao.get_list_page(request, page, page_size)
+        page_result = self.dao.get_list_page(
+            request, page, page_size, desc_order_column=ServeEntity.gmt_modified.name
+        )
         for item in page_result.items:
             metadata = self.dag_manager.get_dag_metadata(
                 item.dag_id, alias_name=item.uid
@@ -852,7 +852,7 @@ class ExampleFlowUploadOperator(MapOperator[str, str]):
             ui=ui.UIUpload(
                 max_file_size=1024 * 1024 * 100,
                 up_event="button_click",
-                file_types=["image/*", "*.pdf"],
+                file_types=["image/*", ".pdf"],
                 drag=True,
                 attr=ui.UIUpload.UIAttribute(max_count=5),
             ),
@@ -897,7 +897,7 @@ class ExampleFlowUploadOperator(MapOperator[str, str]):
         files_metadata = await self.blocking_func_to_async(
             self._parse_files_metadata, fsc
         )
-        files_metadata_str = json.dumps(files_metadata, ensure_ascii=False)
+        files_metadata_str = json.dumps(files_metadata, ensure_ascii=False, indent=4)
         return "Your name is %s, and you files are %s." % (
             user_name,
             files_metadata_str,