mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-25 11:29:29 +00:00
feat: Support retry for 'Chat Data' (#1419)
This commit is contained in:
parent
2e2e120ace
commit
461607e421
@ -248,3 +248,12 @@ DBGPT_LOG_LEVEL=INFO
|
|||||||
#*******************************************************************#
|
#*******************************************************************#
|
||||||
# API_KEYS - The list of API keys that are allowed to access the API. Each entry below is an option; separate multiple keys with commas.
|
# API_KEYS - The list of API keys that are allowed to access the API. Each entry below is an option; separate multiple keys with commas.
|
||||||
# API_KEYS=dbgpt
|
# API_KEYS=dbgpt
|
||||||
|
|
||||||
|
|
||||||
|
#*******************************************************************#
|
||||||
|
#** Application Config **#
|
||||||
|
#*******************************************************************#
|
||||||
|
## Non-streaming scene retries
|
||||||
|
# DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE=1
|
||||||
|
## Non-streaming scene parallelism
|
||||||
|
# DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE=1
|
@ -300,6 +300,15 @@ class Config(metaclass=Singleton):
|
|||||||
# global dbgpt api key
|
# global dbgpt api key
|
||||||
self.API_KEYS = os.getenv("API_KEYS", None)
|
self.API_KEYS = os.getenv("API_KEYS", None)
|
||||||
|
|
||||||
|
# Non-streaming scene retries
|
||||||
|
self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int(
|
||||||
|
os.getenv("DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE", 1)
|
||||||
|
)
|
||||||
|
# Non-streaming scene parallelism
|
||||||
|
self.DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE = int(
|
||||||
|
os.getenv("DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE", 1)
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def local_db_manager(self) -> "ConnectorManager":
|
def local_db_manager(self) -> "ConnectorManager":
|
||||||
from dbgpt.datasource.manages import ConnectorManager
|
from dbgpt.datasource.manages import ConnectorManager
|
||||||
|
@ -21,8 +21,11 @@ from dbgpt.model.cluster import WorkerManagerFactory
|
|||||||
from dbgpt.serve.conversation.serve import Serve as ConversationServe
|
from dbgpt.serve.conversation.serve import Serve as ConversationServe
|
||||||
from dbgpt.util import get_or_create_event_loop
|
from dbgpt.util import get_or_create_event_loop
|
||||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||||
|
from dbgpt.util.retry import async_retry
|
||||||
from dbgpt.util.tracer import root_tracer, trace
|
from dbgpt.util.tracer import root_tracer, trace
|
||||||
|
|
||||||
|
from .exceptions import BaseAppException
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -321,24 +324,43 @@ class BaseChat(ABC):
|
|||||||
"BaseChat.nostream_call", metadata=payload.to_dict()
|
"BaseChat.nostream_call", metadata=payload.to_dict()
|
||||||
)
|
)
|
||||||
logger.info(f"Request: \n{payload}")
|
logger.info(f"Request: \n{payload}")
|
||||||
ai_response_text = ""
|
|
||||||
payload.span_id = span.span_id
|
payload.span_id = span.span_id
|
||||||
try:
|
try:
|
||||||
|
ai_response_text, view_message = await self._no_streaming_call_with_retry(
|
||||||
|
payload
|
||||||
|
)
|
||||||
|
self.current_message.add_ai_message(ai_response_text)
|
||||||
|
self.current_message.add_view_message(view_message)
|
||||||
|
self.message_adjust()
|
||||||
|
span.end()
|
||||||
|
except BaseAppException as e:
|
||||||
|
self.current_message.add_view_message(e.view)
|
||||||
|
span.end(metadata={"error": str(e)})
|
||||||
|
except Exception as e:
|
||||||
|
view_message = f"<span style='color:red'>ERROR!</span> {str(e)}"
|
||||||
|
self.current_message.add_view_message(view_message)
|
||||||
|
span.end(metadata={"error": str(e)})
|
||||||
|
|
||||||
|
# Store current conversation
|
||||||
|
await blocking_func_to_async(
|
||||||
|
self._executor, self.current_message.end_current_round
|
||||||
|
)
|
||||||
|
return self.current_ai_response()
|
||||||
|
|
||||||
|
@async_retry(
|
||||||
|
retries=CFG.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE,
|
||||||
|
parallel_executions=CFG.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE,
|
||||||
|
catch_exceptions=(Exception, BaseAppException),
|
||||||
|
)
|
||||||
|
async def _no_streaming_call_with_retry(self, payload):
|
||||||
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
|
with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
|
||||||
model_output = await self.call_llm_operator(payload)
|
model_output = await self.call_llm_operator(payload)
|
||||||
|
|
||||||
### output parse
|
ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(
|
||||||
ai_response_text = (
|
|
||||||
self.prompt_template.output_parser.parse_model_nostream_resp(
|
|
||||||
model_output, self.prompt_template.sep
|
model_output, self.prompt_template.sep
|
||||||
)
|
)
|
||||||
)
|
|
||||||
### model result deal
|
|
||||||
self.current_message.add_ai_message(ai_response_text)
|
|
||||||
prompt_define_response = (
|
prompt_define_response = (
|
||||||
self.prompt_template.output_parser.parse_prompt_response(
|
self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
|
||||||
ai_response_text
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
metadata = {
|
metadata = {
|
||||||
"model_output": model_output.to_dict(),
|
"model_output": model_output.to_dict(),
|
||||||
@ -348,17 +370,12 @@ class BaseChat(ABC):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
|
with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
|
||||||
### run
|
|
||||||
result = await blocking_func_to_async(
|
result = await blocking_func_to_async(
|
||||||
self._executor, self.do_action, prompt_define_response
|
self._executor, self.do_action, prompt_define_response
|
||||||
)
|
)
|
||||||
|
|
||||||
### llm speaker
|
|
||||||
speak_to_user = self.get_llm_speak(prompt_define_response)
|
speak_to_user = self.get_llm_speak(prompt_define_response)
|
||||||
|
|
||||||
# view_message = self.prompt_template.output_parser.parse_view_response(
|
|
||||||
# speak_to_user, result
|
|
||||||
# )
|
|
||||||
view_message = await blocking_func_to_async(
|
view_message = await blocking_func_to_async(
|
||||||
self._executor,
|
self._executor,
|
||||||
self.prompt_template.output_parser.parse_view_response,
|
self.prompt_template.output_parser.parse_view_response,
|
||||||
@ -366,24 +383,7 @@ class BaseChat(ABC):
|
|||||||
result,
|
result,
|
||||||
prompt_define_response,
|
prompt_define_response,
|
||||||
)
|
)
|
||||||
|
return ai_response_text, view_message.replace("\n", "\\n")
|
||||||
view_message = view_message.replace("\n", "\\n")
|
|
||||||
self.current_message.add_view_message(view_message)
|
|
||||||
self.message_adjust()
|
|
||||||
|
|
||||||
span.end()
|
|
||||||
except Exception as e:
|
|
||||||
print(traceback.format_exc())
|
|
||||||
logger.error("model response parase faild!" + str(e))
|
|
||||||
self.current_message.add_view_message(
|
|
||||||
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
|
|
||||||
)
|
|
||||||
span.end(metadata={"error": str(e)})
|
|
||||||
### store dialogue
|
|
||||||
await blocking_func_to_async(
|
|
||||||
self._executor, self.current_message.end_current_round
|
|
||||||
)
|
|
||||||
return self.current_ai_response()
|
|
||||||
|
|
||||||
async def get_llm_response(self):
|
async def get_llm_response(self):
|
||||||
payload = await self._build_model_request()
|
payload = await self._build_model_request()
|
||||||
|
@ -9,6 +9,8 @@ from dbgpt._private.config import Config
|
|||||||
from dbgpt.core.interface.output_parser import BaseOutputParser
|
from dbgpt.core.interface.output_parser import BaseOutputParser
|
||||||
from dbgpt.util.json_utils import serialize
|
from dbgpt.util.json_utils import serialize
|
||||||
|
|
||||||
|
from ...exceptions import AppActionException
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
@ -66,9 +68,10 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
param = {}
|
param = {}
|
||||||
api_call_element = ET.Element("chart-view")
|
api_call_element = ET.Element("chart-view")
|
||||||
err_msg = None
|
err_msg = None
|
||||||
|
success = False
|
||||||
try:
|
try:
|
||||||
if not prompt_response.sql or len(prompt_response.sql) <= 0:
|
if not prompt_response.sql or len(prompt_response.sql) <= 0:
|
||||||
return f"""{speak}"""
|
raise AppActionException("Can not find sql in response", speak)
|
||||||
|
|
||||||
df = data(prompt_response.sql)
|
df = data(prompt_response.sql)
|
||||||
param["type"] = prompt_response.display
|
param["type"] = prompt_response.display
|
||||||
@ -77,20 +80,26 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
df.to_json(orient="records", date_format="iso", date_unit="s")
|
df.to_json(orient="records", date_format="iso", date_unit="s")
|
||||||
)
|
)
|
||||||
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
|
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
|
||||||
|
success = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("parse_view_response error!" + str(e))
|
logger.error("parse_view_response error!" + str(e))
|
||||||
err_param = {}
|
err_param = {
|
||||||
err_param["sql"] = f"{prompt_response.sql}"
|
"sql": f"{prompt_response.sql}",
|
||||||
err_param["type"] = "response_table"
|
"type": "response_table",
|
||||||
|
"data": [],
|
||||||
|
}
|
||||||
# err_param["err_msg"] = str(e)
|
# err_param["err_msg"] = str(e)
|
||||||
err_param["data"] = []
|
|
||||||
err_msg = str(e)
|
err_msg = str(e)
|
||||||
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
|
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
|
||||||
|
|
||||||
# api_call_element.text = view_json_str
|
# api_call_element.text = view_json_str
|
||||||
api_call_element.set("content", view_json_str)
|
api_call_element.set("content", view_json_str)
|
||||||
result = ET.tostring(api_call_element, encoding="utf-8")
|
result = ET.tostring(api_call_element, encoding="utf-8")
|
||||||
if err_msg:
|
if not success:
|
||||||
return f"""{speak} \\n <span style=\"color:red\">ERROR!</span>{err_msg} \n {result.decode("utf-8")}"""
|
view_content = (
|
||||||
|
f'{speak} \\n <span style="color:red">ERROR!</span>'
|
||||||
|
f"{err_msg} \n {result.decode('utf-8')}"
|
||||||
|
)
|
||||||
|
raise AppActionException("Generate view content failed", view_content)
|
||||||
else:
|
else:
|
||||||
return speak + "\n" + result.decode("utf-8")
|
return speak + "\n" + result.decode("utf-8")
|
||||||
|
22
dbgpt/app/scene/exceptions.py
Normal file
22
dbgpt/app/scene/exceptions.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
"""Exceptions for Application."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAppException(Exception):
    """Base exception for application-level errors.

    Besides the usual error message, the exception carries a ``view`` string:
    the content that should be shown to the user in place of a normal answer
    (callers add it to the conversation as the view message).

    Attributes are documented here rather than inline:
        message: Developer-facing description of the failure; also the value
            of ``str(exc)`` and the single entry of ``exc.args``.
        view: User-facing content describing the failure.
    """

    def __init__(self, message: str, view: str):
        """Create the exception.

        Args:
            message: Developer-facing description of the failure.
            view: User-facing content to display for this failure.
        """
        # Only ``message`` is forwarded so that ``str(exc)`` stays the plain
        # message instead of the repr of a ``(message, view)`` tuple.
        super().__init__(message)
        self.message = message
        self.view = view

    def __reduce__(self):
        """Support pickling and copying.

        The default ``Exception.__reduce__`` rebuilds the instance from
        ``self.args`` which only holds ``message``; re-creation would then
        fail with ``TypeError`` because ``view`` is a required argument.
        Rebuild from both constructor arguments instead and carry the
        instance ``__dict__`` along as state.

        Subclasses that change the constructor signature must override this.
        """
        return type(self), (self.message, self.view), dict(self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
|
class AppActionException(BaseAppException):
    """Error raised while performing an application action.

    Behaves exactly like ``BaseAppException``; the separate type lets callers
    distinguish action failures from other application errors.
    """

    def __init__(self, message: str, view: str):
        """Create the exception.

        Args:
            message: Developer-facing description of the failure.
            view: User-facing content to display for this failure.
        """
        super().__init__(message=message, view=view)
|
51
dbgpt/util/retry.py
Normal file
51
dbgpt/util/retry.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def async_retry(
    retries: int = 1, parallel_executions: int = 1, catch_exceptions=(Exception,)
):
    """Async retry decorator.

    Every attempt starts ``parallel_executions`` concurrent calls of the
    decorated coroutine function with the same arguments and waits for all of
    them. The first successful result (in launch order) is returned. If every
    call of an attempt failed with one of ``catch_exceptions``, the next
    attempt is started, until ``retries`` attempts have been made.

    Examples:
        .. code-block:: python

            @async_retry(retries=3, parallel_executions=2)
            async def my_func():
                # Some code that may raise exceptions
                pass

    Args:
        retries (int): Total number of attempts. Values below 1 are treated
            as 1 so the decorated function is always called at least once.
        parallel_executions (int): Number of concurrent calls per attempt.
            Values below 1 are treated as 1.
        catch_exceptions (tuple): Exception types that trigger another
            attempt.

    Returns:
        A decorator for coroutine functions.

    Raises:
        The last caught exception once all attempts are exhausted. An
        exception that is not an instance of ``catch_exceptions`` (including
        ``asyncio.CancelledError``) is re-raised immediately, without further
        attempts, unless another call of the same attempt succeeded.
    """
    # Function-scope import: keeps this block independent of the file-level
    # import list.
    import functools

    # A non-positive setting would otherwise skip the loop (or gather nothing)
    # and end in ``raise None``, i.e. an unrelated TypeError.
    attempts = max(1, retries)
    fan_out = max(1, parallel_executions)

    def decorator(func):
        @functools.wraps(func)  # preserve name/docstring of the wrapped function
        async def wrapper(*args, **kwargs):
            last_exception = None
            for attempt in range(attempts):
                tasks = [func(*args, **kwargs) for _ in range(fan_out)]
                results = await asyncio.gather(*tasks, return_exceptions=True)

                unexpected = None
                for result in results:
                    # BaseException, not Exception: a cancelled call shows up
                    # as a CancelledError instance and must not be mistaken
                    # for a successful return value.
                    if not isinstance(result, BaseException):
                        return result
                    if not isinstance(result, catch_exceptions):
                        # Not retryable; remember the first one and keep
                        # scanning in case a sibling call succeeded.
                        if unexpected is None:
                            unexpected = result
                        continue
                    last_exception = result
                    logger.error(
                        "Attempt %s of %s failed with error: %s, %s",
                        attempt + 1,
                        attempts,
                        type(result).__name__,
                        result,
                    )
                    # We are not inside an ``except`` block here, so the
                    # traceback has to be formatted from the exception object.
                    logger.debug(
                        "".join(
                            traceback.format_exception(
                                type(result), result, result.__traceback__
                            )
                        )
                    )

                if unexpected is not None:
                    raise unexpected

                # Only announce a retry when another attempt will follow.
                if attempt + 1 < attempts:
                    logger.info(
                        "Retrying... (Attempt %s of %s)", attempt + 1, attempts
                    )

            raise last_exception  # After all retries, raise the last caught exception

        return wrapper

    return decorator
|
Loading…
Reference in New Issue
Block a user