feat(core): Support pass span id to threadpool

This commit is contained in:
FangYin Cheng
2023-11-04 18:08:28 +08:00
parent 23347d52a9
commit 59ac4ee548
14 changed files with 70 additions and 18 deletions

View File

@@ -34,6 +34,7 @@ CREATE TABLE `knowledge_document` (
`content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result',
`result` TEXT NULL COMMENT 'knowledge content', `result` TEXT NULL COMMENT 'knowledge content',
`vector_ids` LONGTEXT NULL COMMENT 'vector_ids', `vector_ids` LONGTEXT NULL COMMENT 'vector_ids',
`summary` LONGTEXT NULL COMMENT 'knowledge summary',
`gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time',
`gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time',
PRIMARY KEY (`id`), PRIMARY KEY (`id`),

View File

@@ -145,7 +145,14 @@ class BaseChat(ABC):
) )
self.current_message.tokens = 0 self.current_message.tokens = 0
if self.prompt_template.template: if self.prompt_template.template:
current_prompt = self.prompt_template.format(**input_values) metadata = {
"template_scene": self.prompt_template.template_scene,
"input_values": input_values,
}
with root_tracer.start_span(
"BaseChat.__call_base.prompt_template.format", metadata=metadata
):
current_prompt = self.prompt_template.format(**input_values)
self.current_message.add_system_message(current_prompt) self.current_message.add_system_message(current_prompt)
llm_messages = self.generate_llm_messages() llm_messages = self.generate_llm_messages()

View File

@@ -11,6 +11,7 @@ from pilot.common.string_utils import extract_content
from .prompt import prompt from .prompt import prompt
from pilot.component import ComponentType from pilot.component import ComponentType
from pilot.base_modules.agent.controller import ModuleAgent from pilot.base_modules.agent.controller import ModuleAgent
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -51,6 +52,7 @@ class ChatAgent(BaseChat):
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator) self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
@trace()
async def generate_input_values(self) -> Dict[str, str]: async def generate_input_values(self) -> Dict[str, str]:
input_values = { input_values = {
"user_goal": self.current_user_input, "user_goal": self.current_user_input,
@@ -63,7 +65,10 @@ class ChatAgent(BaseChat):
def stream_plugin_call(self, text): def stream_plugin_call(self, text):
text = text.replace("\n", " ") text = text.replace("\n", " ")
return self.api_call.run(text) with root_tracer.start_span(
"ChatAgent.stream_plugin_call.api_call", metadata={"text": text}
):
return self.api_call.run(text)
def __list_to_prompt_str(self, list: List) -> str: def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list)) return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))

View File

@@ -13,6 +13,7 @@ from pilot.scene.chat_dashboard.data_preparation.report_schma import (
from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -53,6 +54,7 @@ class ChatDashboard(BaseChat):
data = f.read() data = f.read()
return json.loads(data) return json.loads(data)
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
try: try:
from pilot.summary.db_summary_client import DBSummaryClient from pilot.summary.db_summary_client import DBSummaryClient

View File

@@ -14,6 +14,7 @@ from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
from pilot.common.path_utils import has_path from pilot.common.path_utils import has_path
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.base_modules.agent.common.schema import Status from pilot.base_modules.agent.common.schema import Status
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -62,6 +63,7 @@ class ChatExcel(BaseChat):
# ] # ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"user_input": self.current_user_input, "user_input": self.current_user_input,
@@ -88,4 +90,9 @@ class ChatExcel(BaseChat):
def stream_plugin_call(self, text): def stream_plugin_call(self, text):
text = text.replace("\n", " ") text = text.replace("\n", " ")
return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex) with root_tracer.start_span(
"ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text}
):
return self.api_call.run_display_sql(
text, self.excel_reader.get_df_by_sql_ex
)

View File

@@ -13,6 +13,7 @@ from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.json_utils.utilities import DateTimeEncoder from pilot.json_utils.utilities import DateTimeEncoder
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -44,6 +45,7 @@ class ExcelLearning(BaseChat):
if parent_mode: if parent_mode:
self.current_message.chat_mode = parent_mode.value() self.current_message.chat_mode = parent_mode.value()
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
# colunms, datas = self.excel_reader.get_sample_data() # colunms, datas = self.excel_reader.get_sample_data()
colunms, datas = await blocking_func_to_async( colunms, datas = await blocking_func_to_async(

View File

@@ -6,6 +6,7 @@ from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.auto_execute.prompt import prompt from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -35,10 +36,13 @@ class ChatWithDbAutoExecute(BaseChat):
raise ValueError( raise ValueError(
f"{ChatScene.ChatWithDbExecute.value} mode should chose db!" f"{ChatScene.ChatWithDbExecute.value} mode should chose db!"
) )
with root_tracer.start_span(
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) "ChatWithDbAutoExecute.get_connect", metadata={"db_name": self.db_name}
):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 200 self.top_k: int = 200
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
""" """
generate input values generate input values
@@ -55,13 +59,14 @@ class ChatWithDbAutoExecute(BaseChat):
# query=self.current_user_input, # query=self.current_user_input,
# topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE, # topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
# ) # )
table_infos = await blocking_func_to_async( with root_tracer.start_span("ChatWithDbAutoExecute.get_db_summary"):
self._executor, table_infos = await blocking_func_to_async(
client.get_db_summary, self._executor,
self.db_name, client.get_db_summary,
self.current_user_input, self.db_name,
CFG.KNOWLEDGE_SEARCH_TOP_SIZE, self.current_user_input,
) CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
)
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))
if not table_infos: if not table_infos:
@@ -80,4 +85,8 @@ class ChatWithDbAutoExecute(BaseChat):
def do_action(self, prompt_response): def do_action(self, prompt_response):
print(f"do_action:{prompt_response}") print(f"do_action:{prompt_response}")
return self.database.run(prompt_response.sql) with root_tracer.start_span(
"ChatWithDbAutoExecute.do_action.run_sql",
metadata=prompt_response.to_dict(),
):
return self.database.run(prompt_response.sql)

View File

@@ -12,6 +12,9 @@ class SqlAction(NamedTuple):
sql: str sql: str
thoughts: Dict thoughts: Dict
def to_dict(self) -> Dict[str, Dict]:
return {"sql": self.sql, "thoughts": self.thoughts}
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -6,6 +6,7 @@ from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.professional_qa.prompt import prompt from pilot.scene.chat_db.professional_qa.prompt import prompt
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -39,6 +40,7 @@ class ChatWithDbQA(BaseChat):
else len(self.tables) else len(self.tables)
) )
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
table_info = "" table_info = ""
dialect = "mysql" dialect = "mysql"

View File

@@ -6,6 +6,7 @@ from pilot.configs.config import Config
from pilot.base_modules.agent.commands.command import execute_command from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent import PluginPromptGenerator from pilot.base_modules.agent import PluginPromptGenerator
from .prompt import prompt from .prompt import prompt
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -50,6 +51,7 @@ class ChatWithPlugin(BaseChat):
self.plugins_prompt_generator self.plugins_prompt_generator
) )
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"input": self.current_user_input, "input": self.current_user_input,

View File

@@ -4,6 +4,7 @@ from pilot.scene.base import ChatScene
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt from pilot.scene.chat_knowledge.inner_db_summary.prompt import prompt
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -31,6 +32,7 @@ class InnerChatDBSummary(BaseChat):
self.db_input = db_select self.db_input = db_select
self.db_summary = db_summary self.db_summary = db_summary
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
input_values = { input_values = {
"db_input": self.db_input, "db_input": self.db_input,

View File

@@ -15,6 +15,7 @@ from pilot.configs.model_config import (
from pilot.scene.chat_knowledge.v1.prompt import prompt from pilot.scene.chat_knowledge.v1.prompt import prompt
from pilot.server.knowledge.service import KnowledgeService from pilot.server.knowledge.service import KnowledgeService
from pilot.utils.executor_utils import blocking_func_to_async from pilot.utils.executor_utils import blocking_func_to_async
from pilot.utils.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -92,6 +93,7 @@ class ChatKnowledge(BaseChat):
"""return reference""" """return reference"""
return text + f"\n\n{self.parse_source_view(self.sources)}" return text + f"\n\n{self.parse_source_view(self.sources)}"
@trace()
async def generate_input_values(self) -> Dict: async def generate_input_values(self) -> Dict:
if self.space_context: if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"] self.prompt_template.template_define = self.space_context["prompt"]["scene"]

View File

@@ -1,5 +1,6 @@
from typing import Callable, Awaitable, Any from typing import Callable, Awaitable, Any
import asyncio import asyncio
import contextvars
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial from functools import partial
@@ -55,6 +56,12 @@ async def blocking_func_to_async(
""" """
if asyncio.iscoroutinefunction(func): if asyncio.iscoroutinefunction(func):
raise ValueError(f"The function {func} is not blocking function") raise ValueError(f"The function {func} is not blocking function")
# This function will be called within the new thread, capturing the current context
ctx = contextvars.copy_context()
def run_with_context():
return ctx.run(partial(func, *args, **kwargs))
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
sync_function_noargs = partial(func, *args, **kwargs) return await loop.run_in_executor(executor, run_with_context)
return await loop.run_in_executor(executor, sync_function_noargs)

View File

@@ -303,8 +303,6 @@ def chat(
print(table.get_formatted_string(out_format=output, **out_kwargs)) print(table.get_formatted_string(out_format=output, **out_kwargs))
if sys_table: if sys_table:
print(sys_table.get_formatted_string(out_format=output, **out_kwargs)) print(sys_table.get_formatted_string(out_format=output, **out_kwargs))
if hide_conv:
return
if not found_trace_id: if not found_trace_id:
print(f"Can't found conversation with trace_id: {trace_id}") print(f"Can't found conversation with trace_id: {trace_id}")
@@ -315,9 +313,12 @@ def chat(
trace_spans = [s for s in reversed(trace_spans)] trace_spans = [s for s in reversed(trace_spans)]
hierarchy = _build_trace_hierarchy(trace_spans) hierarchy = _build_trace_hierarchy(trace_spans)
if tree: if tree:
print("\nInvoke Trace Tree:\n") print(f"\nInvoke Trace Tree(trace_id: {trace_id}):\n")
_print_trace_hierarchy(hierarchy) _print_trace_hierarchy(hierarchy)
if hide_conv:
return
trace_spans = _get_ordered_trace_from(hierarchy) trace_spans = _get_ordered_trace_from(hierarchy)
table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details") table = PrettyTable(["Key", "Value Value"], title="Chat Trace Details")
split_long_text = output == "text" split_long_text = output == "text"