Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-29 14:57:35 +00:00)
feat(awel): New AWEL RAG example

Commit: 1801138b62
Parent: e67d62a785
examples/awel/simple_rag_example.py (new file, 70 lines)

@@ -0,0 +1,70 @@
+"""AWEL: Simple rag example
+
+    Example:
+
+    .. code-block:: shell
+
+        curl -X POST http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag \
+            -H "Content-Type: application/json" -d '{
+                "conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
+                "model_name": "proxyllm",
+                "chat_mode": "chat_knowledge",
+                "user_input": "What is DB-GPT?",
+                "select_param": "default"
+            }'
+"""
+from pilot.awel import HttpTrigger, DAG, MapOperator
+from pilot.scene.operator._experimental import (
+    ChatContext,
+    PromptManagerOperator,
+    ChatHistoryStorageOperator,
+    ChatHistoryOperator,
+    EmbeddingEngingOperator,
+    BaseChatOperator,
+)
+from pilot.scene.base import ChatScene
+from pilot.openapi.api_view_model import ConversationVo
+from pilot.model.base import ModelOutput
+from pilot.model.operator.model_operator import ModelOperator
+
+
+class RequestParseOperator(MapOperator[ConversationVo, ChatContext]):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    async def map(self, input_value: ConversationVo) -> ChatContext:
+        return ChatContext(
+            current_user_input=input_value.user_input,
+            model_name=input_value.model_name,
+            chat_session_id=input_value.conv_uid,
+            select_param=input_value.select_param,
+            chat_scene=ChatScene.ChatKnowledge,
+        )
+
+
+with DAG("simple_rag_example") as dag:
+    trigger_task = HttpTrigger(
+        "/examples/simple_rag", methods="POST", request_body=ConversationVo
+    )
+    req_parse_task = RequestParseOperator()
+    prompt_task = PromptManagerOperator()
+    history_storage_task = ChatHistoryStorageOperator()
+    history_task = ChatHistoryOperator()
+    embedding_task = EmbeddingEngingOperator()
+    chat_task = BaseChatOperator()
+    model_task = ModelOperator()
+    output_parser_task = MapOperator(lambda out: out.to_dict()["text"])
+
+    (
+        trigger_task
+        >> req_parse_task
+        >> prompt_task
+        >> history_storage_task
+        >> history_task
+        >> embedding_task
+        >> chat_task
+        >> model_task
+        >> output_parser_task
+    )
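
For reference, the trigger above can also be exercised from Python rather than curl. A minimal sketch, assuming the DB-GPT webserver is running locally on port 5000 as in the docstring and that the `requests` package is installed:

    import requests

    resp = requests.post(
        "http://127.0.0.1:5000/api/v1/awel/trigger/examples/simple_rag",
        json={
            "conv_uid": "36f0e992-8825-11ee-8638-0242ac150003",
            "model_name": "proxyllm",
            "chat_mode": "chat_knowledge",
            "user_input": "What is DB-GPT?",
            "select_param": "default",  # name of the knowledge space to query
        },
        timeout=120,
    )
    print(resp.status_code, resp.text)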
@@ -7,6 +7,7 @@ import asyncio
 import logging
 from collections import deque
 from functools import cache
+from concurrent.futures import Executor

 from pilot.component import SystemApp
 from ..resource.base import ResourceGroup
@@ -102,6 +103,7 @@ class DAGVar:
     _thread_local = threading.local()
     _async_local = contextvars.ContextVar("current_dag_stack", default=deque())
     _system_app: SystemApp = None
+    _executor: Executor = None

     @classmethod
     def enter_dag(cls, dag) -> None:
@@ -157,6 +159,14 @@ class DAGVar:
         else:
             cls._system_app = system_app

+    @classmethod
+    def get_executor(cls) -> Executor:
+        return cls._executor
+
+    @classmethod
+    def set_executor(cls, executor: Executor) -> None:
+        cls._executor = executor
+

 class DAGNode(DependencyMixin, ABC):
     resource_group: Optional[ResourceGroup] = None
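
The two new class methods give AWEL a process-wide default executor that later hunks fall back to. A minimal usage sketch, assuming any `concurrent.futures.Executor` will do (the thread pool below is illustrative):

    from concurrent.futures import ThreadPoolExecutor

    # Seed the shared default once, e.g. at application startup.
    if DAGVar.get_executor() is None:
        DAGVar.set_executor(ThreadPoolExecutor(max_workers=4))

    # Any DAGNode created afterwards without an explicit executor
    # picks this instance up via DAGVar.get_executor().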
@@ -165,9 +175,10 @@ class DAGNode(DependencyMixin, ABC):
     def __init__(
         self,
         dag: Optional["DAG"] = None,
-        node_id: str = None,
-        node_name: str = None,
-        system_app: SystemApp = None,
+        node_id: Optional[str] = None,
+        node_name: Optional[str] = None,
+        system_app: Optional[SystemApp] = None,
+        executor: Optional[Executor] = None,
     ) -> None:
         super().__init__()
         self._upstream: List["DAGNode"] = []
@@ -176,6 +187,7 @@ class DAGNode(DependencyMixin, ABC):
         self._system_app: Optional[SystemApp] = (
             system_app or DAGVar.get_current_system_app()
         )
+        self._executor: Optional[Executor] = executor or DAGVar.get_executor()
         if not node_id and self._dag:
             node_id = self._dag._new_node_id()
         self._node_id: str = node_id
@@ -14,7 +14,13 @@ from typing import (
 )
 import functools
 from inspect import signature
-from pilot.component import SystemApp
+from pilot.component import SystemApp, ComponentType
+from pilot.utils.executor_utils import (
+    ExecutorFactory,
+    DefaultExecutorFactory,
+    blocking_func_to_async,
+    BlockingFunction,
+)

 from ..dag.base import DAGNode, DAGContext, DAGVar, DAG
 from ..task.base import (
@@ -71,6 +77,16 @@ class BaseOperatorMeta(ABCMeta):
         system_app: Optional[SystemApp] = (
             kwargs.get("system_app") or DAGVar.get_current_system_app()
         )
+        executor = kwargs.get("executor") or DAGVar.get_executor()
+        if not executor:
+            if system_app:
+                executor = system_app.get_component(
+                    ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
+                ).create()
+            else:
+                executor = DefaultExecutorFactory().create()
+            DAGVar.set_executor(executor)
+
         if not task_id and dag:
             task_id = dag._new_node_id()
         runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner
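
The block above establishes the executor resolution order: an explicit `executor` kwarg wins, then the `DAGVar` process default, then an `ExecutorFactory` registered on the `SystemApp`, and finally a fresh `DefaultExecutorFactory()`; whichever fallback is chosen is cached back into `DAGVar` so subsequent operators reuse it. The same cascade restated as a standalone helper (illustrative only, not part of the commit):

    def resolve_executor(kwargs: dict, system_app):
        executor = kwargs.get("executor") or DAGVar.get_executor()
        if not executor:
            if system_app:
                # Prefer the executor factory registered as a component.
                executor = system_app.get_component(
                    ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
                ).create()
            else:
                executor = DefaultExecutorFactory().create()
            DAGVar.set_executor(executor)  # cache the fallback for reuse
        return executor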
@@ -86,6 +102,8 @@ class BaseOperatorMeta(ABCMeta):
             kwargs["runner"] = runner
         if not kwargs.get("system_app"):
             kwargs["system_app"] = system_app
+        if not kwargs.get("executor"):
+            kwargs["executor"] = executor
         real_obj = func(self, *args, **kwargs)
         return real_obj

@@ -177,6 +195,11 @@ class BaseOperator(DAGNode, ABC, Generic[OUT], metaclass=BaseOperatorMeta):
         out_ctx = await self._runner.execute_workflow(self, call_data)
         return out_ctx.current_task_context.task_output.output_stream

+    async def blocking_func_to_async(
+        self, func: BlockingFunction, *args, **kwargs
+    ) -> Any:
+        return await blocking_func_to_async(self._executor, func, *args, **kwargs)
+

 def initialize_runner(runner: WorkflowRunner):
     global default_runner
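
With this helper, every operator can hand a synchronous callable to the shared executor instead of blocking the event loop. A minimal sketch of an operator using it; `slow_lookup` is hypothetical:

    import time

    from pilot.awel import MapOperator


    def slow_lookup(query: str) -> str:
        time.sleep(2)  # stand-in for blocking I/O such as a DB query
        return f"result for {query}"


    class LookupOperator(MapOperator[str, str]):
        async def map(self, input_value: str) -> str:
            # Runs slow_lookup on self._executor; the event loop stays free.
            return await self.blocking_func_to_async(slow_lookup, input_value)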
@@ -67,7 +67,7 @@ class DefaultWorkflowRunner(WorkflowRunner):
             node_outputs[node.node_id] = task_ctx
             return
         try:
-            logger.info(
+            logger.debug(
                 f"Begin run operator, node id: {node.node_id}, node name: {node.node_name}, cls: {node}"
             )
             await node._run(dag_ctx)
@@ -76,7 +76,7 @@ class DefaultWorkflowRunner(WorkflowRunner):

         if isinstance(node, BranchOperator):
             skip_nodes = task_ctx.metadata.get("skip_node_names", [])
-            logger.info(
+            logger.debug(
                 f"Current is branch operator, skip node names: {skip_nodes}"
             )
             _skip_current_downstream_by_node_name(node, skip_nodes, skip_node_ids)
@@ -47,7 +47,7 @@ class DbHistoryMemory(BaseChatHistoryMemory):
             logger.error("init create conversation log error!" + str(e))

     def append(self, once_message: OnceConversation) -> None:
-        logger.info(f"db history append: {once_message}")
+        logger.debug(f"db history append: {once_message}")
         chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
             self.chat_seesion_id
         )
@@ -143,9 +143,7 @@ def _build_request(model: ProxyModel, params):
     proxyllm_backend = proxyllm_backend or "gpt-3.5-turbo"
     payloads["model"] = proxyllm_backend

-    logger.info(
-        f"Send request to real model {proxyllm_backend}, openai_params: {openai_params}"
-    )
+    logger.info(f"Send request to real model {proxyllm_backend}")
     return history, payloads

@@ -68,7 +68,7 @@ class BaseChat(ABC):
             CFG.prompt_template_registry.get_prompt_template(
                 self.chat_mode.value(),
                 language=CFG.LANGUAGE,
-                model_name=CFG.LLM_MODEL,
+                model_name=self.llm_model,
                 proxyllm_backend=CFG.PROXYLLM_BACKEND,
             )
         )
@@ -141,13 +141,7 @@ class BaseChat(ABC):
         return speak_to_user

     async def __call_base(self):
-        import inspect
-
-        input_values = (
-            await self.generate_input_values()
-            if inspect.isawaitable(self.generate_input_values())
-            else self.generate_input_values()
-        )
+        input_values = await self.generate_input_values()
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
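
Because `__call_base` now awaits unconditionally, every `BaseChat` subclass must declare `generate_input_values` as a coroutine; that is exactly the `def` to `async def` change applied to the Extract* scenes later in this commit. A sketch of the new expected shape (the subclass is hypothetical):

    class MyChatScene(BaseChat):
        async def generate_input_values(self) -> dict:
            # Subclasses may now await retrieval or DB calls directly here.
            return {"context": "...", "question": self.current_user_input}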
@@ -379,16 +373,18 @@ class BaseChat(ABC):
         if self.prompt_template.template_define:
             text += self.prompt_template.template_define + self.prompt_template.sep
         ### Load prompt
-        text += self.__load_system_message()
+        text += _load_system_message(self.current_message, self.prompt_template)

         ### Load examples
-        text += self.__load_example_messages()
+        text += _load_example_messages(self.prompt_template)

         ### Load History
-        text += self.__load_history_messages()
+        text += _load_history_messages(
+            self.prompt_template, self.history_message, self.chat_retention_rounds
+        )

         ### Load User Input
-        text += self.__load_user_message()
+        text += _load_user_message(self.current_message, self.prompt_template)
         return text

     def generate_llm_messages(self) -> List[ModelMessage]:
@@ -406,137 +402,26 @@ class BaseChat(ABC):
             )
         )
         ### Load prompt
-        messages += self.__load_system_message(str_message=False)
+        messages += _load_system_message(
+            self.current_message, self.prompt_template, str_message=False
+        )
         ### Load examples
-        messages += self.__load_example_messages(str_message=False)
+        messages += _load_example_messages(self.prompt_template, str_message=False)

         ### Load History
-        messages += self.__load_history_messages(str_message=False)
+        messages += _load_history_messages(
+            self.prompt_template,
+            self.history_message,
+            self.chat_retention_rounds,
+            str_message=False,
+        )

         ### Load User Input
-        messages += self.__load_user_message(str_message=False)
+        messages += _load_user_message(
+            self.current_message, self.prompt_template, str_message=False
+        )
         return messages

-    def __load_system_message(self, str_message: bool = True):
-        system_convs = self.current_message.get_system_conv()
-        system_text = ""
-        system_messages = []
-        for system_conv in system_convs:
-            system_text += (
-                system_conv.type + ":" + system_conv.content + self.prompt_template.sep
-            )
-            system_messages.append(
-                ModelMessage(role=system_conv.type, content=system_conv.content)
-            )
-        return system_text if str_message else system_messages
-
-    def __load_user_message(self, str_message: bool = True):
-        user_conv = self.current_message.get_user_conv()
-        user_messages = []
-        if user_conv:
-            user_text = (
-                user_conv.type + ":" + user_conv.content + self.prompt_template.sep
-            )
-            user_messages.append(
-                ModelMessage(role=user_conv.type, content=user_conv.content)
-            )
-            return user_text if str_message else user_messages
-        else:
-            raise ValueError("Hi! What do you want to talk about?")
-
-    def __load_example_messages(self, str_message: bool = True):
-        example_text = ""
-        example_messages = []
-        if self.prompt_template.example_selector:
-            for round_conv in self.prompt_template.example_selector.examples():
-                for round_message in round_conv["messages"]:
-                    if not round_message["type"] in [
-                        ModelMessageRoleType.VIEW,
-                        ModelMessageRoleType.SYSTEM,
-                    ]:
-                        message_type = round_message["type"]
-                        message_content = round_message["data"]["content"]
-                        example_text += (
-                            message_type
-                            + ":"
-                            + message_content
-                            + self.prompt_template.sep
-                        )
-                        example_messages.append(
-                            ModelMessage(role=message_type, content=message_content)
-                        )
-        return example_text if str_message else example_messages
-
-    def __load_history_messages(self, str_message: bool = True):
-        history_text = ""
-        history_messages = []
-        if self.prompt_template.need_historical_messages:
-            if self.history_message:
-                logger.info(
-                    f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
-                )
-            if len(self.history_message) > self.chat_retention_rounds:
-                for first_message in self.history_message[0]["messages"]:
-                    if not first_message["type"] in [
-                        ModelMessageRoleType.VIEW,
-                        ModelMessageRoleType.SYSTEM,
-                    ]:
-                        message_type = first_message["type"]
-                        message_content = first_message["data"]["content"]
-                        history_text += (
-                            message_type
-                            + ":"
-                            + message_content
-                            + self.prompt_template.sep
-                        )
-                        history_messages.append(
-                            ModelMessage(role=message_type, content=message_content)
-                        )
-                if self.chat_retention_rounds > 1:
-                    index = self.chat_retention_rounds - 1
-                    for round_conv in self.history_message[-index:]:
-                        for round_message in round_conv["messages"]:
-                            if not round_message["type"] in [
-                                ModelMessageRoleType.VIEW,
-                                ModelMessageRoleType.SYSTEM,
-                            ]:
-                                message_type = round_message["type"]
-                                message_content = round_message["data"]["content"]
-                                history_text += (
-                                    message_type
-                                    + ":"
-                                    + message_content
-                                    + self.prompt_template.sep
-                                )
-                                history_messages.append(
-                                    ModelMessage(
-                                        role=message_type, content=message_content
-                                    )
-                                )
-
-            else:
-                ### user all history
-                for conversation in self.history_message:
-                    for message in conversation["messages"]:
-                        ### histroy message not have promot and view info
-                        if not message["type"] in [
-                            ModelMessageRoleType.VIEW,
-                            ModelMessageRoleType.SYSTEM,
-                        ]:
-                            message_type = message["type"]
-                            message_content = message["data"]["content"]
-                            history_text += (
-                                message_type
-                                + ":"
-                                + message_content
-                                + self.prompt_template.sep
-                            )
-                            history_messages.append(
-                                ModelMessage(role=message_type, content=message_content)
-                            )
-
-        return history_text if str_message else history_messages
-
     def current_ai_response(self) -> str:
         for message in self.current_message.messages:
             if message.type == "view":
@@ -656,3 +541,127 @@ def _build_model_operator(
     cache_check_branch_node >> cached_node >> join_node

     return join_node
+
+
+def _load_system_message(
+    current_message: OnceConversation,
+    prompt_template: PromptTemplate,
+    str_message: bool = True,
+):
+    system_convs = current_message.get_system_conv()
+    system_text = ""
+    system_messages = []
+    for system_conv in system_convs:
+        system_text += (
+            system_conv.type + ":" + system_conv.content + prompt_template.sep
+        )
+        system_messages.append(
+            ModelMessage(role=system_conv.type, content=system_conv.content)
+        )
+    return system_text if str_message else system_messages
+
+
+def _load_user_message(
+    current_message: OnceConversation,
+    prompt_template: PromptTemplate,
+    str_message: bool = True,
+):
+    user_conv = current_message.get_user_conv()
+    user_messages = []
+    if user_conv:
+        user_text = user_conv.type + ":" + user_conv.content + prompt_template.sep
+        user_messages.append(
+            ModelMessage(role=user_conv.type, content=user_conv.content)
+        )
+        return user_text if str_message else user_messages
+    else:
+        raise ValueError("Hi! What do you want to talk about?")
+
+
+def _load_example_messages(prompt_template: PromptTemplate, str_message: bool = True):
+    example_text = ""
+    example_messages = []
+    if prompt_template.example_selector:
+        for round_conv in prompt_template.example_selector.examples():
+            for round_message in round_conv["messages"]:
+                if not round_message["type"] in [
+                    ModelMessageRoleType.VIEW,
+                    ModelMessageRoleType.SYSTEM,
+                ]:
+                    message_type = round_message["type"]
+                    message_content = round_message["data"]["content"]
+                    example_text += (
+                        message_type + ":" + message_content + prompt_template.sep
+                    )
+                    example_messages.append(
+                        ModelMessage(role=message_type, content=message_content)
+                    )
+    return example_text if str_message else example_messages
+
+
+def _load_history_messages(
+    prompt_template: PromptTemplate,
+    history_message: List[OnceConversation],
+    chat_retention_rounds: int,
+    str_message: bool = True,
+):
+    history_text = ""
+    history_messages = []
+    if prompt_template.need_historical_messages:
+        if history_message:
+            logger.info(
+                f"There are already {len(history_message)} rounds of conversations! Will use {chat_retention_rounds} rounds of content as history!"
+            )
+        if len(history_message) > chat_retention_rounds:
+            for first_message in history_message[0]["messages"]:
+                if not first_message["type"] in [
+                    ModelMessageRoleType.VIEW,
+                    ModelMessageRoleType.SYSTEM,
+                ]:
+                    message_type = first_message["type"]
+                    message_content = first_message["data"]["content"]
+                    history_text += (
+                        message_type + ":" + message_content + prompt_template.sep
+                    )
+                    history_messages.append(
+                        ModelMessage(role=message_type, content=message_content)
+                    )
+            if chat_retention_rounds > 1:
+                index = chat_retention_rounds - 1
+                for round_conv in history_message[-index:]:
+                    for round_message in round_conv["messages"]:
+                        if not round_message["type"] in [
+                            ModelMessageRoleType.VIEW,
+                            ModelMessageRoleType.SYSTEM,
+                        ]:
+                            message_type = round_message["type"]
+                            message_content = round_message["data"]["content"]
+                            history_text += (
+                                message_type
+                                + ":"
+                                + message_content
+                                + prompt_template.sep
+                            )
+                            history_messages.append(
+                                ModelMessage(role=message_type, content=message_content)
+                            )
+
+        else:
+            ### user all history
+            for conversation in history_message:
+                for message in conversation["messages"]:
+                    ### histroy message not have promot and view info
+                    if not message["type"] in [
+                        ModelMessageRoleType.VIEW,
+                        ModelMessageRoleType.SYSTEM,
+                    ]:
+                        message_type = message["type"]
+                        message_content = message["data"]["content"]
+                        history_text += (
+                            message_type + ":" + message_content + prompt_template.sep
+                        )
+                        history_messages.append(
+                            ModelMessage(role=message_type, content=message_content)
+                        )
+
+    return history_text if str_message else history_messages
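
Hoisting these loaders to module level lets both `BaseChat` and the new `ChatHistoryManager` in `pilot/scene/operator/_experimental.py` share them. Each returns a flat prompt string by default, or a `List[ModelMessage]` when `str_message=False`. A small sketch of the dual return shape, assuming `current_message` and `prompt_template` are already populated:

    # Flat text for plain-prompt models: "system:<content><sep>..."
    text = _load_system_message(current_message, prompt_template)

    # Structured messages for chat-style models
    messages = _load_system_message(
        current_message, prompt_template, str_message=False
    )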
@@ -6,7 +6,6 @@ import re
 import sqlparse
 import pandas as pd
 import chardet
-import pandas as pd
 import numpy as np
 from pyparsing import (
     CaselessKeyword,
@@ -27,6 +26,8 @@ from pyparsing import (
 from pilot.common.pd_utils import csv_colunm_foramt
 from pilot.common.string_utils import is_chinese_include_number

+logger = logging.getLogger(__name__)
+

 def excel_colunm_format(old_name: str) -> str:
     new_column = old_name.strip()
@@ -263,7 +264,7 @@ class ExcelReader:
         file_name = os.path.basename(file_path)
         self.file_name_without_extension = os.path.splitext(file_name)[0]
         encoding, confidence = detect_encoding(file_path)
-        logging.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
+        logger.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
         self.excel_file_name = file_name
         self.extension = os.path.splitext(file_name)[1]
         # read excel file
@@ -323,7 +324,7 @@ class ExcelReader:
                 colunms.append(descrip[0])
             return colunms, results.fetchall()
         except Exception as e:
-            logging.error("excel sql run error!", e)
+            logger.error(f"excel sql run error!, {str(e)}")
             raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")

     def get_df_by_sql_ex(self, sql):
@@ -37,7 +37,7 @@ class DbChatOutputParser(BaseOutputParser):

     def parse_prompt_response(self, model_out_text):
         clean_str = super().parse_prompt_response(model_out_text)
-        logging.info("clean prompt response:", clean_str)
+        logger.info(f"clean prompt response: {clean_str}")
         # Compatible with community pure sql output model
         if self.is_sql_statement(clean_str):
             return SqlAction(clean_str, "")
@@ -51,7 +51,7 @@ class DbChatOutputParser(BaseOutputParser):
                     thoughts = response[key]
             return SqlAction(sql, thoughts)
         except Exception as e:
-            logging.error("json load faild")
+            logger.error("json load faild")
             return SqlAction("", clean_str)

     def parse_view_response(self, speak, data, prompt_response) -> str:
@@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }
@@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }
@@ -23,7 +23,7 @@ class ExtractRefineSummary(BaseChat):

         self.existing_answer = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             # "context": self.user_input,
             "existing_answer": self.existing_answer,
@@ -23,7 +23,7 @@ class ExtractSummary(BaseChat):

         self.user_input = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "context": self.user_input,
         }
@@ -104,7 +104,7 @@ class ChatKnowledge(BaseChat):
             self.current_user_input,
             self.top_k,
         )
-        self.sources = self.merge_by_key(
+        self.sources = _merge_by_key(
             list(map(lambda doc: doc.metadata, docs)), "source"
         )

@@ -149,29 +149,6 @@ class ChatKnowledge(BaseChat):
         )
         return html

-    def merge_by_key(self, data, key):
-        result = {}
-        for item in data:
-            if item.get(key):
-                item_key = os.path.basename(item.get(key))
-                if item_key in result:
-                    if "pages" in result[item_key] and "page" in item:
-                        result[item_key]["pages"].append(str(item["page"]))
-                    elif "page" in item:
-                        result[item_key]["pages"] = [
-                            result[item_key]["pages"],
-                            str(item["page"]),
-                        ]
-                else:
-                    if "page" in item:
-                        result[item_key] = {
-                            "source": item_key,
-                            "pages": [str(item["page"])],
-                        }
-                    else:
-                        result[item_key] = {"source": item_key}
-        return list(result.values())
-
     @property
     def chat_type(self) -> str:
         return ChatScene.ChatKnowledge.value()
@@ -179,3 +156,27 @@ class ChatKnowledge(BaseChat):
     def get_space_context(self, space_name):
         service = KnowledgeService()
         return service.get_space_context(space_name)
+
+
+def _merge_by_key(data, key):
+    result = {}
+    for item in data:
+        if item.get(key):
+            item_key = os.path.basename(item.get(key))
+            if item_key in result:
+                if "pages" in result[item_key] and "page" in item:
+                    result[item_key]["pages"].append(str(item["page"]))
+                elif "page" in item:
+                    result[item_key]["pages"] = [
+                        result[item_key]["pages"],
+                        str(item["page"]),
+                    ]
+            else:
+                if "page" in item:
+                    result[item_key] = {
+                        "source": item_key,
+                        "pages": [str(item["page"])],
+                    }
+                else:
+                    result[item_key] = {"source": item_key}
+    return list(result.values())
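
A worked example of `_merge_by_key` on retrieval metadata: entries that resolve to the same source basename are collapsed and their page numbers accumulated (the sample data is illustrative):

    docs_metadata = [
        {"source": "/data/DB-GPT.pdf", "page": 1},
        {"source": "/data/DB-GPT.pdf", "page": 3},
        {"source": "/data/notes.md"},
    ]
    print(_merge_by_key(docs_metadata, "source"))
    # [{'source': 'DB-GPT.pdf', 'pages': ['1', '3']}, {'source': 'notes.md'}]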
pilot/scene/operator/_experimental.py (new file, 255 lines)

@@ -0,0 +1,255 @@
+from typing import Dict, Optional, List, Any
+from dataclasses import dataclass
+import datetime
+import os
+from pilot.awel import MapOperator
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.scene.message import OnceConversation
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+
+from pilot.memory.chat_history.base import BaseChatHistoryMemory
+from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
+
+# TODO move global config
+CFG = Config()
+
+
+@dataclass
+class ChatContext:
+    current_user_input: str
+    model_name: Optional[str]
+    chat_session_id: Optional[str] = None
+    select_param: Optional[str] = None
+    chat_scene: Optional[ChatScene] = ChatScene.ChatNormal
+    prompt_template: Optional[PromptTemplate] = None
+    chat_retention_rounds: Optional[int] = 0
+    history_storage: Optional[BaseChatHistoryMemory] = None
+    history_manager: Optional["ChatHistoryManager"] = None
+    # The input values for prompt template
+    input_values: Optional[Dict] = None
+    echo: Optional[bool] = False
+
+    def build_model_payload(self) -> Dict:
+        if not self.input_values:
+            raise ValueError("The input value can't be empty")
+        llm_messages = self.history_manager._new_chat(self.input_values)
+        return {
+            "model": self.model_name,
+            "prompt": "",
+            "messages": llm_messages,
+            "temperature": float(self.prompt_template.temperature),
+            "max_new_tokens": int(self.prompt_template.max_new_tokens),
+            "echo": self.echo,
+        }
+
+
+class ChatHistoryManager:
+    def __init__(
+        self,
+        chat_ctx: ChatContext,
+        prompt_template: PromptTemplate,
+        history_storage: BaseChatHistoryMemory,
+        chat_retention_rounds: Optional[int] = 0,
+    ) -> None:
+        self._chat_ctx = chat_ctx
+        self.chat_retention_rounds = chat_retention_rounds
+        self.current_message: OnceConversation = OnceConversation(
+            chat_ctx.chat_scene.value()
+        )
+        self.prompt_template = prompt_template
+        self.history_storage: BaseChatHistoryMemory = history_storage
+        self.history_message: List[OnceConversation] = history_storage.messages()
+        self.current_message.model_name = chat_ctx.model_name
+        if chat_ctx.select_param:
+            if len(chat_ctx.chat_scene.param_types()) > 0:
+                self.current_message.param_type = chat_ctx.chat_scene.param_types()[0]
+            self.current_message.param_value = chat_ctx.select_param
+
+    def _new_chat(self, input_values: Dict) -> List[ModelMessage]:
+        self.current_message.chat_order = len(self.history_message) + 1
+        self.current_message.add_user_message(self._chat_ctx.current_user_input)
+        self.current_message.start_date = datetime.datetime.now().strftime(
+            "%Y-%m-%d %H:%M:%S"
+        )
+        self.current_message.tokens = 0
+        if self.prompt_template.template:
+            current_prompt = self.prompt_template.format(**input_values)
+            self.current_message.add_system_message(current_prompt)
+        return self._generate_llm_messages()
+
+    def _generate_llm_messages(self) -> List[ModelMessage]:
+        from pilot.scene.base_chat import (
+            _load_system_message,
+            _load_example_messages,
+            _load_history_messages,
+            _load_user_message,
+        )
+
+        messages = []
+        ### Load scene setting or character definition as system message
+        if self.prompt_template.template_define:
+            messages.append(
+                ModelMessage(
+                    role=ModelMessageRoleType.SYSTEM,
+                    content=self.prompt_template.template_define,
+                )
+            )
+        ### Load prompt
+        messages += _load_system_message(
+            self.current_message, self.prompt_template, str_message=False
+        )
+        ### Load examples
+        messages += _load_example_messages(self.prompt_template, str_message=False)
+
+        ### Load History
+        messages += _load_history_messages(
+            self.prompt_template,
+            self.history_message,
+            self.chat_retention_rounds,
+            str_message=False,
+        )
+
+        ### Load User Input
+        messages += _load_user_message(
+            self.current_message, self.prompt_template, str_message=False
+        )
+        return messages
+
+
+class PromptManagerOperator(MapOperator[ChatContext, ChatContext]):
+    def __init__(self, prompt_template: PromptTemplate = None, **kwargs):
+        super().__init__(**kwargs)
+        self._prompt_template = prompt_template
+
+    async def map(self, input_value: ChatContext) -> ChatContext:
+        if not self._prompt_template:
+            self._prompt_template: PromptTemplate = (
+                CFG.prompt_template_registry.get_prompt_template(
+                    input_value.chat_scene.value(),
+                    language=CFG.LANGUAGE,
+                    model_name=input_value.model_name,
+                    proxyllm_backend=CFG.PROXYLLM_BACKEND,
+                )
+            )
+        input_value.prompt_template = self._prompt_template
+        return input_value
+
+
+class ChatHistoryStorageOperator(MapOperator[ChatContext, ChatContext]):
+    def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
+        super().__init__(**kwargs)
+        self._history = history
+
+    async def map(self, input_value: ChatContext) -> ChatContext:
+        if self._history:
+            return self._history
+        chat_history_fac = ChatHistory()
+        input_value.history_storage = chat_history_fac.get_store_instance(
+            input_value.chat_session_id
+        )
+        return input_value
+
+
+class ChatHistoryOperator(MapOperator[ChatContext, ChatContext]):
+    def __init__(self, history: BaseChatHistoryMemory = None, **kwargs):
+        super().__init__(**kwargs)
+        self._history = history
+
+    async def map(self, input_value: ChatContext) -> ChatContext:
+        history_storage = self._history or input_value.history_storage
+        if not history_storage:
+            from pilot.memory.chat_history.store_type.mem_history import (
+                MemHistoryMemory,
+            )
+
+            history_storage = MemHistoryMemory(input_value.chat_session_id)
+            input_value.history_storage = history_storage
+        input_value.history_manager = ChatHistoryManager(
+            input_value,
+            input_value.prompt_template,
+            history_storage,
+            input_value.chat_retention_rounds,
+        )
+        return input_value
+
+
+class EmbeddingEngingOperator(MapOperator[ChatContext, ChatContext]):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    async def map(self, input_value: ChatContext) -> ChatContext:
+        from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG
+        from pilot.embedding_engine.embedding_engine import EmbeddingEngine
+        from pilot.embedding_engine.embedding_factory import EmbeddingFactory
+        from pilot.scene.chat_knowledge.v1.chat import _merge_by_key
+
+        # TODO, decompose the current operator into some atomic operators
+        knowledge_space = input_value.select_param
+        vector_store_config = {
+            "vector_store_name": knowledge_space,
+            "vector_store_type": CFG.VECTOR_STORE_TYPE,
+        }
+        embedding_factory = self.system_app.get_component(
+            "embedding_factory", EmbeddingFactory
+        )
+        knowledge_embedding_client = EmbeddingEngine(
+            model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            vector_store_config=vector_store_config,
+            embedding_factory=embedding_factory,
+        )
+        space_context = await self._get_space_context(knowledge_space)
+        top_k = (
+            CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+            if space_context is None
+            else int(space_context["embedding"]["topk"])
+        )
+        max_token = (
+            CFG.KNOWLEDGE_SEARCH_MAX_TOKEN
+            if space_context is None or space_context.get("prompt") is None
+            else int(space_context["prompt"]["max_token"])
+        )
+        input_value.prompt_template.template_is_strict = False
+        if space_context and space_context.get("prompt"):
+            input_value.prompt_template.template_define = space_context["prompt"][
+                "scene"
+            ]
+            input_value.prompt_template.template = space_context["prompt"]["template"]
+
+        docs = await self.blocking_func_to_async(
+            knowledge_embedding_client.similar_search,
+            input_value.current_user_input,
+            top_k,
+        )
+        sources = _merge_by_key(list(map(lambda doc: doc.metadata, docs)), "source")
+        if not docs or len(docs) == 0:
+            print("no relevant docs to retrieve")
+            context = "no relevant docs to retrieve"
+        else:
+            context = [d.page_content for d in docs]
+        context = context[:max_token]
+        relations = list(
+            set([os.path.basename(str(d.metadata.get("source", ""))) for d in docs])
+        )
+        input_value.input_values = {
+            "context": context,
+            "question": input_value.current_user_input,
+            "relations": relations,
+        }
+        return input_value
+
+    async def _get_space_context(self, space_name):
+        from pilot.server.knowledge.service import KnowledgeService
+
+        service = KnowledgeService()
+        return await self.blocking_func_to_async(service.get_space_context, space_name)
+
+
+class BaseChatOperator(MapOperator[ChatContext, Dict]):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    async def map(self, input_value: ChatContext) -> Dict:
+        return input_value.build_model_payload()
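
Taken together, the operators in this file replay the stages of `BaseChat` as discrete DAG nodes: `PromptManagerOperator` resolves the scene's prompt template, `ChatHistoryStorageOperator` and `ChatHistoryOperator` attach storage and a `ChatHistoryManager`, `EmbeddingEngingOperator` retrieves context, and `BaseChatOperator` turns the accumulated `ChatContext` into the payload that `ModelOperator` consumes. A sketch of that payload boundary, with the upstream fields assumed to have been filled in by the earlier operators:

    ctx = ChatContext(
        current_user_input="What is DB-GPT?",
        model_name="proxyllm",
        chat_session_id="36f0e992-8825-11ee-8638-0242ac150003",
        select_param="default",
        chat_scene=ChatScene.ChatKnowledge,
    )
    # After the upstream operators have set ctx.prompt_template,
    # ctx.history_manager and ctx.input_values:
    #   payload = ctx.build_model_payload()
    # yields keys: model, prompt, messages, temperature, max_new_tokens, echo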