feat: Support mappers in inputs and outputs

This commit is contained in:
Fangyin Cheng
2024-08-29 12:03:14 +08:00
parent 1f676b9ebf
commit 439b5b32e2
8 changed files with 169 additions and 22 deletions

View File

@@ -4,6 +4,7 @@ from typing import List, Optional
from dbgpt._private.config import Config
from dbgpt.agent.resource.database import DBResource
from dbgpt.core import Chunk
from dbgpt.core.awel import DAGContext, MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
@@ -193,6 +194,19 @@ class GPTVisMixin:
class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
"""Retrieve the table schemas from the datasource."""
_share_data_key = "__datasource_retriever_chunks__"
class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
async def map(self, context: HOContextBody) -> List[Chunk]:
schema_info = await self.current_dag_context.get_from_share_data(
HODatasourceRetrieverOperator._share_data_key
)
if isinstance(schema_info, list):
chunks = [Chunk(content=table_info) for table_info in schema_info]
else:
chunks = [Chunk(content=schema_info)]
return chunks
metadata = ViewMetadata(
label=_("Datasource Retriever Operator"),
name="higher_order_datasource_retriever_operator",
@@ -207,7 +221,17 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
_PARAMETER_CONTEXT_KEY.new(),
],
inputs=[_INPUTS_QUESTION.new()],
outputs=[_OUTPUTS_CONTEXT.new()],
outputs=[
_OUTPUTS_CONTEXT.new(),
IOField.build_from(
_("Retrieved schema chunks"),
"chunks",
Chunk,
is_list=True,
description=_("The retrieved schema chunks from the datasource"),
mappers=[ChunkMapper],
),
],
tags={"order": TAGS_ORDER_HIGH},
)
@@ -239,6 +263,9 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
db=db_name,
question=question,
)
await self.current_dag_context.save_to_share_data(
self._share_data_key, schema_info
)
context = self._prompt_template.format(
db_name=db_name,
table_info=schema_info,

View File

@@ -1,6 +1,7 @@
from typing import List, Optional
from dbgpt._private.config import Config
from dbgpt.core import Chunk
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
@@ -93,6 +94,15 @@ _OUTPUTS_CONTEXT = IOField.build_from(
class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
_share_data_key = "_higher_order_knowledge_operator_retriever_chunks"
class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
async def map(self, context: HOContextBody) -> List[Chunk]:
chunks = await self.current_dag_context.get_from_share_data(
HOKnowledgeOperator._share_data_key
)
return chunks
metadata = ViewMetadata(
label=_("Knowledge Operator"),
name="higher_order_knowledge_operator",
@@ -122,6 +132,14 @@ class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
],
outputs=[
_OUTPUTS_CONTEXT.new(),
IOField.build_from(
_("Chunks"),
"chunks",
Chunk,
is_list=True,
description=_("The retrieved chunks from the knowledge space"),
mappers=[ChunkMapper],
),
],
tags={"order": TAGS_ORDER_HIGH},
)
@@ -185,6 +203,7 @@ class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
chunks = await self._space_retriever.aretrieve_with_scores(
query, self._score_threshold
)
await self.current_dag_context.save_to_share_data(self._share_data_key, chunks)
return HOContextBody(
context_key=self._context_key,
context=[chunk.content for chunk in chunks],