style:fmt

aries_ckt
2023-10-17 13:55:19 +08:00
parent da87e40163
commit f65ca37a02
49 changed files with 582 additions and 496 deletions
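The diff below is consistent with a black-style formatting pass: string literals are normalized to double quotes, long calls are exploded one argument per line with a magic trailing comma, and dict/tuple literals are rewrapped at the default 88-column limit. A minimal before/after sketch of the pattern on a hypothetical call (not taken from this commit):

# before: single quotes, one over-long call
response = gen.call(Model.default, messages=messages, stream=True, top_p=0.8, result_format='message')

# after: double quotes, one argument per line, magic trailing comma
response = gen.call(
    Model.default,
    messages=messages,
    stream=True,
    top_p=0.8,
    result_format="message",
)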

View File

@@ -7,16 +7,19 @@ import hashlib
from http import HTTPStatus
from dashscope import Generation
def call_with_messages():
messages = [{'role': 'system', 'content': 'You are a life assistant bot.'},
{'role': 'user', 'content': 'How do I make scrambled eggs with tomato?'}]
messages = [
{"role": "system", "content": "你是生活助手机器人。"},
{"role": "user", "content": "如何做西红柿鸡蛋?"},
]
gen = Generation()
response = gen.call(
Generation.Models.qwen_turbo,
messages=messages,
stream=True,
top_p=0.8,
result_format='message', # set the result to be "message" format.
result_format="message", # set the result to be "message" format.
)
for response in response:
@@ -24,21 +27,24 @@ def call_with_messages():
# otherwise the request failed; the error code and message
# are available in response.code and response.message.
if response.status_code == HTTPStatus.OK:
print(response.output) # The output text
print(response.usage) # The usage information
else:
print(response.code) # The error code.
print(response.message) # The error message.
def build_access_token(api_key: str, secret_key: str) -> str:
"""
Generate an access token from the AK and SK
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
params = {
"grant_type": "client_credentials",
"client_id": api_key,
"client_secret": secret_key,
}
res = requests.get(url=url, params=params)
@@ -47,15 +53,15 @@ def build_access_token(api_key: str, secret_key: str) -> str:
def _calculate_md5(text: str) -> str:
md5 = hashlib.md5()
md5.update(text.encode("utf-8"))
encrypted = md5.hexdigest()
return encrypted
def baichuan_call():
url = "https://api.baichuan-ai.com/v1/stream/chat"
if __name__ == '__main__':
if __name__ == "__main__":
call_with_messages()
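build_access_token above implements Baidu's client-credentials token exchange; the endpoint replies with JSON that carries the token. A hedged usage sketch (the AK/SK values are placeholders, not credentials from this commit):

# hypothetical caller; substitute real Qianfan AK/SK values
token = build_access_token(api_key="<your_ak>", secret_key="<your_sk>")
print(token)  # opaque token string passed along on subsequent API calls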

View File

@@ -1,4 +1,4 @@
from .db.my_plugin_db import MyPluginEntity, MyPluginDao
from .db.plugin_hub_db import PluginHubEntity, PluginHubDao
from .commands.command import execute_command, get_command

View File

@@ -9,6 +9,7 @@ from .generator import PluginPromptGenerator
from pilot.configs.config import Config
def _resolve_pathlike_command_args(command_args):
if "directory" in command_args and command_args["directory"] in {"", "/"}:
# todo
@@ -64,8 +65,6 @@ def execute_ai_response_json(
return result
def execute_command(
command_name: str,
arguments,
@@ -81,10 +80,8 @@ def execute_command(
str: The result of the command
"""
cmd = plugin_generator.command_registry.commands.get(command_name)
# If the command is found, call it with the provided arguments
if cmd:
try:
@@ -153,6 +150,3 @@ def get_command(response_json: Dict):
# For all other errors, return "Error: <error message>"
except Exception as e:
return "Error:", str(e)

View File

@@ -28,13 +28,13 @@ class Command:
"""
def __init__(
self,
name: str,
description: str,
method: Callable[..., Any],
signature: str = "",
enabled: bool = True,
disabled_reason: Optional[str] = None,
):
self.name = name
self.description = description
@@ -87,11 +87,12 @@ class CommandRegistry:
if hasattr(reloaded_module, "register"):
reloaded_module.register(self)
def is_valid_command(self, name:str)-> bool:
def is_valid_command(self, name: str) -> bool:
if name not in self.commands:
return False
else:
return True
def get_command(self, name: str) -> Callable[..., Any]:
return self.commands[name]
@@ -129,23 +130,23 @@ class CommandRegistry:
attr = getattr(module, attr_name)
# Register decorated functions
if hasattr(attr, AUTO_GPT_COMMAND_IDENTIFIER) and getattr(
attr, AUTO_GPT_COMMAND_IDENTIFIER
):
self.register(attr.command)
# Register command classes
elif (
inspect.isclass(attr) and issubclass(attr, Command) and attr != Command
):
cmd_instance = attr()
self.register(cmd_instance)
def command(
name: str,
description: str,
signature: str = "",
enabled: bool = True,
disabled_reason: Optional[str] = None,
) -> Callable[..., Any]:
"""The command decorator is used to create Command objects from ordinary functions."""
@@ -241,34 +242,60 @@ class ApiCall:
else:
md_tag_end = "```"
for tag in error_md_tags:
all_context = all_context.replace(tag + api_context + md_tag_end, api_context)
all_context = all_context.replace(tag + "\n" +api_context + "\n" + md_tag_end, api_context)
all_context = all_context.replace(tag + " " +api_context + " " + md_tag_end, api_context)
all_context = all_context.replace(
tag + api_context + md_tag_end, api_context
)
all_context = all_context.replace(
tag + "\n" + api_context + "\n" + md_tag_end, api_context
)
all_context = all_context.replace(
tag + " " + api_context + " " + md_tag_end, api_context
)
all_context = all_context.replace(tag + api_context, api_context)
return all_context
def api_view_context(self, all_context: str, display_mode: bool = False):
error_mk_tags = ["```", "```python", "```xml"]
call_context_map = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end, True)
call_context_map = extract_content_open_ending(
all_context, self.agent_prefix, self.agent_end, True
)
for api_index, api_context in call_context_map.items():
api_status = self.plugin_status_map.get(api_context)
if api_status is not None:
if display_mode:
if api_status.api_result:
all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(api_context, api_status.api_result)
all_context = self.__deal_error_md_tags(
all_context, api_context
)
all_context = all_context.replace(
api_context, api_status.api_result
)
else:
if api_status.status == Status.FAILED.value:
all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(api_context, f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """)
all_context = self.__deal_error_md_tags(
all_context, api_context
)
all_context = all_context.replace(
api_context,
f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """,
)
else:
cost = (api_status.end_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(api_context, f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n')
all_context = self.__deal_error_md_tags(
all_context, api_context
)
all_context = all_context.replace(
api_context,
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
)
else:
all_context = self.__deal_error_md_tags(all_context, api_context, False)
all_context = all_context.replace(api_context, self.to_view_text(api_status))
all_context = self.__deal_error_md_tags(
all_context, api_context, False
)
all_context = all_context.replace(
api_context, self.to_view_text(api_status)
)
else:
# render the waiting view for api calls that are not ready yet
@@ -276,27 +303,34 @@ class ApiCall:
cost = (now_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
for tag in error_mk_tags:
all_context = all_context.replace(tag + api_context , api_context)
all_context = all_context.replace(api_context, f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n')
all_context = all_context.replace(tag + api_context, api_context)
all_context = all_context.replace(
api_context,
f'\n<span style="color:green">Waiting...{cost_str}S</span>\n',
)
return all_context
def update_from_context(self, all_context):
api_context_map = extract_content(all_context, self.agent_prefix, self.agent_end, True)
api_context_map = extract_content(
all_context, self.agent_prefix, self.agent_end, True
)
for api_index, api_context in api_context_map.items():
api_context = api_context.replace("\\n", "").replace("\n", "")
api_call_element = ET.fromstring(api_context)
api_name = api_call_element.find('name').text
if api_name.find("[")>=0 or api_name.find("]")>=0:
api_name = api_call_element.find("name").text
if api_name.find("[") >= 0 or api_name.find("]") >= 0:
api_name = api_name.replace("[", "").replace("]", "")
api_args = {}
args_elements = api_call_element.find('args')
args_elements = api_call_element.find("args")
for child_element in args_elements.iter():
api_args[child_element.tag] = child_element.text
api_status = self.plugin_status_map.get(api_context)
if api_status is None:
api_status = PluginStatus(name=api_name, location=[api_index], args=api_args)
api_status = PluginStatus(
name=api_name, location=[api_index], args=api_args
)
self.plugin_status_map[api_context] = api_status
else:
api_status.location.append(api_index)
@@ -304,20 +338,20 @@ class ApiCall:
def __to_view_param_str(self, api_status):
param = {}
if api_status.name:
param['name'] = api_status.name
param['status'] = api_status.status
param["name"] = api_status.name
param["status"] = api_status.status
if api_status.logo_url:
param['logo'] = api_status.logo_url
param["logo"] = api_status.logo_url
if api_status.err_msg:
param['err_msg'] = api_status.err_msg
param["err_msg"] = api_status.err_msg
if api_status.api_result:
param['result'] = api_status.api_result
param["result"] = api_status.api_result
return json.dumps(param)
def to_view_text(self, api_status: PluginStatus):
api_call_element = ET.Element('dbgpt-view')
api_call_element = ET.Element("dbgpt-view")
api_call_element.text = self.__to_view_param_str(api_status)
result = ET.tostring(api_call_element, encoding="utf-8")
return result.decode("utf-8")
@@ -332,7 +366,9 @@ class ApiCall:
value.status = Status.RUNNING.value
logging.info(f"插件执行:{value.name},{value.args}")
try:
value.api_result = execute_command(value.name, value.args, self.plugin_generator)
value.api_result = execute_command(
value.name, value.args, self.plugin_generator
)
value.status = Status.COMPLETED.value
except Exception as e:
value.status = Status.FAILED.value
@@ -350,15 +386,19 @@ class ApiCall:
value.status = Status.RUNNING.value
logging.info(f"sql展示执行:{value.name},{value.args}")
try:
sql = value.args['sql']
sql = value.args["sql"]
if sql:
param = {
"df": sql_run_func(sql),
}
if self.display_registry.is_valid_command(value.name):
value.api_result = self.display_registry.call(value.name, **param)
value.api_result = self.display_registry.call(
value.name, **param
)
else:
value.api_result = self.display_registry.call("response_table", **param)
value.api_result = self.display_registry.call(
"response_table", **param
)
value.status = Status.COMPLETED.value
except Exception as e:
@@ -366,4 +406,3 @@ class ApiCall:
value.err_msg = str(e)
value.end_time = datetime.now().timestamp() * 1000
return self.api_view_context(llm_text, True)

View File

@@ -15,6 +15,7 @@ from matplotlib.font_manager import FontManager
from pilot.common.string_utils import is_scientific_notation
import logging
logger = logging.getLogger(__name__)
@@ -88,12 +89,13 @@ def zh_font_set():
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
def format_axis(value, pos):
# check whether the value is numeric
if is_scientific_notation(value):
# check whether it needs reformatting out of scientific notation
return '{:.2f}'.format(value)
return "{:.2f}".format(value)
return value
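format_axis has the (value, pos) shape matplotlib expects of a tick-formatter callback; presumably it is attached with FuncFormatter. A sketch under that assumption (ax is an existing Axes object):

from matplotlib.ticker import FuncFormatter

# rewrite scientific-notation tick values as plain two-decimal numbers
ax.yaxis.set_major_formatter(FuncFormatter(format_axis))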
@@ -102,7 +104,7 @@ def format_axis(value, pos):
"Line chart display, used to display comparative trend analysis data",
'"df":"<data frame>"',
)
def response_line_chart( df: DataFrame) -> str:
def response_line_chart(df: DataFrame) -> str:
logger.info(f"response_line_chart")
if df.size <= 0:
raise ValueError("No Data")
@@ -143,9 +145,15 @@ def response_line_chart( df: DataFrame) -> str:
if len(num_colmns) > 0:
num_colmns.append(y)
df_melted = pd.melt(
df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
df,
id_vars=x,
value_vars=num_colmns,
var_name="line",
value_name="Value",
)
sns.lineplot(
data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2"
)
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
else:
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
@@ -154,7 +162,7 @@ def response_line_chart( df: DataFrame) -> str:
chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, dpi=100, transparent=True)
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
@@ -168,7 +176,7 @@ def response_line_chart( df: DataFrame) -> str:
"Histogram, suitable for comparative analysis of multiple target values",
'"df":"<data frame>"',
)
def response_bar_chart( df: DataFrame) -> str:
def response_bar_chart(df: DataFrame) -> str:
logger.info(f"response_bar_chart")
if df.size <= 0:
raise ValueError("No Data")
@@ -246,7 +254,7 @@ def response_bar_chart( df: DataFrame) -> str:
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, dpi=100,transparent=True)
plt.savefig(chart_path, dpi=100, transparent=True)
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img

View File

@@ -3,8 +3,10 @@ from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
import logging
logger = logging.getLogger(__name__)
@command(
"response_table",
"Table display, suitable for display with many display columns or non-numeric columns",

View File

@@ -3,6 +3,7 @@ from pandas import DataFrame
from pilot.base_modules.agent.commands.command_mange import command
import logging
logger = logging.getLogger(__name__)
@@ -22,7 +23,7 @@ def response_data_text(df: DataFrame) -> str:
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
text_info = html.replace("\n", " ")
elif row_size == 1:
row = data[0]
for value in row:

View File

@@ -133,5 +133,3 @@ class PluginPromptGenerator:
def generate_commands_string(self) -> str:
return f"{self._generate_numbered_list(self.commands, item_type='command')}"

View File

@@ -5,11 +5,12 @@ class PluginStorageType(Enum):
Git = "git"
Oss = "oss"
class Status(Enum):
TODO = "todo"
RUNNING = 'running'
FAILED = 'failed'
COMPLETED = 'completed'
RUNNING = "running"
FAILED = "failed"
COMPLETED = "completed"
class ApiTagType(Enum):

View File

@@ -16,7 +16,13 @@ from pilot.openapi.api_view_model import (
Result,
)
from .model import PluginHubParam, PagenationFilter, PagenationResult, PluginHubFilter, MyPluginFilter
from .model import (
PluginHubParam,
PagenationFilter,
PagenationResult,
PluginHubFilter,
MyPluginFilter,
)
from .hub.agent_hub import AgentHub
from .db.plugin_hub_db import PluginHubEntity
from .plugins_util import scan_plugins
@@ -33,17 +39,18 @@ class ModuleAgent(BaseComponent, ABC):
name = ComponentType.AGENT_HUB
def __init__(self):
#load plugins
# load plugins
self.plugins = scan_plugins(PLUGINS_DIR)
def init_app(self, system_app: SystemApp):
system_app.app.include_router(router, prefix="/api", tags=["Agent"])
def refresh_plugins(self):
self.plugins = scan_plugins(PLUGINS_DIR)
def load_select_plugin(self, generator:PluginPromptGenerator, select_plugins:List[str])->PluginPromptGenerator:
def load_select_plugin(
self, generator: PluginPromptGenerator, select_plugins: List[str]
) -> PluginPromptGenerator:
logger.info(f"load_select_plugin:{select_plugins}")
# load select plugin
for plugin in self.plugins:
@@ -53,6 +60,7 @@ class ModuleAgent(BaseComponent, ABC):
generator = plugin.post_prompt(generator)
return generator
module_agent = ModuleAgent()
@@ -61,25 +69,28 @@ async def agent_hub_update(update_param: PluginHubParam = Body()):
logger.info(f"agent_hub_update:{update_param.__dict__}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
agent_hub.refresh_hub_from_git(update_param.url, update_param.branch, update_param.authorization)
agent_hub.refresh_hub_from_git(
update_param.url, update_param.branch, update_param.authorization
)
return Result.succ(None)
except Exception as e:
logger.error("Agent Hub Update Error!", e)
return Result.faild(code="E0020", msg=f"Agent Hub Update Error! {e}")
@router.post("/v1/agent/query", response_model=Result[str])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
logger.info(f"get_agent_list:{filter.__dict__}")
agent_hub = AgentHub(PLUGINS_DIR)
filter_enetity:PluginHubEntity = PluginHubEntity()
filter_enetity: PluginHubEntity = PluginHubEntity()
if filter.filter:
attrs = vars(filter.filter) # get the attribute dict of the original object
for attr, value in attrs.items():
setattr(filter_enetity, attr, value) # copy the attribute values onto the entity
datas, total_pages, total_count = agent_hub.hub_dao.list(filter_enetity, filter.page_index, filter.page_size)
datas, total_pages, total_count = agent_hub.hub_dao.list(
filter_enetity, filter.page_index, filter.page_size
)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
result.page_index = filter.page_index
result.page_size = filter.page_size
@@ -89,11 +100,12 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
# print(json.dumps(result.to_dic()))
return Result.succ(result.to_dic())
@router.post("/v1/agent/my", response_model=Result[str])
async def my_agents(user:str= None):
async def my_agents(user: str = None):
logger.info(f"my_agents:{user}")
agent_hub = AgentHub(PLUGINS_DIR)
agents = agent_hub.get_my_plugin(user)
agent_dicts = []
for agent in agents:
agent_dicts.append(agent.__dict__)
@@ -102,7 +114,7 @@ async def my_agents(user:str= None):
@router.post("/v1/agent/install", response_model=Result[str])
async def agent_install(plugin_name:str, user: str = None):
async def agent_install(plugin_name: str, user: str = None):
logger.info(f"agent_install:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
@@ -111,14 +123,13 @@ async def agent_install(plugin_name:str, user: str = None):
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Install Error!", e)
return Result.faild(code="E0021", msg=f"Plugin Install Error {e}")
@router.post("/v1/agent/uninstall", response_model=Result[str])
async def agent_uninstall(plugin_name:str, user: str = None):
async def agent_uninstall(plugin_name: str, user: str = None):
logger.info(f"agent_uninstall:{plugin_name},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
@@ -126,19 +137,18 @@ async def agent_uninstall(plugin_name:str, user: str = None):
module_agent.refresh_plugins()
return Result.succ(None)
except Exception as e:
logger.error("Plugin Uninstall Error!", e)
return Result.faild(code="E0022", msg=f"Plugin Uninstall Error {e}")
@router.post("/v1/personal/agent/upload", response_model=Result[str])
async def personal_agent_upload( doc_file: UploadFile = File(...), user: str =None):
async def personal_agent_upload(doc_file: UploadFile = File(...), user: str = None):
logger.info(f"personal_agent_upload:{doc_file.filename},{user}")
try:
agent_hub = AgentHub(PLUGINS_DIR)
await agent_hub.upload_my_plugin(doc_file, user)
return Result.succ(None)
except Exception as e:
logger.error("Upload Personal Plugin Error!", e)
return Result.faild(code="E0023", msg=f"Upload Personal Plugin Error {e}")

View File

@@ -9,30 +9,33 @@ from pilot.base_modules.meta_data.base_dao import BaseDao
from pilot.base_modules.meta_data.meta_data import Base, engine, session
class MyPluginEntity(Base):
__tablename__ = 'my_plugin'
__tablename__ = "my_plugin"
id = Column(Integer, primary_key=True, comment="autoincrement id")
tenant = Column(String(255), nullable=True, comment="user's tenant")
user_code = Column(String(255), nullable=False, comment="user code")
user_name = Column(String(255), nullable=True, comment="user name")
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
file_name = Column(String(255), nullable=False, comment="plugin package file name")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
use_count = Column(Integer, nullable=True, default=0, comment="plugin total use count")
succ_count = Column(Integer, nullable=True, default=0, comment="plugin total success count")
created_at = Column(DateTime, default=datetime.utcnow, comment="plugin install time")
__table_args__ = (
UniqueConstraint('user_code','name', name="uk_name"),
file_name = Column(String(255), nullable=False, comment="plugin package file name")
type = Column(String(255), comment="plugin type")
version = Column(String(255), comment="plugin version")
use_count = Column(
Integer, nullable=True, default=0, comment="plugin total use count"
)
succ_count = Column(
Integer, nullable=True, default=0, comment="plugin total success count"
)
created_at = Column(
DateTime, default=datetime.utcnow, comment="plugin install time"
)
__table_args__ = (UniqueConstraint("user_code", "name", name="uk_name"),)
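The trailing comma black preserves here matters: SQLAlchemy expects __table_args__ to be a tuple (or dict), and without the comma the parentheses around a single UniqueConstraint would not form a tuple at all:

__table_args__ = (UniqueConstraint("user_code", "name", name="uk_name"),)  # one-element tuple
__table_args__ = (UniqueConstraint("user_code", "name", name="uk_name"))   # wrong: just a parenthesized constraint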
class MyPluginDao(BaseDao[MyPluginEntity]):
def __init__(self):
super().__init__(
database="dbgpt", orm_base=Base, db_engine =engine , session= session
database="dbgpt", orm_base=Base, db_engine=engine, session=session
)
def add(self, engity: MyPluginEntity):
@@ -60,87 +63,61 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
session.commit()
return updated.id
def get_by_user(self, user: str)->list[MyPluginEntity]:
def get_by_user(self, user: str) -> list[MyPluginEntity]:
session = self.get_session()
my_plugins = session.query(MyPluginEntity)
if user:
my_plugins = my_plugins.filter(
MyPluginEntity.user_code == user
)
if user:
my_plugins = my_plugins.filter(MyPluginEntity.user_code == user)
result = my_plugins.all()
session.close()
return result
def list(self, query: MyPluginEntity, page=1, page_size=20)->list[MyPluginEntity]:
def list(self, query: MyPluginEntity, page=1, page_size=20) -> list[MyPluginEntity]:
session = self.get_session()
my_plugins = session.query(MyPluginEntity)
all_count = my_plugins.count()
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.name == query.name
)
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
if query.tenant is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.tenant == query.tenant
)
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
if query.type is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.type == query.type
)
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
if query.user_code is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.user_code == query.user_code
)
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.user_name == query.user_name
)
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
my_plugins = my_plugins.order_by(MyPluginEntity.id.desc())
my_plugins = my_plugins.offset((page - 1) * page_size).limit( page_size)
my_plugins = my_plugins.offset((page - 1) * page_size).limit(page_size)
result = my_plugins.all()
session.close()
total_pages = all_count // page_size
if all_count % page_size != 0:
total_pages += 1
return result, total_pages, all_count
def count(self, query: MyPluginEntity):
session = self.get_session()
my_plugins = session.query(func.count(MyPluginEntity.id))
if query.id is not None:
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
if query.name is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.name == query.name
)
my_plugins = my_plugins.filter(MyPluginEntity.name == query.name)
if query.type is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.type == query.type
)
my_plugins = my_plugins.filter(MyPluginEntity.type == query.type)
if query.tenant is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.tenant == query.tenant
)
my_plugins = my_plugins.filter(MyPluginEntity.tenant == query.tenant)
if query.user_code is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.user_code == query.user_code
)
my_plugins = my_plugins.filter(MyPluginEntity.user_code == query.user_code)
if query.user_name is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.user_name == query.user_name
)
my_plugins = my_plugins.filter(MyPluginEntity.user_name == query.user_name)
count = my_plugins.scalar()
session.close()
return count
def delete(self, plugin_id: int):
session = self.get_session()
if plugin_id is None:
@@ -148,9 +125,7 @@ class MyPluginDao(BaseDao[MyPluginEntity]):
query = MyPluginEntity(id=plugin_id)
my_plugins = session.query(MyPluginEntity)
if query.id is not None:
my_plugins = my_plugins.filter(
MyPluginEntity.id == query.id
)
my_plugins = my_plugins.filter(MyPluginEntity.id == query.id)
my_plugins.delete()
session.commit()
session.close()

View File

@@ -10,8 +10,10 @@ from pilot.base_modules.meta_data.meta_data import Base, engine, session
class PluginHubEntity(Base):
__tablename__ = 'plugin_hub'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
__tablename__ = "plugin_hub"
id = Column(
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
)
name = Column(String(255), unique=True, nullable=False, comment="plugin name")
description = Column(String(255), nullable=False, comment="plugin description")
author = Column(String(255), nullable=True, comment="plugin author")
@@ -25,8 +27,8 @@ class PluginHubEntity(Base):
installed = Column(Integer, default=False, comment="plugin already installed count")
__table_args__ = (
UniqueConstraint('name', name="uk_name"),
Index('idx_q_type', 'type'),
UniqueConstraint("name", name="uk_name"),
Index("idx_q_type", "type"),
)
@@ -38,7 +40,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
def add(self, engity: PluginHubEntity):
session = self.get_session()
timezone = pytz.timezone('Asia/Shanghai')
timezone = pytz.timezone("Asia/Shanghai")
plugin_hub = PluginHubEntity(
name=engity.name,
author=engity.author,
@@ -64,7 +66,9 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
finally:
session.close()
def list(self, query: PluginHubEntity, page=1, page_size=20) -> list[PluginHubEntity]:
def list(
self, query: PluginHubEntity, page=1, page_size=20
) -> list[PluginHubEntity]:
session = self.get_session()
plugin_hubs = session.query(PluginHubEntity)
all_count = plugin_hubs.count()
@@ -72,17 +76,11 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.name == query.name
)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.type == query.type
)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.author == query.author
)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
@@ -110,9 +108,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
def get_by_name(self, name: str) -> PluginHubEntity:
session = self.get_session()
plugin_hubs = session.query(PluginHubEntity)
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.name == name
)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == name)
result = plugin_hubs.first()
session.close()
return result
@@ -123,17 +119,11 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
if query.id is not None:
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == query.id)
if query.name is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.name == query.name
)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.name == query.name)
if query.type is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.type == query.type
)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.type == query.type)
if query.author is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.author == query.author
)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.author == query.author)
if query.storage_channel is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.storage_channel == query.storage_channel
@@ -148,9 +138,7 @@ class PluginHubDao(BaseDao[PluginHubEntity]):
raise Exception("plugin_id is None")
plugin_hubs = session.query(PluginHubEntity)
if plugin_id is not None:
plugin_hubs = plugin_hubs.filter(
PluginHubEntity.id == plugin_id
)
plugin_hubs = plugin_hubs.filter(PluginHubEntity.id == plugin_id)
plugin_hubs.delete()
session.commit()
session.close()

View File

@@ -4,7 +4,7 @@ import os
import glob
import shutil
from fastapi import UploadFile
from typing import Any
import tempfile
from ..db.plugin_hub_db import PluginHubEntity, PluginHubDao
@@ -38,7 +38,9 @@ class AgentHub:
download_param = json.loads(plugin_entity.download_param)
branch_name = download_param.get("branch_name")
authorization = download_param.get("authorization")
file_name = self.__download_from_git(plugin_entity.storage_url, branch_name, authorization)
file_name = self.__download_from_git(
plugin_entity.storage_url, branch_name, authorization
)
# add to my plugins and edit hub status
plugin_entity.installed = plugin_entity.installed + 1
@@ -65,7 +67,9 @@ class AgentHub:
logger.error("install pluguin exception!", e)
raise ValueError(f"Install Plugin {plugin_name} Faild! {str(e)}")
else:
raise ValueError(f"Unsupport Storage Channel {plugin_entity.storage_channel}!")
raise ValueError(
f"Unsupport Storage Channel {plugin_entity.storage_channel}!"
)
else:
raise ValueError(f"Can't Find Plugin {plugin_name}!")
@@ -75,7 +79,9 @@ class AgentHub:
plugin_entity.installed = plugin_entity.installed - 1
with self.hub_dao.get_session() as session:
try:
my_plugin_q = session.query(MyPluginEntity).filter(MyPluginEntity.name == plugin_name)
my_plugin_q = session.query(MyPluginEntity).filter(
MyPluginEntity.name == plugin_name
)
if user:
my_plugin_q.filter(MyPluginEntity.user_code == user)
my_plugin_q.delete()
@@ -92,10 +98,10 @@ class AgentHub:
have_installed = True
break
if not have_installed:
plugin_repo_name = plugin_entity.storage_url.replace(".git", "").strip('/').split('/')[-1]
files = glob.glob(
os.path.join(self.plugin_dir, f"{plugin_repo_name}*")
plugin_repo_name = (
plugin_entity.storage_url.replace(".git", "").strip("/").split("/")[-1]
)
files = glob.glob(os.path.join(self.plugin_dir, f"{plugin_repo_name}*"))
for file in files:
os.remove(file)
@@ -109,9 +115,16 @@ class AgentHub:
my_plugin_entity.version = hub_plugin.version
return my_plugin_entity
def refresh_hub_from_git(self, github_repo: str = None, branch_name: str = None, authorization: str = None):
def refresh_hub_from_git(
self,
github_repo: str = None,
branch_name: str = None,
authorization: str = None,
):
logger.info("refresh_hub_by_git start!")
update_from_git(self.temp_hub_file_path, github_repo, branch_name, authorization)
update_from_git(
self.temp_hub_file_path, github_repo, branch_name, authorization
)
git_plugins = scan_plugins(self.temp_hub_file_path)
try:
for git_plugin in git_plugins:
@@ -123,13 +136,13 @@ class AgentHub:
plugin_hub_info.type = ""
plugin_hub_info.storage_channel = PluginStorageType.Git.value
plugin_hub_info.storage_url = DEFAULT_PLUGIN_REPO
plugin_hub_info.author = getattr(git_plugin, '_author', 'DB-GPT')
plugin_hub_info.email = getattr(git_plugin, '_email', '')
plugin_hub_info.author = getattr(git_plugin, "_author", "DB-GPT")
plugin_hub_info.email = getattr(git_plugin, "_email", "")
download_param = {}
if branch_name:
download_param['branch_name'] = branch_name
download_param["branch_name"] = branch_name
if authorization and len(authorization) > 0:
download_param['authorization'] = authorization
download_param["authorization"] = authorization
plugin_hub_info.download_param = json.dumps(download_param)
plugin_hub_info.installed = 0
@@ -140,15 +153,12 @@ class AgentHub:
except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")
async def upload_my_plugin(self, doc_file: UploadFile, user: Any=Default_User):
async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User):
# We cannot move the temp file on Windows while it is still open inside a `with` block
file_path = os.path.join(self.plugin_dir, doc_file.filename)
if os.path.exists(file_path):
os.remove(file_path)
tmp_fd, tmp_path = tempfile.mkstemp(
dir=os.path.join(self.plugin_dir)
)
tmp_fd, tmp_path = tempfile.mkstemp(dir=os.path.join(self.plugin_dir))
with os.fdopen(tmp_fd, "wb") as tmp:
tmp.write(await doc_file.read())
shutil.move(
@@ -158,14 +168,14 @@ class AgentHub:
my_plugins = scan_plugins(self.plugin_dir, doc_file.filename)
if user is None or len(user) <=0:
if user is None or len(user) <= 0:
user = Default_User
for my_plugin in my_plugins:
my_plugin_entiy = MyPluginEntity()
my_plugin_entiy.name = my_plugin._name
my_plugin_entiy.version = my_plugin._version
my_plugin_entiy.type = "Personal"
my_plugin_entiy.user_code = user
my_plugin_entiy.user_name = user
@@ -183,4 +193,3 @@ class AgentHub:
if not user:
user = Default_User
return self.my_lugin_dao.get_by_user(user)

View File

@@ -3,32 +3,35 @@ from dataclasses import dataclass
from pydantic import BaseModel, Field
from typing import TypeVar, Generic, Any
T = TypeVar('T')
T = TypeVar("T")
class PagenationFilter(BaseModel, Generic[T]):
page_index: int = 1
page_size: int = 20
filter: T = None
class PagenationResult(BaseModel, Generic[T]):
page_index: int = 1
page_size: int = 20
total_page: int = 0
total_row_count: int = 0
datas: List[T] = []
def to_dic(self):
data_dicts =[]
data_dicts = []
for item in self.datas:
data_dicts.append(item.__dict__)
return {
'page_index': self.page_index,
'page_size': self.page_size,
'total_page': self.total_page,
'total_row_count': self.total_row_count,
'datas': data_dicts
"page_index": self.page_index,
"page_size": self.page_size,
"total_page": self.total_page,
"total_row_count": self.total_row_count,
"datas": data_dicts,
}
@dataclass
class PluginHubFilter(BaseModel):
name: str
@@ -53,9 +56,14 @@ class MyPluginFilter(BaseModel):
class PluginHubParam(BaseModel):
channel: Optional[str] = Field("git", description="Plugin storage channel")
url: Optional[str] = Field("https://github.com/eosphoros-ai/DB-GPT-Plugins.git", description="Plugin storage url")
branch: Optional[str] = Field("main", description="github download branch", nullable=True)
authorization: Optional[str] = Field(None, description="github download authorization", nullable=True)
channel: Optional[str] = Field("git", description="Plugin storage channel")
url: Optional[str] = Field(
"https://github.com/eosphoros-ai/DB-GPT-Plugins.git",
description="Plugin storage url",
)
branch: Optional[str] = Field(
"main", description="github download branch", nullable=True
)
authorization: Optional[str] = Field(
None, description="github download authorization", nullable=True
)

View File

@@ -117,7 +117,7 @@ def load_native_plugins(cfg: Config):
t.start()
def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTemplate]:
def __scan_plugin_file(file_path, debug: bool = False) -> List[AutoGPTPluginTemplate]:
logger.info(f"__scan_plugin_file:{file_path},{debug}")
loaded_plugins = []
if moduleList := inspect_zip_for_modules(str(file_path), debug):
@@ -133,14 +133,17 @@ def __scan_plugin_file(file_path, debug: bool = False)-> List[AutoGPTPluginTempl
a_module = getattr(zipped_module, key)
a_keys = dir(a_module)
if (
"_abc_impl" in a_keys
and a_module.__name__ != "AutoGPTPluginTemplate"
# and denylist_allowlist_check(a_module.__name__, cfg)
"_abc_impl" in a_keys
and a_module.__name__ != "AutoGPTPluginTemplate"
# and denylist_allowlist_check(a_module.__name__, cfg)
):
loaded_plugins.append(a_module())
return loaded_plugins
def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = False) -> List[AutoGPTPluginTemplate]:
def scan_plugins(
plugins_file_path: str, file_name: str = "", debug: bool = False
) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.
Args:
@@ -159,7 +162,7 @@ def scan_plugins(plugins_file_path: str, file_name: str = "", debug: bool = Fals
loaded_plugins = __scan_plugin_file(plugin_path)
else:
for plugin_path in plugins_path.glob("*.zip"):
loaded_plugins.extend(__scan_plugin_file(plugin_path))
if loaded_plugins:
logger.info(f"\nPlugins found: {len(loaded_plugins)}\n" "--------------------")
@@ -192,17 +195,23 @@ def denylist_allowlist_check(plugin_name: str, cfg: Config) -> bool:
return ack.lower() == cfg.authorise_key
def update_from_git(download_path: str, github_repo: str = "", branch_name: str = "main",
authorization: str = None):
def update_from_git(
download_path: str,
github_repo: str = "",
branch_name: str = "main",
authorization: str = None,
):
os.makedirs(download_path, exist_ok=True)
if github_repo:
if github_repo.index("github.com") <= 0:
raise ValueError("Not a correct Github repository address" + github_repo)
github_repo = github_repo.replace(".git", "")
url = github_repo + "/archive/refs/heads/" + branch_name + ".zip"
plugin_repo_name = github_repo.strip('/').split('/')[-1]
plugin_repo_name = github_repo.strip("/").split("/")[-1]
else:
url = "https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
url = (
"https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip"
)
plugin_repo_name = "DB-GPT-Plugins"
try:
session = requests.Session()
@@ -216,14 +225,14 @@ def update_from_git(download_path: str, github_repo: str = "", branch_name: str
if response.status_code == 200:
plugins_path_path = Path(download_path)
files = glob.glob(
os.path.join(plugins_path_path, f"{plugin_repo_name}*")
)
files = glob.glob(os.path.join(plugins_path_path, f"{plugin_repo_name}*"))
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime("%Y%m%d%H%M%S")
file_name = f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
file_name = (
f"{plugins_path_path}/{plugin_repo_name}-{branch_name}-{time_str}.zip"
)
print(file_name)
with open(file_name, "wb") as f:
f.write(response.content)
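For the default branch of update_from_git, the URL arithmetic above resolves as follows (a sketch traced from the code, not captured output):

# github_repo : https://github.com/eosphoros-ai/DB-GPT-Plugins
# branch_name : main
# url         : https://github.com/eosphoros-ai/DB-GPT-Plugins/archive/refs/heads/main.zip
# saved as    : <download_path>/DB-GPT-Plugins-main-<YYYYmmddHHMMSS>.zip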

View File

@@ -1,5 +1,4 @@
class ModuleMangeApi:
def module_name(self):
pass

View File

@@ -1,11 +1,16 @@
from typing import TypeVar, Generic, List, Any
from sqlalchemy.orm import sessionmaker
T = TypeVar('T')
T = TypeVar("T")
class BaseDao(Generic[T]):
def __init__(
self, orm_base=None, database: str = None, db_engine: Any = None, session: Any = None,
self,
orm_base=None,
database: str = None,
db_engine: Any = None,
session: Any = None,
) -> None:
"""BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
self._orm_base = orm_base

View File

@@ -7,7 +7,7 @@ import fnmatch
from datetime import datetime
from typing import Optional, Type, TypeVar
from sqlalchemy import create_engine,DateTime, String, func, MetaData
from sqlalchemy import create_engine, DateTime, String, func, MetaData
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped
@@ -32,16 +32,17 @@ db_name = "dbgpt"
db_path = default_db_path + f"/{db_name}.db"
connection = sqlite3.connect(db_path)
if CFG.LOCAL_DB_TYPE == 'mysql':
engine_temp = create_engine(f"mysql+pymysql://"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
)
if CFG.LOCAL_DB_TYPE == "mysql":
engine_temp = create_engine(
f"mysql+pymysql://"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
)
# check for and auto-create the mysql database
try:
# try to connect
@@ -53,20 +54,19 @@ if CFG.LOCAL_DB_TYPE == 'mysql':
# if connect failed, create dbgpt database
logger.error(f"{db_name} not connect success!")
engine = create_engine(f"mysql+pymysql://"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
+ f"/{db_name}"
)
engine = create_engine(
f"mysql+pymysql://"
+ quote(CFG.LOCAL_DB_USER)
+ ":"
+ quote(CFG.LOCAL_DB_PASSWORD)
+ "@"
+ CFG.LOCAL_DB_HOST
+ ":"
+ str(CFG.LOCAL_DB_PORT)
+ f"/{db_name}"
)
else:
engine = create_engine(f'sqlite:///{db_path}')
engine = create_engine(f"sqlite:///{db_path}")
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@@ -81,16 +81,16 @@ Base = declarative_base()
alembic_ini_path = default_db_path + "/alembic.ini"
alembic_cfg = AlembicConfig(alembic_ini_path)
alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))
alembic_cfg.set_main_option("sqlalchemy.url", str(engine.url))
os.makedirs(default_db_path + "/alembic", exist_ok=True)
os.makedirs(default_db_path + "/alembic/versions", exist_ok=True)
alembic_cfg.set_main_option('script_location', default_db_path + "/alembic")
alembic_cfg.set_main_option("script_location", default_db_path + "/alembic")
# pass the model metadata and session to the Alembic config
alembic_cfg.attributes['target_metadata'] = Base.metadata
alembic_cfg.attributes['session'] = session
alembic_cfg.attributes["target_metadata"] = Base.metadata
alembic_cfg.attributes["session"] = session
# # create tables
@@ -106,7 +106,7 @@ def ddl_init_and_upgrade():
# command.upgrade(alembic_cfg, 'head')
# subprocess.run(["alembic", "revision", "--autogenerate", "-m", "Added account table"])
with engine.connect() as connection:
alembic_cfg.attributes['connection'] = connection
alembic_cfg.attributes["connection"] = connection
heads = command.heads(alembic_cfg)
print("heads:" + str(heads))

View File

@@ -1,2 +1 @@

View File

@@ -1,27 +1,30 @@
import re
def is_all_chinese(text):
### Determine whether the string is pure Chinese
pattern = re.compile(r'^[一-龥]+$')
pattern = re.compile(r"^[一-龥]+$")
match = re.match(pattern, text)
return match is not None
def is_number_chinese(text):
### Determine whether the string is numbers and Chinese
pattern = re.compile(r'^[\d一-龥]+$')
pattern = re.compile(r"^[\d一-龥]+$")
match = re.match(pattern, text)
return match is not None
def is_chinese_include_number(text):
### Determine whether the string is pure Chinese or Chinese containing numbers
pattern = re.compile(r'^[一-龥]+[\d一-龥]*$')
pattern = re.compile(r"^[一-龥]+[\d一-龥]*$")
match = re.match(pattern, text)
return match is not None
def is_scientific_notation(string):
# regex for scientific notation
pattern = r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$'
pattern = r"^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$"
# match the string against the regex
match = re.match(pattern, str(string))
# check whether the match succeeded
@@ -30,28 +33,30 @@ def is_scientific_notation(string):
else:
return False
def extract_content(long_string, s1, s2, is_include: bool = False):
# extract text
match_map ={}
match_map = {}
start_index = long_string.find(s1)
while start_index != -1:
if is_include:
end_index = long_string.find(s2, start_index + len(s1) + 1)
extracted_content = long_string[start_index:end_index + len(s2)]
extracted_content = long_string[start_index : end_index + len(s2)]
else:
end_index = long_string.find(s2, start_index + len(s1))
extracted_content = long_string[start_index + len(s1):end_index]
extracted_content = long_string[start_index + len(s1) : end_index]
if extracted_content:
match_map[start_index] = extracted_content
start_index = long_string.find(s1, start_index + 1)
return match_map
def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
# extract text, allowing an open-ended (unclosed) final match
match_map = {}
start_index = long_string.find(s1)
while start_index != -1:
if long_string.find(s2, start_index) <=0:
if long_string.find(s2, start_index) <= 0:
end_index = len(long_string)
else:
if is_include:
@@ -59,17 +64,16 @@ def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
else:
end_index = long_string.find(s2, start_index + len(s1))
if is_include:
extracted_content = long_string[start_index:end_index + len(s2)]
extracted_content = long_string[start_index : end_index + len(s2)]
else:
extracted_content = long_string[start_index + len(s1):end_index]
extracted_content = long_string[start_index + len(s1) : end_index]
if extracted_content:
match_map[start_index] = extracted_content
start_index= long_string.find(s1, start_index + 1)
start_index = long_string.find(s1, start_index + 1)
return match_map
if __name__=="__main__":
if __name__ == "__main__":
s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
s1 = "123"
s2 = "456"

View File

@@ -60,7 +60,9 @@ class Config(metaclass=Singleton):
self.zhipu_proxy_api_key = os.getenv("ZHIPU_PROXY_API_KEY")
if self.zhipu_proxy_api_key:
os.environ["zhipu_proxyllm_proxy_api_key"] = self.zhipu_proxy_api_key
os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv("ZHIPU_MODEL_VERSION")
os.environ["zhipu_proxyllm_proxyllm_backend"] = os.getenv(
"ZHIPU_MODEL_VERSION"
)
# wenxin
self.wenxin_proxy_api_key = os.getenv("WEN_XIN_API_KEY")
@@ -68,7 +70,9 @@ class Config(metaclass=Singleton):
self.wenxin_model_version = os.getenv("WEN_XIN_MODEL_VERSION")
if self.wenxin_proxy_api_key and self.wenxin_proxy_api_secret:
os.environ["wenxin_proxyllm_proxy_api_key"] = self.wenxin_proxy_api_key
os.environ["wenxin_proxyllm_proxy_api_secret"] = self.wenxin_proxy_api_secret
os.environ[
"wenxin_proxyllm_proxy_api_secret"
] = self.wenxin_proxy_api_secret
os.environ["wenxin_proxyllm_proxyllm_backend"] = self.wenxin_model_version
# xunfei spark
@@ -91,7 +95,6 @@ class Config(metaclass=Singleton):
os.environ["bc_proxyllm_proxy_api_secret"] = self.bc_proxy_api_secret
os.environ["bc_proxyllm_proxyllm_backend"] = self.bc_model_version
self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
@@ -172,7 +175,6 @@ class Config(metaclass=Singleton):
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
)
self.LOCAL_DB_MANAGE = None
###dbgpt meta info database connection configuration
@@ -190,7 +192,6 @@ class Config(metaclass=Singleton):
self.CHAT_HISTORY_STORE_TYPE = os.getenv("CHAT_HISTORY_STORE_TYPE", "duckdb")
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"

View File

@@ -4,10 +4,13 @@ from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
class ConnectConfigEntity(Base):
__tablename__ = 'connect_config'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
db_type = Column(String(255), nullable=False, comment="db type")
__tablename__ = "connect_config"
id = Column(
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
)
db_type = Column(String(255), nullable=False, comment="db type")
db_name = Column(String(255), nullable=False, comment="db name")
db_path = Column(String(255), nullable=True, comment="file db path")
db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
@@ -17,8 +20,8 @@ class ConnectConfigEntity(Base):
comment = Column(Text, nullable=True, comment="db comment")
__table_args__ = (
UniqueConstraint('db_name', name="uk_db"),
Index('idx_q_db_type', 'db_type'),
UniqueConstraint("db_name", name="uk_db"),
Index("idx_q_db_type", "db_type"),
)
@@ -43,9 +46,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
raise Exception("db_name is None")
db_connect = session.query(ConnectConfigEntity)
db_connect = db_connect.filter(
ConnectConfigEntity.db_name == db_name
)
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
db_connect.delete()
session.commit()
session.close()
@@ -53,10 +54,7 @@ class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
def get_by_name(self, db_name: str) -> ConnectConfigEntity:
session = self.get_session()
db_connect = session.query(ConnectConfigEntity)
db_connect = db_connect.filter(
ConnectConfigEntity.db_name == db_name
)
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
result = db_connect.first()
session.close()
return result

View File

@@ -5,15 +5,17 @@ from typing import List
from enum import Enum
from pilot.scene.message import OnceConversation
class MemoryStoreType(Enum):
File= 'file'
Memory = 'memory'
DB = 'db'
DuckDb = 'duckdb'
File = "file"
Memory = "memory"
DB = "db"
DuckDb = "duckdb"
class BaseChatHistoryMemory(ABC):
store_type: MemoryStoreType
def __init__(self):
self.conversations: List[OnceConversation] = []

View File

@@ -4,9 +4,7 @@ from pilot.configs.config import Config
CFG = Config()
class ChatHistory:
def __init__(self):
self.memory_type = MemoryStoreType.DB.value
self.mem_store_class_map = {}
@@ -14,15 +12,16 @@ class ChatHistory:
from .store_type.file_history import FileHistoryMemory
from .store_type.meta_db_history import DbHistoryMemory
from .store_type.mem_history import MemHistoryMemory
self.mem_store_class_map[DuckdbHistoryMemory.store_type] = DuckdbHistoryMemory
self.mem_store_class_map[FileHistoryMemory.store_type] = FileHistoryMemory
self.mem_store_class_map[DbHistoryMemory.store_type] = DbHistoryMemory
self.mem_store_class_map[MemHistoryMemory.store_type] = MemHistoryMemory
def get_store_instance(self, chat_session_id):
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(chat_session_id)
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)(
chat_session_id
)
def get_store_cls(self):
return self.mem_store_class_map.get(CFG.CHAT_HISTORY_STORE_TYPE)

View File

@@ -4,20 +4,28 @@ from typing import List
from sqlalchemy import Column, Integer, String, Index, DateTime, func, Boolean, Text
from sqlalchemy import UniqueConstraint
class ChatHistoryEntity(Base):
__tablename__ = 'chat_history'
id = Column(Integer, primary_key=True, autoincrement=True, comment="autoincrement id")
conv_uid = Column(String(255), unique=False, nullable=False, comment="Conversation record unique id")
__tablename__ = "chat_history"
id = Column(
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
)
conv_uid = Column(
String(255),
unique=False,
nullable=False,
comment="Conversation record unique id",
)
chat_mode = Column(String(255), nullable=False, comment="Conversation scene mode")
summary = Column(String(255), nullable=False, comment="Conversation record summary")
user_name = Column(String(255), nullable=True, comment="interlocutor")
messages = Column(Text, nullable=True, comment="Conversation details")
__table_args__ = (
UniqueConstraint('conv_uid', name="uk_conversation"),
Index('idx_q_user', 'user_name'),
Index('idx_q_mode', 'chat_mode'),
Index('idx_q_conv', 'summary'),
UniqueConstraint("conv_uid", name="uk_conversation"),
Index("idx_q_user", "user_name"),
Index("idx_q_mode", "chat_mode"),
Index("idx_q_conv", "summary"),
)
@@ -31,9 +39,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
session = self.get_session()
chat_history = session.query(ChatHistoryEntity)
if user_name:
chat_history = chat_history.filter(
ChatHistoryEntity.user_name == user_name
)
chat_history = chat_history.filter(ChatHistoryEntity.user_name == user_name)
chat_history = chat_history.order_by(ChatHistoryEntity.id.desc())
@@ -50,13 +56,11 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
finally:
session.close()
def update_message_by_uid(self, message: str, conv_uid:str):
def update_message_by_uid(self, message: str, conv_uid: str):
session = self.get_session()
try:
chat_history = session.query(ChatHistoryEntity)
chat_history = chat_history.filter(
ChatHistoryEntity.conv_uid == conv_uid
)
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
updated = chat_history.update({ChatHistoryEntity.messages: message})
session.commit()
return updated.id
@@ -69,9 +73,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
raise Exception("conv_uid is None")
chat_history = session.query(ChatHistoryEntity)
chat_history = chat_history.filter(
ChatHistoryEntity.conv_uid == conv_uid
)
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
chat_history.delete()
session.commit()
session.close()
@@ -79,10 +81,7 @@ class ChatHistoryDao(BaseDao[ChatHistoryEntity]):
def get_by_uid(self, conv_uid: str) -> ChatHistoryEntity:
session = self.get_session()
chat_history = session.query(ChatHistoryEntity)
chat_history = chat_history.filter(
ChatHistoryEntity.conv_uid == conv_uid
)
chat_history = chat_history.filter(ChatHistoryEntity.conv_uid == conv_uid)
result = chat_history.first()
session.close()
return result

View File

@@ -148,7 +148,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return json.loads(context[0])
return None
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
if os.path.isfile(duckdb_path):

View File

@@ -17,7 +17,7 @@ CFG = Config()
class FileHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.File.value
def __init__(self, chat_session_id: str):
now = datetime.datetime.now()
@@ -49,5 +49,3 @@ class FileHistoryMemory(BaseChatHistoryMemory):
def clear(self) -> None:
self.file_path.write_text(json.dumps([]))

View File

@@ -10,7 +10,7 @@ CFG = Config()
class MemHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.Memory.value
histroies_map = FixedSizeDict(100)

View File

@@ -12,18 +12,22 @@ from pilot.scene.message import (
from ..chat_history_db import ChatHistoryEntity, ChatHistoryDao
from pilot.memory.chat_history.base import MemoryStoreType
CFG = Config()
logger = logging.getLogger("db_chat_history")
class DbHistoryMemory(BaseChatHistoryMemory):
store_type: str = MemoryStoreType.DB.value
def __init__(self, chat_session_id: str):
self.chat_seesion_id = chat_session_id
self.chat_history_dao = ChatHistoryDao()
def messages(self) -> List[OnceConversation]:
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
self.chat_seesion_id
)
if chat_history:
context = chat_history.messages
if context:
@@ -31,7 +35,6 @@ class DbHistoryMemory(BaseChatHistoryMemory):
return conversations
return []
def create(self, chat_mode, summary: str, user_name: str) -> None:
try:
chat_history: ChatHistoryEntity = ChatHistoryEntity()
@@ -43,10 +46,11 @@ class DbHistoryMemory(BaseChatHistoryMemory):
except Exception as e:
logger.error("init create conversation log error" + str(e))
def append(self, once_message: OnceConversation) -> None:
logger.info("db history append:{}", once_message)
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
chat_history: ChatHistoryEntity = self.chat_history_dao.get_by_uid(
self.chat_seesion_id
)
conversations: List[OnceConversation] = []
if chat_history:
context = chat_history.messages
@@ -59,7 +63,7 @@ class DbHistoryMemory(BaseChatHistoryMemory):
chat_history.conv_uid = self.chat_seesion_id
chat_history.chat_mode = once_message.chat_mode
chat_history.user_name = "default"
chat_history.summary = once_message.get_user_conv().content
conversations.append(_conversation_to_dic(once_message))
chat_history.messages = json.dumps(conversations, ensure_ascii=False)
@@ -67,31 +71,28 @@ class DbHistoryMemory(BaseChatHistoryMemory):
self.chat_history_dao.update(chat_history)
def update(self, messages: List[OnceConversation]) -> None:
self.chat_history_dao.update_message_by_uid(json.dumps(messages, ensure_ascii=False), self.chat_seesion_id)
self.chat_history_dao.update_message_by_uid(
json.dumps(messages, ensure_ascii=False), self.chat_seesion_id
)
def delete(self) -> bool:
self.chat_history_dao.delete(self.chat_seesion_id)
def conv_info(self, conv_uid: str = None) -> None:
logger.info("conv_info:{}", conv_uid)
chat_history = self.chat_history_dao.get_by_uid(conv_uid)
return chat_history.__dict__
def get_messages(self) -> List[OnceConversation]:
logger.info("get_messages:{}", self.chat_seesion_id)
chat_history = self.chat_history_dao.get_by_uid(self.chat_seesion_id)
if chat_history:
context = chat_history.messages
return json.loads(context)
return []
@staticmethod
def conv_list(cls, user_name: str = None) -> None:
chat_history_dao = ChatHistoryDao()
history_list = chat_history_dao.list_last_20()
result = []
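The append path above is a read-modify-write cycle against a single row keyed by `conv_uid`: load the stored JSON, append the new conversation, serialize it back with `ensure_ascii=False`. A rough sketch of the same cycle, with a plain dict standing in for `ChatHistoryDao` (assuming one JSON string per session, as the code implies):

    import json

    def append_message(store: dict, session_id: str, message: dict) -> None:
        # Read the stored JSON (or start fresh), append, write the row back
        conversations = json.loads(store.get(session_id) or "[]")
        conversations.append(message)
        store[session_id] = json.dumps(conversations, ensure_ascii=False)

    store = {}
    append_message(store, "abc123", {"role": "human", "content": "hi"})
    append_message(store, "abc123", {"role": "ai", "content": "hello"})
    print(store["abc123"])  # a JSON list with two entries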

View File

@@ -66,12 +66,14 @@ def run_migrations_online() -> None:
)
with connectable.connect() as connection:
if engine.dialect.name == 'sqlite':
    context.configure(connection=engine.connect(), target_metadata=target_metadata, render_as_batch=True)
else:
    context.configure(
        connection=connection, target_metadata=target_metadata
    )
if engine.dialect.name == "sqlite":
    context.configure(
        connection=engine.connect(),
        target_metadata=target_metadata,
        render_as_batch=True,
    )
else:
    context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
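Two things are worth noting in this hunk. `render_as_batch=True` makes Alembic emit batch operations, which recreate the table for each change, because SQLite supports only a narrow subset of ALTER TABLE. Also, the reformatted code still calls `engine.connect()` inside the `with connectable.connect()` block, opening a second connection; reusing `connection` is probably what was meant. A sketch of that variant, assuming `engine` and `target_metadata` are defined earlier in `env.py`:

    from alembic import context

    with engine.connect() as connection:
        if connection.dialect.name == "sqlite":
            # Batch mode: Alembic rewrites ALTER TABLE as copy-and-rename steps
            context.configure(
                connection=connection,  # reuse the already-open connection
                target_metadata=target_metadata,
                render_as_batch=True,
            )
        else:
            context.configure(connection=connection, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()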

View File

@@ -145,12 +145,12 @@ def initialize_controller(
controller.backend = LocalModelController()
if app:
app.include_router(router, prefix="/api", tags=['Model'])
app.include_router(router, prefix="/api", tags=["Model"])
else:
import uvicorn
app = FastAPI()
app.include_router(router, prefix="/api", tags=['Model'])
app.include_router(router, prefix="/api", tags=["Model"])
uvicorn.run(app, host=host, port=port, log_level="info")
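The only change here is `tags=['Model']` becoming `tags=["Model"]`; in either spelling, `tags` takes a list of strings used to group the router's endpoints in the OpenAPI docs. A minimal runnable sketch of the same mounting pattern (route and names are illustrative):

    from fastapi import APIRouter, FastAPI
    import uvicorn

    router = APIRouter()

    @router.get("/model/list")
    def list_models():
        return {"models": []}

    app = FastAPI()
    app.include_router(router, prefix="/api", tags=["Model"])  # grouped under "Model" in /docs

    if __name__ == "__main__":
        uvicorn.run(app, host="127.0.0.1", port=8000, log_level="info")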

View File

@@ -9,18 +9,21 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
BAICHUAN_DEFAULT_MODEL = "Baichuan2-53B"
def _calculate_md5(text: str) -> str:
"""Calculate md5 """
"""Calculate md5"""
md5 = hashlib.md5()
md5.update(text.encode("utf-8"))
encrypted = md5.hexdigest()
return encrypted
def _sign(data: dict, secret_key: str, timestamp: str):
data_str = json.dumps(data)
signature = _calculate_md5(secret_key + data_str + timestamp)
return signature
def baichuan_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=4096
):
@@ -31,7 +34,6 @@ def baichuan_generate_stream(
proxy_api_key = model_params.proxy_api_key
proxy_api_secret = model_params.proxy_api_secret
history = []
messages: List[ModelMessage] = params["messages"]
# Add history conversation
@@ -47,11 +49,11 @@ def baichuan_generate_stream(
payload = {
"model": model_name,
"messages": history,
"messages": history,
"parameters": {
"temperature": params.get("temperature"),
"top_k": params.get("top_k", 10)
}
"top_k": params.get("top_k", 10),
},
}
timestamp = int(time.time())
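Taken together, `_calculate_md5` and `_sign` implement a simple shared-secret scheme: the signature is the MD5 hex digest of `secret_key + json.dumps(data) + timestamp`. A condensed sketch of the same flow; the header names are illustrative, not confirmed against the Baichuan docs:

    import hashlib
    import json
    import time

    def sign_payload(payload: dict, secret_key: str) -> dict:
        timestamp = str(int(time.time()))
        data_str = json.dumps(payload)
        # MD5 over secret + serialized body + timestamp, matching _sign above
        digest = hashlib.md5((secret_key + data_str + timestamp).encode("utf-8")).hexdigest()
        return {"X-BC-Timestamp": timestamp, "X-BC-Signature": digest}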

View File

@@ -15,6 +15,7 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel
SPARK_DEFAULT_API_VERSION = "v2"
def spark_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
@@ -31,7 +32,6 @@ def spark_generate_stream(
domain = "general"
url = "ws://spark-api.xf-yun.com/v1.1/chat"
messages: List[ModelMessage] = params["messages"]
history = []
@@ -57,43 +57,39 @@ def spark_generate_stream(
break
data = {
"header": {
"app_id": proxy_app_id,
"uid": params.get("request_id", 1)
},
"header": {"app_id": proxy_app_id, "uid": params.get("request_id", 1)},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": context_len,
"auditing": "default",
"temperature": params.get("temperature")
"temperature": params.get("temperature"),
}
},
"payload": {
"message": {
"text": last_user_input.get("content")
}
}
"payload": {"message": {"text": last_user_input.get("content")}},
}
async_call(request_url, data)
async def async_call(request_url, data):
async with websockets.connect(request_url) as ws:
await ws.send(json.dumps(data, ensure_ascii=False))
finish = False
while not finish:
chunk = ws.recv()
chunk = ws.recv()
response = json.loads(chunk)
if response.get("header", {}).get("status") == 2:
finish = True
if text := response.get("payload", {}).get("choices", {}).get("text"):
yield text[0]["content"]
class SparkAPI:
def __init__(self, appid: str, api_key: str, api_secret: str, spark_url: str) -> None:
class SparkAPI:
def __init__(
self, appid: str, api_key: str, api_secret: str, spark_url: str
) -> None:
self.appid = appid
self.api_key = api_key
self.api_secret = api_secret
@@ -102,9 +98,7 @@ class SparkAPI:
self.spark_url = spark_url
def gen_url(self):
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
@@ -112,21 +106,22 @@ class SparkAPI:
_signature += "data: " + date + "\n"
_signature += "GET " + self.path + " HTTP/1.1"
_signature_sha = hmac.new(self.api_secret.encode("utf-8"), _signature.encode("utf-8"),
digestmod=hashlib.sha256).digest()
_signature_sha = hmac.new(
self.api_secret.encode("utf-8"),
_signature.encode("utf-8"),
digestmod=hashlib.sha256,
).digest()
_signature_sha_base64 = base64.b64encode(_signature_sha).decode(encoding="utf-8")
_signature_sha_base64 = base64.b64encode(_signature_sha).decode(
encoding="utf-8"
)
_authorization = f"api_key='{self.api_key}', algorithm='hmac-sha256', headers='host date request-line', signature='{_signature_sha_base64}'"
authorization = base64.b64encode(_authorization.encode('utf-8')).decode(encoding='utf-8')
authorization = base64.b64encode(_authorization.encode("utf-8")).decode(
encoding="utf-8"
)
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
v = {"authorization": authorization, "date": date, "host": self.host}
url = self.spark_url + "?" + urlencode(v)
return url
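The `gen_url` flow is standard HMAC request signing: build a canonical string from host, date, and the request line, sign it with HMAC-SHA256, then base64-encode twice (once for the raw signature, once for the assembled authorization header) before appending everything as query parameters. Note that the canonical string above says "data: " where the usual convention, as far as I can tell, is "date: ". A condensed, self-contained sketch:

    import base64
    import hashlib
    import hmac
    from datetime import datetime
    from time import mktime
    from urllib.parse import urlencode
    from wsgiref.handlers import format_date_time

    def build_signed_url(base_url, host, path, api_key, api_secret):
        date = format_date_time(mktime(datetime.now().timetuple()))  # RFC 1123 date
        raw = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        digest = hmac.new(
            api_secret.encode("utf-8"), raw.encode("utf-8"), digestmod=hashlib.sha256
        ).digest()
        signature = base64.b64encode(digest).decode("utf-8")
        authorization_origin = (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
        return base_url + "?" + urlencode({"authorization": authorization, "date": date, "host": host})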

View File

@@ -8,10 +8,11 @@ logger = logging.getLogger(__name__)
def tongyi_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
model: ProxyModel, tokenizer, params, device, context_len=2048
):
import dashscope
from dashscope import Generation
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
@@ -62,14 +63,14 @@ def tongyi_generate_stream(
messages=history,
top_p=params.get("top_p", 0.8),
stream=True,
result_format='message'
result_format="message",
)
for r in res:
if r:
if r['status_code'] == 200:
if r["status_code"] == 200:
content = r["output"]["choices"][0]["message"].get("content")
yield content
else:
content = r['code'] + ":" + r["message"]
content = r["code"] + ":" + r["message"]
yield content

View File

@@ -6,20 +6,26 @@ from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from cachetools import cached, TTLCache
@cached(TTLCache(1, 1800))
def _build_access_token(api_key: str, secret_key: str) -> str:
"""
Generate an access token from the API key (AK) and secret key (SK)
Generate an access token from the API key (AK) and secret key (SK)
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
params = {
"grant_type": "client_credentials",
"client_id": api_key,
"client_secret": secret_key,
}
res = requests.get(url=url, params=params)
if res.status_code == 200:
return res.json().get("access_token")
def wenxin_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
@@ -38,12 +44,9 @@ def wenxin_generate_stream(
proxy_api_secret = model_params.proxy_api_secret
access_token = _build_access_token(proxy_api_key, proxy_api_secret)
headers = {
"Content-Type": "application/json",
"Accept": "application/json"
}
headers = {"Content-Type": "application/json", "Accept": "application/json"}
proxy_server_url = f'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}'
proxy_server_url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}"
if not access_token:
yield "Failed to get access token. please set the correct api_key and secret key."
@@ -86,7 +89,7 @@ def wenxin_generate_stream(
"messages": history,
"system": system,
"temperature": params.get("temperature"),
"stream": True
"stream": True,
}
text = ""
@@ -106,6 +109,3 @@ def wenxin_generate_stream(
content = obj["result"]
text += content
yield text
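The `@cached(TTLCache(1, 1800))` decorator above is what keeps the token request cheap: `TTLCache(maxsize=1, ttl=1800)` holds a single entry for 30 minutes, so repeated calls with the same AK/SK pair hit the network at most once per half hour, and a call with a different pair evicts the cached one since maxsize is 1. A toy sketch of the behavior:

    from cachetools import TTLCache, cached

    @cached(TTLCache(maxsize=1, ttl=1800))
    def fetch_token(api_key: str, secret_key: str) -> str:
        print("cache miss, calling the token endpoint...")
        return f"token-for-{api_key}"  # stand-in for the real HTTP call

    fetch_token("ak", "sk")  # prints the message, result is cached
    fetch_token("ak", "sk")  # silent: served from cache for up to 1800 s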

View File

@@ -7,6 +7,7 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
CHATGLM_DEFAULT_MODEL = "chatglm_pro"
def zhipu_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
@@ -19,6 +20,7 @@ def zhipu_generate_stream(
proxyllm_backend = CHATGLM_DEFAULT_MODEL or model_params.proxyllm_backend
import zhipuai
zhipuai.api_key = proxy_api_key
history = []

View File

@@ -297,7 +297,6 @@ async def params_load(
@router.post("/v1/chat/dialogue/delete")
async def dialogue_delete(con_uid: str):
history_fac = ChatHistory()
history_mem = history_fac.get_store_instance(con_uid)
history_mem.delete()

View File

@@ -54,7 +54,7 @@ class BaseChat(ABC):
)
chat_history_fac = ChatHistory()
### the storage backend is configurable
self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
self.memory = chat_history_fac.get_store_instance(chat_param["chat_session_id"])
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(

View File

@@ -20,10 +20,11 @@ logger = logging.getLogger("chat_agent")
class ChatAgent(BaseChat):
chat_scene: str = ChatScene.ChatAgent.value()
chat_retention_rounds = 0
def __init__(self, chat_param: Dict):
if not chat_param['select_param']:
if not chat_param["select_param"]:
raise ValueError("Please select a Plugin!")
self.select_plugins = chat_param['select_param'].split(",")
self.select_plugins = chat_param["select_param"].split(",")
chat_param["chat_mode"] = ChatScene.ChatAgent
super().__init__(chat_param=chat_param)
@@ -31,8 +32,12 @@ class ChatAgent(BaseChat):
self.plugins_prompt_generator.command_registry = CFG.command_registry
# load select plugin
agent_module = CFG.SYSTEM_APP.get_component(ComponentType.AGENT_HUB, ModuleAgent)
self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)
agent_module = CFG.SYSTEM_APP.get_component(
ComponentType.AGENT_HUB, ModuleAgent
)
self.plugins_prompt_generator = agent_module.load_select_plugin(
self.plugins_prompt_generator, self.select_plugins
)
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
@@ -53,4 +58,3 @@ class ChatAgent(BaseChat):
def __list_to_prompt_str(self, list: List) -> str:
return "\n".join(f"{i + 1 + 1}. {item}" for i, item in enumerate(list))

View File

@@ -54,7 +54,7 @@ _DEFAULT_TEMPLATE = (
)
_PROMPT_SCENE_DEFINE=(
_PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
@@ -76,7 +76,7 @@ prompt = PromptTemplate(
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
temperature = 1
temperature=1
# example_selector=plugin_example,
)

View File

@@ -36,7 +36,7 @@ class ChatExcel(BaseChat):
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
)
)
self.api_call = ApiCall(display_registry = CFG.command_disply)
self.api_call = ApiCall(display_registry=CFG.command_disply)
super().__init__(chat_param=chat_param)
def _generate_numbered_list(self) -> str:

View File

@@ -44,7 +44,7 @@ _DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
PROMPT_SCENE_DEFINE =(
PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)

View File

@@ -9,8 +9,21 @@ import pandas as pd
import chardet
import pandas as pd
import numpy as np
from pyparsing import CaselessKeyword, Word, alphas, alphanums, delimitedList, Forward, Group, Optional,\
Literal, infixNotation, opAssoc, unicodeString,Regex
from pyparsing import (
CaselessKeyword,
Word,
alphas,
alphanums,
delimitedList,
Forward,
Group,
Optional,
Literal,
infixNotation,
opAssoc,
unicodeString,
Regex,
)
from pilot.common.pd_utils import csv_colunm_foramt
from pilot.common.string_utils import is_chinese_include_number
@@ -21,14 +34,15 @@ def excel_colunm_format(old_name: str) -> str:
new_column = new_column.replace(" ", "_")
return new_column
def detect_encoding(file_path):
# Read the raw binary data of the file
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
data = f.read()
# Use chardet to detect the file encoding
result = chardet.detect(data)
encoding = result['encoding']
confidence = result['confidence']
encoding = result["encoding"]
confidence = result["confidence"]
return encoding, confidence
@@ -39,10 +53,11 @@ def add_quotes_ex(sql: str, column_names):
sql = sql.replace(column_name, f'"{column_name}"')
return sql
def parse_sql(sql):
# Define keywords and identifiers
select_stmt = Forward()
column = Regex(r'[\w一-龥]*')
column = Regex(r"[\w一-龥]*")
table = Word(alphanums)
join_expr = Forward()
where_expr = Forward()
@@ -62,21 +77,24 @@ def parse_sql(sql):
not_in_keyword = CaselessKeyword("NOT IN")
# Define the grammar rules
select_stmt <<= (select_keyword + delimitedList(column) +
from_keyword + delimitedList(table) +
Optional(join_expr) +
Optional(where_keyword + where_expr) +
Optional(group_by_keyword + group_by_expr) +
Optional(order_by_keyword + order_by_expr))
select_stmt <<= (
select_keyword
+ delimitedList(column)
+ from_keyword
+ delimitedList(table)
+ Optional(join_expr)
+ Optional(where_keyword + where_expr)
+ Optional(group_by_keyword + group_by_expr)
+ Optional(order_by_keyword + order_by_expr)
)
join_expr <<= join_keyword + table + on_keyword + column + Literal("=") + column
where_expr <<= column + Literal("=") + Word(alphanums) + \
Optional(and_keyword + where_expr) | \
column + Literal(">") + Word(alphanums) + \
Optional(and_keyword + where_expr) | \
column + Literal("<") + Word(alphanums) + \
Optional(and_keyword + where_expr)
where_expr <<= (
column + Literal("=") + Word(alphanums) + Optional(and_keyword + where_expr)
| column + Literal(">") + Word(alphanums) + Optional(and_keyword + where_expr)
| column + Literal("<") + Word(alphanums) + Optional(and_keyword + where_expr)
)
group_by_expr <<= delimitedList(column)
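The reflowed grammar reads much better, but the mechanics are easy to miss in a diff: `Forward()` declares a rule before it is defined, `<<=` fills it in later, and `Regex(r"[\w一-龥]*")` lets column names include Chinese characters. A trimmed-down, runnable version of the same idea, with the grammar simplified to SELECT/FROM/WHERE:

    from pyparsing import (
        CaselessKeyword,
        Group,
        Optional,
        Regex,
        Word,
        alphanums,
        delimitedList,
        oneOf,
    )

    select_kw = CaselessKeyword("SELECT")
    from_kw = CaselessKeyword("FROM")
    where_kw = CaselessKeyword("WHERE")

    column = Regex(r"[\w一-龥]+")  # ASCII identifiers or Chinese column names
    table = Word(alphanums)
    condition = Group(column + oneOf("= > <") + Word(alphanums))

    stmt = (
        select_kw
        + delimitedList(column)("columns")
        + from_kw
        + table("table")
        + Optional(where_kw + condition)("where")
    )

    print(stmt.parseString("SELECT 城市, total FROM sales WHERE total > 100").asList())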
@@ -88,7 +106,6 @@ def parse_sql(sql):
return parsed_result.asList()
def add_quotes(sql, column_names=[]):
sql = sql.replace("`", "")
sql = sql.replace("'", "")
@@ -108,6 +125,7 @@ def deep_quotes(token, column_names=[]):
new_value = token.value.replace("`", "").replace("'", "")
token.value = f'"{new_value}"'
def get_select_clause(sql):
parsed = sqlparse.parse(sql)[0]  # Parse the SQL and take the first statement block
@@ -123,6 +141,7 @@ def get_select_clause(sql):
select_tokens.append(token)
return "".join(str(token) for token in select_tokens)
def parse_select_fields(sql):
parsed = sqlparse.parse(sql)[0]  # Parse the SQL and take the first statement block
fields = []
@@ -139,12 +158,14 @@ def parse_select_fields(sql):
return fields
def add_quotes_to_chinese_columns(sql, column_names=[]):
parsed = sqlparse.parse(sql)
for stmt in parsed:
process_statement(stmt, column_names)
return str(parsed[0])
def process_statement(statement, column_names=[]):
if isinstance(statement, sqlparse.sql.IdentifierList):
for identifier in statement.get_identifiers():
@@ -155,22 +176,23 @@ def process_statement(statement, column_names=[]):
for item in statement.tokens:
process_statement(item)
def process_identifier(identifier, column_names=[]):
# if identifier.has_alias():
# alias = identifier.get_alias()
# identifier.tokens[-1].value = '[' + alias + ']'
if hasattr(identifier, 'tokens') and identifier.value in column_names:
if is_chinese(identifier.value):
if hasattr(identifier, "tokens") and identifier.value in column_names:
if is_chinese(identifier.value):
new_value = get_new_value(identifier.value)
identifier.value = new_value
identifier.normalized = new_value
identifier.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
else:
if hasattr(identifier, 'tokens'):
if hasattr(identifier, "tokens"):
for token in identifier.tokens:
if isinstance(token, sqlparse.sql.Function):
process_function(token)
elif token.ttype in sqlparse.tokens.Name :
elif token.ttype in sqlparse.tokens.Name:
new_value = get_new_value(token.value)
token.value = new_value
token.normalized = new_value
@@ -179,9 +201,12 @@ def process_identifier(identifier, column_names=[]):
token.value = new_value
token.normalized = new_value
token.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
def get_new_value(value):
return f""" "{value.replace("`", "").replace("'", "").replace('"', "")}" """
def process_function(function):
function_params = list(function.get_parameters())
# for param in function_params:
@@ -191,15 +216,18 @@ def process_function(function):
if isinstance(param, sqlparse.sql.Identifier):
# Check whether the field value needs to be replaced
# if is_chinese(param.value):
# Replace the field value
# Replace the field value
new_value = get_new_value(param.value)
# new_parameter = sqlparse.sql.Identifier(f'[{param.value}]')
function_params[i].tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
function_params[i].tokens = [
sqlparse.sql.Token(sqlparse.tokens.Name, new_value)
]
print(str(function))
def is_chinese(text):
for char in text:
if '' <= char <= '鿿':
if "" <= char <= "鿿":
return True
return False
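The surrounding functions all lean on one sqlparse property: leaf tokens are mutable, and serializing a statement re-reads the current token values, so rewriting `token.value` in place is enough to quote an identifier. A rough sketch of the core move (simplified; the real code above also walks functions and identifier lists):

    import sqlparse

    def quote_columns(sql: str, column_names: list) -> str:
        parsed = sqlparse.parse(sql)[0]
        for token in parsed.flatten():  # leaf tokens only
            if token.ttype in sqlparse.tokens.Name and token.value in column_names:
                token.value = f'"{token.value}"'
        return str(parsed)

    print(quote_columns("SELECT 城市, total FROM sales", ["城市"]))
    # SELECT "城市", total FROM sales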
@@ -240,7 +268,7 @@ class ExcelReader:
df_tmp = pd.read_csv(file_path, encoding=encoding)
self.df = pd.read_csv(
file_path,
encoding = encoding,
encoding=encoding,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
else:
@@ -280,7 +308,6 @@ class ExcelReader:
logging.error("excel sql run error!", e)
raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")
def get_df_by_sql_ex(self, sql):
colunms, values = self.run(sql)
return pd.DataFrame(values, columns=colunms)
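`detect_encoding` plus the `converters=` argument is the heart of the loader: sniff the encoding from raw bytes first, then hand pandas a per-column formatter. A self-contained sketch with a generic formatter standing in for `csv_colunm_foramt` from `pilot.common.pd_utils`:

    import chardet
    import pandas as pd

    def read_csv_any_encoding(file_path: str) -> pd.DataFrame:
        with open(file_path, "rb") as f:
            detected = chardet.detect(f.read())  # e.g. {'encoding': 'GB2312', 'confidence': 0.99}
        encoding = detected["encoding"] or "utf-8"  # fall back if detection fails
        # Peek at the header to learn the column count, then re-read with converters
        df_tmp = pd.read_csv(file_path, encoding=encoding, nrows=0)
        return pd.read_csv(
            file_path,
            encoding=encoding,
            converters={i: str.strip for i in range(df_tmp.shape[1])},
        )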

View File

@@ -16,7 +16,7 @@ class ChatWithPlugin(BaseChat):
select_plugin: str = None
def __init__(self, chat_param: Dict):
self.plugin_selector = chat_param["select_param"]
self.plugin_selector = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatExecution
super().__init__(chat_param=chat_param)
self.plugins_prompt_generator = PluginPromptGenerator()

View File

@@ -15,7 +15,6 @@ ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__fi
sys.path.append(ROOT_PATH)
def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem")
os._exit(0)
@@ -32,7 +31,6 @@ def async_db_summary(system_app: SystemApp):
def server_init(args, system_app: SystemApp):
from pilot.base_modules.agent.commands.command_mange import CommandRegistry
# logger.info(f"args: {args}")
# init config
@@ -44,8 +42,6 @@ def server_init(args, system_app: SystemApp):
# load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
# Loader plugins and commands
command_categories = [
"pilot.base_modules.agent.commands.built_in.audio_text",
@@ -126,4 +122,3 @@ class WebWerverParameters(BaseParameters):
},
)
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})

View File

@@ -29,6 +29,7 @@ def initialize_components(
system_app.register_instance(controller)
from pilot.base_modules.agent.controller import module_agent
system_app.register_instance(module_agent)
_initialize_embedding_model(

View File

@@ -31,7 +31,9 @@ from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
from pilot.base_modules.agent.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.base_modules.agent.commands.disply_type.show_chart_gen import (
static_message_img_path,
)
from pilot.model.cluster import initialize_worker_manager_in_client
from pilot.utils.utils import (
setup_logging,
@@ -56,6 +58,8 @@ def swagger_monkey_patch(*args, **kwargs):
swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
)
app = FastAPI()
applications.get_swagger_ui_html = swagger_monkey_patch
@@ -73,14 +77,14 @@ app.add_middleware(
)
app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
app.include_router(api_v1, prefix="/api", tags=["Chat"])
app.include_router(api_editor_route_v1, prefix="/api", tags=["Editor"])
app.include_router(llm_manage_api, prefix="/api", tags=["LLM Manage"])
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])
app.include_router(knowledge_router, tags=["Knowledge"])
app.include_router(prompt_router, tags=["Prompt"])
app.include_router(knowledge_router, tags=["Knowledge"])
app.include_router(prompt_router, tags=["Prompt"])
def mount_static_files(app):
@@ -98,6 +102,7 @@ def mount_static_files(app):
app.add_exception_handler(RequestValidationError, validation_exception_handler)
def _get_webserver_params(args: List[str] = None):
from pilot.utils.parameter_utils import EnvArgumentParser
@@ -106,6 +111,7 @@ def _get_webserver_params(args: List[str] = None):
)
return WebWerverParameters(**vars(parser.parse_args(args=args)))
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
"""Initialize app
If you use gunicorn as a process manager, initialize_app can be invoked in the `on_starting` hook.