mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-04 01:50:08 +00:00
194 lines
6.3 KiB
Python
194 lines
6.3 KiB
Python
"""Knowledge Process Branch Operator."""
|
|
from typing import Dict, List, Optional
|
|
|
|
from dbgpt.core import Chunk
|
|
from dbgpt.core.awel import (
|
|
BranchFunc,
|
|
BranchOperator,
|
|
BranchTaskType,
|
|
JoinOperator,
|
|
logger,
|
|
)
|
|
from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
|
|
from dbgpt.rag.knowledge.base import Knowledge
|
|
from dbgpt.util.i18n_utils import _
|
|
|
|
|
|
class KnowledgeProcessBranchOperator(BranchOperator[Knowledge, Knowledge]):
    """Knowledge Process branch operator.

    Routes the knowledge-processing workflow according to which storage
    operators are connected directly downstream of this operator. Each of the
    three branches (knowledge graph, vector storage, full text) is enabled
    only when a downstream task of the exact matching class
    (``KnowledgeGraphOperator``, ``VectorStorageOperator`` or
    ``FullTextStorageOperator``) is present.
    """

    # NOTE(review): the ``description`` text below ("stream flag of the
    # request") looks copied from another branch operator. It is left
    # unchanged because the wrapped string is also an i18n message id.
    metadata = ViewMetadata(
        label=_("Knowledge Process Branch Operator"),
        name="knowledge_process_operator",
        category=OperatorCategory.RAG,
        operator_type=OperatorType.BRANCH,
        description=_("Branch the workflow based on the stream flag of the request."),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Document Chunks"),
                "input_value",
                List[Chunk],
                description=_("The input value of the operator."),
                is_list=True,
            ),
        ],
        outputs=[
            IOField.build_from(
                _("Chunks"),
                "chunks",
                List[Chunk],
                description=_("Chunks for Vector Storage Connector."),
                is_list=True,
            ),
            IOField.build_from(
                _("Chunks"),
                "chunks",
                List[Chunk],
                description=_("Chunks for Knowledge Graph Connector."),
                is_list=True,
            ),
            IOField.build_from(
                _("Chunks"),
                "chunks",
                List[Chunk],
                description=_("Chunks for Full Text Connector."),
                is_list=True,
            ),
        ],
    )

    def __init__(
        self,
        graph_task_name: Optional[str] = None,
        embedding_task_name: Optional[str] = None,
        full_text_task_name: Optional[str] = None,
        **kwargs,
    ):
        """Create the knowledge process branch operator.

        Args:
            graph_task_name: Name of the downstream task for the knowledge
                graph branch. ``None`` means the branch is not registered.
            embedding_task_name: Name of the downstream task for the vector
                storage (embedding) branch. ``None`` means the branch is not
                registered.
            full_text_task_name: Name of the downstream task for the full-text
                branch. ``None`` means the branch is not registered.
            **kwargs: Forwarded unchanged to ``BranchOperator.__init__``.
        """
        super().__init__(**kwargs)
        self._graph_task_name = graph_task_name
        self._embedding_task_name = embedding_task_name
        # Previously this was assigned ``embedding_task_name``, so the
        # full-text predicate and the embedding predicate shared one target
        # task and a full-text task name could not be configured at all.
        self._full_text_task_name = full_text_task_name

    async def branches(
        self,
    ) -> Dict[BranchFunc[Knowledge], BranchTaskType]:
        """Map each storage-type predicate to the task name it controls.

        Returns:
            A dict from an async predicate to a downstream task name. Each
            predicate ignores its ``Knowledge`` argument and only checks
            whether a downstream task of the matching class exists. Branches
            whose task name is ``None`` are omitted because they have no task
            to route to.
        """
        # Exact classes of the directly connected downstream tasks. Membership
        # is tested by class identity, so subclasses do not match.
        downstream_classes = {task.__class__ for task in self.downstream}

        # The operator classes are imported inside each predicate, as in the
        # original code (presumably to avoid a circular import with
        # dbgpt.rag.operators; TODO confirm).
        async def check_graph_process(r: Knowledge) -> bool:
            # True -> run knowledge graph triplet extraction.
            from dbgpt.rag.operators import KnowledgeGraphOperator

            return KnowledgeGraphOperator in downstream_classes

        async def check_embedding_process(r: Knowledge) -> bool:
            # True -> run document embedding / vector storage.
            from dbgpt.rag.operators import VectorStorageOperator

            return VectorStorageOperator in downstream_classes

        async def check_full_text_process(r: Knowledge) -> bool:
            # True -> run full-text (keyword) storage.
            from dbgpt.rag.operators.full_text import FullTextStorageOperator

            return FullTextStorageOperator in downstream_classes

        candidates = (
            (check_graph_process, self._graph_task_name),
            (check_embedding_process, self._embedding_task_name),
            (check_full_text_process, self._full_text_task_name),
        )
        branch_func_map = {
            func: task_name for func, task_name in candidates if task_name is not None
        }
        return branch_func_map  # type: ignore
|
|
|
|
|
|
class KnowledgeProcessJoinOperator(JoinOperator[List[str]]):
    """Knowledge Process Join Operator.

    Joins the outputs of the knowledge storage branches into a list of
    human-readable summary messages, one per branch that produced chunks.
    """

    metadata = ViewMetadata(
        label=_("Knowledge Process Join Operator"),
        name="knowledge_process_join_operator",
        category=OperatorCategory.RAG,
        operator_type=OperatorType.JOIN,
        description=_(
            "Join Branch the workflow based on the Knowledge Process Results."
        ),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Vector Storage Results"),
                "input_value",
                List[Chunk],
                description=_("vector storage results."),
                is_list=True,
            ),
            IOField.build_from(
                _("Knowledge Graph Storage Results"),
                "input_value",
                List[Chunk],
                description=_("knowledge graph storage results."),
                is_list=True,
            ),
        ],
        outputs=[
            IOField.build_from(
                _("Chunks"),
                "chunks",
                List[Chunk],
                description=_("Knowledge Process Results."),
                is_list=True,
            ),
        ],
    )

    def __init__(self, **kwargs):
        """Create the join operator with ``_join`` as its combine function.

        Args:
            **kwargs: Forwarded unchanged to ``JoinOperator.__init__``.
        """
        super().__init__(combine_function=self._join, **kwargs)

    async def _join(
        self,
        vector_chunks: Optional[List[Chunk]] = None,
        graph_chunks: Optional[List[Chunk]] = None,
        full_text_chunks: Optional[List[Chunk]] = None,
    ) -> List[str]:
        """Summarize how many chunks each storage branch persisted.

        Args:
            vector_chunks: Chunks persisted to the vector store, if any.
            graph_chunks: Chunks persisted to the graph store, if any.
            full_text_chunks: Chunks persisted to the full-text store, if any.

        Returns:
            One message per non-empty input, in vector, graph, full-text
            order. Each message is also logged at INFO level.
        """
        persisted = (
            ("vector", vector_chunks),
            ("graph", graph_chunks),
            ("full text", full_text_chunks),
        )
        summaries: List[str] = []
        for store_label, chunks in persisted:
            # Both None and an empty list mean "nothing persisted": skip.
            if not chunks:
                continue
            message = f"async persist {store_label} store success {len(chunks)} chunks."
            logger.info(message)
            summaries.append(message)
        return summaries
|