mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 17:16:51 +00:00
feat(core): Support pass span id to threadpool
This commit is contained in:
@@ -34,6 +34,7 @@ CREATE TABLE `knowledge_document` (
|
|||||||
`content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
|
`content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
|
||||||
`result` TEXT NULL COMMENT 'knowledge content',
|
`result` TEXT NULL COMMENT 'knowledge content',
|
||||||
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
|
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
|
||||||
|
`summary` LONGTEXT NULL COMMENT 'knowledge summary',
|
||||||
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
|
||||||
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
|
||||||
PRIMARY KEY (`id`),
|
PRIMARY KEY (`id`),
|
||||||
|
@@ -145,6 +145,13 @@ class BaseChat(ABC):
|
|||||||
)
|
)
|
||||||
self.current_message.tokens = 0
|
self.current_message.tokens = 0
|
||||||
if self.prompt_template.template:
|
if self.prompt_template.template:
|
||||||
|
metadata = {
|
||||||
|
"template_scene": self.prompt_template.template_scene,
|
||||||
|
"input_values": input_values,
|
||||||
|
}
|
||||||
|
with root_tracer.start_span(
|
||||||
|
"BaseChat.__call_base.prompt_template.format", metadata=metadata
|
||||||
|
):
|
||||||
current_prompt = self.prompt_template.format(**input_values)
|
current_prompt = self.prompt_template.format(**input_values)
|
||||||
self.current_message.add_system_message(current_prompt)
|
self.current_message.add_system_message(current_prompt)
|
||||||
|
|
||||||
|
@@ -11,6 +11,7 @@ from pilot.common.string_utils import extract_content
|
|||||||
from .prompt import prompt
|
from .prompt import prompt
|
||||||
from pilot.component import ComponentType
|
from pilot.component import ComponentType
|
||||||
from pilot.base_modules.agent.controller import ModuleAgent
|
from pilot.base_modules.agent.controller import ModuleAgent
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -51,6 +52,7 @@ class ChatAgent(BaseChat):
|
|||||||
|
|
||||||
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
|
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict[str, str]:
|
async def generate_input_values(self) -> Dict[str, str]:
|
||||||
input_values = {
|
input_values = {
|
||||||
"user_goal": self.current_user_input,
|
"user_goal": self.current_user_input,
|
||||||
@@ -63,6 +65,9 @@ class ChatAgent(BaseChat):
|
|||||||
|
|
||||||
def stream_plugin_call(self, text):
|
def stream_plugin_call(self, text):
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
|
with root_tracer.start_span(
|
||||||
|
"ChatAgent.stream_plugin_call.api_call", metadata={"text": text}
|
||||||
|
):
|
||||||
return self.api_call.run(text)
|
return self.api_call.run(text)
|
||||||
|
|
||||||
def __list_to_prompt_str(self, list: List) -> str:
|
def __list_to_prompt_str(self, list: List) -> str:
|
||||||
|
@@ -13,6 +13,7 @@ from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
|||||||
from pilot.scene.chat_dashboard.prompt import prompt
|
from pilot.scene.chat_dashboard.prompt import prompt
|
||||||
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
||||||
from pilot.utils.executor_utils import blocking_func_to_async
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -53,6 +54,7 @@ class ChatDashboard(BaseChat):
|
|||||||
data = f.read()
|
data = f.read()
|
||||||
return json.loads(data)
|
return json.loads(data)
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
try:
|
try:
|
||||||
from pilot.summary.db_summary_client import DBSummaryClient
|
from pilot.summary.db_summary_client import DBSummaryClient
|
||||||
|
@@ -14,6 +14,7 @@ from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
|
|||||||
from pilot.common.path_utils import has_path
|
from pilot.common.path_utils import has_path
|
||||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||||
from pilot.base_modules.agent.common.schema import Status
|
from pilot.base_modules.agent.common.schema import Status
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -62,6 +63,7 @@ class ChatExcel(BaseChat):
|
|||||||
# ]
|
# ]
|
||||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
input_values = {
|
input_values = {
|
||||||
"user_input": self.current_user_input,
|
"user_input": self.current_user_input,
|
||||||
@@ -88,4 +90,9 @@ class ChatExcel(BaseChat):
|
|||||||
|
|
||||||
def stream_plugin_call(self, text):
|
def stream_plugin_call(self, text):
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex)
|
with root_tracer.start_span(
|
||||||
|
"ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text}
|
||||||
|
):
|
||||||
|
return self.api_call.run_display_sql(
|
||||||
|
text, self.excel_reader.get_df_by_sql_ex
|
||||||
|
)
|
||||||
|
@@ -13,6 +13,7 @@ from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
|
|||||||
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||||
from pilot.json_utils.utilities import DateTimeEncoder
|
from pilot.json_utils.utilities import DateTimeEncoder
|
||||||
from pilot.utils.executor_utils import blocking_func_to_async
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -44,6 +45,7 @@ class ExcelLearning(BaseChat):
|
|||||||
if parent_mode:
|
if parent_mode:
|
||||||
self.current_message.chat_mode = parent_mode.value()
|
self.current_message.chat_mode = parent_mode.value()
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
# colunms, datas = self.excel_reader.get_sample_data()
|
# colunms, datas = self.excel_reader.get_sample_data()
|
||||||
colunms, datas = await blocking_func_to_async(
|
colunms, datas = await blocking_func_to_async(
|
||||||
|
@@ -6,6 +6,7 @@ from pilot.common.sql_database import Database
|
|||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
||||||
from pilot.utils.executor_utils import blocking_func_to_async
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -35,10 +36,13 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
|
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
|
||||||
)
|
)
|
||||||
|
with root_tracer.start_span(
|
||||||
|
"ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
|
||||||
|
):
|
||||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||||
self.top_k: int = 200
|
self.top_k: int = 200
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
"""
|
"""
|
||||||
generate input values
|
generate input values
|
||||||
@@ -55,6 +59,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
# query=self.current_user_input,
|
# query=self.current_user_input,
|
||||||
# topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
# topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
|
||||||
# )
|
# )
|
||||||
|
with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
|
||||||
table_infos = await blocking_func_to_async(
|
table_infos = await blocking_func_to_async(
|
||||||
self._executor,
|
self._executor,
|
||||||
client.get_db_summary,
|
client.get_db_summary,
|
||||||
@@ -80,4 +85,8 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
|
|
||||||
def do_action(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
print(f"do_action:{prompt_response}")
|
print(f"do_action:{prompt_response}")
|
||||||
|
with root_tracer.start_span(
|
||||||
|
"ChatWithDbAutoExecute.do_action.run_sql",
|
||||||
|
metadata=prompt_response.to_dict(),
|
||||||
|
):
|
||||||
return self.database.run(prompt_response.sql)
|
return self.database.run(prompt_response.sql)
|
||||||
|
@@ -12,6 +12,9 @@ class SqlAction(NamedTuple):
|
|||||||
sql: str
|
sql: str
|
||||||
thoughts: Dict
|
thoughts: Dict
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Dict]:
|
||||||
|
return {"sql": self.sql, "thoughts": self.thoughts}
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@@ -6,6 +6,7 @@ from pilot.common.sql_database import Database
|
|||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.chat_db.professional_qa.prompt import prompt
|
from pilot.scene.chat_db.professional_qa.prompt import prompt
|
||||||
from pilot.utils.executor_utils import blocking_func_to_async
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -39,6 +40,7 @@ class ChatWithDbQA(BaseChat):
|
|||||||
else len(self.tables)
|
else len(self.tables)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
table_info = ""
|
table_info = ""
|
||||||
dialect = "mysql"
|
dialect = "mysql"
|
||||||
|
@@ -6,6 +6,7 @@ from pilot.configs.config import Config
|
|||||||
from pilot.base_modules.agent.commands.command import execute_command
|
from pilot.base_modules.agent.commands.command import execute_command
|
||||||
from pilot.base_modules.agent import PluginPromptGenerator
|
from pilot.base_modules.agent import PluginPromptGenerator
|
||||||
from .prompt import prompt
|
from .prompt import prompt
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -50,6 +51,7 @@ class ChatWithPlugin(BaseChat):
|
|||||||
self.plugins_prompt_generator
|
self.plugins_prompt_generator
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
input_values = {
|
input_values = {
|
||||||
"input": self.current_user_input,
|
"input": self.current_user_input,
|
||||||
|
@@ -4,6 +4,7 @@ from pilot.scene.base import ChatScene
|
|||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
|
|
||||||
from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
|
from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -31,6 +32,7 @@ class InnerChatDBSummary(BaseChat):
|
|||||||
self.db_input = db_select
|
self.db_input = db_select
|
||||||
self.db_summary = db_summary
|
self.db_summary = db_summary
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
input_values = {
|
input_values = {
|
||||||
"db_input": self.db_input,
|
"db_input": self.db_input,
|
||||||
|
@@ -15,6 +15,7 @@ from pilot.configs.model_config import (
|
|||||||
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
from pilot.scene.chat_knowledge.v1.prompt import prompt
|
||||||
from pilot.server.knowledge.service import KnowledgeService
|
from pilot.server.knowledge.service import KnowledgeService
|
||||||
from pilot.utils.executor_utils import blocking_func_to_async
|
from pilot.utils.executor_utils import blocking_func_to_async
|
||||||
|
from pilot.utils.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@@ -92,6 +93,7 @@ class ChatKnowledge(BaseChat):
|
|||||||
"""return reference"""
|
"""return reference"""
|
||||||
return text + f"\n\n{self.parse_source_view(self.sources)}"
|
return text + f"\n\n{self.parse_source_view(self.sources)}"
|
||||||
|
|
||||||
|
@trace()
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
if self.space_context:
|
if self.space_context:
|
||||||
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
from typing import Callable, Awaitable, Any
|
from typing import Callable, Awaitable, Any
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextvars
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -55,6 +56,12 @@ async def blocking_func_to_async(
|
|||||||
"""
|
"""
|
||||||
if asyncio.iscoroutinefunction(func):
|
if asyncio.iscoroutinefunction(func):
|
||||||
raise ValueError(f"The function {func} is not blocking function")
|
raise ValueError(f"The function {func} is not blocking function")
|
||||||
|
|
||||||
|
# This function will be called within the new thread, capturing the current context
|
||||||
|
ctx = contextvars.copy_context()
|
||||||
|
|
||||||
|
def run_with_context():
|
||||||
|
return ctx.run(partial(func, *args, **kwargs))
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
sync_function_noargs = partial(func, *args, **kwargs)
|
return await loop.run_in_executor(executor, run_with_context)
|
||||||
return await loop.run_in_executor(executor, sync_function_noargs)
|
|
||||||
|
@@ -303,8 +303,6 @@ def chat(
|
|||||||
print(table.get_formatted_string(out_format=output, **out_kwargs))
|
print(table.get_formatted_string(out_format=output, **out_kwargs))
|
||||||
if sys_table:
|
if sys_table:
|
||||||
print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
|
print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
|
||||||
if hide_conv:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not found_trace_id:
|
if not found_trace_id:
|
||||||
print(f"Can't found conversation with trace_id: {trace_id}")
|
print(f"Can't found conversation with trace_id: {trace_id}")
|
||||||
@@ -315,9 +313,12 @@ def chat(
|
|||||||
trace_spans = [s for s in reversed(trace_spans)]
|
trace_spans = [s for s in reversed(trace_spans)]
|
||||||
hierarchy = _build_trace_hierarchy(trace_spans)
|
hierarchy = _build_trace_hierarchy(trace_spans)
|
||||||
if tree:
|
if tree:
|
||||||
print("\nInvoke Trace Tree:\n")
|
print(f"\nInvoke Trace Tree(trace_id: {trace_id}):\n")
|
||||||
_print_trace_hierarchy(hierarchy)
|
_print_trace_hierarchy(hierarchy)
|
||||||
|
|
||||||
|
if hide_conv:
|
||||||
|
return
|
||||||
|
|
||||||
trace_spans = _get_ordered_trace_from(hierarchy)
|
trace_spans = _get_ordered_trace_from(hierarchy)
|
||||||
table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details")
|
table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details")
|
||||||
split_long_text = output == "text"
|
split_long_text = output == "text"
|
||||||
|
Reference in New Issue
Block a user