mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-17 23:18:20 +00:00
feat: Support mappers in inputs and outputs
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import List, Optional
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.agent.resource.database import DBResource
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel import DAGContext, MapOperator
|
||||
from dbgpt.core.awel.flow import (
|
||||
TAGS_ORDER_HIGH,
|
||||
@@ -193,6 +194,19 @@ class GPTVisMixin:
|
||||
class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
|
||||
"""Retrieve the table schemas from the datasource."""
|
||||
|
||||
_share_data_key = "__datasource_retriever_chunks__"
|
||||
|
||||
class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
|
||||
async def map(self, context: HOContextBody) -> List[Chunk]:
|
||||
schema_info = await self.current_dag_context.get_from_share_data(
|
||||
HODatasourceRetrieverOperator._share_data_key
|
||||
)
|
||||
if isinstance(schema_info, list):
|
||||
chunks = [Chunk(content=table_info) for table_info in schema_info]
|
||||
else:
|
||||
chunks = [Chunk(content=schema_info)]
|
||||
return chunks
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label=_("Datasource Retriever Operator"),
|
||||
name="higher_order_datasource_retriever_operator",
|
||||
@@ -207,7 +221,17 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
|
||||
_PARAMETER_CONTEXT_KEY.new(),
|
||||
],
|
||||
inputs=[_INPUTS_QUESTION.new()],
|
||||
outputs=[_OUTPUTS_CONTEXT.new()],
|
||||
outputs=[
|
||||
_OUTPUTS_CONTEXT.new(),
|
||||
IOField.build_from(
|
||||
_("Retrieved schema chunks"),
|
||||
"chunks",
|
||||
Chunk,
|
||||
is_list=True,
|
||||
description=_("The retrieved schema chunks from the datasource"),
|
||||
mappers=[ChunkMapper],
|
||||
),
|
||||
],
|
||||
tags={"order": TAGS_ORDER_HIGH},
|
||||
)
|
||||
|
||||
@@ -239,6 +263,9 @@ class HODatasourceRetrieverOperator(MapOperator[str, HOContextBody]):
|
||||
db=db_name,
|
||||
question=question,
|
||||
)
|
||||
await self.current_dag_context.save_to_share_data(
|
||||
self._share_data_key, schema_info
|
||||
)
|
||||
context = self._prompt_template.format(
|
||||
db_name=db_name,
|
||||
table_info=schema_info,
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.flow import (
|
||||
TAGS_ORDER_HIGH,
|
||||
@@ -93,6 +94,15 @@ _OUTPUTS_CONTEXT = IOField.build_from(
|
||||
|
||||
|
||||
class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
|
||||
_share_data_key = "_higher_order_knowledge_operator_retriever_chunks"
|
||||
|
||||
class ChunkMapper(MapOperator[HOContextBody, List[Chunk]]):
|
||||
async def map(self, context: HOContextBody) -> List[Chunk]:
|
||||
chunks = await self.current_dag_context.get_from_share_data(
|
||||
HOKnowledgeOperator._share_data_key
|
||||
)
|
||||
return chunks
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label=_("Knowledge Operator"),
|
||||
name="higher_order_knowledge_operator",
|
||||
@@ -122,6 +132,14 @@ class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
|
||||
],
|
||||
outputs=[
|
||||
_OUTPUTS_CONTEXT.new(),
|
||||
IOField.build_from(
|
||||
_("Chunks"),
|
||||
"chunks",
|
||||
Chunk,
|
||||
is_list=True,
|
||||
description=_("The retrieved chunks from the knowledge space"),
|
||||
mappers=[ChunkMapper],
|
||||
),
|
||||
],
|
||||
tags={"order": TAGS_ORDER_HIGH},
|
||||
)
|
||||
@@ -185,6 +203,7 @@ class HOKnowledgeOperator(MapOperator[str, HOContextBody]):
|
||||
chunks = await self._space_retriever.aretrieve_with_scores(
|
||||
query, self._score_threshold
|
||||
)
|
||||
await self.current_dag_context.save_to_share_data(self._share_data_key, chunks)
|
||||
return HOContextBody(
|
||||
context_key=self._context_key,
|
||||
context=[chunk.content for chunk in chunks],
|
||||
|
Reference in New Issue
Block a user