mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 21:21:08 +00:00
style:fmt
This commit is contained in:
@@ -7,16 +7,19 @@ import hashlib
|
||||
from http import HTTPStatus
|
||||
from dashscope import Generation
|
||||
|
||||
|
||||
def call_with_messages():
|
||||
messages = [{'role': 'system', 'content': '你是生活助手机器人。'},
|
||||
{'role': 'user', 'content': '如何做西红柿鸡蛋?'}]
|
||||
messages = [
|
||||
{"role": "system", "content": "你是生活助手机器人。"},
|
||||
{"role": "user", "content": "如何做西红柿鸡蛋?"},
|
||||
]
|
||||
gen = Generation()
|
||||
response = gen.call(
|
||||
Generation.Models.qwen_turbo,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
top_p=0.8,
|
||||
result_format='message', # set the result to be "message" format.
|
||||
result_format="message", # set the result to be "message" format.
|
||||
)
|
||||
|
||||
for response in response:
|
||||
@@ -24,21 +27,24 @@ def call_with_messages():
|
||||
# otherwise indicate request is failed, you can get error code
|
||||
# and message from code and message.
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
print(response.output) # The output text
|
||||
print(response.output) # The output text
|
||||
print(response.usage) # The usage information
|
||||
else:
|
||||
print(response.code) # The error code.
|
||||
print(response.message) # The error message.
|
||||
|
||||
print(response.code) # The error code.
|
||||
print(response.message) # The error message.
|
||||
|
||||
|
||||
def build_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
Generate Access token according AK, SK
|
||||
Generate Access token according AK, SK
|
||||
"""
|
||||
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
|
||||
params = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": api_key,
|
||||
"client_secret": secret_key,
|
||||
}
|
||||
|
||||
res = requests.get(url=url, params=params)
|
||||
|
||||
@@ -47,15 +53,15 @@ def build_access_token(api_key: str, secret_key: str) -> str:
|
||||
|
||||
|
||||
def _calculate_md5(text: str) -> str:
|
||||
|
||||
md5 = hashlib.md5()
|
||||
md5.update(text.encode("utf-8"))
|
||||
encrypted = md5.hexdigest()
|
||||
return encrypted
|
||||
|
||||
|
||||
def baichuan_call():
|
||||
url = "https://api.baichuan-ai.com/v1/stream/chat"
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
call_with_messages()
|
@@ -1,4 +1,4 @@
|
||||
from .db.my_plugin_db import MyPluginEntity, MyPluginDao
|
||||
from .db.my_plugin_db import MyPluginEntity, MyPluginDao
|
||||
from .db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||
|
||||
from .commands.command import execute_command, get_command
|
||||
|
@@ -9,6 +9,7 @@ from .generator import PluginPromptGenerator
|
||||
|
||||
from pilot.configs.config import Config
|
||||
|
||||
|
||||
def _resolve_pathlike_command_args(command_args):
|
||||
if "directory" in command_args and command_args["directory"] in {"", "/"}:
|
||||
# todo
|
||||
@@ -64,8 +65,6 @@ def execute_ai_response_json(
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
def execute_command(
|
||||
command_name: str,
|
||||
arguments,
|
||||
@@ -81,10 +80,8 @@ def execute_command(
|
||||
str: The result of the command
|
||||
"""
|
||||
|
||||
|
||||
cmd = plugin_generator.command_registry.commands.get(command_name)
|
||||
|
||||
|
||||
# If the command is found, call it with the provided arguments
|
||||
if cmd:
|
||||
try:
|
||||
@@ -153,6 +150,3 @@ def get_command(response_json: Dict):
|
||||
# All other errors, return "Error: + error message"
|
||||
except Exception as e:
|
||||
return "Error:", str(e)
|
||||
|
||||
|
||||
|
||||
|
@@ -28,13 +28,13 @@ class Command:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
method: Callable[..., Any],
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
method: Callable[..., Any],
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
@@ -87,11 +87,12 @@ class CommandRegistry:
|
||||
if hasattr(reloaded_module, "register"):
|
||||
reloaded_module.register(self)
|
||||
|
||||
def is_valid_command(self, name:str)-> bool:
|
||||
def is_valid_command(self, name: str) -> bool:
|
||||
if name not in self.commands:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_command(self, name: str) -> Callable[..., Any]:
|
||||
return self.commands[name]
|
||||
|
||||
@@ -129,23 +130,23 @@ class CommandRegistry:
|
||||
attr = getattr(module, attr_name)
|
||||
# Register decorated functions
|
||||
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
|
||||
attr, AUTO_GPT_COMMAND_IDENTIFIER
|
||||
attr, AUTO_GPT_COMMAND_IDENTIFIER
|
||||
):
|
||||
self.register(attr.command)
|
||||
# Register command classes
|
||||
elif (
|
||||
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
|
||||
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
|
||||
):
|
||||
cmd_instance = attr()
|
||||
self.register(cmd_instance)
|
||||
|
||||
|
||||
def command(
|
||||
name: str,
|
||||
description: str,
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
name: str,
|
||||
description: str,
|
||||
signature: str = "",
|
||||
enabled: bool = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""The command decorator is used to create Command objects from ordinary functions."""
|
||||
|
||||
@@ -241,34 +242,60 @@ class ApiCall:
|
||||
else:
|
||||
md_tag_end = "```"
|
||||
for tag in error_md_tags:
|
||||
all_context = all_context.replace(tag + api_context + md_tag_end, api_context)
|
||||
all_context = all_context.replace(tag + "\n" +api_context + "\n" + md_tag_end, api_context)
|
||||
all_context = all_context.replace(tag + " " +api_context + " " + md_tag_end, api_context)
|
||||
all_context = all_context.replace(
|
||||
tag + api_context + md_tag_end, api_context
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
tag + "\n" + api_context + "\n" + md_tag_end, api_context
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
tag + " " + api_context + " " + md_tag_end, api_context
|
||||
)
|
||||
all_context = all_context.replace(tag + api_context, api_context)
|
||||
return all_context
|
||||
|
||||
def api_view_context(self, all_context: str, display_mode: bool = False):
|
||||
error_mk_tags = ["```", "```python", "```xml"]
|
||||
call_context_map = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end, True)
|
||||
call_context_map = extract_content_open_ending(
|
||||
all_context, self.agent_prefix, self.agent_end, True
|
||||
)
|
||||
for api_index, api_context in call_context_map.items():
|
||||
api_status = self.plugin_status_map.get(api_context)
|
||||
if api_status is not None:
|
||||
if display_mode:
|
||||
if api_status.api_result:
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context)
|
||||
all_context = all_context.replace(api_context, api_status.api_result)
|
||||
all_context = self.__deal_error_md_tags(
|
||||
all_context, api_context
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
api_context, api_status.api_result
|
||||
)
|
||||
else:
|
||||
if api_status.status == Status.FAILED.value:
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context)
|
||||
all_context = all_context.replace(api_context, f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """)
|
||||
all_context = self.__deal_error_md_tags(
|
||||
all_context, api_context
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
api_context,
|
||||
f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """,
|
||||
)
|
||||
else:
|
||||
cost = (api_status.end_time - self.start_time) / 1000
|
||||
cost_str = "{:.2f}".format(cost)
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context)
|
||||
all_context = all_context.replace(api_context, f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n')
|
||||
all_context = self.__deal_error_md_tags(
|
||||
all_context, api_context
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
api_context,
|
||||
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
|
||||
)
|
||||
else:
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context, False)
|
||||
all_context = all_context.replace(api_context, self.to_view_text(api_status))
|
||||
all_context = self.__deal_error_md_tags(
|
||||
all_context, api_context, False
|
||||
)
|
||||
all_context = all_context.replace(
|
||||
api_context, self.to_view_text(api_status)
|
||||
)
|
||||
|
||||
else:
|
||||
# not ready api call view change
|
||||
@@ -276,27 +303,34 @@ class ApiCall:
|
||||
cost = (now_time - self.start_time) / 1000
|
||||
cost_str = "{:.2f}".format(cost)
|
||||
for tag in error_mk_tags:
|
||||
all_context = all_context.replace(tag + api_context , api_context)
|
||||
all_context = all_context.replace(api_context, f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n')
|
||||
all_context = all_context.replace(tag + api_context, api_context)
|
||||
all_context = all_context.replace(
|
||||
api_context,
|
||||
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
|
||||
)
|
||||
|
||||
return all_context
|
||||
|
||||
def update_from_context(self, all_context):
|
||||
api_context_map = extract_content(all_context, self.agent_prefix, self.agent_end, True)
|
||||
api_context_map = extract_content(
|
||||
all_context, self.agent_prefix, self.agent_end, True
|
||||
)
|
||||
for api_index, api_context in api_context_map.items():
|
||||
api_context = api_context.replace("\\n", "").replace("\n", "")
|
||||
api_call_element = ET.fromstring(api_context)
|
||||
api_name = api_call_element.find('name').text
|
||||
if api_name.find("[")>=0 or api_name.find("]")>=0:
|
||||
api_name = api_call_element.find("name").text
|
||||
if api_name.find("[") >= 0 or api_name.find("]") >= 0:
|
||||
api_name = api_name.replace("[", "").replace("]", "")
|
||||
api_args = {}
|
||||
args_elements = api_call_element.find('args')
|
||||
args_elements = api_call_element.find("args")
|
||||
for child_element in args_elements.iter():
|
||||
api_args[child_element.tag] = child_element.text
|
||||
|
||||
api_status = self.plugin_status_map.get(api_context)
|
||||
if api_status is None:
|
||||
api_status = PluginStatus(name=api_name, location=[api_index], args=api_args)
|
||||
api_status = PluginStatus(
|
||||
name=api_name, location=[api_index], args=api_args
|
||||
)
|
||||
self.plugin_status_map[api_context] = api_status
|
||||
else:
|
||||
api_status.location.append(api_index)
|
||||
@@ -304,20 +338,20 @@ class ApiCall:
|
||||
def __to_view_param_str(self, api_status):
|
||||
param = {}
|
||||
if api_status.name:
|
||||
param['name'] = api_status.name
|
||||
param['status'] = api_status.status
|
||||
param["name"] = api_status.name
|
||||
param["status"] = api_status.status
|
||||
if api_status.logo_url:
|
||||
param['logo'] = api_status.logo_url
|
||||
param["logo"] = api_status.logo_url
|
||||
|
||||
if api_status.err_msg:
|
||||
param['err_msg'] = api_status.err_msg
|
||||
param["err_msg"] = api_status.err_msg
|
||||
|
||||
if api_status.api_result:
|
||||
param['result'] = api_status.api_result
|
||||
param["result"] = api_status.api_result
|
||||
return json.dumps(param)
|
||||
|
||||
def to_view_text(self, api_status: PluginStatus):
|
||||
api_call_element = ET.Element('dbgpt-view')
|
||||
api_call_element = ET.Element("dbgpt-view")
|
||||
api_call_element.text = self.__to_view_param_str(api_status)
|
||||
result = ET.tostring(api_call_element, encoding="utf-8")
|
||||
return result.decode("utf-8")
|
||||
@@ -332,7 +366,9 @@ class ApiCall:
|
||||
value.status = Status.RUNNING.value
|
||||
logging.info(f"插件执行:{value.name},{value.args}")
|
||||
try:
|
||||
value.api_result = execute_command(value.name, value.args, self.plugin_generator)
|
||||
value.api_result = execute_command(
|
||||
value.name, value.args, self.plugin_generator
|
||||
)
|
||||
value.status = Status.COMPLETED.value
|
||||
except Exception as e:
|
||||
value.status = Status.FAILED.value
|
||||
@@ -350,15 +386,19 @@ class ApiCall:
|
||||
value.status = Status.RUNNING.value
|
||||
logging.info(f"sql展示执行:{value.name},{value.args}")
|
||||
try:
|
||||
sql = value.args['sql']
|
||||
sql = value.args["sql"]
|
||||
if sql:
|
||||
param = {
|
||||
"df": sql_run_func(sql),
|
||||
}
|
||||
if self.display_registry.is_valid_command(value.name):
|
||||
value.api_result = self.display_registry.call(value.name, **param)
|
||||
value.api_result = self.display_registry.call(
|
||||
value.name, **param
|
||||
)
|
||||
else:
|
||||
value.api_result = self.display_registry.call("response_table", **param)
|
||||
value.api_result = self.display_registry.call(
|
||||
"response_table", **param
|
||||
)
|
||||
|
||||
value.status = Status.COMPLETED.value
|
||||
except Exception as e:
|
||||
@@ -366,4 +406,3 @@ class ApiCall:
|
||||
value.err_msg = str(e)
|
||||
value.end_time = datetime.now().timestamp() * 1000
|
||||
return self.api_view_context(llm_text, True)
|
||||
|
||||
|
@@ -15,6 +15,7 @@ from matplotlib.font_manager import FontManager
|
||||
from pilot.common.string_utils import is_scientific_notation
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -88,12 +89,13 @@ def zh_font_set():
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
|
||||
def format_axis(value, pos):
|
||||
# 判断是否为数字
|
||||
if is_scientific_notation(value):
|
||||
# 判断是否需要进行非科学计数法格式化
|
||||
|
||||
return '{:.2f}'.format(value)
|
||||
return "{:.2f}".format(value)
|
||||
return value
|
||||
|
||||
|
||||
@@ -102,7 +104,7 @@ def format_axis(value, pos):
|
||||
"Line chart display, used to display comparative trend analysis data",
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_line_chart( df: DataFrame) -> str:
|
||||
def response_line_chart(df: DataFrame) -> str:
|
||||
logger.info(f"response_line_chart")
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
@@ -143,9 +145,15 @@ def response_line_chart( df: DataFrame) -> str:
|
||||
if len(num_colmns) > 0:
|
||||
num_colmns.append(y)
|
||||
df_melted = pd.melt(
|
||||
df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
|
||||
df,
|
||||
id_vars=x,
|
||||
value_vars=num_colmns,
|
||||
var_name="line",
|
||||
value_name="Value",
|
||||
)
|
||||
sns.lineplot(
|
||||
data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2"
|
||||
)
|
||||
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
|
||||
else:
|
||||
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
|
||||
|
||||
@@ -154,7 +162,7 @@ def response_line_chart( df: DataFrame) -> str:
|
||||
|
||||
chart_name = "line_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, dpi=100, transparent=True)
|
||||
plt.savefig(chart_path, dpi=100, transparent=True)
|
||||
|
||||
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
@@ -168,7 +176,7 @@ def response_line_chart( df: DataFrame) -> str:
|
||||
"Histogram, suitable for comparative analysis of multiple target values",
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_bar_chart( df: DataFrame) -> str:
|
||||
def response_bar_chart(df: DataFrame) -> str:
|
||||
logger.info(f"response_bar_chart")
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
@@ -246,7 +254,7 @@ def response_bar_chart( df: DataFrame) -> str:
|
||||
|
||||
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, dpi=100,transparent=True)
|
||||
plt.savefig(chart_path, dpi=100, transparent=True)
|
||||
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
|
||||
|
@@ -3,8 +3,10 @@ from pandas import DataFrame
|
||||
from pilot.base_modules.agent.commands.command_mange import command
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
"response_table",
|
||||
"Table display, suitable for display with many display columns or non-numeric columns",
|
||||
|
@@ -3,6 +3,7 @@ from pandas import DataFrame
|
||||
from pilot.base_modules.agent.commands.command_mange import command
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -22,7 +23,7 @@ def response_data_text(df: DataFrame) -> str:
|
||||
html_table = df.to_html(index=False, escape=False, sparsify=False)
|
||||
table_str = "".join(html_table.split())
|
||||
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
|
||||
text_info = html.replace("\n", " ")
|
||||
text_info = html.replace("\n", " ")
|
||||
elif row_size == 1:
|
||||
row = data[0]
|
||||
for value in row:
|
||||
|
@@ -133,5 +133,3 @@ class PluginPromptGenerator:
|
||||
|
||||
def generate_commands_string(self) -> str:
|
||||
return f"{self._generate_numbered_list(self.commands, item_type='command')}"
|
||||
|
||||
|
||||
|
@@ -5,11 +5,12 @@ class PluginStorageType(Enum):
|
||||
Git = "git"
|
||||
Oss = "oss"
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
TODO = "todo"
|
||||
RUNNING = 'running'
|
||||
FAILED = 'failed'
|
||||
COMPLETED = 'completed'
|
||||
RUNNING = "running"
|
||||
FAILED = "failed"
|
||||
COMPLETED = "completed"
|
||||
|
||||
|
||||
class ApiTagType(Enum):
|
||||
|
@@ -16,7 +16,13 @@ from pilot.openapi.api_view_model import (
|
||||
Result,
|
||||
)
|
||||
|
||||
from .model import PluginHubParam, PagenationFilter, PagenationResult, PluginHubFilter, MyPluginFilter
|
||||
from .model import (
|
||||
PluginHubParam,
|
||||
PagenationFilter,
|
||||
PagenationResult,
|
||||
PluginHubFilter,
|
||||
MyPluginFilter,
|
||||
)
|
||||
from .hub.agent_hub import AgentHub
|
||||
from .db.plugin_hub_db import PluginHubEntity
|
||||
from .plugins_util import scan_plugins
|
||||
@@ -33,17 +39,18 @@ class ModuleAgent(BaseComponent, ABC):
|
||||
name = ComponentType.AGENT_HUB
|
||||
|
||||
def __init__(self):
|
||||
#load plugins
|
||||
# load plugins
|
||||
self.plugins = scan_plugins(PLUGINS_DIR)
|
||||
|
||||
def init_app(self, system_app: SystemApp):
|
||||
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
|
||||
|
||||
|
||||
def refresh_plugins(self):
|
||||
self.plugins = scan_plugins(PLUGINS_DIR)
|
||||
|
||||
def load_select_plugin(self, generator:PluginPromptGenerator, select_plugins:List[str])->PluginPromptGenerator:
|
||||
def load_select_plugin(
|
||||
self, generator: PluginPromptGenerator, select_plugins: List[str]
|
||||
) -> PluginPromptGenerator:
|
||||
logger.info(f"load_select_plugin:{select_plugins}")
|
||||
# load select plugin
|
||||
for plugin in self.plugins:
|
||||
@@ -53,6 +60,7 @@ class ModuleAgent(BaseComponent, ABC):
|
||||
generator = plugin.post_prompt(generator)
|
||||
return generator
|
||||
|
||||
|
||||
module_agent = ModuleAgent()
|
||||
|
||||
|
||||
@@ -61,25 +69,28 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):
|
||||
logger.info(f"agent_hub_update:{update_param.__dict__}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization)
|
||||
agent_hub.refresh_hub_from_git(
|
||||
update_param.url, update_param.branch, update_param.authorization
|
||||
)
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
logger.error("Agent Hub Update Error!", e)
|
||||
return Result.faild(code="E0020", msg=f"Agent Hub Update Error! {e}")
|
||||
|
||||
|
||||
|
||||
@router.post("/v1/agent/query", response_model=Result[str])
|
||||
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
|
||||
logger.info(f"get_agent_list:{filter.__dict__}")
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
filter_enetity:PluginHubEntity = PluginHubEntity()
|
||||
filter_enetity: PluginHubEntity = PluginHubEntity()
|
||||
if filter.filter:
|
||||
attrs = vars(filter.filter) # 获取原始对象的属性字典
|
||||
for attr, value in attrs.items():
|
||||
setattr(filter_enetity, attr, value) # 设置拷贝对象的属性值
|
||||
|
||||
datas, total_pages, total_count = agent_hub.hub_dao.list(filter_enetity, filter.page_index, filter.page_size)
|
||||
datas, total_pages, total_count = agent_hub.hub_dao.list(
|
||||
filter_enetity, filter.page_index, filter.page_size
|
||||
)
|
||||
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
|
||||
result.page_index = filter.page_index
|
||||
result.page_size = filter.page_size
|
||||
@@ -89,11 +100,12 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
|
||||
# print(json.dumps(result.to_dic()))
|
||||
return Result.succ(result.to_dic())
|
||||
|
||||
|
||||
@router.post("/v1/agent/my", response_model=Result[str])
|
||||
async def my_agents(user:str= None):
|
||||
async def my_agents(user: str = None):
|
||||
logger.info(f"my_agents:{user}")
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
agents = agent_hub.get_my_plugin(user)
|
||||
agents = agent_hub.get_my_plugin(user)
|
||||
agent_dicts = []
|
||||
for agent in agents:
|
||||
agent_dicts.append(agent.__dict__)
|
||||
@@ -102,7 +114,7 @@ async def my_agents(user:str= None):
|
||||
|
||||
|
||||
@router.post("/v1/agent/install", response_model=Result[str])
|
||||
async def agent_install(plugin_name:str, user: str = None):
|
||||
async def agent_install(plugin_name: str, user: str = None):
|
||||
logger.info(f"agent_install:{plugin_name},{user}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
@@ -111,14 +123,13 @@ async def agent_install(plugin_name:str, user: str = None):
|
||||
module_agent.refresh_plugins()
|
||||
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
logger.error("Plugin Install Error!", e)
|
||||
return Result.faild(code="E0021", msg=f"Plugin Install Error {e}")
|
||||
|
||||
|
||||
|
||||
@router.post("/v1/agent/uninstall", response_model=Result[str])
|
||||
async def agent_uninstall(plugin_name:str, user: str = None):
|
||||
async def agent_uninstall(plugin_name: str, user: str = None):
|
||||
logger.info(f"agent_uninstall:{plugin_name},{user}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
@@ -126,19 +137,18 @@ async def agent_uninstall(plugin_name:str, user: str = None):
|
||||
|
||||
module_agent.refresh_plugins()
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
logger.error("Plugin Uninstall Error!", e)
|
||||
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
|
||||
|
||||
|
||||
@router.post("/v1/personal/agent/upload", response_model=Result[str])
|
||||
async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
|
||||
async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
|
||||
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
|
||||
try:
|
||||
agent_hub = AgentHub(PLUGINS_DIR)
|
||||
await agent_hub.upload_my_plugin(doc_file, user)
|
||||
return Result.succ(None)
|
||||
except Exception as e:
|
||||
except Exception as e:
|
||||
logger.error("Upload Personal Plugin Error!", e)
|
||||
return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}")
|
||||
|
||||
|
@@ -9,30 +9,33 @@ from pilot.base_modules.meta_data.base_dao import BaseDao
|
||||
from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||
|
||||
|
||||
|
||||
class MyPluginEntity(Base):
|
||||
__tablename__ = 'my_plugin'
|
||||
__tablename__ = "my_plugin"
|
||||
|
||||
id = Column(Integer, primary_key=True, comment="autoincrement id")
|
||||
tenant = Column(String(255), nullable=True, comment="user's tenant")
|
||||
user_code = Column(String(255), nullable=False, comment="user code")
|
||||
user_name = Column(String(255), nullable=True, comment="user name")
|
||||
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
|
||||
file_name = Column(String(255), nullable=False, comment="plugin package file name")
|
||||
type = Column(String(255), comment="plugin type")
|
||||
version = Column(String(255), comment="plugin version")
|
||||
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
|
||||
succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count")
|
||||
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")
|
||||
__table_args__ = (
|
||||
UniqueConstraint('user_code','name', name="uk_name"),
|
||||
file_name = Column(String(255), nullable=False, comment="plugin package file name")
|
||||
type = Column(String(255), comment="plugin type")
|
||||
version = Column(String(255), comment="plugin version")
|
||||
use_count = Column(
|
||||
Integer, nullable=True, default=0, comment="plugin total use count"
|
||||
)
|
||||
succ_count = Column(
|
||||
Integer, nullable=True, default=0, comment="plugin total success count"
|
||||
)
|
||||
created_at = Column(
|
||||
DateTime, default=datetime.utcnow, comment="plugin install time"
|
||||
)
|
||||
__table_args__ = (UniqueConstraint("user_code", "name", name="uk_name"),)
|
||||
|
||||
|
||||
class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database="dbgpt", orm_base=Base, db_engine =engine , session= session
|
||||
database="dbgpt", orm_base=Base, db_engine=engine, session=session
|
||||
)
|
||||
|
||||
def add(self, engity: MyPluginEntity):
|
||||
@@ -60,87 +63,61 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
session.commit()
|
||||
return updated.id
|
||||
|
||||
def get_by_user(self, user: str)->list[MyPluginEntity]:
|
||||
def get_by_user(self, user: str) -> list[MyPluginEntity]:
|
||||
session = self.get_session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
if user:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.user_code == user
|
||||
)
|
||||
if user:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
|
||||
result = my_plugins.all()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
|
||||
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
|
||||
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
|
||||
session = self.get_session()
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
all_count = my_plugins.count()
|
||||
if query.id is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.name == query.name
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
|
||||
if query.tenant is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.tenant == query.tenant
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
|
||||
if query.type is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.type == query.type
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
|
||||
if query.user_code is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.user_code == query.user_code
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
|
||||
if query.user_name is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.user_name == query.user_name
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
|
||||
|
||||
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
|
||||
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
|
||||
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
|
||||
result = my_plugins.all()
|
||||
session.close()
|
||||
total_pages = all_count // page_size
|
||||
if all_count % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
|
||||
return result, total_pages, all_count
|
||||
|
||||
|
||||
def count(self, query: MyPluginEntity):
|
||||
session = self.get_session()
|
||||
my_plugins = session.query(func.count(MyPluginEntity.id))
|
||||
if query.id is not None:
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.name == query.name
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
|
||||
if query.type is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.type == query.type
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
|
||||
if query.tenant is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.tenant == query.tenant
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
|
||||
if query.user_code is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.user_code == query.user_code
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
|
||||
if query.user_name is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.user_name == query.user_name
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
|
||||
count = my_plugins.scalar()
|
||||
session.close()
|
||||
return count
|
||||
|
||||
|
||||
def delete(self, plugin_id: int):
|
||||
session = self.get_session()
|
||||
if plugin_id is None:
|
||||
@@ -148,9 +125,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
|
||||
query = MyPluginEntity(id=plugin_id)
|
||||
my_plugins = session.query(MyPluginEntity)
|
||||
if query.id is not None:
|
||||
my_plugins = my_plugins.filter(
|
||||
MyPluginEntity.id == query.id
|
||||
)
|
||||
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
|
||||
my_plugins.delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
@@ -10,8 +10,10 @@ from pilot.base_modules.meta_data.meta_data import Base, engine, session
|
||||
|
||||
|
||||
class PluginHubEntity(Base):
|
||||
__tablename__ = 'plugin_hub'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
|
||||
__tablename__ = "plugin_hub"
|
||||
id = Column(
|
||||
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
|
||||
)
|
||||
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
|
||||
description = Column(String(255), nullable=False, comment="plugin description")
|
||||
author = Column(String(255), nullable=True, comment="plugin author")
|
||||
@@ -25,8 +27,8 @@ class PluginHubEntity(Base):
|
||||
installed = Column(Integer, default=False, comment="plugin already installed count")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint('name', name="uk_name"),
|
||||
Index('idx_q_type', 'type'),
|
||||
UniqueConstraint("name", name="uk_name"),
|
||||
Index("idx_q_type", "type"),
|
||||
)
|
||||
|
||||
|
||||
@@ -38,7 +40,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
|
||||
def add(self, engity: PluginHubEntity):
|
||||
session = self.get_session()
|
||||
timezone = pytz.timezone('Asia/Shanghai')
|
||||
timezone = pytz.timezone("Asia/Shanghai")
|
||||
plugin_hub = PluginHubEntity(
|
||||
name=engity.name,
|
||||
author=engity.author,
|
||||
@@ -64,7 +66,9 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
|
||||
def list(
|
||||
self, query: PluginHubEntity, page=1, page_size=20
|
||||
) -> list[PluginHubEntity]:
|
||||
session = self.get_session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
all_count = plugin_hubs.count()
|
||||
@@ -72,17 +76,11 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
if query.id is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.name == query.name
|
||||
)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
|
||||
if query.type is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.type == query.type
|
||||
)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
|
||||
if query.author is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.author == query.author
|
||||
)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
|
||||
if query.storage_channel is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.storage_channel == query.storage_channel
|
||||
@@ -110,9 +108,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
def get_by_name(self, name: str) -> PluginHubEntity:
|
||||
session = self.get_session()
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.name == name
|
||||
)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
|
||||
result = plugin_hubs.first()
|
||||
session.close()
|
||||
return result
|
||||
@@ -123,17 +119,11 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
if query.id is not None:
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
|
||||
if query.name is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.name == query.name
|
||||
)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
|
||||
if query.type is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.type == query.type
|
||||
)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
|
||||
if query.author is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.author == query.author
|
||||
)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
|
||||
if query.storage_channel is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.storage_channel == query.storage_channel
|
||||
@@ -148,9 +138,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
|
||||
raise Exception("plugin_id is None")
|
||||
plugin_hubs = session.query(PluginHubEntity)
|
||||
if plugin_id is not None:
|
||||
plugin_hubs = plugin_hubs.filter(
|
||||
PluginHubEntity.id == plugin_id
|
||||
)
|
||||
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == plugin_id)
|
||||
plugin_hubs.delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
@@ -4,7 +4,7 @@ import os
|
||||
import glob
|
||||
import shutil
|
||||
from fastapi import UploadFile
|
||||
from typing import Any
|
||||
from typing import Any
|
||||
import tempfile
|
||||
|
||||
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
|
||||
@@ -38,7 +38,9 @@ class AgentHub:
|
||||
download_param = json.loads(plugin_entity.download_param)
|
||||
branch_name = download_param.get("branch_name")
|
||||
authorization = download_param.get("authorization")
|
||||
file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization)
|
||||
file_name = self.__download_from_git(
|
||||
plugin_entity.storage_url, branch_name, authorization
|
||||
)
|
||||
|
||||
# add to my plugins and edit hub status
|
||||
plugin_entity.installed = plugin_entity.installed + 1
|
||||
@@ -65,7 +67,9 @@ class AgentHub:
|
||||
logger.error("install pluguin exception!", e)
|
||||
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
|
||||
else:
|
||||
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
|
||||
raise ValueError(
|
||||
f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Can't Find Plugin {plugin_name}!")
|
||||
|
||||
@@ -75,7 +79,9 @@ class AgentHub:
|
||||
plugin_entity.installed = plugin_entity.installed - 1
|
||||
with self.hub_dao.get_session() as session:
|
||||
try:
|
||||
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name)
|
||||
my_plugin_q = session.query(MyPluginEntity).filter(
|
||||
MyPluginEntity.name == plugin_name
|
||||
)
|
||||
if user:
|
||||
my_plugin_q.filter(MyPluginEntity.user_code == user)
|
||||
my_plugin_q.delete()
|
||||
@@ -92,10 +98,10 @@ class AgentHub:
|
||||
have_installed = True
|
||||
break
|
||||
if not have_installed:
|
||||
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
|
||||
files = glob.glob(
|
||||
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
|
||||
plugin_repo_name = (
|
||||
plugin_entity.storage_url.replace(".git", "").strip("/").split("/")[-1]
|
||||
)
|
||||
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
|
||||
@@ -109,9 +115,16 @@ class AgentHub:
|
||||
my_plugin_entity.version = hub_plugin.version
|
||||
return my_plugin_entity
|
||||
|
||||
def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = None, authorization: str = None):
|
||||
def refresh_hub_from_git(
|
||||
self,
|
||||
github_repo: str = None,
|
||||
branch_name: str = None,
|
||||
authorization: str = None,
|
||||
):
|
||||
logger.info("refresh_hub_by_git start!")
|
||||
update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization)
|
||||
update_from_git(
|
||||
self.temp_hub_file_path, github_repo, branch_name, authorization
|
||||
)
|
||||
git_plugins = scan_plugins(self.temp_hub_file_path)
|
||||
try:
|
||||
for git_plugin in git_plugins:
|
||||
@@ -123,13 +136,13 @@ class AgentHub:
|
||||
plugin_hub_info.type = ""
|
||||
plugin_hub_info.storage_channel = PluginStorageType.Git.value
|
||||
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
|
||||
plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT')
|
||||
plugin_hub_info.email = getattr(git_plugin, '_email', '')
|
||||
plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
|
||||
plugin_hub_info.email = getattr(git_plugin, "_email", "")
|
||||
download_param = {}
|
||||
if branch_name:
|
||||
download_param['branch_name'] = branch_name
|
||||
download_param["branch_name"] = branch_name
|
||||
if authorization and len(authorization) > 0:
|
||||
download_param['authorization'] = authorization
|
||||
download_param["authorization"] = authorization
|
||||
plugin_hub_info.download_param = json.dumps(download_param)
|
||||
plugin_hub_info.installed = 0
|
||||
|
||||
@@ -140,15 +153,12 @@ class AgentHub:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
|
||||
|
||||
async def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User):
|
||||
|
||||
async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
|
||||
# We can not move temp file in windows system when we open file in context of `with`
|
||||
file_path = os.path.join(self.plugin_dir, doc_file.filename)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(
|
||||
dir=os.path.join(self.plugin_dir)
|
||||
)
|
||||
tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir))
|
||||
with os.fdopen(tmp_fd, "wb") as tmp:
|
||||
tmp.write(await doc_file.read())
|
||||
shutil.move(
|
||||
@@ -158,14 +168,14 @@ class AgentHub:
|
||||
|
||||
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
|
||||
|
||||
if user is None or len(user) <=0:
|
||||
if user is None or len(user) <= 0:
|
||||
user = Default_User
|
||||
|
||||
for my_plugin in my_plugins:
|
||||
my_plugin_entiy = MyPluginEntity()
|
||||
|
||||
my_plugin_entiy.name = my_plugin._name
|
||||
my_plugin_entiy.version = my_plugin._version
|
||||
my_plugin_entiy.version = my_plugin._version
|
||||
my_plugin_entiy.type = "Personal"
|
||||
my_plugin_entiy.user_code = user
|
||||
my_plugin_entiy.user_name = user
|
||||
@@ -183,4 +193,3 @@ class AgentHub:
|
||||
if not user:
|
||||
user = Default_User
|
||||
return self.my_lugin_dao.get_by_user(user)
|
||||
|
||||
|
@@ -3,32 +3,35 @@ from dataclasses import dataclass
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import TypeVar, Generic, Any
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PagenationFilter(BaseModel, Generic[T]):
|
||||
page_index: int = 1
|
||||
page_size: int = 20
|
||||
page_size: int = 20
|
||||
filter: T = None
|
||||
|
||||
|
||||
class PagenationResult(BaseModel, Generic[T]):
|
||||
page_index: int = 1
|
||||
page_size: int = 20
|
||||
page_size: int = 20
|
||||
total_page: int = 0
|
||||
total_row_count: int = 0
|
||||
datas: List[T] = []
|
||||
|
||||
def to_dic(self):
|
||||
data_dicts =[]
|
||||
data_dicts = []
|
||||
for item in self.datas:
|
||||
data_dicts.append(item.__dict__)
|
||||
return {
|
||||
'page_index': self.page_index,
|
||||
'page_size': self.page_size,
|
||||
'total_page': self.total_page,
|
||||
'total_row_count': self.total_row_count,
|
||||
'datas': data_dicts
|
||||
"page_index": self.page_index,
|
||||
"page_size": self.page_size,
|
||||
"total_page": self.total_page,
|
||||
"total_row_count": self.total_row_count,
|
||||
"datas": data_dicts,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginHubFilter(BaseModel):
|
||||
name: str
|
||||
@@ -53,9 +56,14 @@ class MyPluginFilter(BaseModel):
|
||||
|
||||
|
||||
class PluginHubParam(BaseModel):
|
||||
channel: Optional[str] = Field("git", description="Plugin storage channel")
|
||||
url: Optional[str] = Field("https://github.com/eosphoros-ai/DB-GPT-Plugins.git", description="Plugin storage url")
|
||||
branch: Optional[str] = Field("main", description="github download branch", nullable=True)
|
||||
authorization: Optional[str] = Field(None, description="github download authorization", nullable=True)
|
||||
|
||||
|
||||
channel: Optional[str] = Field("git", description="Plugin storage channel")
|
||||
url: Optional[str] = Field(
|
||||
"https://github.com/eosphoros-ai/DB-GPT-Plugins.git",
|
||||
description="Plugin storage url",
|
||||
)
|
||||
branch: Optional[str] = Field(
|
||||
"main", description="github download branch", nullable=True
|
||||
)
|
||||
authorization: Optional[str] = Field(
|
||||
None, description="github download authorization", nullable=True
|
||||
)
|
||||
|
@@ -117,7 +117,7 @@ def load_native_plugins(cfg: Config):
|
||||
t.start()
|
||||
|
||||
|
||||
def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]:
|
||||
def __scan_plugin_file(file_path, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
logger.info(f"__scan_plugin_file:{file_path},{debug}")
|
||||
loaded_plugins = []
|
||||
if moduleList := inspect_zip_for_modules(str(file_path), debug):
|
||||
@@ -133,14 +133,17 @@ def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTempl
|
||||
a_module = getattr(zipped_module, key)
|
||||
a_keys = dir(a_module)
|
||||
if (
|
||||
"_abc_impl" in a_keys
|
||||
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||
# and denylist_allowlist_check(a_module.__name__, cfg)
|
||||
"_abc_impl" in a_keys
|
||||
and a_module.__name__ != "AutoGPTPluginTemplate"
|
||||
# and denylist_allowlist_check(a_module.__name__, cfg)
|
||||
):
|
||||
loaded_plugins.append(a_module())
|
||||
return loaded_plugins
|
||||
|
||||
def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
|
||||
def scan_plugins(
|
||||
plugins_file_path: str, file_name: str = "", debug: bool = False
|
||||
) -> List[AutoGPTPluginTemplate]:
|
||||
"""Scan the plugins directory for plugins and loads them.
|
||||
|
||||
Args:
|
||||
@@ -159,7 +162,7 @@ def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = Fals
|
||||
loaded_plugins = __scan_plugin_file(plugin_path)
|
||||
else:
|
||||
for plugin_path in plugins_path.glob("*.zip"):
|
||||
loaded_plugins.extend(__scan_plugin_file(plugin_path))
|
||||
loaded_plugins.extend(__scan_plugin_file(plugin_path))
|
||||
|
||||
if loaded_plugins:
|
||||
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
|
||||
@@ -192,17 +195,23 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
|
||||
return ack.lower() == cfg.authorise_key
|
||||
|
||||
|
||||
def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main",
|
||||
authorization: str = None):
|
||||
def update_from_git(
|
||||
download_path: str,
|
||||
github_repo: str = "",
|
||||
branch_name: str = "main",
|
||||
authorization: str = None,
|
||||
):
|
||||
os.makedirs(download_path, exist_ok=True)
|
||||
if github_repo:
|
||||
if github_repo.index("github.com") <= 0:
|
||||
raise ValueError("Not a correct Github repository address!" + github_repo)
|
||||
github_repo = github_repo.replace(".git", "")
|
||||
url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
|
||||
plugin_repo_name = github_repo.strip('/').split('/')[-1]
|
||||
plugin_repo_name = github_repo.strip("/").split("/")[-1]
|
||||
else:
|
||||
url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
|
||||
url = (
|
||||
"https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
|
||||
)
|
||||
plugin_repo_name = "DB-GPT-Plugins"
|
||||
try:
|
||||
session = requests.Session()
|
||||
@@ -216,14 +225,14 @@ def update_from_git(download_path: str, github_repo: str = "", branch_name: str
|
||||
|
||||
if response.status_code == 200:
|
||||
plugins_path_path = Path(download_path)
|
||||
files = glob.glob(
|
||||
os.path.join(plugins_path_path, f"{plugin_repo_name}*")
|
||||
)
|
||||
files = glob.glob(os.path.join(plugins_path_path, f"{plugin_repo_name}*"))
|
||||
for file in files:
|
||||
os.remove(file)
|
||||
now = datetime.datetime.now()
|
||||
time_str = now.strftime("%Y%m%d%H%M%S")
|
||||
file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
|
||||
file_name = (
|
||||
f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
|
||||
)
|
||||
print(file_name)
|
||||
with open(file_name, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
@@ -1,5 +1,4 @@
|
||||
class ModuleMangeApi:
|
||||
|
||||
def module_name(self):
|
||||
pass
|
||||
|
||||
|
@@ -1,11 +1,16 @@
|
||||
from typing import TypeVar, Generic, List, Any
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseDao(Generic[T]):
|
||||
def __init__(
|
||||
self, orm_base=None, database: str = None, db_engine: Any = None, session: Any = None,
|
||||
self,
|
||||
orm_base=None,
|
||||
database: str = None,
|
||||
db_engine: Any = None,
|
||||
session: Any = None,
|
||||
) -> None:
|
||||
"""BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
|
||||
self._orm_base = orm_base
|
||||
|
@@ -7,7 +7,7 @@ import fnmatch
|
||||
from datetime import datetime
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
from sqlalchemy import create_engine,DateTime, String, func, MetaData
|
||||
from sqlalchemy import create_engine, DateTime, String, func, MetaData
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped
|
||||
@@ -32,16 +32,17 @@ db_name = "dbgpt"
|
||||
db_path = default_db_path + f"/{db_name}.db"
|
||||
connection = sqlite3.connect(db_path)
|
||||
|
||||
if CFG.LOCAL_DB_TYPE == 'mysql':
|
||||
engine_temp = create_engine(f"mysql+pymysql://"
|
||||
+ quote(CFG.LOCAL_DB_USER)
|
||||
+ ":"
|
||||
+ quote(CFG.LOCAL_DB_PASSWORD)
|
||||
+ "@"
|
||||
+ CFG.LOCAL_DB_HOST
|
||||
+ ":"
|
||||
+ str(CFG.LOCAL_DB_PORT)
|
||||
)
|
||||
if CFG.LOCAL_DB_TYPE == "mysql":
|
||||
engine_temp = create_engine(
|
||||
f"mysql+pymysql://"
|
||||
+ quote(CFG.LOCAL_DB_USER)
|
||||
+ ":"
|
||||
+ quote(CFG.LOCAL_DB_PASSWORD)
|
||||
+ "@"
|
||||
+ CFG.LOCAL_DB_HOST
|
||||
+ ":"
|
||||
+ str(CFG.LOCAL_DB_PORT)
|
||||
)
|
||||
# check and auto create mysqldatabase
|
||||
try:
|
||||
# try to connect
|
||||
@@ -53,20 +54,19 @@ if CFG.LOCAL_DB_TYPE == 'mysql':
|
||||
# if connect failed, create dbgpt database
|
||||
logger.error(f"{db_name} not connect success!")
|
||||
|
||||
engine = create_engine(f"mysql+pymysql://"
|
||||
+ quote(CFG.LOCAL_DB_USER)
|
||||
+ ":"
|
||||
+ quote(CFG.LOCAL_DB_PASSWORD)
|
||||
+ "@"
|
||||
+ CFG.LOCAL_DB_HOST
|
||||
+ ":"
|
||||
+ str(CFG.LOCAL_DB_PORT)
|
||||
+ f"/{db_name}"
|
||||
)
|
||||
engine = create_engine(
|
||||
f"mysql+pymysql://"
|
||||
+ quote(CFG.LOCAL_DB_USER)
|
||||
+ ":"
|
||||
+ quote(CFG.LOCAL_DB_PASSWORD)
|
||||
+ "@"
|
||||
+ CFG.LOCAL_DB_HOST
|
||||
+ ":"
|
||||
+ str(CFG.LOCAL_DB_PORT)
|
||||
+ f"/{db_name}"
|
||||
)
|
||||
else:
|
||||
engine = create_engine(f'sqlite:///{db_path}')
|
||||
|
||||
|
||||
engine = create_engine(f"sqlite:///{db_path}")
|
||||
|
||||
|
||||
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
@@ -81,16 +81,16 @@ Base = declarative_base()
|
||||
alembic_ini_path = default_db_path + "/alembic.ini"
|
||||
alembic_cfg = AlembicConfig(alembic_ini_path)
|
||||
|
||||
alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))
|
||||
|
||||
os.makedirs(default_db_path + "/alembic", exist_ok=True)
|
||||
os.makedirs(default_db_path + "/alembic/versions", exist_ok=True)
|
||||
|
||||
alembic_cfg.set_main_option('script_location', default_db_path + "/alembic")
|
||||
alembic_cfg.set_main_option("script_location", default_db_path + "/alembic")
|
||||
|
||||
# 将模型和会话传递给Alembic配置
|
||||
alembic_cfg.attributes['target_metadata'] = Base.metadata
|
||||
alembic_cfg.attributes['session'] = session
|
||||
alembic_cfg.attributes["target_metadata"] = Base.metadata
|
||||
alembic_cfg.attributes["session"] = session
|
||||
|
||||
|
||||
# # 创建表
|
||||
@@ -106,7 +106,7 @@ def ddl_init_and_upgrade():
|
||||
# command.upgrade(alembic_cfg, 'head')
|
||||
# subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"])
|
||||
with engine.connect() as connection:
|
||||
alembic_cfg.attributes['connection'] = connection
|
||||
alembic_cfg.attributes["connection"] = connection
|
||||
heads = command.heads(alembic_cfg)
|
||||
print("heads:" + str(heads))
|
||||
|
||||
|
@@ -1,2 +1 @@
|
||||
|
||||
|
||||
|
@@ -1,27 +1,30 @@
|
||||
import re
|
||||
|
||||
|
||||
def is_all_chinese(text):
|
||||
### Determine whether the string is pure Chinese
|
||||
pattern = re.compile(r'^[一-龥]+$')
|
||||
pattern = re.compile(r"^[一-龥]+$")
|
||||
match = re.match(pattern, text)
|
||||
return match is not None
|
||||
|
||||
|
||||
def is_number_chinese(text):
|
||||
### Determine whether the string is numbers and Chinese
|
||||
pattern = re.compile(r'^[\d一-龥]+$')
|
||||
pattern = re.compile(r"^[\d一-龥]+$")
|
||||
match = re.match(pattern, text)
|
||||
return match is not None
|
||||
|
||||
|
||||
def is_chinese_include_number(text):
|
||||
### Determine whether the string is pure Chinese or Chinese containing numbers
|
||||
pattern = re.compile(r'^[一-龥]+[\d一-龥]*$')
|
||||
pattern = re.compile(r"^[一-龥]+[\d一-龥]*$")
|
||||
match = re.match(pattern, text)
|
||||
return match is not None
|
||||
|
||||
|
||||
def is_scientific_notation(string):
|
||||
# 科学计数法的正则表达式
|
||||
pattern = r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$'
|
||||
pattern = r"^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$"
|
||||
# 使用正则表达式匹配字符串
|
||||
match = re.match(pattern, str(string))
|
||||
# 判断是否匹配成功
|
||||
@@ -30,28 +33,30 @@ def is_scientific_notation(string):
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def extract_content(long_string, s1, s2, is_include: bool = False):
|
||||
# extract text
|
||||
match_map ={}
|
||||
match_map = {}
|
||||
start_index = long_string.find(s1)
|
||||
while start_index != -1:
|
||||
if is_include:
|
||||
end_index = long_string.find(s2, start_index + len(s1) + 1)
|
||||
extracted_content = long_string[start_index:end_index + len(s2)]
|
||||
extracted_content = long_string[start_index : end_index + len(s2)]
|
||||
else:
|
||||
end_index = long_string.find(s2, start_index + len(s1))
|
||||
extracted_content = long_string[start_index + len(s1):end_index]
|
||||
extracted_content = long_string[start_index + len(s1) : end_index]
|
||||
if extracted_content:
|
||||
match_map[start_index] = extracted_content
|
||||
start_index = long_string.find(s1, start_index + 1)
|
||||
return match_map
|
||||
|
||||
|
||||
def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
|
||||
# extract text open ending
|
||||
match_map = {}
|
||||
start_index = long_string.find(s1)
|
||||
while start_index != -1:
|
||||
if long_string.find(s2, start_index) <=0:
|
||||
if long_string.find(s2, start_index) <= 0:
|
||||
end_index = len(long_string)
|
||||
else:
|
||||
if is_include:
|
||||
@@ -59,17 +64,16 @@ def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
|
||||
else:
|
||||
end_index = long_string.find(s2, start_index + len(s1))
|
||||
if is_include:
|
||||
extracted_content = long_string[start_index:end_index + len(s2)]
|
||||
extracted_content = long_string[start_index : end_index + len(s2)]
|
||||
else:
|
||||
extracted_content = long_string[start_index + len(s1):end_index]
|
||||
extracted_content = long_string[start_index + len(s1) : end_index]
|
||||
if extracted_content:
|
||||
match_map[start_index] = extracted_content
|
||||
start_index= long_string.find(s1, start_index + 1)
|
||||
start_index = long_string.find(s1, start_index + 1)
|
||||
return match_map
|
||||
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
if __name__ == "__main__":
|
||||
s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
|
||||
s1 = "123"
|
||||
s2 = "456"
|
||||
|
@@ -60,7 +60,9 @@ class Config(metaclass=Singleton):
|
||||
self.zhipu_proxy_api_key = os.getenv("ZHIPU_PROXY_API_KEY")
|
||||
if self.zhipu_proxy_api_key:
|
||||
os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
|
||||
os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv("ZHIPU_MODEL_VERSION")
|
||||
os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
|
||||
"ZHIPU_MODEL_VERSION"
|
||||
)
|
||||
|
||||
# wenxin
|
||||
self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY")
|
||||
@@ -68,7 +70,9 @@ class Config(metaclass=Singleton):
|
||||
self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION")
|
||||
if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret:
|
||||
os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key
|
||||
os.environ["wenxin_proxyllm_proxy_api_secret"] = self.wenxin_proxy_api_secret
|
||||
os.environ[
|
||||
"wenxin_proxyllm_proxy_api_secret"
|
||||
] = self.wenxin_proxy_api_secret
|
||||
os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version
|
||||
|
||||
# xunfei spark
|
||||
@@ -91,7 +95,6 @@ class Config(metaclass=Singleton):
|
||||
os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
|
||||
os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
|
||||
|
||||
|
||||
self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
|
||||
|
||||
self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
|
||||
@@ -172,7 +175,6 @@ class Config(metaclass=Singleton):
|
||||
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
self.LOCAL_DB_MANAGE = None
|
||||
|
||||
###dbgpt meta info database connection configuration
|
||||
@@ -190,7 +192,6 @@ class Config(metaclass=Singleton):
|
||||
|
||||
self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "duckdb")
|
||||
|
||||
|
||||
### LLM Model Service Configuration
|
||||
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
|
||||
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
|
||||
|
@@ -4,10 +4,13 @@ from typing import List
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
|
||||
class ConnectConfigEntity(Base):
|
||||
__tablename__ = 'connect_config'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
|
||||
db_type = Column(String(255), nullable=False, comment="db type")
|
||||
__tablename__ = "connect_config"
|
||||
id = Column(
|
||||
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
|
||||
)
|
||||
db_type = Column(String(255), nullable=False, comment="db type")
|
||||
db_name = Column(String(255), nullable=False, comment="db name")
|
||||
db_path = Column(String(255), nullable=True, comment="file db path")
|
||||
db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
|
||||
@@ -17,8 +20,8 @@ class ConnectConfigEntity(Base):
|
||||
comment = Column(Text, nullable=True, comment="db comment")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint('db_name', name="uk_db"),
|
||||
Index('idx_q_db_type', 'db_type'),
|
||||
UniqueConstraint("db_name", name="uk_db"),
|
||||
Index("idx_q_db_type", "db_type"),
|
||||
)
|
||||
|
||||
|
||||
@@ -43,9 +46,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
raise Exception("db_name is None")
|
||||
|
||||
db_connect = session.query(ConnectConfigEntity)
|
||||
db_connect = db_connect.filter(
|
||||
ConnectConfigEntity.db_name == db_name
|
||||
)
|
||||
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
|
||||
db_connect.delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
@@ -53,10 +54,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
def get_by_name(self, db_name: str) -> ConnectConfigEntity:
|
||||
session = self.get_session()
|
||||
db_connect = session.query(ConnectConfigEntity)
|
||||
db_connect = db_connect.filter(
|
||||
ConnectConfigEntity.db_name == db_name
|
||||
)
|
||||
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
|
||||
result = db_connect.first()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
|
@@ -5,15 +5,17 @@ from typing import List
|
||||
from enum import Enum
|
||||
from pilot.scene.message import OnceConversation
|
||||
|
||||
|
||||
class MemoryStoreType(Enum):
|
||||
File= 'file'
|
||||
Memory = 'memory'
|
||||
DB = 'db'
|
||||
DuckDb = 'duckdb'
|
||||
File = "file"
|
||||
Memory = "memory"
|
||||
DB = "db"
|
||||
DuckDb = "duckdb"
|
||||
|
||||
|
||||
class BaseChatHistoryMemory(ABC):
|
||||
store_type: MemoryStoreType
|
||||
|
||||
def __init__(self):
|
||||
self.conversations: List[OnceConversation] = []
|
||||
|
||||
|
@@ -4,9 +4,7 @@ from pilot.configs.config import Config
|
||||
CFG = Config()
|
||||
|
||||
|
||||
|
||||
class ChatHistory:
|
||||
|
||||
def __init__(self):
|
||||
self.memory_type = MemoryStoreType.DB.value
|
||||
self.mem_store_class_map = {}
|
||||
@@ -14,15 +12,16 @@ class ChatHistory:
|
||||
from .store_type.file_history import FileHistoryMemory
|
||||
from .store_type.meta_db_history import DbHistoryMemory
|
||||
from .store_type.mem_history import MemHistoryMemory
|
||||
|
||||
self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
|
||||
self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
|
||||
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
|
||||
self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
|
||||
|
||||
|
||||
def get_store_instance(self, chat_session_id):
|
||||
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(chat_session_id)
|
||||
|
||||
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(
|
||||
chat_session_id
|
||||
)
|
||||
|
||||
def get_store_cls(self):
|
||||
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)
|
||||
|
@@ -4,20 +4,28 @@ from typing import List
|
||||
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
|
||||
class ChatHistoryEntity(Base):
|
||||
__tablename__ = 'chat_history'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
|
||||
conv_uid = Column(String(255), unique=False, nullable=False, comment="Conversation record unique id")
|
||||
__tablename__ = "chat_history"
|
||||
id = Column(
|
||||
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
|
||||
)
|
||||
conv_uid = Column(
|
||||
String(255),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
comment="Conversation record unique id",
|
||||
)
|
||||
chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
|
||||
summary = Column(String(255), nullable=False, comment="Conversation record summary")
|
||||
user_name = Column(String(255), nullable=True, comment="interlocutor")
|
||||
messages = Column(Text, nullable=True, comment="Conversation details")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint('conv_uid', name="uk_conversation"),
|
||||
Index('idx_q_user', 'user_name'),
|
||||
Index('idx_q_mode', 'chat_mode'),
|
||||
Index('idx_q_conv', 'summary'),
|
||||
UniqueConstraint("conv_uid", name="uk_conversation"),
|
||||
Index("idx_q_user", "user_name"),
|
||||
Index("idx_q_mode", "chat_mode"),
|
||||
Index("idx_q_conv", "summary"),
|
||||
)
|
||||
|
||||
|
||||
@@ -31,9 +39,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
|
||||
session = self.get_session()
|
||||
chat_history = session.query(ChatHistoryEntity)
|
||||
if user_name:
|
||||
chat_history = chat_history.filter(
|
||||
ChatHistoryEntity.user_name == user_name
|
||||
)
|
||||
chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
|
||||
|
||||
chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())
|
||||
|
||||
@@ -50,13 +56,11 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def update_message_by_uid(self, message: str, conv_uid:str):
|
||||
def update_message_by_uid(self, message: str, conv_uid: str):
|
||||
session = self.get_session()
|
||||
try:
|
||||
chat_history = session.query(ChatHistoryEntity)
|
||||
chat_history = chat_history.filter(
|
||||
ChatHistoryEntity.conv_uid == conv_uid
|
||||
)
|
||||
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
|
||||
updated = chat_history.update({ChatHistoryEntity.messages: message})
|
||||
session.commit()
|
||||
return updated.id
|
||||
@@ -69,9 +73,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
|
||||
raise Exception("conv_uid is None")
|
||||
|
||||
chat_history = session.query(ChatHistoryEntity)
|
||||
chat_history = chat_history.filter(
|
||||
ChatHistoryEntity.conv_uid == conv_uid
|
||||
)
|
||||
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
|
||||
chat_history.delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
@@ -79,10 +81,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
|
||||
def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
|
||||
session = self.get_session()
|
||||
chat_history = session.query(ChatHistoryEntity)
|
||||
chat_history = chat_history.filter(
|
||||
ChatHistoryEntity.conv_uid == conv_uid
|
||||
)
|
||||
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
|
||||
result = chat_history.first()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
|
@@ -148,7 +148,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
|
||||
return json.loads(context[0])
|
||||
return None
|
||||
|
||||
|
||||
@staticmethod
|
||||
def conv_list(cls, user_name: str = None) -> None:
|
||||
if os.path.isfile(duckdb_path):
|
||||
|
@@ -17,7 +17,7 @@ CFG = Config()
|
||||
|
||||
|
||||
class FileHistoryMemory(BaseChatHistoryMemory):
|
||||
store_type: str = MemoryStoreType.File.value
|
||||
store_type: str = MemoryStoreType.File.value
|
||||
|
||||
def __init__(self, chat_session_id: str):
|
||||
now = datetime.datetime.now()
|
||||
@@ -49,5 +49,3 @@ class FileHistoryMemory(BaseChatHistoryMemory):
|
||||
|
||||
def clear(self) -> None:
|
||||
self.file_path.write_text(json.dumps([]))
|
||||
|
||||
|
||||
|
@@ -10,7 +10,7 @@ CFG = Config()
|
||||
|
||||
|
||||
class MemHistoryMemory(BaseChatHistoryMemory):
|
||||
store_type: str = MemoryStoreType.Memory.value
|
||||
store_type: str = MemoryStoreType.Memory.value
|
||||
|
||||
histroies_map = FixedSizeDict(100)
|
||||
|
||||
|
@@ -12,18 +12,22 @@ from pilot.scene.message import (
|
||||
from ..chat_history_db import ChatHistoryEntity, ChatHistoryDao
|
||||
|
||||
from pilot.memory.chat_history.base import MemoryStoreType
|
||||
|
||||
CFG = Config()
|
||||
logger = logging.getLogger("db_chat_history")
|
||||
|
||||
|
||||
class DbHistoryMemory(BaseChatHistoryMemory):
|
||||
store_type: str = MemoryStoreType.DB.value
|
||||
store_type: str = MemoryStoreType.DB.value
|
||||
|
||||
def __init__(self, chat_session_id: str):
|
||||
self.chat_seesion_id = chat_session_id
|
||||
self.chat_history_dao = ChatHistoryDao()
|
||||
|
||||
def messages(self) -> List[OnceConversation]:
|
||||
|
||||
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
|
||||
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
|
||||
self.chat_seesion_id
|
||||
)
|
||||
if chat_history:
|
||||
context = chat_history.messages
|
||||
if context:
|
||||
@@ -31,7 +35,6 @@ class DbHistoryMemory(BaseChatHistoryMemory):
|
||||
return conversations
|
||||
return []
|
||||
|
||||
|
||||
def create(self, chat_mode, summary: str, user_name: str) -> None:
|
||||
try:
|
||||
chat_history: ChatHistoryEntity = ChatHistoryEntity()
|
||||
@@ -43,10 +46,11 @@ class DbHistoryMemory(BaseChatHistoryMemory):
|
||||
except Exception as e:
|
||||
logger.error("init create conversation log error!" + str(e))
|
||||
|
||||
|
||||
def append(self, once_message: OnceConversation) -> None:
|
||||
logger.info("db history append:{}", once_message)
|
||||
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
|
||||
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
|
||||
self.chat_seesion_id
|
||||
)
|
||||
conversations: List[OnceConversation] = []
|
||||
if chat_history:
|
||||
context = chat_history.messages
|
||||
@@ -59,7 +63,7 @@ class DbHistoryMemory(BaseChatHistoryMemory):
|
||||
chat_history.conv_uid = self.chat_seesion_id
|
||||
chat_history.chat_mode = once_message.chat_mode
|
||||
chat_history.user_name = "default"
|
||||
chat_history.summary = once_message.get_user_conv().content
|
||||
chat_history.summary = once_message.get_user_conv().content
|
||||
|
||||
conversations.append(_conversation_to_dic(once_message))
|
||||
chat_history.messages = json.dumps(conversations, ensure_ascii=False)
|
||||
@@ -67,31 +71,28 @@ class DbHistoryMemory(BaseChatHistoryMemory):
|
||||
self.chat_history_dao.update(chat_history)
|
||||
|
||||
def update(self, messages: List[OnceConversation]) -> None:
|
||||
self.chat_history_dao.update_message_by_uid(json.dumps(messages, ensure_ascii=False), self.chat_seesion_id)
|
||||
|
||||
self.chat_history_dao.update_message_by_uid(
|
||||
json.dumps(messages, ensure_ascii=False), self.chat_seesion_id
|
||||
)
|
||||
|
||||
def delete(self) -> bool:
|
||||
self.chat_history_dao.delete(self.chat_seesion_id)
|
||||
|
||||
|
||||
def conv_info(self, conv_uid: str = None) -> None:
|
||||
logger.info("conv_info:{}", conv_uid)
|
||||
chat_history = self.chat_history_dao.get_by_uid(conv_uid)
|
||||
chat_history = self.chat_history_dao.get_by_uid(conv_uid)
|
||||
return chat_history.__dict__
|
||||
|
||||
|
||||
def get_messages(self) -> List[OnceConversation]:
|
||||
logger.info("get_messages:{}", self.chat_seesion_id)
|
||||
chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
|
||||
chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
|
||||
if chat_history:
|
||||
context = chat_history.messages
|
||||
return json.loads(context)
|
||||
return []
|
||||
|
||||
|
||||
@staticmethod
|
||||
def conv_list(cls, user_name: str = None) -> None:
|
||||
|
||||
chat_history_dao = ChatHistoryDao()
|
||||
history_list = chat_history_dao.list_last_20()
|
||||
result = []
|
||||
|
@@ -66,12 +66,14 @@ def run_migrations_online() -> None:
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
if engine.dialect.name == 'sqlite':
|
||||
context.configure(connection=engine.connect(), target_metadata=target_metadata, render_as_batch=True)
|
||||
else:
|
||||
if engine.dialect.name == "sqlite":
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
connection=engine.connect(),
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=True,
|
||||
)
|
||||
else:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
@@ -145,12 +145,12 @@ def initialize_controller(
|
||||
controller.backend = LocalModelController()
|
||||
|
||||
if app:
|
||||
app.include_router(router, prefix="/api", tags=['Model'])
|
||||
app.include_router(router, prefix="/api", tags=["Model"])
|
||||
else:
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/api", tags=['Model'])
|
||||
app.include_router(router, prefix="/api", tags=["Model"])
|
||||
uvicorn.run(app, host=host, port=port, log_level="info")
|
||||
|
||||
|
||||
|
@@ -9,18 +9,21 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
BAICHUAN_DEFAULT_MODEL = "Baichuan2-53B"
|
||||
|
||||
|
||||
def _calculate_md5(text: str) -> str:
|
||||
"""Calculate md5 """
|
||||
"""Calculate md5"""
|
||||
md5 = hashlib.md5()
|
||||
md5.update(text.encode("utf-8"))
|
||||
encrypted = md5.hexdigest()
|
||||
return encrypted
|
||||
|
||||
|
||||
def _sign(data: dict, secret_key: str, timestamp: str):
|
||||
data_str = json.dumps(data)
|
||||
signature = _calculate_md5(secret_key + data_str + timestamp)
|
||||
return signature
|
||||
|
||||
|
||||
def baichuan_generate_stream(
|
||||
model: ProxyModel, tokenizer, params, device, context_len=4096
|
||||
):
|
||||
@@ -31,7 +34,6 @@ def baichuan_generate_stream(
|
||||
proxy_api_key = model_params.proxy_api_key
|
||||
proxy_api_secret = model_params.proxy_api_secret
|
||||
|
||||
|
||||
history = []
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
# Add history conversation
|
||||
@@ -47,11 +49,11 @@ def baichuan_generate_stream(
|
||||
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"messages": history,
|
||||
"messages": history,
|
||||
"parameters": {
|
||||
"temperature": params.get("temperature"),
|
||||
"top_k": params.get("top_k", 10)
|
||||
}
|
||||
"top_k": params.get("top_k", 10),
|
||||
},
|
||||
}
|
||||
|
||||
timestamp = int(time.time())
|
||||
|
@@ -15,6 +15,7 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel
|
||||
|
||||
SPARK_DEFAULT_API_VERSION = "v2"
|
||||
|
||||
|
||||
def spark_generate_stream(
|
||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||
):
|
||||
@@ -31,7 +32,6 @@ def spark_generate_stream(
|
||||
domain = "general"
|
||||
url = "ws://spark-api.xf-yun.com/v1.1/chat"
|
||||
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
|
||||
history = []
|
||||
@@ -57,43 +57,39 @@ def spark_generate_stream(
|
||||
break
|
||||
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": proxy_app_id,
|
||||
"uid": params.get("request_id", 1)
|
||||
},
|
||||
"header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": context_len,
|
||||
"auditing": "default",
|
||||
"temperature": params.get("temperature")
|
||||
"temperature": params.get("temperature"),
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": last_user_input.get("content")
|
||||
}
|
||||
}
|
||||
"payload": {"message": {"text": last_user_input.get("content")}},
|
||||
}
|
||||
|
||||
async_call(request_url, data)
|
||||
|
||||
|
||||
async def async_call(request_url, data):
|
||||
async with websockets.connect(request_url) as ws:
|
||||
await ws.send(json.dumps(data, ensure_ascii=False))
|
||||
finish = False
|
||||
while not finish:
|
||||
chunk = ws.recv()
|
||||
chunk = ws.recv()
|
||||
response = json.loads(chunk)
|
||||
if response.get("header", {}).get("status") == 2:
|
||||
finish = True
|
||||
if text := response.get("payload", {}).get("choices", {}).get("text"):
|
||||
yield text[0]["content"]
|
||||
|
||||
class SparkAPI:
|
||||
|
||||
def __init__(self, appid: str, api_key: str, api_secret: str, spark_url: str) -> None:
|
||||
class SparkAPI:
|
||||
def __init__(
|
||||
self, appid: str, api_key: str, api_secret: str, spark_url: str
|
||||
) -> None:
|
||||
self.appid = appid
|
||||
self.api_key = api_key
|
||||
self.api_secret = api_secret
|
||||
@@ -102,9 +98,7 @@ class SparkAPI:
|
||||
|
||||
self.spark_url = spark_url
|
||||
|
||||
|
||||
def gen_url(self):
|
||||
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
@@ -112,21 +106,22 @@ class SparkAPI:
|
||||
_signature += "data: " + date + "\n"
|
||||
_signature += "GET " + self.path + " HTTP/1.1"
|
||||
|
||||
_signature_sha = hmac.new(self.api_secret.encode("utf-8"), _signature.encode("utf-8"),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
_signature_sha = hmac.new(
|
||||
self.api_secret.encode("utf-8"),
|
||||
_signature.encode("utf-8"),
|
||||
digestmod=hashlib.sha256,
|
||||
).digest()
|
||||
|
||||
_signature_sha_base64 = base64.b64encode(_signature_sha).decode(encoding="utf-8")
|
||||
_signature_sha_base64 = base64.b64encode(_signature_sha).decode(
|
||||
encoding="utf-8"
|
||||
)
|
||||
_authorization = f"api_key='{self.api_key}', algorithm='hmac-sha256', headers='host date request-line', signature='{_signature_sha_base64}'"
|
||||
|
||||
authorization = base64.b64encode(_authorization.encode('utf-8')).decode(encoding='utf-8')
|
||||
authorization = base64.b64encode(_authorization.encode("utf-8")).decode(
|
||||
encoding="utf-8"
|
||||
)
|
||||
|
||||
v = {
|
||||
"authorization": authorization,
|
||||
"date": date,
|
||||
"host": self.host
|
||||
}
|
||||
v = {"authorization": authorization, "date": date, "host": self.host}
|
||||
|
||||
url = self.spark_url + "?" + urlencode(v)
|
||||
return url
|
||||
|
||||
|
@@ -8,10 +8,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tongyi_generate_stream(
|
||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||
):
|
||||
import dashscope
|
||||
from dashscope import Generation
|
||||
|
||||
model_params = model.get_params()
|
||||
print(f"Model: {model}, model_params: {model_params}")
|
||||
|
||||
@@ -62,14 +63,14 @@ def tongyi_generate_stream(
|
||||
messages=history,
|
||||
top_p=params.get("top_p", 0.8),
|
||||
stream=True,
|
||||
result_format='message'
|
||||
result_format="message",
|
||||
)
|
||||
|
||||
for r in res:
|
||||
if r:
|
||||
if r['status_code'] == 200:
|
||||
if r["status_code"] == 200:
|
||||
content = r["output"]["choices"][0]["message"].get("content")
|
||||
yield content
|
||||
else:
|
||||
content = r['code'] + ":" + r["message"]
|
||||
content = r["code"] + ":" + r["message"]
|
||||
yield content
|
||||
|
@@ -6,20 +6,26 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
from cachetools import cached, TTLCache
|
||||
|
||||
|
||||
@cached(TTLCache(1, 1800))
|
||||
def _build_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
Generate Access token according AK, SK
|
||||
Generate Access token according AK, SK
|
||||
"""
|
||||
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
|
||||
params = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": api_key,
|
||||
"client_secret": secret_key,
|
||||
}
|
||||
|
||||
res = requests.get(url=url, params=params)
|
||||
|
||||
if res.status_code == 200:
|
||||
return res.json().get("access_token")
|
||||
|
||||
|
||||
def wenxin_generate_stream(
|
||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||
):
|
||||
@@ -38,12 +44,9 @@ def wenxin_generate_stream(
|
||||
proxy_api_secret = model_params.proxy_api_secret
|
||||
access_token = _build_access_token(proxy_api_key, proxy_api_secret)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||
|
||||
proxy_server_url = f'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}'
|
||||
proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}"
|
||||
|
||||
if not access_token:
|
||||
yield "Failed to get access token. please set the correct api_key and secret key."
|
||||
@@ -86,7 +89,7 @@ def wenxin_generate_stream(
|
||||
"messages": history,
|
||||
"system": system,
|
||||
"temperature": params.get("temperature"),
|
||||
"stream": True
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
text = ""
|
||||
@@ -106,6 +109,3 @@ def wenxin_generate_stream(
|
||||
content = obj["result"]
|
||||
text += content
|
||||
yield text
|
||||
|
||||
|
||||
|
@@ -7,6 +7,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
|
||||
|
||||
|
||||
def zhipu_generate_stream(
|
||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||
):
|
||||
@@ -19,6 +20,7 @@ def zhipu_generate_stream(
|
||||
proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
|
||||
|
||||
import zhipuai
|
||||
|
||||
zhipuai.api_key = proxy_api_key
|
||||
history = []
|
||||
|
||||
|
@@ -297,7 +297,6 @@ async def params_load(
|
||||
|
||||
@router.post("/v1/chat/dialogue/delete")
|
||||
async def dialogue_delete(con_uid: str):
|
||||
|
||||
history_fac = ChatHistory()
|
||||
history_mem = history_fac.get_store_instance(con_uid)
|
||||
history_mem.delete()
|
||||
|
@@ -54,7 +54,7 @@ class BaseChat(ABC):
|
||||
)
|
||||
chat_history_fac = ChatHistory()
|
||||
### can configurable storage methods
|
||||
self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
|
||||
self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
|
||||
|
||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||
self.current_message: OnceConversation = OnceConversation(
|
||||
|
@@ -20,10 +20,11 @@ logger = logging.getLogger("chat_agent")
|
||||
class ChatAgent(BaseChat):
|
||||
chat_scene: str = ChatScene.ChatAgent.value()
|
||||
chat_retention_rounds = 0
|
||||
|
||||
def __init__(self, chat_param: Dict):
|
||||
if not chat_param['select_param']:
|
||||
if not chat_param["select_param"]:
|
||||
raise ValueError("Please select a Plugin!")
|
||||
self.select_plugins = chat_param['select_param'].split(",")
|
||||
self.select_plugins = chat_param["select_param"].split(",")
|
||||
|
||||
chat_param["chat_mode"] = ChatScene.ChatAgent
|
||||
super().__init__(chat_param=chat_param)
|
||||
@@ -31,8 +32,12 @@ class ChatAgent(BaseChat):
|
||||
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
||||
|
||||
# load select plugin
|
||||
agent_module = CFG.SYSTEM_APP.get_component(ComponentType.AGENT_HUB, ModuleAgent)
|
||||
self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)
|
||||
agent_module = CFG.SYSTEM_APP.get_component(
|
||||
ComponentType.AGENT_HUB, ModuleAgent
|
||||
)
|
||||
self.plugins_prompt_generator = agent_module.load_select_plugin(
|
||||
self.plugins_prompt_generator, self.select_plugins
|
||||
)
|
||||
|
||||
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
|
||||
|
||||
@@ -53,4 +58,3 @@ class ChatAgent(BaseChat):
|
||||
|
||||
def __list_to_prompt_str(self, list: List) -> str:
|
||||
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))
|
||||
|
||||
|
@@ -54,7 +54,7 @@ _DEFAULT_TEMPLATE = (
|
||||
)
|
||||
|
||||
|
||||
_PROMPT_SCENE_DEFINE=(
|
||||
_PROMPT_SCENE_DEFINE = (
|
||||
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ prompt = PromptTemplate(
|
||||
output_parser=PluginChatOutputParser(
|
||||
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
|
||||
),
|
||||
temperature = 1
|
||||
temperature=1
|
||||
# example_selector=plugin_example,
|
||||
)
|
||||
|
||||
|
@@ -36,7 +36,7 @@ class ChatExcel(BaseChat):
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
|
||||
)
|
||||
)
|
||||
self.api_call = ApiCall(display_registry = CFG.command_disply)
|
||||
self.api_call = ApiCall(display_registry=CFG.command_disply)
|
||||
super().__init__(chat_param=chat_param)
|
||||
|
||||
def _generate_numbered_list(self) -> str:
|
||||
|
@@ -44,7 +44,7 @@ _DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
PROMPT_SCENE_DEFINE =(
|
||||
PROMPT_SCENE_DEFINE = (
|
||||
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
|
||||
)
|
||||
|
||||
|
@@ -9,8 +9,21 @@ import pandas as pd
|
||||
import chardet
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pyparsing import CaselessKeyword, Word, alphas, alphanums, delimitedList, Forward, Group, Optional,\
|
||||
Literal, infixNotation, opAssoc, unicodeString,Regex
|
||||
from pyparsing import (
|
||||
CaselessKeyword,
|
||||
Word,
|
||||
alphas,
|
||||
alphanums,
|
||||
delimitedList,
|
||||
Forward,
|
||||
Group,
|
||||
Optional,
|
||||
Literal,
|
||||
infixNotation,
|
||||
opAssoc,
|
||||
unicodeString,
|
||||
Regex,
|
||||
)
|
||||
|
||||
from pilot.common.pd_utils import csv_colunm_foramt
|
||||
from pilot.common.string_utils import is_chinese_include_number
|
||||
@@ -21,14 +34,15 @@ def excel_colunm_format(old_name: str) -> str:
|
||||
new_column = new_column.replace(" ", "_")
|
||||
return new_column
|
||||
|
||||
|
||||
def detect_encoding(file_path):
|
||||
# 读取文件的二进制数据
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
data = f.read()
|
||||
# 使用 chardet 来检测文件编码
|
||||
result = chardet.detect(data)
|
||||
encoding = result['encoding']
|
||||
confidence = result['confidence']
|
||||
encoding = result["encoding"]
|
||||
confidence = result["confidence"]
|
||||
return encoding, confidence
|
||||
|
||||
|
||||
@@ -39,10 +53,11 @@ def add_quotes_ex(sql: str, column_names):
|
||||
sql = sql.replace(column_name, f'"{column_name}"')
|
||||
return sql
|
||||
|
||||
|
||||
def parse_sql(sql):
|
||||
# 定义关键字和标识符
|
||||
select_stmt = Forward()
|
||||
column = Regex(r'[\w一-龥]*')
|
||||
column = Regex(r"[\w一-龥]*")
|
||||
table = Word(alphanums)
|
||||
join_expr = Forward()
|
||||
where_expr = Forward()
|
||||
@@ -62,21 +77,24 @@ def parse_sql(sql):
|
||||
not_in_keyword = CaselessKeyword("NOT IN")
|
||||
|
||||
# 定义语法规则
|
||||
select_stmt <<= (select_keyword + delimitedList(column) +
|
||||
from_keyword + delimitedList(table) +
|
||||
Optional(join_expr) +
|
||||
Optional(where_keyword + where_expr) +
|
||||
Optional(group_by_keyword + group_by_expr) +
|
||||
Optional(order_by_keyword + order_by_expr))
|
||||
select_stmt <<= (
|
||||
select_keyword
|
||||
+ delimitedList(column)
|
||||
+ from_keyword
|
||||
+ delimitedList(table)
|
||||
+ Optional(join_expr)
|
||||
+ Optional(where_keyword + where_expr)
|
||||
+ Optional(group_by_keyword + group_by_expr)
|
||||
+ Optional(order_by_keyword + order_by_expr)
|
||||
)
|
||||
|
||||
join_expr <<= join_keyword + table + on_keyword + column + Literal("=") + column
|
||||
|
||||
where_expr <<= column + Literal("=") + Word(alphanums) + \
|
||||
Optional(and_keyword + where_expr) | \
|
||||
column + Literal(">") + Word(alphanums) + \
|
||||
Optional(and_keyword + where_expr) | \
|
||||
column + Literal("<") + Word(alphanums) + \
|
||||
Optional(and_keyword + where_expr)
|
||||
where_expr <<= (
|
||||
column + Literal("=") + Word(alphanums) + Optional(and_keyword + where_expr)
|
||||
| column + Literal(">") + Word(alphanums) + Optional(and_keyword + where_expr)
|
||||
| column + Literal("<") + Word(alphanums) + Optional(and_keyword + where_expr)
|
||||
)
|
||||
|
||||
group_by_expr <<= delimitedList(column)
|
||||
|
||||
@@ -88,7 +106,6 @@ def parse_sql(sql):
|
||||
return parsed_result.asList()
|
||||
|
||||
|
||||
|
||||
def add_quotes(sql, column_names=[]):
|
||||
sql = sql.replace("`", "")
|
||||
sql = sql.replace("'", "")
|
||||
@@ -108,6 +125,7 @@ def deep_quotes(token, column_names=[]):
|
||||
new_value = token.value.replace("`", "").replace("'", "")
|
||||
token.value = f'"{new_value}"'
|
||||
|
||||
|
||||
def get_select_clause(sql):
|
||||
parsed = sqlparse.parse(sql)[0] # 解析 SQL 语句,获取第一个语句块
|
||||
|
||||
@@ -123,6 +141,7 @@ def get_select_clause(sql):
|
||||
select_tokens.append(token)
|
||||
return "".join(str(token) for token in select_tokens)
|
||||
|
||||
|
||||
def parse_select_fields(sql):
|
||||
parsed = sqlparse.parse(sql)[0] # 解析 SQL 语句,获取第一个语句块
|
||||
fields = []
|
||||
@@ -139,12 +158,14 @@ def parse_select_fields(sql):
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def add_quotes_to_chinese_columns(sql, column_names=[]):
|
||||
parsed = sqlparse.parse(sql)
|
||||
for stmt in parsed:
|
||||
process_statement(stmt, column_names)
|
||||
return str(parsed[0])
|
||||
|
||||
|
||||
def process_statement(statement, column_names=[]):
|
||||
if isinstance(statement, sqlparse.sql.IdentifierList):
|
||||
for identifier in statement.get_identifiers():
|
||||
@@ -155,22 +176,23 @@ def process_statement(statement, column_names=[]):
|
||||
for item in statement.tokens:
|
||||
process_statement(item)
|
||||
|
||||
|
||||
def process_identifier(identifier, column_names=[]):
|
||||
# if identifier.has_alias():
|
||||
# alias = identifier.get_alias()
|
||||
# identifier.tokens[-1].value = '[' + alias + ']'
|
||||
if hasattr(identifier, 'tokens') and identifier.value in column_names:
|
||||
if is_chinese(identifier.value):
|
||||
if hasattr(identifier, "tokens") and identifier.value in column_names:
|
||||
if is_chinese(identifier.value):
|
||||
new_value = get_new_value(identifier.value)
|
||||
identifier.value = new_value
|
||||
identifier.normalized = new_value
|
||||
identifier.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
|
||||
else:
|
||||
if hasattr(identifier, 'tokens'):
|
||||
if hasattr(identifier, "tokens"):
|
||||
for token in identifier.tokens:
|
||||
if isinstance(token, sqlparse.sql.Function):
|
||||
process_function(token)
|
||||
elif token.ttype in sqlparse.tokens.Name :
|
||||
elif token.ttype in sqlparse.tokens.Name:
|
||||
new_value = get_new_value(token.value)
|
||||
token.value = new_value
|
||||
token.normalized = new_value
|
||||
@@ -179,9 +201,12 @@ def process_identifier(identifier, column_names=[]):
|
||||
token.value = new_value
|
||||
token.normalized = new_value
|
||||
token.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
|
||||
|
||||
|
||||
def get_new_value(value):
|
||||
return f""" "{value.replace("`", "").replace("'", "").replace('"', "")}" """
|
||||
|
||||
|
||||
def process_function(function):
|
||||
function_params = list(function.get_parameters())
|
||||
# for param in function_params:
|
||||
@@ -191,15 +216,18 @@ def process_function(function):
|
||||
if isinstance(param, sqlparse.sql.Identifier):
|
||||
# 判断是否需要替换字段值
|
||||
# if is_chinese(param.value):
|
||||
# 替换字段值
|
||||
# 替换字段值
|
||||
new_value = get_new_value(param.value)
|
||||
# new_parameter = sqlparse.sql.Identifier(f'[{param.value}]')
|
||||
function_params[i].tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
|
||||
function_params[i].tokens = [
|
||||
sqlparse.sql.Token(sqlparse.tokens.Name, new_value)
|
||||
]
|
||||
print(str(function))
|
||||
|
||||
|
||||
def is_chinese(text):
|
||||
for char in text:
|
||||
if '一' <= char <= '鿿':
|
||||
if "一" <= char <= "鿿":
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -240,7 +268,7 @@ class ExcelReader:
|
||||
df_tmp = pd.read_csv(file_path, encoding=encoding)
|
||||
self.df = pd.read_csv(
|
||||
file_path,
|
||||
encoding = encoding,
|
||||
encoding=encoding,
|
||||
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
|
||||
)
|
||||
else:
|
||||
@@ -280,7 +308,6 @@ class ExcelReader:
|
||||
logging.error("excel sql run error!", e)
|
||||
raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")
|
||||
|
||||
|
||||
def get_df_by_sql_ex(self, sql):
|
||||
colunms, values = self.run(sql)
|
||||
return pd.DataFrame(values, columns=colunms)
|
||||
|
@@ -16,7 +16,7 @@ class ChatWithPlugin(BaseChat):
|
||||
select_plugin: str = None
|
||||
|
||||
def __init__(self, chat_param: Dict):
|
||||
self.plugin_selector = chat_param["select_param"]
|
||||
self.plugin_selector = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatExecution
|
||||
super().__init__(chat_param=chat_param)
|
||||
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||
|
@@ -15,7 +15,6 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
|
||||
sys.path.append(ROOT_PATH)
|
||||
|
||||
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
print("in order to avoid chroma db atexit problem")
|
||||
os._exit(0)
|
||||
@@ -32,7 +31,6 @@ def async_db_summary(system_app: SystemApp):
|
||||
def server_init(args, system_app: SystemApp):
|
||||
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
|
||||
|
||||
|
||||
# logger.info(f"args: {args}")
|
||||
|
||||
# init config
|
||||
@@ -44,8 +42,6 @@ def server_init(args, system_app: SystemApp):
|
||||
# load_native_plugins(cfg)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
|
||||
|
||||
# Loader plugins and commands
|
||||
command_categories = [
|
||||
"pilot.base_modules.agent.commands.built_in.audio_text",
|
||||
@@ -126,4 +122,3 @@ class WebWerverParameters(BaseParameters):
|
||||
},
|
||||
)
|
||||
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
|
||||
|
||||
|
@@ -29,6 +29,7 @@ def initialize_components(
|
||||
system_app.register_instance(controller)
|
||||
|
||||
from pilot.base_modules.agent.controller import module_agent
|
||||
|
||||
system_app.register_instance(module_agent)
|
||||
|
||||
_initialize_embedding_model(
|
||||
|
@@ -31,7 +31,9 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1
|
||||
from pilot.openapi.base import validation_exception_handler
|
||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
|
||||
from pilot.base_modules.agent.commands.disply_type.show_chart_gen import static_message_img_path
|
||||
from pilot.base_modules.agent.commands.disply_type.show_chart_gen import (
|
||||
static_message_img_path,
|
||||
)
|
||||
from pilot.model.cluster import initialize_worker_manager_in_client
|
||||
from pilot.utils.utils import (
|
||||
setup_logging,
|
||||
@@ -56,6 +58,8 @@ def swagger_monkey_patch(*args, **kwargs):
|
||||
swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
|
||||
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
applications.get_swagger_ui_html = swagger_monkey_patch
|
||||
|
||||
@@ -73,14 +77,14 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
app.include_router(api_v1, prefix="/api", tags=["Chat"])
|
||||
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
|
||||
app.include_router(api_v1, prefix="/api", tags=["Chat"])
|
||||
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
|
||||
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
|
||||
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
|
||||
|
||||
|
||||
app.include_router(knowledge_router, tags=["Knowledge"])
|
||||
app.include_router(prompt_router, tags=["Prompt"])
|
||||
app.include_router(knowledge_router, tags=["Knowledge"])
|
||||
app.include_router(prompt_router, tags=["Prompt"])
|
||||
|
||||
|
||||
def mount_static_files(app):
|
||||
@@ -98,6 +102,7 @@ def mount_static_files(app):
|
||||
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
|
||||
|
||||
def _get_webserver_params(args: List[str] = None):
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
@@ -106,6 +111,7 @@ def _get_webserver_params(args: List[str] = None):
|
||||
)
|
||||
return WebWerverParameters(**vars(parser.parse_args(args=args)))
|
||||
|
||||
|
||||
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
"""Initialize app
|
||||
If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook.
|
||||
|
Reference in New Issue
Block a user