# DB-GPT/dbgpt/rag/operators/process_branch.py
# 2024-12-18 11:16:30 +08:00
#
# 194 lines
# 6.3 KiB
# Python

"""Knowledge Process Branch Operator."""
from typing import Dict, List, Optional
from dbgpt.core import Chunk
from dbgpt.core.awel import (
BranchFunc,
BranchOperator,
BranchTaskType,
JoinOperator,
logger,
)
from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.util.i18n_utils import _
class KnowledgeProcessBranchOperator(BranchOperator[Knowledge, Knowledge]):
    """Knowledge Process branch operator.

    Decides which knowledge-processing branches (knowledge graph, embedding,
    full text) are active by checking which storage operator classes appear
    among this operator's downstream tasks.
    """

    # NOTE(review): the description below still says the branching is based on
    # "the stream flag of the request", while ``branches`` only inspects the
    # classes of the downstream operators. The text is left untouched because
    # it is a runtime / i18n string. The three outputs also share the name
    # "chunks" -- confirm that this is intended.
    metadata = ViewMetadata(
        label=_("Knowledge Process Branch Operator"),
        name="knowledge_process_operator",
        category=OperatorCategory.RAG,
        operator_type=OperatorType.BRANCH,
        description=_("Branch the workflow based on the stream flag of the request."),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Document Chunks"),
                "input_value",
                List[Chunk],
                description=_("The input value of the operator."),
                is_list=True,
            ),
        ],
        outputs=[
            IOField.build_from(
                _("Chunks"),
                "chunks",
                List[Chunk],
                description=_("Chunks for Vector Storage Connector."),
                is_list=True,
            ),
            IOField.build_from(
                _("Chunks"),
                "chunks",
                List[Chunk],
                description=_("Chunks for Knowledge Graph Connector."),
                is_list=True,
            ),
            IOField.build_from(
                _("Chunks"),
                "chunks",
                List[Chunk],
                description=_("Chunks for Full Text Connector."),
                is_list=True,
            ),
        ],
    )

    def __init__(
        self,
        graph_task_name: Optional[str] = None,
        embedding_task_name: Optional[str] = None,
        full_text_task_name: Optional[str] = None,
        **kwargs,
    ):
        """Create the knowledge process branch operator.

        Args:
            graph_task_name: Task name associated with the knowledge graph
                branch.
            embedding_task_name: Task name associated with the embedding
                (vector storage) branch.
            full_text_task_name: Task name associated with the full text
                branch. When omitted it falls back to ``embedding_task_name``,
                which is what this operator always used before the parameter
                existed.
            **kwargs: Forwarded unchanged to ``BranchOperator.__init__``.
        """
        super().__init__(**kwargs)
        self._graph_task_name = graph_task_name
        self._embedding_task_name = embedding_task_name
        # Previously the full text branch was hard-wired to the embedding task
        # name; keep that as the default so existing callers are unaffected.
        self._full_text_task_name = (
            full_text_task_name
            if full_text_task_name is not None
            else embedding_task_name
        )

    async def branches(
        self,
    ) -> Dict[BranchFunc[Knowledge], BranchTaskType]:
        """Build the mapping from branch predicates to task names.

        Returns:
            A dict with three entries, in this order: graph, embedding and
            full text. Each key is an async predicate that ignores its
            argument and returns True when an operator of the matching storage
            class is among ``self.downstream``; each value is the
            corresponding task name stored by ``__init__``.
        """
        # Classes of every directly connected downstream operator.
        downstream_classes = {task.__class__ for task in self.downstream}

        async def check_graph_process(r: Knowledge) -> bool:
            # True when a knowledge graph operator is connected downstream.
            # The import is kept local as in the original code (presumably to
            # avoid a circular import -- TODO confirm).
            from dbgpt.rag.operators import KnowledgeGraphOperator

            return KnowledgeGraphOperator in downstream_classes

        async def check_embedding_process(r: Knowledge) -> bool:
            # True when a vector storage operator is connected downstream.
            from dbgpt.rag.operators import VectorStorageOperator

            return VectorStorageOperator in downstream_classes

        async def check_full_text_process(r: Knowledge) -> bool:
            # True when a full text storage operator is connected downstream.
            from dbgpt.rag.operators.full_text import FullTextStorageOperator

            return FullTextStorageOperator in downstream_classes

        # Values may be None when a task name was not supplied, hence the
        # ``type: ignore`` on the return.
        branch_func_map = {
            check_graph_process: self._graph_task_name,
            check_embedding_process: self._embedding_task_name,
            check_full_text_process: self._full_text_task_name,
        }
        return branch_func_map  # type: ignore
class KnowledgeProcessJoinOperator(JoinOperator[List[str]]):
    """Knowledge Process Join Operator.

    Collects the outputs of the knowledge-processing branches and turns them
    into a list of human-readable status messages.
    """

    # NOTE(review): only two inputs are declared here while ``_join`` accepts
    # three chunk lists, and the declared output type is ``List[Chunk]`` while
    # ``_join`` returns ``List[str]`` -- confirm whether this is intended.
    metadata = ViewMetadata(
        label=_("Knowledge Process Join Operator"),
        name="knowledge_process_join_operator",
        category=OperatorCategory.RAG,
        operator_type=OperatorType.JOIN,
        description=_(
            "Join Branch the workflow based on the Knowledge Process Results."
        ),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Vector Storage Results"),
                "input_value",
                List[Chunk],
                description=_("vector storage results."),
                is_list=True,
            ),
            IOField.build_from(
                _("Knowledge Graph Storage Results"),
                "input_value",
                List[Chunk],
                description=_("knowledge graph storage results."),
                is_list=True,
            ),
        ],
        outputs=[
            IOField.build_from(
                _("Chunks"),
                "chunks",
                List[Chunk],
                description=_("Knowledge Process Results."),
                is_list=True,
            ),
        ],
    )

    def __init__(self, **kwargs):
        """Create the join operator, using ``_join`` as the combine function.

        Args:
            **kwargs: Forwarded unchanged to ``JoinOperator.__init__``.
        """
        super().__init__(combine_function=self._join, **kwargs)

    async def _join(
        self,
        vector_chunks: Optional[List[Chunk]] = None,
        graph_chunks: Optional[List[Chunk]] = None,
        full_text_chunks: Optional[List[Chunk]] = None,
    ) -> List[str]:
        """Summarise the branch results as status messages.

        Args:
            vector_chunks: The list of vector chunks.
            graph_chunks: The list of graph chunks.
            full_text_chunks: The list of full text chunks.

        Returns:
            One message per non-empty chunk list, in the order vector, graph,
            full text. Each message is also logged at INFO level. Lists that
            are None or empty produce no message.
        """
        # (store label used in the message, chunks reported by that branch)
        branch_outputs = (
            ("vector", vector_chunks),
            ("graph", graph_chunks),
            ("full text", full_text_chunks),
        )
        messages: List[str] = []
        for store_label, chunks in branch_outputs:
            if not chunks:
                continue
            message = (
                f"async persist {store_label} store success {len(chunks)} chunks."
            )
            logger.info(message)
            messages.append(message)
        return messages