DB-GPT/dbgpt/experimental/intent/operators.py
2024-05-30 18:51:57 +08:00

93 lines
3.3 KiB
Python

"""Operators for intent detection."""
from typing import Dict, List, Optional, cast
from dbgpt.core import ModelMessage, ModelRequest, ModelRequestContext
from dbgpt.core.awel import BranchFunc, BranchOperator, BranchTaskType, MapOperator
from dbgpt.model.operators.llm_operator import MixinLLMOperator
from .base import BaseIntentDetection, IntentDetectionResponse
class IntentDetectionOperator(
    MixinLLMOperator, BaseIntentDetection, MapOperator[ModelRequest, ModelRequest]
):
    """Map operator that tags a model request with its detected intent.

    The request is passed through; the only change is that the detection
    result is stored under ``context.extra["intent_detection"]``.
    """

    def __init__(
        self,
        intent_definitions: str,
        prompt_template: Optional[str] = None,
        response_format: Optional[str] = None,
        examples: Optional[str] = None,
        **kwargs
    ):
        """Initialize each base class explicitly.

        Args:
            intent_definitions: Forwarded to ``BaseIntentDetection``.
            prompt_template: Forwarded to ``BaseIntentDetection``.
            response_format: Forwarded to ``BaseIntentDetection``.
            examples: Forwarded to ``BaseIntentDetection``.
            **kwargs: Forwarded to ``MapOperator``.
        """
        # The bases take different arguments, so each initializer is invoked
        # by name rather than through a cooperative ``super()`` chain.
        MixinLLMOperator.__init__(self)
        MapOperator.__init__(self, **kwargs)
        detection_options = {
            "intent_definitions": intent_definitions,
            "prompt_template": prompt_template,
            "response_format": response_format,
            "examples": examples,
        }
        BaseIntentDetection.__init__(self, **detection_options)

    async def map(self, input_value: ModelRequest) -> ModelRequest:
        """Detect the intent of the request and record it in its context.

        Args:
            input_value: The incoming model request; it is mutated in place.

        Returns:
            The same request object, with the detection result stored under
            ``context.extra["intent_detection"]``.
        """
        # Use the configured language when a system app is present, else "en".
        language = (
            self.system_app.config.get_current_lang() if self.system_app else "en"
        )
        detected = await self.detect_intent(
            self.parse_messages(input_value),
            input_value.model,
            language=language,
        )

        # Create the context and its ``extra`` mapping when missing/empty.
        if not input_value.context:
            input_value.context = ModelRequestContext()
        if not input_value.context.extra:
            input_value.context.extra = {}
        input_value.context.extra["intent_detection"] = detected
        return input_value

    def parse_messages(self, request: ModelRequest) -> List[ModelMessage]:
        """Return the messages of ``request`` used for intent detection."""
        messages = request.get_messages()
        return messages
class IntentDetectionBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
    """Branch operator that routes a request by its intent detection result.

    The result is read from ``context.extra["intent_detection"]`` of the
    request.
    """

    def __init__(self, end_task_name: str, **kwargs):
        """Create the branch operator.

        Args:
            end_task_name: Name of the downstream task selected when the
                detection result reports an empty slot.
            **kwargs: Forwarded to the parent operator.
        """
        super().__init__(**kwargs)
        self._end_task_name = end_task_name

    async def branches(
        self,
    ) -> Dict[BranchFunc[ModelRequest], BranchTaskType]:
        """Build one predicate per distinct downstream task name.

        Returns:
            A mapping from predicate to downstream task name. A predicate is
            true for a request when its task name is the one selected by the
            request's intent detection result.
        """
        downstream_names = {node.node_name for node in self.downstream}
        routes = {}
        for name in downstream_names:

            # ``outer_task_name`` is bound as a default argument so that each
            # predicate keeps its own task name instead of the loop's last one.
            def check(r: ModelRequest, outer_task_name=name):
                extra = r.context.extra if r.context else None
                if not extra:
                    return False
                detection = extra.get("intent_detection")
                if not detection:
                    return False
                ic: IntentDetectionResponse = cast(
                    IntentDetectionResponse, detection
                )
                # An empty slot selects the end task; otherwise the detected
                # task name decides.
                selected = (
                    self._end_task_name if ic.has_empty_slot() else ic.task_name
                )
                return outer_task_name == selected

            routes[check] = name
        return routes  # type: ignore