diff --git a/.env.template b/.env.template
index 64b451265..7cfd0ce0f 100644
--- a/.env.template
+++ b/.env.template
@@ -247,4 +247,13 @@ DBGPT_LOG_LEVEL=INFO
#** API_KEYS **#
#*******************************************************************#
# API_KEYS - The list of API keys that are allowed to access the API. Each of the below are an option, separated by commas.
-# API_KEYS=dbgpt
\ No newline at end of file
+# API_KEYS=dbgpt
+
+
+#*******************************************************************#
+#** Application Config **#
+#*******************************************************************#
+## Non-streaming scene retries
+# DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE=1
+## Non-streaming scene parallelism
+# DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE=1
diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py
index 07e1f0a89..b9fa4ea27 100644
--- a/dbgpt/_private/config.py
+++ b/dbgpt/_private/config.py
@@ -300,6 +300,15 @@ class Config(metaclass=Singleton):
# global dbgpt api key
self.API_KEYS = os.getenv("API_KEYS", None)
+ # Non-streaming scene retries
+ self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int(
+ os.getenv("DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE", 1)
+ )
+ # Non-streaming scene parallelism
+ self.DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE = int(
+ os.getenv("DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE", 1)
+ )
+
@property
def local_db_manager(self) -> "ConnectorManager":
from dbgpt.datasource.manages import ConnectorManager
diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py
index 9ed49f623..d36d1d05e 100644
--- a/dbgpt/app/scene/base_chat.py
+++ b/dbgpt/app/scene/base_chat.py
@@ -21,8 +21,11 @@ from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util import get_or_create_event_loop
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
+from dbgpt.util.retry import async_retry
from dbgpt.util.tracer import root_tracer, trace
+from .exceptions import BaseAppException
+
logger = logging.getLogger(__name__)
CFG = Config()
@@ -321,70 +324,67 @@ class BaseChat(ABC):
"BaseChat.nostream_call", metadata=payload.to_dict()
)
logger.info(f"Request: \n{payload}")
- ai_response_text = ""
payload.span_id = span.span_id
try:
- with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
- model_output = await self.call_llm_operator(payload)
-
- ### output parse
- ai_response_text = (
- self.prompt_template.output_parser.parse_model_nostream_resp(
- model_output, self.prompt_template.sep
- )
+ ai_response_text, view_message = await self._no_streaming_call_with_retry(
+ payload
)
- ### model result deal
self.current_message.add_ai_message(ai_response_text)
- prompt_define_response = (
- self.prompt_template.output_parser.parse_prompt_response(
- ai_response_text
- )
- )
- metadata = {
- "model_output": model_output.to_dict(),
- "ai_response_text": ai_response_text,
- "prompt_define_response": self._parse_prompt_define_response(
- prompt_define_response
- ),
- }
- with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
- ### run
- result = await blocking_func_to_async(
- self._executor, self.do_action, prompt_define_response
- )
-
- ### llm speaker
- speak_to_user = self.get_llm_speak(prompt_define_response)
-
- # view_message = self.prompt_template.output_parser.parse_view_response(
- # speak_to_user, result
- # )
- view_message = await blocking_func_to_async(
- self._executor,
- self.prompt_template.output_parser.parse_view_response,
- speak_to_user,
- result,
- prompt_define_response,
- )
-
- view_message = view_message.replace("\n", "\\n")
self.current_message.add_view_message(view_message)
self.message_adjust()
-
span.end()
- except Exception as e:
- print(traceback.format_exc())
- logger.error("model response parase faild!" + str(e))
- self.current_message.add_view_message(
- f"""ERROR!{str(e)}\n {ai_response_text} """
- )
+ except BaseAppException as e:
+ self.current_message.add_view_message(e.view)
span.end(metadata={"error": str(e)})
- ### store dialogue
+ except Exception as e:
+ view_message = f"ERROR! {str(e)}"
+ self.current_message.add_view_message(view_message)
+ span.end(metadata={"error": str(e)})
+
+ # Store current conversation
await blocking_func_to_async(
self._executor, self.current_message.end_current_round
)
return self.current_ai_response()
+    @async_retry(
+        retries=CFG.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE,
+        parallel_executions=CFG.DBGPT_APP_SCENE_NON_STREAMING_PARALLELISM_BASE,
+        catch_exceptions=(Exception, BaseAppException),
+    )
+    async def _no_streaming_call_with_retry(self, payload):
+        with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
+            model_output = await self.call_llm_operator(payload)
+
+        ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(
+            model_output, self.prompt_template.sep
+        )
+        prompt_define_response = (
+            self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+        )
+        metadata = {
+            "model_output": model_output.to_dict(),
+            "ai_response_text": ai_response_text,
+            "prompt_define_response": self._parse_prompt_define_response(
+                prompt_define_response
+            ),
+        }
+        with root_tracer.start_span("BaseChat.do_action", metadata=metadata):
+            result = await blocking_func_to_async(
+                self._executor, self.do_action, prompt_define_response
+            )
+
+        speak_to_user = self.get_llm_speak(prompt_define_response)
+
+        view_message = await blocking_func_to_async(
+            self._executor,
+            self.prompt_template.output_parser.parse_view_response,
+            speak_to_user,
+            result,
+            prompt_define_response,
+        )
+        return ai_response_text, view_message.replace("\n", "\\n")
+
async def get_llm_response(self):
payload = await self._build_model_request()
logger.info(f"Request: \n{payload}")
diff --git a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py
index d2e1eae96..7aad46bd8 100644
--- a/dbgpt/app/scene/chat_db/auto_execute/out_parser.py
+++ b/dbgpt/app/scene/chat_db/auto_execute/out_parser.py
@@ -9,6 +9,8 @@ from dbgpt._private.config import Config
from dbgpt.core.interface.output_parser import BaseOutputParser
from dbgpt.util.json_utils import serialize
+from ...exceptions import AppActionException
+
CFG = Config()
@@ -66,9 +68,10 @@ class DbChatOutputParser(BaseOutputParser):
param = {}
api_call_element = ET.Element("chart-view")
err_msg = None
+ success = False
try:
if not prompt_response.sql or len(prompt_response.sql) <= 0:
- return f"""{speak}"""
+ raise AppActionException("Can not find sql in response", speak)
df = data(prompt_response.sql)
param["type"] = prompt_response.display
@@ -77,20 +80,26 @@ class DbChatOutputParser(BaseOutputParser):
df.to_json(orient="records", date_format="iso", date_unit="s")
)
view_json_str = json.dumps(param, default=serialize, ensure_ascii=False)
+ success = True
except Exception as e:
logger.error("parse_view_response error!" + str(e))
- err_param = {}
- err_param["sql"] = f"{prompt_response.sql}"
- err_param["type"] = "response_table"
+ err_param = {
+ "sql": f"{prompt_response.sql}",
+ "type": "response_table",
+ "data": [],
+ }
# err_param["err_msg"] = str(e)
- err_param["data"] = []
err_msg = str(e)
view_json_str = json.dumps(err_param, default=serialize, ensure_ascii=False)
# api_call_element.text = view_json_str
api_call_element.set("content", view_json_str)
result = ET.tostring(api_call_element, encoding="utf-8")
- if err_msg:
- return f"""{speak} \\n ERROR!{err_msg} \n {result.decode("utf-8")}"""
+ if not success:
+ view_content = (
+ f'{speak} \\n ERROR!'
+ f"{err_msg} \n {result.decode('utf-8')}"
+ )
+ raise AppActionException("Generate view content failed", view_content)
else:
return speak + "\n" + result.decode("utf-8")
diff --git a/dbgpt/app/scene/exceptions.py b/dbgpt/app/scene/exceptions.py
new file mode 100644
index 000000000..10da0e0db
--- /dev/null
+++ b/dbgpt/app/scene/exceptions.py
@@ -0,0 +1,23 @@
+"""Exceptions for Application."""
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BaseAppException(Exception):
+    """Base Exception for App."""
+
+    def __init__(self, message: str, view: str):
+        """Create a BaseAppException.
+
+        Args:
+            message (str): Error message for logs and exception chaining.
+            view (str): Rendered message to show to the end user.
+        """
+        super().__init__(message)
+        self.message = message
+        self.view = view
+
+
+class AppActionException(BaseAppException):
+    """Exception raised when executing an App action fails."""
diff --git a/dbgpt/util/retry.py b/dbgpt/util/retry.py
new file mode 100644
index 000000000..c29134c2d
--- /dev/null
+++ b/dbgpt/util/retry.py
@@ -0,0 +1,81 @@
+import asyncio
+import functools
+import logging
+import traceback
+
+logger = logging.getLogger(__name__)
+
+
+def async_retry(
+    retries: int = 1, parallel_executions: int = 1, catch_exceptions=(Exception,)
+):
+    """Async retry decorator.
+
+    Each attempt launches ``parallel_executions`` concurrent invocations of the
+    wrapped coroutine and returns the first successful result. If every
+    invocation fails with one of ``catch_exceptions``, the round is retried up
+    to ``retries`` times; the last caught exception is re-raised at the end.
+
+    Examples:
+        .. code-block:: python
+
+            @async_retry(retries=3, parallel_executions=2)
+            async def my_func():
+                # Some code that may raise exceptions
+                pass
+
+    Args:
+        retries (int): Number of retries, must be at least 1.
+        parallel_executions (int): Number of parallel executions per attempt.
+        catch_exceptions (tuple): Tuple of exceptions to catch and retry on.
+
+    Raises:
+        The last caught exception if all attempts fail, or the first exception
+        that is not an instance of ``catch_exceptions``.
+    """
+
+    def decorator(func):
+        @functools.wraps(func)
+        async def wrapper(*args, **kwargs):
+            last_exception = None
+            for attempt in range(retries):
+                tasks = [func(*args, **kwargs) for _ in range(parallel_executions)]
+                results = await asyncio.gather(*tasks, return_exceptions=True)
+
+                # Any successful execution in this round wins. Check against
+                # BaseException: gather() also returns e.g. CancelledError,
+                # which must never be mistaken for a successful result.
+                for result in results:
+                    if not isinstance(result, BaseException):
+                        return result
+
+                for result in results:
+                    # Never silently swallow exceptions we were not asked to
+                    # retry on (e.g. asyncio.CancelledError).
+                    if not isinstance(result, catch_exceptions):
+                        raise result
+                    last_exception = result
+                    logger.error(
+                        f"Attempt {attempt + 1} of {retries} failed with error: "
+                        f"{type(result).__name__}, {str(result)}"
+                    )
+                    # traceback.format_exc() only works inside an except block,
+                    # so format the gathered exception explicitly.
+                    logger.debug(
+                        "".join(
+                            traceback.format_exception(
+                                type(result), result, result.__traceback__
+                            )
+                        )
+                    )
+
+                if attempt + 1 < retries:
+                    logger.info(f"Retrying... (Attempt {attempt + 2} of {retries})")
+
+            if last_exception is None:
+                # Guard against retries < 1, which would otherwise `raise None`.
+                raise ValueError("async_retry requires retries >= 1")
+            # After all retries, raise the last caught exception.
+            raise last_exception
+
+        return wrapper
+
+    return decorator