mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-06 03:20:41 +00:00
feat: Support intent detection (#1588)
This commit is contained in:
92
dbgpt/experimental/intent/operators.py
Normal file
92
dbgpt/experimental/intent/operators.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Operators for intent detection."""
|
||||
|
||||
from typing import Dict, List, Optional, cast
|
||||
|
||||
from dbgpt.core import ModelMessage, ModelRequest, ModelRequestContext
|
||||
from dbgpt.core.awel import BranchFunc, BranchOperator, BranchTaskType, MapOperator
|
||||
from dbgpt.model.operators.llm_operator import MixinLLMOperator
|
||||
|
||||
from .base import BaseIntentDetection, IntentDetectionResponse
|
||||
|
||||
|
||||
class IntentDetectionOperator(
|
||||
MixinLLMOperator, BaseIntentDetection, MapOperator[ModelRequest, ModelRequest]
|
||||
):
|
||||
"""The intent detection operator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
intent_definitions: str,
|
||||
prompt_template: Optional[str] = None,
|
||||
response_format: Optional[str] = None,
|
||||
examples: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Create the intent detection operator."""
|
||||
MixinLLMOperator.__init__(self)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
BaseIntentDetection.__init__(
|
||||
self,
|
||||
intent_definitions=intent_definitions,
|
||||
prompt_template=prompt_template,
|
||||
response_format=response_format,
|
||||
examples=examples,
|
||||
)
|
||||
|
||||
async def map(self, input_value: ModelRequest) -> ModelRequest:
|
||||
"""Detect the intent.
|
||||
|
||||
Merge the intent detection result into the context.
|
||||
"""
|
||||
language = "en"
|
||||
if self.system_app:
|
||||
language = self.system_app.config.get_current_lang()
|
||||
messages = self.parse_messages(input_value)
|
||||
ic = await self.detect_intent(
|
||||
messages,
|
||||
input_value.model,
|
||||
language=language,
|
||||
)
|
||||
if not input_value.context:
|
||||
input_value.context = ModelRequestContext()
|
||||
if not input_value.context.extra:
|
||||
input_value.context.extra = {}
|
||||
input_value.context.extra["intent_detection"] = ic
|
||||
return input_value
|
||||
|
||||
def parse_messages(self, request: ModelRequest) -> List[ModelMessage]:
|
||||
"""Parse the messages from the request."""
|
||||
return request.get_messages()
|
||||
|
||||
|
||||
class IntentDetectionBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
"""The intent detection branch operator."""
|
||||
|
||||
def __init__(self, end_task_name: str, **kwargs):
|
||||
"""Create the intent detection branch operator."""
|
||||
super().__init__(**kwargs)
|
||||
self._end_task_name = end_task_name
|
||||
|
||||
async def branches(
|
||||
self,
|
||||
) -> Dict[BranchFunc[ModelRequest], BranchTaskType]:
|
||||
"""Branch the intent detection result to different tasks."""
|
||||
download_task_names = set(task.node_name for task in self.downstream) # noqa
|
||||
branch_func_map = {}
|
||||
for task_name in download_task_names:
|
||||
|
||||
def check(r: ModelRequest, outer_task_name=task_name):
|
||||
if not r.context or not r.context.extra:
|
||||
return False
|
||||
ic_result = r.context.extra.get("intent_detection")
|
||||
if not ic_result:
|
||||
return False
|
||||
ic: IntentDetectionResponse = cast(IntentDetectionResponse, ic_result)
|
||||
if ic.has_empty_slot():
|
||||
return self._end_task_name == outer_task_name
|
||||
else:
|
||||
return outer_task_name == ic.task_name
|
||||
|
||||
branch_func_map[check] = task_name
|
||||
|
||||
return branch_func_map # type: ignore
|
Reference in New Issue
Block a user