DB-GPT/dbgpt/app/scene/base_chat.py

import asyncio
import datetime
import logging
import traceback
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict

from dbgpt._private.config import Config
from dbgpt._private.pydantic import EXTRA_FORBID
from dbgpt.app.scene.base import AppScenePromptTemplateAdapter, ChatScene
from dbgpt.app.scene.operators.app_operator import (
    AppChatComposerOperator,
    ChatComposerInput,
    build_cached_chat_operator,
)
from dbgpt.component import ComponentType
from dbgpt.core import LLMClient, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.core.interface.message import StorageConversation
from dbgpt.model import DefaultLLMClient
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.retry import async_retry
from dbgpt.util.tracer import root_tracer, trace

from .exceptions import BaseAppException

logger = logging.getLogger(__name__)

CFG = Config()


def _build_conversation(
    chat_mode: ChatScene,
    chat_param: Dict[str, Any],
    model_name: str,
    conv_serve: ConversationServe,
) -> StorageConversation:
    param_type = ""
    param_value = ""
    if chat_param["select_param"]:
        if len(chat_mode.param_types()) > 0:
            param_type = chat_mode.param_types()[0]
        param_value = chat_param["select_param"]
    return StorageConversation(
        chat_param["chat_session_id"],
        chat_mode=chat_mode.value(),
        user_name=chat_param.get("user_name"),
        sys_code=chat_param.get("sys_code"),
        model_name=model_name,
        param_type=param_type,
        param_value=param_value,
        conv_storage=conv_serve.conv_storage,
        message_storage=conv_serve.message_storage,
    )


class BaseChat(ABC):
    """DB-GPT Chat Service Base Module

    Include:
        stream_call(): scene + prompt -> stream response
        nostream_call(): scene + prompt -> non-stream response
    """

    chat_scene: str = None
    llm_model: Any = None
    # Number of historical conversation rounds to keep as context
    # (counted from the start and from the end of the history, respectively)
    keep_start_rounds: int = 0
    keep_end_rounds: int = 0
    # Some models do not support the system role, this config is used to control
    # whether to convert system messages to human messages
    auto_convert_message: bool = True
@trace("BaseChat.__init__")
def __init__(self, chat_param: Dict):
"""Chat Module Initialization
Args:
- chat_param: Dict
- chat_session_id: (str) chat session_id
- current_user_input: (str) current user input
- model_name:(str) llm model name
- select_param:(str) select param
"""

        self.chat_session_id = chat_param["chat_session_id"]
        self.chat_mode = chat_param["chat_mode"]
        self.current_user_input: str = chat_param["current_user_input"]
        self.llm_model = (
            chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
        )
        self.llm_echo = False
        self.worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        self.model_cache_enable = chat_param.get("model_cache_enable", False)

        ### load prompt template
        # self.prompt_template: PromptTemplate = CFG.prompt_templates[
        #     self.chat_mode.value()
        # ]
        self.prompt_template: AppScenePromptTemplateAdapter = (
            CFG.prompt_template_registry.get_prompt_template(
                self.chat_mode.value(),
                language=CFG.LANGUAGE,
                model_name=self.llm_model,
                proxyllm_backend=CFG.PROXYLLM_BACKEND,
            )
        )
        self._conv_serve = ConversationServe.get_instance(CFG.SYSTEM_APP)
        # chat_history_fac = ChatHistory()
        ### can configurable storage methods
        # self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
        # self.history_message: List[OnceConversation] = self.memory.messages()
        # self.current_message: OnceConversation = OnceConversation(
        #     self.chat_mode.value(),
        #     user_name=chat_param.get("user_name"),
        #     sys_code=chat_param.get("sys_code"),
        # )
        self.current_message: StorageConversation = _build_conversation(
            self.chat_mode, chat_param, self.llm_model, self._conv_serve
        )
        self.history_messages = self.current_message.get_history_message()
        self.current_tokens_used: int = 0
        # The executor to submit blocking function
        self._executor = CFG.SYSTEM_APP.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()
        # self._model_operator: BaseOperator = _build_model_operator()
        # self._model_stream_operator: BaseOperator = _build_model_operator(
        #     is_stream=True, dag_name="llm_stream_model_dag"
        # )

        # In v1, we will transform the message to compatible format of specific model
        # In the future, we will upgrade the message version to v2, and the message
        # will be compatible with all models
        self._message_version = chat_param.get("message_version", "v2")
        self._chat_param = chat_param

    @property
    def chat_type(self) -> str:
        raise NotImplementedError("Not supported for this chat type.")

    @abstractmethod
    async def generate_input_values(self) -> Dict:
        """Generate input to LLM.

        Please note that you must not perform any blocking operations in this function.

        Returns:
            a dictionary to be formatted by the prompt template
        """

    @property
    def llm_client(self) -> LLMClient:
        """Return the LLM client."""
        worker_manager = CFG.SYSTEM_APP.get_component(
            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
        ).create()
        return DefaultLLMClient(
            worker_manager, auto_convert_message=self.auto_convert_message
        )

    async def call_llm_operator(self, request: ModelRequest) -> ModelOutput:
        llm_task = build_cached_chat_operator(self.llm_client, False, CFG.SYSTEM_APP)
        return await llm_task.call(call_data=request)

    async def call_streaming_operator(
        self, request: ModelRequest
    ) -> AsyncIterator[ModelOutput]:
        llm_task = build_cached_chat_operator(self.llm_client, True, CFG.SYSTEM_APP)
        async for out in await llm_task.call_stream(call_data=request):
            yield out

    def do_action(self, prompt_response):
        return prompt_response

    def message_adjust(self):
        pass

    def has_history_messages(self) -> bool:
        """Whether there are history messages.

        Returns:
            bool: True if there is a history message, False otherwise
        """
        return len(self.history_messages) > 0

    def get_llm_speak(self, prompt_define_response):
        if hasattr(prompt_define_response, "thoughts"):
            if isinstance(prompt_define_response.thoughts, dict):
                if "speak" in prompt_define_response.thoughts:
                    speak_to_user = prompt_define_response.thoughts.get("speak")
                else:
                    speak_to_user = str(prompt_define_response.thoughts)
            else:
                if hasattr(prompt_define_response.thoughts, "speak"):
                    speak_to_user = getattr(prompt_define_response.thoughts, "speak")
                elif hasattr(prompt_define_response.thoughts, "reasoning"):
                    speak_to_user = getattr(
                        prompt_define_response.thoughts, "reasoning"
                    )
                else:
                    speak_to_user = prompt_define_response.thoughts
        else:
            speak_to_user = prompt_define_response
        return speak_to_user

    async def _build_model_request(self) -> ModelRequest:
        input_values = await self.generate_input_values()
        # Load history
        self.history_messages = self.current_message.get_history_message()
        self.current_message.start_new_round()
        self.current_message.add_user_message(self.current_user_input)
        self.current_message.start_date = datetime.datetime.now().strftime(
            "%Y-%m-%d %H:%M:%S"
        )
        self.current_message.tokens = 0

        keep_start_rounds = (
            self.keep_start_rounds
            if self.prompt_template.need_historical_messages
            else 0
        )
        keep_end_rounds = (
            self.keep_end_rounds if self.prompt_template.need_historical_messages else 0
        )
        req_ctx = ModelRequestContext(
            stream=self.prompt_template.stream_out,
            user_name=self._chat_param.get("user_name"),
            sys_code=self._chat_param.get("sys_code"),
            chat_mode=self.chat_mode.value(),
            span_id=root_tracer.get_current_span_id(),
        )
        node = AppChatComposerOperator(
            model=self.llm_model,
            temperature=float(self.prompt_template.temperature),
            max_new_tokens=int(self.prompt_template.max_new_tokens),
            prompt=self.prompt_template.prompt,
            message_version=self._message_version,
            echo=self.llm_echo,
            streaming=self.prompt_template.stream_out,
            keep_start_rounds=keep_start_rounds,
            keep_end_rounds=keep_end_rounds,
            str_history=self.prompt_template.str_history,
            request_context=req_ctx,
        )
        node_input = ChatComposerInput(
            messages=self.history_messages, prompt_dict=input_values
        )
        # llm_messages = self.generate_llm_messages()
        model_request: ModelRequest = await node.call(call_data=node_input)
        model_request.context.cache_enable = self.model_cache_enable
        return model_request

    def stream_plugin_call(self, text):
        return text

    def stream_call_reinforce_fn(self, text):
        return text

    def _get_span_metadata(self, payload: Dict) -> Dict:
        metadata = {k: v for k, v in payload.items()}
        del metadata["prompt"]
        metadata["messages"] = list(
            map(lambda m: m if isinstance(m, dict) else m.dict(), metadata["messages"])
        )
        return metadata

    async def stream_call(self):
        # TODO Retry when server connection error
        payload = await self._build_model_request()
        logger.info(f"payload request: \n{payload}")

        ai_response_text = ""
        span = root_tracer.start_span(
            "BaseChat.stream_call", metadata=payload.to_dict()
        )
        payload.span_id = span.span_id
        try:
            async for output in self.call_streaming_operator(payload):
                # Plugin research in result generation
                msg = self.prompt_template.output_parser.parse_model_stream_resp_ex(
                    output, 0
                )
                view_msg = self.stream_plugin_call(msg)
                view_msg = view_msg.replace("\n", "\\n")
                yield view_msg
            self.current_message.add_ai_message(msg)
            view_msg = self.stream_call_reinforce_fn(view_msg)
            self.current_message.add_view_message(view_msg)
            span.end()
        except Exception as e:
            print(traceback.format_exc())
            logger.error("model response parse failed: " + str(e))
            self.current_message.add_view_message(
                f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
            )
            ### store current conversation
            span.end(metadata={"error": str(e)})
        await blocking_func_to_async(
            self._executor, self.current_message.end_current_round
        )

    async def nostream_call(self):
        payload = await self._build_model_request()
        span = root_tracer.start_span(
            "BaseChat.nostream_call", metadata=payload.to_dict()
        )
        logger.info(f"Request: \n{payload}")
        payload.span_id = span.span_id
        try:
            ai_response_text, view_message = await self._no_streaming_call_with_retry(
                payload
            )
            self.current_message.add_ai_message(ai_response_text)
            self.current_message.add_view_message(view_message)
            self.message_adjust()
            span.end()
        except BaseAppException as e:
            self.current_message.add_view_message(e.view)
            span.end(metadata={"error": str(e)})
        except Exception as e:
            view_message = f"<span style='color:red'>ERROR!</span> {str(e)}"
            self.current_message.add_view_message(view_message)
            span.end(metadata={"error": str(e)})

        # Store current conversation
        await blocking_func_to_async(
            self._executor, self.current_message.end_current_round
        )
        return self.current_ai_response()

    @async_retry(
        retries=CFG.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE,
        parallel_executions=CFG.DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE,
        catch_exceptions=(Exception, BaseAppException),
    )
    async def _no_streaming_call_with_retry(self, payload):
        with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
            model_output = await self.call_llm_operator(payload)

        ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(
            model_output, self.prompt_template.sep
        )
        prompt_define_response = (
            self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
        )
        metadata = {
            "model_output": model_output.to_dict(),
            "ai_response_text": ai_response_text,
            "prompt_define_response": self._parse_prompt_define_response(
                prompt_define_response
            ),
        }
        with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
            result = await blocking_func_to_async(
                self._executor, self.do_action, prompt_define_response
            )

        speak_to_user = self.get_llm_speak(prompt_define_response)
        view_message = await blocking_func_to_async(
            self._executor,
            self.prompt_template.output_parser.parse_view_response,
            speak_to_user,
            result,
            prompt_define_response,
        )
        return ai_response_text, view_message.replace("\n", "\\n")

    async def get_llm_response(self):
        payload = await self._build_model_request()
        logger.info(f"Request: \n{payload}")
        ai_response_text = ""
        prompt_define_response = None
        try:
            model_output = await self.call_llm_operator(payload)
            ### output parse
            ai_response_text = (
                self.prompt_template.output_parser.parse_model_nostream_resp(
                    model_output, self.prompt_template.sep
                )
            )
            ### model result deal
            self.current_message.add_ai_message(ai_response_text)
            prompt_define_response = (
                self.prompt_template.output_parser.parse_prompt_response(
                    ai_response_text
                )
            )
        except Exception as e:
            print(traceback.format_exc())
            logger.error("model response parse failed: " + str(e))
            self.current_message.add_view_message(
                f"""model response parse failed: {str(e)}\n {ai_response_text} """
            )
        return prompt_define_response

    def _blocking_stream_call(self):
        logger.warning(
            "_blocking_stream_call is only temporarily used in the webserver and will "
            "be deleted soon, please use stream_call instead for higher performance"
        )
        loop = get_or_create_event_loop()
        async_gen = self.stream_call()
        while True:
            try:
                value = loop.run_until_complete(async_gen.__anext__())
                yield value
            except StopAsyncIteration:
                break

    def _blocking_nostream_call(self):
        logger.warning(
            "_blocking_nostream_call is only temporarily used in the webserver and will "
            "be deleted soon, please use nostream_call instead for higher performance"
        )
        loop = get_or_create_event_loop()
        try:
            return loop.run_until_complete(self.nostream_call())
        finally:
            loop.close()

    def call(self):
        if self.prompt_template.stream_out:
            yield self._blocking_stream_call()
        else:
            return self._blocking_nostream_call()

    async def prepare(self):
        pass

    def current_ai_response(self) -> str:
        for message in self.current_message.messages[-1:]:
            if message.type == "view":
                return message.content
        return None

    async def prompt_context_token_adapt(self, prompt) -> str:
        """Adapt the prompt to the LLM's max context length."""
        model_metadata = await self.worker_manager.get_model_metadata(
            {"model": self.llm_model}
        )
        current_token_count = await self.worker_manager.count_token(
            {"model": self.llm_model, "prompt": prompt}
        )
        if current_token_count == -1:
            logger.warning(
                "tiktoken not installed, please `pip install tiktoken` first"
            )
        template_define_token_count = 0
        if len(self.prompt_template.template_define) > 0:
            template_define_token_count = await self.worker_manager.count_token(
                {
                    "model": self.llm_model,
                    "prompt": self.prompt_template.template_define,
                }
            )
            current_token_count += template_define_token_count
        if (
            current_token_count + self.prompt_template.max_new_tokens
        ) > model_metadata.context_length:
            prompt = prompt[
                : (
                    model_metadata.context_length
                    - self.prompt_template.max_new_tokens
                    - template_define_token_count
                )
            ]
        return prompt
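
    # Worked example of the truncation above (numbers are illustrative): with a model
    # context_length of 4096, max_new_tokens of 1024 and a 96-token template_define,
    # a prompt whose total would exceed the context window is cut to its first
    # 4096 - 1024 - 96 = 2976 characters (note that the slice is by characters, not
    # by tokens).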

    def generate(self, p) -> str:
        """Generate context for LLM input.

        Args:
            p:

        Returns:
        """
        pass

    def _parse_prompt_define_response(self, prompt_define_response: Any) -> Any:
        if not prompt_define_response:
            return ""
        if isinstance(prompt_define_response, (str, dict)):
            return prompt_define_response
        if isinstance(prompt_define_response, tuple):
            if hasattr(prompt_define_response, "_asdict"):
                # namedtuple
                return prompt_define_response._asdict()
            else:
                return dict(
                    zip(range(len(prompt_define_response)), prompt_define_response)
                )
        else:
            return prompt_define_response

    def _generate_numbered_list(self) -> str:
        """This function was moved from excel_analyze/chat.py and is used by subclasses.

        Returns:
            str: a newline-separated list of the supported antv chart types
        """
        antv_charts = [
            {"response_line_chart": "used to display comparative trend analysis data"},
            {
                "response_pie_chart": "suitable for scenarios such as proportion and distribution statistics"
            },
            {
                "response_table": "suitable for display with many display columns or non-numeric columns"
            },
            # {"response_data_text":" the default display method, suitable for single-line or simple content display"},
            {
                "response_scatter_plot": "Suitable for exploring relationships between variables, detecting outliers, etc."
            },
            {
                "response_bubble_chart": "Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."
            },
            {
                "response_donut_chart": "Suitable for hierarchical structure representation, category proportion display and highlighting key categories, etc."
            },
            {
                "response_area_chart": "Suitable for visualization of time series data, comparison of multiple groups of data, analysis of data change trends, etc."
            },
            {
                "response_heatmap": "Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."
            },
        ]
        return "\n".join(
            f"{key}:{value}"
            for dict_item in antv_charts
            for key, value in dict_item.items()
        )
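
    # The join above renders one "<chart name>:<description>" hint per line, e.g. the
    # first two entries produce:
    #
    #     response_line_chart:used to display comparative trend analysis data
    #     response_pie_chart:suitable for scenarios such as proportion and distribution statistics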