From d372e73cd58bd36021bb10c9c44522f01babb820 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Sun, 25 Jun 2023 14:46:46 +0800 Subject: [PATCH] WEB API independent --- pilot/common/plugins.py | 14 +- pilot/common/schema.py | 4 + pilot/configs/config.py | 4 +- pilot/configs/model_config.py | 2 - pilot/connections/rdbms/py_study/pd_study.py | 40 +++-- .../connections/rdbms/py_study/test_cls_1.py | 12 ++ .../connections/rdbms/py_study/test_cls_2.py | 15 ++ .../rdbms/py_study/test_cls_base.py | 12 ++ pilot/model/adapter.py | 11 +- pilot/model/llm_out/chatglm_llm.py | 2 +- pilot/model/llm_utils.py | 2 +- pilot/out_parser/base.py | 2 +- pilot/prompts/example_base.py | 38 +++++ pilot/prompts/prompt_new.py | 4 +- pilot/prompts/prompt_template.py | 6 +- pilot/scene/base.py | 4 + pilot/scene/base_chat.py | 27 +--- pilot/scene/chat_dashboard/__init__.py | 0 .../business_cockpit/__init__.py | 0 pilot/scene/chat_execution/example.py | 9 ++ pilot/scene/chat_execution/prompt.py | 33 ++-- pilot/scene/chat_execution/prompt_v2.py | 1 + pilot/scene/chat_knowledge/custom/chat.py | 21 +-- pilot/scene/message.py | 21 ++- pilot/server/api_v1/api_v1.py | 146 ++++++++++++++++++ pilot/server/api_v1/api_view_model.py | 57 +++++++ pilot/server/llmserver.py | 12 +- pilot/server/webserver.py | 101 ++++++------ pilot/server/webserver_base.py | 60 +++++++ pilot/speech/eleven_labs.py | 2 +- pilot/vector_store/connector.py | 7 +- pilot/vector_store/weaviate_store.py | 146 ------------------ 32 files changed, 506 insertions(+), 309 deletions(-) create mode 100644 pilot/connections/rdbms/py_study/test_cls_1.py create mode 100644 pilot/connections/rdbms/py_study/test_cls_2.py create mode 100644 pilot/connections/rdbms/py_study/test_cls_base.py create mode 100644 pilot/prompts/example_base.py create mode 100644 pilot/scene/chat_dashboard/__init__.py create mode 100644 pilot/scene/chat_dashboard/business_cockpit/__init__.py create mode 100644 pilot/scene/chat_execution/example.py create mode 100644 pilot/scene/chat_execution/prompt_v2.py create mode 100644 pilot/server/api_v1/api_v1.py create mode 100644 pilot/server/api_v1/api_view_model.py create mode 100644 pilot/server/webserver_base.py delete mode 100644 pilot/vector_store/weaviate_store.py diff --git a/pilot/common/plugins.py b/pilot/common/plugins.py index e22224399..832144d22 100644 --- a/pilot/common/plugins.py +++ b/pilot/common/plugins.py @@ -78,7 +78,6 @@ def load_native_plugins(cfg: Config): if not cfg.plugins_auto_load: print("not auto load_native_plugins") return - def load_from_git(cfg: Config): print("async load_native_plugins") branch_name = cfg.plugins_git_branch @@ -86,20 +85,16 @@ def load_native_plugins(cfg: Config): url = "https://github.com/csunny/{repo}/archive/{branch}.zip" try: session = requests.Session() - response = session.get( - url.format(repo=native_plugin_repo, branch=branch_name), - headers={"Authorization": "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"}, - ) + response = session.get(url.format(repo=native_plugin_repo, branch=branch_name), + headers={'Authorization': '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'}) if response.status_code == 200: plugins_path_path = Path(PLUGINS_DIR) - files = glob.glob( - os.path.join(plugins_path_path, f"{native_plugin_repo}*") - ) + files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*')) for file in files: os.remove(file) now = datetime.datetime.now() - time_str = now.strftime("%Y%m%d%H%M%S") + time_str = now.strftime('%Y%m%d%H%M%S') file_name = 
f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip" print(file_name) with open(file_name, "wb") as f: @@ -115,6 +110,7 @@ def load_native_plugins(cfg: Config): t.start() + def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]: """Scan the plugins directory for plugins and loads them. diff --git a/pilot/common/schema.py b/pilot/common/schema.py index cd462966c..dc7a6c7cc 100644 --- a/pilot/common/schema.py +++ b/pilot/common/schema.py @@ -7,3 +7,7 @@ class SeparatorStyle(Enum): TWO = "" THREE = auto() FOUR = auto() + +class ExampleType(Enum): + ONE_SHOT = "one_shot" + FEW_SHOT = "few_shot" diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 594b8b4ae..94ff19e21 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -90,7 +90,7 @@ class Config(metaclass=Singleton): ### The associated configuration parameters of the plug-in control the loading and use of the plug-in self.plugins: List[AutoGPTPluginTemplate] = [] self.plugins_openai = [] - self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True" + self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True" self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard") @@ -150,8 +150,6 @@ class Config(metaclass=Singleton): self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None) self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None) - self.WEAVIATE_URL = os.getenv("WEAVIATE_URL", "http://127.0.0.1:8080") - # QLoRA self.QLoRA = os.getenv("QUANTIZE_QLORA", "True") diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index f6d59e4e1..0dc78af06 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -43,8 +43,6 @@ LLM_MODEL_CONFIG = { "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"), "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"), "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"), - # TODO Support baichuan-7b - # "baichuan-7b" : os.path.join(MODEL_PATH, "baichuan-7b"), "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"), "proxyllm": "proxyllm", } diff --git a/pilot/connections/rdbms/py_study/pd_study.py b/pilot/connections/rdbms/py_study/pd_study.py index 5a2b3edae..ffe62a50a 100644 --- a/pilot/connections/rdbms/py_study/pd_study.py +++ b/pilot/connections/rdbms/py_study/pd_study.py @@ -6,6 +6,8 @@ import numpy as np from matplotlib.font_manager import FontProperties from pyecharts.charts import Bar from pyecharts import options as opts +from test_cls_1 import TestBase,Test1 +from test_cls_2 import Test2 CFG = Config() @@ -56,20 +58,30 @@ CFG = Config() # + if __name__ == "__main__": - def __extract_json(s): - i = s.index("{") - count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1 :], start=i + 1): - if c == "}": - count -= 1 - elif c == "{": - count += 1 - if count == 0: - break - assert count == 0 # 检查是否找到最后一个'}' - return s[i : j + 1] + # def __extract_json(s): + # i = s.index("{") + # count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 + # for j, c in enumerate(s[i + 1 :], start=i + 1): + # if c == "}": + # count -= 1 + # elif c == "{": + # count += 1 + # if count == 0: + # break + # assert count == 0 # 检查是否找到最后一个'}' + # return s[i : j + 1] + # + # ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each 
city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" + # print(__extract_json(ss)) - ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" - print(__extract_json(ss)) + + test1 = Test1() + test2 = Test2() + test1.write() + test1.test() + test2.write() + test1.test() + test2.test() diff --git a/pilot/connections/rdbms/py_study/test_cls_1.py b/pilot/connections/rdbms/py_study/test_cls_1.py new file mode 100644 index 000000000..66c07de78 --- /dev/null +++ b/pilot/connections/rdbms/py_study/test_cls_1.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from pydantic import BaseModel +from test_cls_base import TestBase + + +class Test1(TestBase): + + def write(self): + self.test_values.append("x") + self.test_values.append("y") + self.test_values.append("g") + diff --git a/pilot/connections/rdbms/py_study/test_cls_2.py b/pilot/connections/rdbms/py_study/test_cls_2.py new file mode 100644 index 000000000..c0fdbb305 --- /dev/null +++ b/pilot/connections/rdbms/py_study/test_cls_2.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod +from pydantic import BaseModel +from test_cls_base import TestBase +from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union + +class Test2(TestBase): + test_2_values:List = [] + + def write(self): + self.test_values.append(1) + self.test_values.append(2) + self.test_values.append(3) + self.test_2_values.append("x") + self.test_2_values.append("y") + self.test_2_values.append("z") \ No newline at end of file diff --git a/pilot/connections/rdbms/py_study/test_cls_base.py b/pilot/connections/rdbms/py_study/test_cls_base.py new file mode 100644 index 000000000..9a04a48b3 --- /dev/null +++ 
b/pilot/connections/rdbms/py_study/test_cls_base.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from pydantic import BaseModel +from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union + + +class TestBase(BaseModel, ABC): + test_values: List = [] + + + def test(self): + print(self.__class__.__name__ + ":" ) + print(self.test_values) \ No newline at end of file diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 9ea80fb7a..2c1089cec 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -32,14 +32,9 @@ class BaseLLMAdaper: return True def loader(self, model_path: str, from_pretrained_kwargs: dict): - tokenizer = AutoTokenizer.from_pretrained( - model_path, use_fast=False, trust_remote_code=True - ) + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) model = AutoModelForCausalLM.from_pretrained( - model_path, - low_cpu_mem_usage=True, - trust_remote_code=True, - **from_pretrained_kwargs, + model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs ) return model, tokenizer @@ -62,7 +57,7 @@ def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper: raise ValueError(f"Invalid model adapter for {model_path}") -# TODO support cpu? for practice we support gpt4all or chatglm-6b-int4? +# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4? class VicunaLLMAdapater(BaseLLMAdaper): diff --git a/pilot/model/llm_out/chatglm_llm.py b/pilot/model/llm_out/chatglm_llm.py index 690be0f06..dc8522fcc 100644 --- a/pilot/model/llm_out/chatglm_llm.py +++ b/pilot/model/llm_out/chatglm_llm.py @@ -11,7 +11,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER def chatglm_generate_stream( model, tokenizer, params, device, context_len=2048, stream_interval=2 ): - """Generate text using chatglm model's chat api""" + """Generate text using chatglm model's chat api_v1""" prompt = params["prompt"] temperature = float(params.get("temperature", 1.0)) top_p = float(params.get("top_p", 1.0)) diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py index 359d478f8..0dfd8d2b5 100644 --- a/pilot/model/llm_utils.py +++ b/pilot/model/llm_utils.py @@ -51,7 +51,7 @@ def create_chat_completion( return message response = None - # TODO impl this use vicuna server api + # TODO impl this use vicuna server api_v1 class Stream(transformers.StoppingCriteria): diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 96ddf64a0..4476a68fc 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -32,7 +32,7 @@ class BaseOutputParser(ABC): Output parsers help structure language model responses. 
""" - def __init__(self, sep: str, is_stream_out: bool): + def __init__(self, sep: str, is_stream_out: bool = True): self.sep = sep self.is_stream_out = is_stream_out diff --git a/pilot/prompts/example_base.py b/pilot/prompts/example_base.py new file mode 100644 index 000000000..4d876aa51 --- /dev/null +++ b/pilot/prompts/example_base.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod +from pydantic import BaseModel +from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union + +from pilot.common.schema import ExampleType + +class ExampleSelector(BaseModel, ABC): + examples: List[List] + use_example: bool = False + type: str = ExampleType.ONE_SHOT.value + + def examples(self, count: int = 2): + if ExampleType.ONE_SHOT.value == self.type: + return self.__one_show_context() + else: + return self.__few_shot_context(count) + + def __few_shot_context(self, count: int = 2) -> List[List]: + """ + Use 2 or more examples, default 2 + Returns: example text + """ + if self.use_example: + need_use = self.examples[:count] + return need_use + return None + + def __one_show_context(self) -> List: + """ + Use one examples + Returns: + + """ + if self.use_example: + need_use = self.examples[:1] + return need_use + + return None diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py index 6f50895fa..65e107d11 100644 --- a/pilot/prompts/prompt_new.py +++ b/pilot/prompts/prompt_new.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator from pilot.common.formatting import formatter from pilot.out_parser.base import BaseOutputParser from pilot.common.schema import SeparatorStyle - +from pilot.prompts.example_base import ExampleSelector def jinja2_formatter(template: str, **kwargs: Any) -> str: """Format a template using jinja2.""" @@ -32,7 +32,6 @@ class PromptTemplate(BaseModel, ABC): input_variables: List[str] """A list of the names of the variables the prompt template expects.""" template_scene: Optional[str] - template_define: Optional[str] """this template define""" template: Optional[str] @@ -46,6 +45,7 @@ class PromptTemplate(BaseModel, ABC): output_parser: BaseOutputParser = None """""" sep: str = SeparatorStyle.SINGLE.value + example: ExampleSelector = None class Config: """Configuration for this pydantic object.""" diff --git a/pilot/prompts/prompt_template.py b/pilot/prompts/prompt_template.py index 0a014b06e..53f66a9da 100644 --- a/pilot/prompts/prompt_template.py +++ b/pilot/prompts/prompt_template.py @@ -182,7 +182,7 @@ class BasePromptTemplate(BaseModel, ABC): Example: .. 
code-block:: python - prompt.save(file_path="path/prompt.yaml") + prompt.save(file_path="path/prompt.api_v1") """ if self.partial_variables: raise ValueError("Cannot save prompt with partial variables.") @@ -201,11 +201,11 @@ class BasePromptTemplate(BaseModel, ABC): if save_path.suffix == ".json": with open(file_path, "w") as f: json.dump(prompt_dict, f, indent=4) - elif save_path.suffix == ".yaml": + elif save_path.suffix == ".api_v1": with open(file_path, "w") as f: yaml.dump(prompt_dict, f, default_flow_style=False) else: - raise ValueError(f"{save_path} must be json or yaml") + raise ValueError(f"{save_path} must be json or api_v1") class StringPromptValue(PromptValue): diff --git a/pilot/scene/base.py b/pilot/scene/base.py index e301a14de..7b75d353e 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -10,3 +10,7 @@ class ChatScene(Enum): ChatUrlKnowledge = "chat_url_knowledge" InnerChatDBSummary = "inner_chat_db_summary" ChatNormal = "chat_normal" + + @staticmethod + def is_valid_mode(mode): + return any(mode == item.value for item in ChatScene) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 0120b9e86..572227966 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -70,7 +70,7 @@ class BaseChat(ABC): self.current_user_input: str = current_user_input self.llm_model = CFG.LLM_MODEL ### can configurable storage methods - self.memory = MemHistoryMemory(chat_session_id) + self.memory = FileHistoryMemory(chat_session_id) ### load prompt template self.prompt_template: PromptTemplate = CFG.prompt_templates[ @@ -139,9 +139,7 @@ class BaseChat(ABC): self.skip_echo_len = len(payload.get("prompt").replace("", " ")) + 11 logger.info(f"Requert: \n{payload}") - ai_response_text = "" try: - show_info = "" response = requests.post( urljoin(CFG.MODEL_SERVER, "generate_stream"), headers=headers, @@ -157,16 +155,10 @@ class BaseChat(ABC): # show_info = resp_text_trunck # yield resp_text_trunck + "▌" - self.current_message.add_ai_message(show_info) - except Exception as e: print(traceback.format_exc()) logger.error("model response parase faild!" 
+ str(e)) - self.current_message.add_view_message( - f"""ERROR!{str(e)}\n {ai_response_text} """ - ) - ### 对话记录存储 - self.memory.append(self.current_message) + raise ValueError(str(e)) def nostream_call(self): payload = self.__call_base() @@ -188,20 +180,6 @@ class BaseChat(ABC): ) ) - # ### MOCK - # ai_response_text = """{ - # "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。", - # "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。", - # "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。", - # "command": { - # "name": "histogram-executor", - # "args": { - # "title": "订单城市分布柱状图", - # "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city" - # } - # } - # }""" - self.current_message.add_ai_message(ai_response_text) prompt_define_response = ( self.prompt_template.output_parser.parse_prompt_response( @@ -293,6 +271,7 @@ class BaseChat(ABC): return text + # 暂时为了兼容前端 def current_ai_response(self) -> str: for message in self.current_message.messages: diff --git a/pilot/scene/chat_dashboard/__init__.py b/pilot/scene/chat_dashboard/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_dashboard/business_cockpit/__init__.py b/pilot/scene/chat_dashboard/business_cockpit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py new file mode 100644 index 000000000..6cd71b39c --- /dev/null +++ b/pilot/scene/chat_execution/example.py @@ -0,0 +1,9 @@ +from pilot.prompts.example_base import ExampleSelector + +## Two examples are defined by default +EXAMPLES = [ + [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}], + [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}] +] + +example = ExampleSelector(examples=EXAMPLES, use_example=True) diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index eebb8de94..b5bc38bb2 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -1,36 +1,31 @@ import json -import importlib from pilot.prompts.prompt_new import PromptTemplate from pilot.configs.config import Config from pilot.scene.base import ChatScene -from pilot.common.schema import SeparatorStyle +from pilot.common.schema import SeparatorStyle, ExampleType from pilot.scene.chat_execution.out_parser import PluginChatOutputParser - +from pilot.scene.chat_execution.example import example CFG = Config() -PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.""" +# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.""" +PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers." -PROMPT_SUFFIX = """ + +_DEFAULT_TEMPLATE = """ Goals: {input} -""" - -_DEFAULT_TEMPLATE = """ Constraints: 0.Exclusively use the commands listed in double quotes e.g. 
"command name" {constraints} Commands: {commands_infos} -""" - -PROMPT_RESPONSE = """ Please response strictly according to the following json format: - {response} +{response} Ensure the response is correct json and can be parsed by Python json.loads """ @@ -40,20 +35,22 @@ RESPONSE_FORMAT = { "command": {"name": "command name", "args": {"arg name": "value"}}, } + + +EXAMPLE_TYPE = ExampleType.ONE_SHOT PROMPT_SEP = SeparatorStyle.SINGLE.value ### Whether the model service is streaming output -PROMPT_NEED_NEED_STREAM_OUT = False +PROMPT_NEED_STREAM_OUT = False prompt = PromptTemplate( template_scene=ChatScene.ChatExecution.value, input_variables=["input", "constraints", "commands_infos", "response"], response_format=json.dumps(RESPONSE_FORMAT, indent=4), template_define=PROMPT_SCENE_DEFINE, - template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE, - stream_out=PROMPT_NEED_NEED_STREAM_OUT, - output_parser=PluginChatOutputParser( - sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT - ), + template=_DEFAULT_TEMPLATE, + stream_out=PROMPT_NEED_STREAM_OUT, + output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT), + example=example ) CFG.prompt_templates.update({prompt.template_scene: prompt}) diff --git a/pilot/scene/chat_execution/prompt_v2.py b/pilot/scene/chat_execution/prompt_v2.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/pilot/scene/chat_execution/prompt_v2.py @@ -0,0 +1 @@ + diff --git a/pilot/scene/chat_knowledge/custom/chat.py b/pilot/scene/chat_knowledge/custom/chat.py index 214bf1656..a56b2a098 100644 --- a/pilot/scene/chat_knowledge/custom/chat.py +++ b/pilot/scene/chat_knowledge/custom/chat.py @@ -1,5 +1,3 @@ -from chromadb.errors import NoIndexException - from pilot.scene.base_chat import BaseChat, logger, headers from pilot.scene.base import ChatScene from pilot.common.sql_database import Database @@ -52,19 +50,12 @@ class ChatNewKnowledge(BaseChat): ) def generate_input_values(self): - try: - docs = self.knowledge_embedding_client.similar_search( - self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE - ) - context = [d.page_content for d in docs] - self.metadata = [d.metadata for d in docs] - context = context[:2000] - input_values = {"context": context, "question": self.current_user_input} - except NoIndexException: - raise ValueError( - f"you have no {self.knowledge_name} knowledge store, please upload your knowledge" - ) - + docs = self.knowledge_embedding_client.similar_search( + self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE + ) + context = [d.page_content for d in docs] + context = context[:2000] + input_values = {"context": context, "question": self.current_user_input} return input_values def do_with_prompt_response(self, prompt_response): diff --git a/pilot/scene/message.py b/pilot/scene/message.py index 0203ec68c..c19ea6112 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -43,12 +43,29 @@ class OnceConversation: def add_ai_message(self, message: str) -> None: """Add an AI message to the store""" + has_message = any(isinstance(instance, AIMessage) for instance in self.messages) if has_message: - raise ValueError("Already Have Ai message") - self.messages.append(AIMessage(content=message)) + self.update_ai_message(message) + else: + self.messages.append(AIMessage(content=message)) """ """ + + def __update_ai_message(self, new_message:str)-> None: + """ + stream out message update + Args: + new_message: + + Returns: + + """ + + for item in self.messages: + if item.type == 
"ai": + item.content = new_message + def add_view_message(self, message: str) -> None: """Add an AI message to the store""" diff --git a/pilot/server/api_v1/api_v1.py b/pilot/server/api_v1/api_v1.py new file mode 100644 index 000000000..19f4e765c --- /dev/null +++ b/pilot/server/api_v1/api_v1.py @@ -0,0 +1,146 @@ +import uuid + +from fastapi import APIRouter, Request, Body, status + +from fastapi.responses import JSONResponse +from fastapi.responses import StreamingResponse +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from typing import List + +from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo +from pilot.configs.config import Config +from pilot.scene.base_chat import BaseChat +from pilot.scene.base import ChatScene +from pilot.scene.chat_factory import ChatFactory +from pilot.configs.model_config import (LOGDIR) +from pilot.utils import build_logger +from pilot.scene.base_message import (BaseMessage) + +router = APIRouter() +CFG = Config() +CHAT_FACTORY = ChatFactory() +logger = build_logger("api_v1", LOGDIR + "api_v1.log") + + +async def validation_exception_handler(request: Request, exc: RequestValidationError): + message = "" + for error in exc.errors(): + message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";" + return Result.faild(message) + + +@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]]) +async def dialogue_list(user_id: str): + #### TODO + + conversations = [ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"), + ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]")] + + return Result[ConversationVo].succ(conversations) + + +@router.post('/v1/chat/dialogue/new', response_model=Result[str]) +async def dialogue_new(user_id: str): + unique_id = uuid.uuid1() + return Result.succ(unique_id) + + +@router.post('/v1/chat/dialogue/delete') +async def dialogue_delete(con_uid: str, user_id: str): + #### TODO + return Result.succ(None) + + +@router.post('/v1/chat/completions', response_model=Result[MessageVo]) +async def chat_completions(dialogue: ConversationVo = Body()): + print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}") + + if not ChatScene.is_valid_mode(dialogue.chat_mode): + raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")) + + chat_param = { + "chat_session_id": dialogue.conv_uid, + "user_input": dialogue.user_input, + } + + if ChatScene.ChatWithDbExecute == dialogue.chat_mode: + chat_param.update("db_name", dialogue.select_param) + elif ChatScene.ChatWithDbQA == dialogue.chat_mode: + chat_param.update("db_name", dialogue.select_param) + elif ChatScene.ChatExecution == dialogue.chat_mode: + chat_param.update("plugin_selector", dialogue.select_param) + elif ChatScene.ChatNewKnowledge == dialogue.chat_mode: + chat_param.update("knowledge_name", dialogue.select_param) + elif ChatScene.ChatUrlKnowledge == dialogue.chat_mode: + chat_param.update("url", dialogue.select_param) + + chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) + if not chat.prompt_template.stream_out: + return non_stream_response(chat) + else: + return stream_response(chat) + + +def stream_generator(chat): + model_response = chat.stream_call() + for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + msg = 
chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + chat.current_message.add_ai_message(msg) + messageVos = [message2Vo(element) for element in chat.current_message.messages] + yield Result.succ(messageVos) +def stream_response(chat): + logger.info("stream out start!") + api_response = StreamingResponse(stream_generator(chat), media_type="application/json") + return api_response + +def message2Vo(message:BaseMessage)->MessageVo: + vo:MessageVo = MessageVo() + vo.role = message.type + vo.role = message.content + vo.time_stamp = message.additional_kwargs.time_stamp if message.additional_kwargs["time_stamp"] else 0 + +def non_stream_response(chat): + logger.info("not stream out, wait model response!") + chat.nostream_call() + messageVos = [message2Vo(element) for element in chat.current_message.messages] + return Result.succ(messageVos) + + +@router.get('/v1/db/types', response_model=Result[str]) +async def db_types(): + return Result.succ(["mysql", "duckdb"]) + + +@router.get('/v1/db/list', response_model=Result[str]) +async def db_list(): + db = CFG.local_db + dbs = db.get_database_list() + return Result.succ(dbs) + + +@router.get('/v1/knowledge/list') +async def knowledge_list(): + return ["test1", "test2"] + + +@router.post('/v1/knowledge/add') +async def knowledge_add(): + return ["test1", "test2"] + + +@router.post('/v1/knowledge/delete') +async def knowledge_delete(): + return ["test1", "test2"] + + +@router.get('/v1/knowledge/types') +async def knowledge_types(): + return ["test1", "test2"] + + +@router.get('/v1/knowledge/detail') +async def knowledge_detail(): + return ["test1", "test2"] diff --git a/pilot/server/api_v1/api_view_model.py b/pilot/server/api_v1/api_view_model.py new file mode 100644 index 000000000..938ce22ec --- /dev/null +++ b/pilot/server/api_v1/api_view_model.py @@ -0,0 +1,57 @@ +from pydantic import BaseModel, Field +from typing import TypeVar, Union, List, Generic + +T = TypeVar('T') + + +class Result(Generic[T], BaseModel): + success: bool + err_code: str + err_msg: str + data: List[T] + + @classmethod + def succ(cls, data: List[T]): + return Result(True, None, None, data) + + @classmethod + def faild(cls, msg): + return Result(True, "E000X", msg, None) + + @classmethod + def faild(cls, code, msg): + return Result(True, code, msg, None) + + +class ConversationVo(BaseModel): + """ + dialogue_uid + """ + conv_uid: str = Field(..., description="dialogue uid") + """ + user input + """ + user_input: str + """ + the scene of chat + """ + chat_mode: str = Field(..., description="the scene of chat ") + """ + chat scene select param + """ + select_param: str + + +class MessageVo(BaseModel): + """ + role that sends out the current message + """ + role: str + """ + current message + """ + context: str + """ + time the current message was sent + """ + time_stamp: float diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 51339f322..1e3a4dcb3 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -5,7 +5,6 @@ import asyncio import json import os import sys -import traceback import uvicorn from fastapi import BackgroundTasks, FastAPI, Request @@ -76,10 +75,8 @@ class ModelWorker: ): # Please do not open the output in production! # The gpt4all thread shares stdout with the parent process, - # and opening it may affect the frontend output - if not ("gptj" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL): - print("output: ", output) - + # and opening it may affect the frontend output. 
+ # print("output: ", output) ret = { "text": output, "error_code": 0, @@ -90,13 +87,10 @@ class ModelWorker: ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0} yield json.dumps(ret).encode() + b"\0" except Exception as e: - msg = "{}: {}".format(str(e), traceback.format_exc()) - ret = { - "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {msg}", + "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", "error_code": 0, } - yield json.dumps(ret).encode() + b"\0" def get_embeddings(self, prompt): diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index c7a033336..564a0c620 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import signal import threading import traceback import argparse @@ -12,12 +11,10 @@ import uuid import gradio as gr - ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) from pilot.summary.db_summary_client import DBSummaryClient -from pilot.commands.command_mange import CommandRegistry from pilot.scene.base_chat import BaseChat @@ -25,8 +22,8 @@ from pilot.configs.config import Config from pilot.configs.model_config import ( DATASETS_DIR, KNOWLEDGE_UPLOAD_ROOT_PATH, - LOGDIR, LLM_MODEL_CONFIG, + LOGDIR, ) from pilot.conversation import ( @@ -35,7 +32,6 @@ from pilot.conversation import ( chat_mode_title, default_conversation, ) -from pilot.common.plugins import scan_plugins, load_native_plugins from pilot.server.gradio_css import code_highlight_css from pilot.server.gradio_patch import Chatbot as grChatbot @@ -49,6 +45,19 @@ from pilot.vector_store.extract_tovec import ( from pilot.scene.base import ChatScene from pilot.scene.chat_factory import ChatFactory from pilot.language.translation_handler import get_lang_text +from pilot.server.webserver_base import server_init + + +import uvicorn +from fastapi import BackgroundTasks, Request +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from fastapi import FastAPI, applications +from fastapi.openapi.docs import get_swagger_ui_html +from fastapi.exceptions import RequestValidationError +from fastapi.staticfiles import StaticFiles + +from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler # 加载插件 CFG = Config() @@ -95,6 +104,19 @@ knowledge_qa_type_list = [ add_knowledge_base_dialogue, ] +def swagger_monkey_patch(*args, **kwargs): + return get_swagger_ui_html( + *args, **kwargs, + swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js', + swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css' + ) +applications.get_swagger_ui_html = swagger_monkey_patch + +app = FastAPI() +# app.mount("static", StaticFiles(directory="static"), name="static") +app.include_router(api_v1) +app.add_exception_handler(RequestValidationError, validation_exception_handler) + def get_simlar(q): docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md")) @@ -324,15 +346,14 @@ def http_bot( response = chat.stream_call() for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: - state.messages[-1][ - -1 - ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex( - chunk, chat.skip_echo_len - ) + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + state.messages[-1][-1] =msg + chat.current_message.add_ai_message(msg) yield (state, 
state.to_gradio_chatbot()) + (enable_btn,) * 5 + chat.memory.append(chat.current_message) except Exception as e: print(traceback.format_exc()) - state.messages[-1][-1] = "Error:" + str(e) + state.messages[-1][-1] = f"""ERROR!{str(e)} """ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 @@ -632,7 +653,7 @@ def knowledge_embedding_store(vs_id, files): ) knowledge_embedding_client = KnowledgeEmbedding( file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename), - model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL], + model_name=LLM_MODEL_CONFIG["text2vec"], vector_store_config={ "vector_store_name": vector_store_name["vs_name"], "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH, @@ -657,48 +678,36 @@ def signal_handler(sig, frame): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) + parser.add_argument('-new', '--new', action='store_true', help='enable new http mode') + + # old version server config parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) parser.add_argument("--concurrency-count", type=int, default=10) - parser.add_argument( - "--model-list-mode", type=str, default="once", choices=["once", "reload"] - ) parser.add_argument("--share", default=False, action="store_true") + + # init server config args = parser.parse_args() - logger.info(f"args: {args}") + server_init(args) + + if args.new: + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=5000) + else: + ### Compatibility mode starts the old version server by default + demo = build_webdemo() + demo.queue( + concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False + ).launch( + server_name=args.host, + server_port=args.port, + share=args.share, + max_threads=200, + ) - # init config - cfg = Config() - load_native_plugins(cfg) - dbs = cfg.local_db.get_database_list() - signal.signal(signal.SIGINT, signal_handler) - async_db_summery() - cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) - # Loader plugins and commands - command_categories = [ - "pilot.commands.built_in.audio_text", - "pilot.commands.built_in.image_gen", - ] - # exclude commands - command_categories = [ - x for x in command_categories if x not in cfg.disabled_command_categories - ] - command_registry = CommandRegistry() - for command_category in command_categories: - command_registry.import_commands(command_category) - cfg.command_registry = command_registry - logger.info(args) - demo = build_webdemo() - demo.queue( - concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False - ).launch( - server_name=args.host, - server_port=args.port, - share=args.share, - max_threads=200, - ) diff --git a/pilot/server/webserver_base.py b/pilot/server/webserver_base.py new file mode 100644 index 000000000..0aa2ac3f9 --- /dev/null +++ b/pilot/server/webserver_base.py @@ -0,0 +1,60 @@ +import signal +import os +import threading +import traceback +import sys + +from pilot.summary.db_summary_client import DBSummaryClient +from pilot.commands.command_mange import CommandRegistry +from pilot.configs.config import Config +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, +) +from pilot.common.plugins import scan_plugins, load_native_plugins +from pilot.utils import build_logger + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 
+sys.path.append(ROOT_PATH) + +logger = build_logger("webserver", LOGDIR + "webserver.log") + + +def signal_handler(sig, frame): + print("in order to avoid chroma db atexit problem") + os._exit(0) + + +def async_db_summery(): + client = DBSummaryClient() + thread = threading.Thread(target=client.init_db_summary) + thread.start() + + +def server_init(args): + logger.info(f"args: {args}") + + # init config + cfg = Config() + + load_native_plugins(cfg) + signal.signal(signal.SIGINT, signal_handler) + async_db_summery() + cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode)) + + # Loader plugins and commands + command_categories = [ + "pilot.commands.built_in.audio_text", + "pilot.commands.built_in.image_gen", + ] + # exclude commands + command_categories = [ + x for x in command_categories if x not in cfg.disabled_command_categories + ] + command_registry = CommandRegistry() + for command_category in command_categories: + command_registry.import_commands(command_category) + + cfg.command_registry = command_registry diff --git a/pilot/speech/eleven_labs.py b/pilot/speech/eleven_labs.py index c461c7ad7..dad841517 100644 --- a/pilot/speech/eleven_labs.py +++ b/pilot/speech/eleven_labs.py @@ -35,7 +35,7 @@ class ElevenLabsSpeech(VoiceBase): } self._headers = { "Content-Type": "application/json", - "xi-api-key": cfg.elevenlabs_api_key, + "xi-api_v1-key": cfg.elevenlabs_api_key, } self._voices = default_voices.copy() if cfg.elevenlabs_voice_1_id in voice_options: diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py index 8ba6df253..482f43007 100644 --- a/pilot/vector_store/connector.py +++ b/pilot/vector_store/connector.py @@ -1,13 +1,12 @@ from pilot.vector_store.chroma_store import ChromaStore -from pilot.vector_store.milvus_store import MilvusStore -from pilot.vector_store.weaviate_store import WeaviateStore +# from pilot.vector_store.milvus_store import MilvusStore -connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore} +connector = {"Chroma": ChromaStore, "Milvus": None} class VectorStoreConnector: - """vector store connector, can connect different vector db provided load document api and similar search api.""" + """vector store connector, can connect different vector db provided load document api_v1 and similar search api_v1.""" def __init__(self, vector_store_type, ctx: {}) -> None: """initialize vector store connector.""" diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py deleted file mode 100644 index fc5455672..000000000 --- a/pilot/vector_store/weaviate_store.py +++ /dev/null @@ -1,146 +0,0 @@ -import os -import json -import weaviate -from langchain.schema import Document -from langchain.vectorstores import Weaviate -from weaviate.exceptions import WeaviateBaseError - -from pilot.configs.config import Config -from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH -from pilot.logs import logger -from pilot.vector_store.vector_store_base import VectorStoreBase - -CFG = Config() - - -class WeaviateStore(VectorStoreBase): - """Weaviate database""" - - def __init__(self, ctx: dict) -> None: - """Initialize with Weaviate client.""" - try: - import weaviate - except ImportError: - raise ValueError( - "Could not import weaviate python package. " - "Please install it with `pip install weaviate-client`." 
- ) - - self.ctx = ctx - self.weaviate_url = CFG.WEAVIATE_URL - self.embedding = ctx.get("embeddings", None) - self.vector_name = ctx["vector_store_name"] - self.persist_dir = os.path.join( - KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb" - ) - - self.vector_store_client = weaviate.Client(self.weaviate_url) - - def similar_search(self, text: str, topk: int) -> None: - """Perform similar search in Weaviate""" - logger.info("Weaviate similar search") - # nearText = { - # "concepts": [text], - # "distance": 0.75, # prior to v1.14 use "certainty" instead of "distance" - # } - # vector = self.embedding.embed_query(text) - response = ( - self.vector_store_client.query.get( - self.vector_name, ["metadata", "page_content"] - ) - # .with_near_vector({"vector": vector}) - .with_limit(topk).do() - ) - res = response["data"]["Get"][list(response["data"]["Get"].keys())[0]] - docs = [] - for r in res: - docs.append( - Document( - page_content=r["page_content"], - metadata={"metadata": r["metadata"]}, - ) - ) - return docs - - def vector_name_exists(self) -> bool: - """Check if a vector name exists for a given class in Weaviate. - Returns: - bool: True if the vector name exists, False otherwise. - """ - try: - if self.vector_store_client.schema.get(self.vector_name): - return True - return False - except WeaviateBaseError as e: - logger.error("vector_name_exists error", e.message) - return False - - def _default_schema(self) -> None: - """ - Create the schema for Weaviate with a Document class containing metadata and text properties. - """ - - schema = { - "classes": [ - { - "class": self.vector_name, - "description": "A document with metadata and text", - # "moduleConfig": { - # "text2vec-transformers": { - # "poolingStrategy": "masked_mean", - # "vectorizeClassName": False, - # } - # }, - "properties": [ - { - "dataType": ["text"], - # "moduleConfig": { - # "text2vec-transformers": { - # "skip": False, - # "vectorizePropertyName": False, - # } - # }, - "description": "Metadata of the document", - "name": "metadata", - }, - { - "dataType": ["text"], - # "moduleConfig": { - # "text2vec-transformers": { - # "skip": False, - # "vectorizePropertyName": False, - # } - # }, - "description": "Text content of the document", - "name": "page_content", - }, - ], - # "vectorizer": "text2vec-transformers", - } - ] - } - - # Create the schema in Weaviate - self.vector_store_client.schema.create(schema) - - def load_document(self, documents: list) -> None: - """Load documents into Weaviate""" - logger.info("Weaviate load document") - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - - # Import data - with self.vector_store_client.batch as batch: - batch.batch_size = 100 - - # Batch import all documents - for i in range(len(texts)): - properties = { - "metadata": metadatas[i]["source"], - "page_content": texts[i], - } - - self.vector_store_client.batch.add_data_object( - data_object=properties, class_name=self.vector_name - ) - self.vector_store_client.batch.flush()
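For readers skimming the prompt changes: pilot/prompts/example_base.py adds an ExampleSelector pydantic model that injects example conversations into a scene prompt, driven by the new ExampleType enum in pilot/common/schema.py. Below is a minimal, function-level sketch of the same selection rule; ExampleType matches the patch, while the other names are illustrative and not part of the patch.

from enum import Enum
from typing import Dict, List, Optional


class ExampleType(Enum):
    ONE_SHOT = "one_shot"
    FEW_SHOT = "few_shot"


def select_examples(
    examples: List[List[Dict[str, str]]],
    use_example: bool,
    example_type: str = ExampleType.ONE_SHOT.value,
    count: int = 2,
) -> Optional[List[List[Dict[str, str]]]]:
    """Pick the example conversations to inject into a scene prompt.

    one_shot keeps only the first example; few_shot keeps the first
    `count` examples (2 by default). None means examples are disabled.
    """
    if not use_example:
        return None
    if example_type == ExampleType.ONE_SHOT.value:
        return examples[:1]
    return examples[:count]


# Same shape as the EXAMPLES list in pilot/scene/chat_execution/example.py.
EXAMPLES = [
    [{"System": "..."}, {"User": "..."}, {"Assistant": "..."}],
    [{"System": "..."}, {"User": "..."}, {"Assistant": "..."}],
]

print(select_examples(EXAMPLES, use_example=True))
print(select_examples(EXAMPLES, use_example=True, example_type=ExampleType.FEW_SHOT.value))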
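The new endpoints in pilot/server/api_v1/api_v1.py wrap their payloads in the generic Result view model. Here is a self-contained sketch of that envelope, assuming pydantic v1's GenericModel; the field names follow api_view_model.py, while the class and method names are illustrative (pydantic models are built with keyword arguments, so the sketch constructs them by field name).

from typing import Generic, List, Optional, TypeVar

from pydantic.generics import GenericModel

T = TypeVar("T")


class ResultSketch(GenericModel, Generic[T]):
    """Illustrative version of the Result envelope in api_view_model.py."""

    success: bool
    err_code: Optional[str] = None
    err_msg: Optional[str] = None
    data: Optional[List[T]] = None

    @classmethod
    def succ(cls, data: Optional[List[T]] = None) -> "ResultSketch[T]":
        # pydantic models are keyword-only, so the envelope is built by field name.
        return cls(success=True, data=data)

    @classmethod
    def failed(cls, msg: str, code: str = "E000X") -> "ResultSketch[T]":
        return cls(success=False, err_code=code, err_msg=msg)


print(ResultSketch[str].succ(["mysql", "duckdb"]).json())
print(ResultSketch[str].failed("Unsupported Chat Mode").json())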
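chat_completions dispatches on dialogue.chat_mode and folds dialogue.select_param into the keyword arguments passed to ChatFactory.get_implementation. A rough sketch of that mapping as a standalone helper follows; the mode strings are assumptions inferred from the ChatScene naming pattern (only chat_url_knowledge, inner_chat_db_summary and chat_normal are visible in this patch), and dict.update is given a mapping.

# Hypothetical helper mirroring the chat_param assembly in api_v1.py's
# chat_completions handler; the chat_mode strings below are assumptions.
def build_chat_param(conv_uid: str, user_input: str, chat_mode: str, select_param: str) -> dict:
    param = {
        "chat_session_id": conv_uid,
        "user_input": user_input,
    }
    # Each scene takes one extra, scene-specific argument carried in select_param.
    select_key = {
        "chat_with_db_execute": "db_name",
        "chat_with_db_qa": "db_name",
        "chat_execution": "plugin_selector",
        "chat_new_knowledge": "knowledge_name",
        "chat_url_knowledge": "url",
    }.get(chat_mode)
    if select_key:
        # dict.update expects a mapping, so the extra key is merged as a dict.
        param.update({select_key: select_param})
    return param


print(build_chat_param("123", "show user count by city", "chat_with_db_execute", "my_test_db"))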
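Once the patch is applied, the new router can be exercised over plain HTTP. A hedged client-side sketch: the port and the mode flag come from the webserver.py changes in this patch, while the chunk handling is illustrative only, since the streaming wire format is not pinned down here.

import requests

# Assumes the server was launched in the new mode, e.g.
#   python pilot/server/webserver.py --new
# which serves the FastAPI app on port 5000 (see the uvicorn.run call in webserver.py).
BASE_URL = "http://127.0.0.1:5000"

# The scene-specific argument goes into select_param
# (db name, plugin selector, knowledge name or url, depending on chat_mode).
payload = {
    "conv_uid": "123",
    "user_input": "hello",
    "chat_mode": "chat_normal",
    "select_param": "",
}

# Streaming scenes answer via StreamingResponse, so read the body incrementally;
# the raw lines are printed as-is rather than assuming a particular chunk format.
with requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, stream=True) as resp:
    print(resp.status_code)
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)

# The catalogue endpoints are plain GETs:
print(requests.get(f"{BASE_URL}/v1/db/list").text)
print(requests.get(f"{BASE_URL}/v1/knowledge/list").json())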