mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-08 11:47:44 +00:00

commit d372e73cd5 "WEB API independent"
parent 0558a8ba37
@@ -78,7 +78,6 @@ def load_native_plugins(cfg: Config):
     if not cfg.plugins_auto_load:
         print("not auto load_native_plugins")
         return
-
    def load_from_git(cfg: Config):
        print("async load_native_plugins")
        branch_name = cfg.plugins_git_branch
@@ -86,20 +85,16 @@ def load_native_plugins(cfg: Config):
        url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
        try:
            session = requests.Session()
-            response = session.get(
-                url.format(repo=native_plugin_repo, branch=branch_name),
-                headers={"Authorization": "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"},
-            )
+            response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
+                                   headers={'Authorization': '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'})

            if response.status_code == 200:
                plugins_path_path = Path(PLUGINS_DIR)
-                files = glob.glob(
-                    os.path.join(plugins_path_path, f"{native_plugin_repo}*")
-                )
+                files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
                for file in files:
                    os.remove(file)
                now = datetime.datetime.now()
-                time_str = now.strftime("%Y%m%d%H%M%S")
+                time_str = now.strftime('%Y%m%d%H%M%S')
                file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
                print(file_name)
                with open(file_name, "wb") as f:
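Review note: the Authorization header above embeds a GitHub personal access token directly in the source. A minimal sketch of the same download with the token read from the environment instead (the PLUGINS_GIT_TOKEN variable name and the "token" scheme are assumptions, not part of this commit):

    import os
    import requests

    def download_branch_zip(repo: str, branch: str) -> bytes:
        """Fetch a GitHub branch archive, authenticating only when a token is configured."""
        url = f"https://github.com/csunny/{repo}/archive/{branch}.zip"
        headers = {}
        token = os.getenv("PLUGINS_GIT_TOKEN")  # assumed env var, not defined by this commit
        if token:
            headers["Authorization"] = f"token {token}"
        response = requests.get(url, headers=headers, timeout=30)
        response.raise_for_status()
        return response.content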
@@ -115,6 +110,7 @@ def load_native_plugins(cfg: Config):
    t.start()

+

def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
    """Scan the plugins directory for plugins and loads them.

@@ -7,3 +7,7 @@ class SeparatorStyle(Enum):
    TWO = "</s>"
    THREE = auto()
    FOUR = auto()
+
+class ExampleType(Enum):
+    ONE_SHOT = "one_shot"
+    FEW_SHOT = "few_shot"
@@ -90,7 +90,7 @@ class Config(metaclass=Singleton):
        ### The associated configuration parameters of the plug-in control the loading and use of the plug-in
        self.plugins: List[AutoGPTPluginTemplate] = []
        self.plugins_openai = []
        self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"

        self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")

@@ -150,8 +150,6 @@ class Config(metaclass=Singleton):
        self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
        self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)

-        self.WEAVIATE_URL = os.getenv("WEAVIATE_URL", "http://127.0.0.1:8080")
-
        # QLoRA
        self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")

@@ -43,8 +43,6 @@ LLM_MODEL_CONFIG = {
    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
    "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
-    # TODO Support baichuan-7b
-    # "baichuan-7b" : os.path.join(MODEL_PATH, "baichuan-7b"),
    "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
    "proxyllm": "proxyllm",
}
@@ -6,6 +6,8 @@ import numpy as np
 from matplotlib.font_manager import FontProperties
 from pyecharts.charts import Bar
 from pyecharts import options as opts
+from test_cls_1 import TestBase,Test1
+from test_cls_2 import Test2

 CFG = Config()

@@ -56,20 +58,30 @@ CFG = Config()
 #



 if __name__ == "__main__":

-    def __extract_json(s):
-        i = s.index("{")
-        count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1 :], start=i + 1):
-            if c == "}":
-                count -= 1
-            elif c == "{":
-                count += 1
-            if count == 0:
-                break
-        assert count == 0  # check that the final '}' was found
-        return s[i : j + 1]
+    # def __extract_json(s):
+    #     i = s.index("{")
+    #     count = 1  # current nesting depth, i.e. the number of '{' not yet closed
+    #     for j, c in enumerate(s[i + 1 :], start=i + 1):
+    #         if c == "}":
+    #             count -= 1
+    #         elif c == "{":
+    #             count += 1
+    #         if count == 0:
+    #             break
+    #     assert count == 0  # check that the final '}' was found
+    #     return s[i : j + 1]
+    #
+    # ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
+    # print(__extract_json(ss))

-    ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
-    print(__extract_json(ss))
+    test1 = Test1()
+    test2 = Test2()
+    test1.write()
+    test1.test()
+    test2.write()
+    test1.test()
+    test2.test()
pilot/connections/rdbms/py_study/test_cls_1.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+from pydantic import BaseModel
+from test_cls_base import TestBase
+
+
+class Test1(TestBase):
+
+    def write(self):
+        self.test_values.append("x")
+        self.test_values.append("y")
+        self.test_values.append("g")
+
pilot/connections/rdbms/py_study/test_cls_2.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from abc import ABC, abstractmethod
+from pydantic import BaseModel
+from test_cls_base import TestBase
+from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
+
+class Test2(TestBase):
+    test_2_values:List = []
+
+    def write(self):
+        self.test_values.append(1)
+        self.test_values.append(2)
+        self.test_values.append(3)
+        self.test_2_values.append("x")
+        self.test_2_values.append("y")
+        self.test_2_values.append("z")
pilot/connections/rdbms/py_study/test_cls_base.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+from abc import ABC, abstractmethod
+from pydantic import BaseModel
+from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
+
+
+class TestBase(BaseModel, ABC):
+    test_values: List = []
+
+
+    def test(self):
+        print(self.__class__.__name__ + ":" )
+        print(self.test_values)
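Review note: these py_study scripts appear to probe whether a mutable default like `test_values: List = []` is shared between instances. A standalone sketch of the distinction being tested (assuming pydantic v1, which copies field defaults per instance, unlike a plain class attribute):

    from typing import List
    from pydantic import BaseModel

    class PlainBase:
        values: List = []   # plain class attribute: one list shared by all instances

    class ModelBase(BaseModel):
        values: List = []   # pydantic field: the default is copied per instance

    p1, p2 = PlainBase(), PlainBase()
    p1.values.append("x")
    print(p2.values)        # ['x']  (shared)

    m1, m2 = ModelBase(), ModelBase()
    m1.values.append("x")
    print(m2.values)        # []     (independent copies)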
@@ -32,14 +32,9 @@ class BaseLLMAdaper:
        return True

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path, use_fast=False, trust_remote_code=True
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            trust_remote_code=True,
-            **from_pretrained_kwargs,
+            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return model, tokenizer

@@ -62,7 +57,7 @@ def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    raise ValueError(f"Invalid model adapter for {model_path}")


-# TODO support cpu? for practice we support gpt4all or chatglm-6b-int4?
+# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?


class VicunaLLMAdapater(BaseLLMAdaper):
@@ -11,7 +11,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 def chatglm_generate_stream(
     model, tokenizer, params, device, context_len=2048, stream_interval=2
 ):
-    """Generate text using chatglm model's chat api"""
+    """Generate text using chatglm model's chat api_v1"""
     prompt = params["prompt"]
     temperature = float(params.get("temperature", 1.0))
     top_p = float(params.get("top_p", 1.0))
|
@ -51,7 +51,7 @@ def create_chat_completion(
|
|||||||
return message
|
return message
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
# TODO impl this use vicuna server api
|
# TODO impl this use vicuna server api_v1
|
||||||
|
|
||||||
|
|
||||||
class Stream(transformers.StoppingCriteria):
|
class Stream(transformers.StoppingCriteria):
|
||||||
|
@ -32,7 +32,7 @@ class BaseOutputParser(ABC):
|
|||||||
Output parsers help structure language model responses.
|
Output parsers help structure language model responses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sep: str, is_stream_out: bool):
|
def __init__(self, sep: str, is_stream_out: bool = True):
|
||||||
self.sep = sep
|
self.sep = sep
|
||||||
self.is_stream_out = is_stream_out
|
self.is_stream_out = is_stream_out
|
||||||
|
|
||||||
|
pilot/prompts/example_base.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+from abc import ABC, abstractmethod
+from pydantic import BaseModel
+from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
+
+from pilot.common.schema import ExampleType
+
+class ExampleSelector(BaseModel, ABC):
+    examples: List[List]
+    use_example: bool = False
+    type: str = ExampleType.ONE_SHOT.value
+
+    def examples(self, count: int = 2):
+        if ExampleType.ONE_SHOT.value == self.type:
+            return self.__one_show_context()
+        else:
+            return self.__few_shot_context(count)
+
+    def __few_shot_context(self, count: int = 2) -> List[List]:
+        """
+        Use 2 or more examples, default 2
+        Returns: example text
+        """
+        if self.use_example:
+            need_use = self.examples[:count]
+            return need_use
+        return None
+
+    def __one_show_context(self) -> List:
+        """
+        Use one examples
+        Returns:
+
+        """
+        if self.use_example:
+            need_use = self.examples[:1]
+            return need_use
+
+        return None
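Review note: `ExampleSelector` declares a field named `examples` and then a method with the same name; on an instance the field value wins the attribute lookup, so the method is unreachable. A standalone sketch of the one-shot/few-shot selection the class appears to intend, with the function renamed to avoid the collision (names here are illustrative, not part of the commit):

    from typing import List, Optional

    def select_examples(examples: List[List], use_example: bool,
                        one_shot: bool = True, count: int = 2) -> Optional[List[List]]:
        """Return the first example for one-shot, the first `count` for few-shot."""
        if not use_example:
            return None
        return examples[:1] if one_shot else examples[:count]

    # select_examples(EXAMPLES, use_example=True)            -> [EXAMPLES[0]]
    # select_examples(EXAMPLES, True, one_shot=False)        -> EXAMPLES[:2]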
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
 from pilot.common.formatting import formatter
 from pilot.out_parser.base import BaseOutputParser
 from pilot.common.schema import SeparatorStyle
+from pilot.prompts.example_base import ExampleSelector

 def jinja2_formatter(template: str, **kwargs: Any) -> str:
     """Format a template using jinja2."""
@@ -32,7 +32,6 @@ class PromptTemplate(BaseModel, ABC):
    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    template_scene: Optional[str]
-
    template_define: Optional[str]
    """this template define"""
    template: Optional[str]
@@ -46,6 +45,7 @@ class PromptTemplate(BaseModel, ABC):
    output_parser: BaseOutputParser = None
    """"""
    sep: str = SeparatorStyle.SINGLE.value
+    example: ExampleSelector = None

    class Config:
        """Configuration for this pydantic object."""
@@ -182,7 +182,7 @@ class BasePromptTemplate(BaseModel, ABC):
        Example:
        .. code-block:: python

-            prompt.save(file_path="path/prompt.yaml")
+            prompt.save(file_path="path/prompt.api_v1")
        """
        if self.partial_variables:
            raise ValueError("Cannot save prompt with partial variables.")
@@ -201,11 +201,11 @@ class BasePromptTemplate(BaseModel, ABC):
        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(prompt_dict, f, indent=4)
-        elif save_path.suffix == ".yaml":
+        elif save_path.suffix == ".api_v1":
            with open(file_path, "w") as f:
                yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
-            raise ValueError(f"{save_path} must be json or yaml")
+            raise ValueError(f"{save_path} must be json or api_v1")


class StringPromptValue(PromptValue):
@@ -10,3 +10,7 @@ class ChatScene(Enum):
    ChatUrlKnowledge = "chat_url_knowledge"
    InnerChatDBSummary = "inner_chat_db_summary"
    ChatNormal = "chat_normal"
+
+    @staticmethod
+    def is_valid_mode(mode):
+        return any(mode == item.value for item in ChatScene)
@@ -70,7 +70,7 @@ class BaseChat(ABC):
        self.current_user_input: str = current_user_input
        self.llm_model = CFG.LLM_MODEL
        ### can configurable storage methods
-        self.memory = MemHistoryMemory(chat_session_id)
+        self.memory = FileHistoryMemory(chat_session_id)

        ### load prompt template
        self.prompt_template: PromptTemplate = CFG.prompt_templates[
@@ -139,9 +139,7 @@ class BaseChat(ABC):

        self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
        logger.info(f"Requert: \n{payload}")
-        ai_response_text = ""
        try:
-            show_info = ""
            response = requests.post(
                urljoin(CFG.MODEL_SERVER, "generate_stream"),
                headers=headers,
@@ -157,16 +155,10 @@ class BaseChat(ABC):
                # show_info = resp_text_trunck
                # yield resp_text_trunck + "▌"

-            self.current_message.add_ai_message(show_info)
-
        except Exception as e:
            print(traceback.format_exc())
            logger.error("model response parase faild!" + str(e))
-            self.current_message.add_view_message(
-                f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
-            )
-        ### store the conversation record
-        self.memory.append(self.current_message)
+            raise ValueError(str(e))

    def nostream_call(self):
        payload = self.__call_base()
@@ -188,20 +180,6 @@ class BaseChat(ABC):
            )
        )

-        # ### MOCK
-        # ai_response_text = """{
-        # "thoughts": "We can join the users table with the tran_order table, group and count orders by city, and present the result as a histogram.",
-        # "reasoning": "To analyze how users are distributed across cities, query the users and tran_order tables and join them with a LEFT JOIN. Group by city and count the orders per city. A histogram makes the per-city order counts easy to see and compare.",
-        # "speak": "Based on your analysis goal, I queried the user and order tables, counted the orders per city, and generated a histogram.",
-        # "command": {
-        # "name": "histogram-executor",
-        # "args": {
-        # "title": "Histogram of order distribution by city",
-        # "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
-        # }
-        # }
-        # }"""
-
        self.current_message.add_ai_message(ai_response_text)
        prompt_define_response = (
            self.prompt_template.output_parser.parse_prompt_response(
@@ -293,6 +271,7 @@ class BaseChat(ABC):

        return text

+
    # temporarily kept for frontend compatibility
    def current_ai_response(self) -> str:
        for message in self.current_message.messages:
pilot/scene/chat_dashboard/__init__.py (new file, empty)

pilot/scene/chat_execution/example.py (new file, 9 lines)
@@ -0,0 +1,9 @@
+from pilot.prompts.example_base import ExampleSelector
+
+## Two examples are defined by default
+EXAMPLES = [
+    [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}],
+    [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}]
+]
+
+example = ExampleSelector(examples=EXAMPLES, use_example=True)
@@ -1,36 +1,31 @@
 import json
-import importlib
 from pilot.prompts.prompt_new import PromptTemplate
 from pilot.configs.config import Config
 from pilot.scene.base import ChatScene
-from pilot.common.schema import SeparatorStyle
+from pilot.common.schema import SeparatorStyle, ExampleType

 from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
+from pilot.scene.chat_execution.example import example

 CFG = Config()

-PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
+# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
+PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."

-PROMPT_SUFFIX = """
+_DEFAULT_TEMPLATE = """
 Goals:
 {input}

-"""
-
-_DEFAULT_TEMPLATE = """
 Constraints:
 0.Exclusively use the commands listed in double quotes e.g. "command name"
 {constraints}

 Commands:
 {commands_infos}
-"""

-PROMPT_RESPONSE = """
 Please response strictly according to the following json format:
 {response}
 Ensure the response is correct json and can be parsed by Python json.loads
 """

@@ -40,20 +35,22 @@ RESPONSE_FORMAT = {
    "command": {"name": "command name", "args": {"arg name": "value"}},
}


+
+EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
-PROMPT_NEED_NEED_STREAM_OUT = False
+PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ChatExecution.value,
    input_variables=["input", "constraints", "commands_infos", "response"],
    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
-    template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
-    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
-    output_parser=PluginChatOutputParser(
-        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
-    ),
+    template=_DEFAULT_TEMPLATE,
+    stream_out=PROMPT_NEED_STREAM_OUT,
+    output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
+    example=example
)

CFG.prompt_templates.update({prompt.template_scene: prompt})
pilot/scene/chat_execution/prompt_v2.py (new file, 1 blank line)
@@ -0,0 +1 @@
+
@@ -1,5 +1,3 @@
-from chromadb.errors import NoIndexException
-
 from pilot.scene.base_chat import BaseChat, logger, headers
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -52,19 +50,12 @@ class ChatNewKnowledge(BaseChat):
        )

    def generate_input_values(self):
-        try:
-            docs = self.knowledge_embedding_client.similar_search(
-                self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
-            )
-            context = [d.page_content for d in docs]
-            self.metadata = [d.metadata for d in docs]
-            context = context[:2000]
-            input_values = {"context": context, "question": self.current_user_input}
-        except NoIndexException:
-            raise ValueError(
-                f"you have no {self.knowledge_name} knowledge store, please upload your knowledge"
-            )
-
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+        )
+        context = [d.page_content for d in docs]
+        context = context[:2000]
+        input_values = {"context": context, "question": self.current_user_input}
        return input_values

    def do_with_prompt_response(self, prompt_response):
@@ -43,12 +43,29 @@ class OnceConversation:

    def add_ai_message(self, message: str) -> None:
        """Add an AI message to the store"""

        has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
        if has_message:
-            raise ValueError("Already Have Ai message")
-        self.messages.append(AIMessage(content=message))
+            self.update_ai_message(message)
+        else:
+            self.messages.append(AIMessage(content=message))
        """ """

+
+    def __update_ai_message(self, new_message:str)-> None:
+        """
+        stream out message update
+        Args:
+            new_message:
+
+        Returns:
+
+        """
+
+        for item in self.messages:
+            if item.type == "ai":
+                item.content = new_message

    def add_view_message(self, message: str) -> None:
        """Add an AI message to the store"""

pilot/server/api_v1/api_v1.py (new file, 146 lines)
@@ -0,0 +1,146 @@
+import uuid
+
+from fastapi import APIRouter, Request, Body, status
+
+from fastapi.responses import JSONResponse
+from fastapi.responses import StreamingResponse
+from fastapi.encoders import jsonable_encoder
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+from typing import List
+
+from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
+from pilot.configs.config import Config
+from pilot.scene.base_chat import BaseChat
+from pilot.scene.base import ChatScene
+from pilot.scene.chat_factory import ChatFactory
+from pilot.configs.model_config import (LOGDIR)
+from pilot.utils import build_logger
+from pilot.scene.base_message import (BaseMessage)
+
+router = APIRouter()
+CFG = Config()
+CHAT_FACTORY = ChatFactory()
+logger = build_logger("api_v1", LOGDIR + "api_v1.log")
+
+
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+    message = ""
+    for error in exc.errors():
+        message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
+    return Result.faild(message)
+
+
+@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]])
+async def dialogue_list(user_id: str):
+    #### TODO
+    conversations = [ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
+                     ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]")]
+    return Result[ConversationVo].succ(conversations)
+
+
+@router.post('/v1/chat/dialogue/new', response_model=Result[str])
+async def dialogue_new(user_id: str):
+    unique_id = uuid.uuid1()
+    return Result.succ(unique_id)
+
+
+@router.post('/v1/chat/dialogue/delete')
+async def dialogue_delete(con_uid: str, user_id: str):
+    #### TODO
+    return Result.succ(None)
+
+
+@router.post('/v1/chat/completions', response_model=Result[MessageVo])
+async def chat_completions(dialogue: ConversationVo = Body()):
+    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+
+    if not ChatScene.is_valid_mode(dialogue.chat_mode):
+        raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
+
+    chat_param = {
+        "chat_session_id": dialogue.conv_uid,
+        "user_input": dialogue.user_input,
+    }
+
+    if ChatScene.ChatWithDbExecute == dialogue.chat_mode:
+        chat_param.update("db_name", dialogue.select_param)
+    elif ChatScene.ChatWithDbQA == dialogue.chat_mode:
+        chat_param.update("db_name", dialogue.select_param)
+    elif ChatScene.ChatExecution == dialogue.chat_mode:
+        chat_param.update("plugin_selector", dialogue.select_param)
+    elif ChatScene.ChatNewKnowledge == dialogue.chat_mode:
+        chat_param.update("knowledge_name", dialogue.select_param)
+    elif ChatScene.ChatUrlKnowledge == dialogue.chat_mode:
+        chat_param.update("url", dialogue.select_param)
+
+    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    if not chat.prompt_template.stream_out:
+        return non_stream_response(chat)
+    else:
+        return stream_response(chat)
+
+
+def stream_generator(chat):
+    model_response = chat.stream_call()
+    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+        if chunk:
+            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+            chat.current_message.add_ai_message(msg)
+            messageVos = [message2Vo(element) for element in chat.current_message.messages]
+            yield Result.succ(messageVos)
+
+def stream_response(chat):
+    logger.info("stream out start!")
+    api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
+    return api_response
+
+
+def message2Vo(message:BaseMessage)->MessageVo:
+    vo:MessageVo = MessageVo()
+    vo.role = message.type
+    vo.role = message.content
+    vo.time_stamp = message.additional_kwargs.time_stamp if message.additional_kwargs["time_stamp"] else 0
+
+
+def non_stream_response(chat):
+    logger.info("not stream out, wait model response!")
+    chat.nostream_call()
+    messageVos = [message2Vo(element) for element in chat.current_message.messages]
+    return Result.succ(messageVos)
+
+
+@router.get('/v1/db/types', response_model=Result[str])
+async def db_types():
+    return Result.succ(["mysql", "duckdb"])
+
+
+@router.get('/v1/db/list', response_model=Result[str])
+async def db_list():
+    db = CFG.local_db
+    dbs = db.get_database_list()
+    return Result.succ(dbs)
+
+
+@router.get('/v1/knowledge/list')
+async def knowledge_list():
+    return ["test1", "test2"]
+
+
+@router.post('/v1/knowledge/add')
+async def knowledge_add():
+    return ["test1", "test2"]
+
+
+@router.post('/v1/knowledge/delete')
+async def knowledge_delete():
+    return ["test1", "test2"]
+
+
+@router.get('/v1/knowledge/types')
+async def knowledge_types():
+    return ["test1", "test2"]
+
+
+@router.get('/v1/knowledge/detail')
+async def knowledge_detail():
+    return ["test1", "test2"]
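Review note: two sketches of presumably intended fixes in this new module (illustrative only, not part of the commit). `dict.update` takes a mapping rather than a key/value pair, so each `chat_param.update("db_name", ...)` call above raises TypeError, and `message2Vo` assigns `vo.role` twice and returns nothing:

    # dict.update expects a mapping (or an iterable of key/value pairs)
    chat_param.update({"db_name": dialogue.select_param})

    def message2Vo(message: BaseMessage) -> MessageVo:
        # fill role and message body separately, and return the value object
        return MessageVo(
            role=message.type,
            context=message.content,
            time_stamp=message.additional_kwargs.get("time_stamp") or 0,
        )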
pilot/server/api_v1/api_view_model.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+from pydantic import BaseModel, Field
+from typing import TypeVar, Union, List, Generic
+
+T = TypeVar('T')
+
+
+class Result(Generic[T], BaseModel):
+    success: bool
+    err_code: str
+    err_msg: str
+    data: List[T]
+
+    @classmethod
+    def succ(cls, data: List[T]):
+        return Result(True, None, None, data)
+
+    @classmethod
+    def faild(cls, msg):
+        return Result(True, "E000X", msg, None)
+
+    @classmethod
+    def faild(cls, code, msg):
+        return Result(True, code, msg, None)
+
+
+class ConversationVo(BaseModel):
+    """
+    dialogue_uid
+    """
+    conv_uid: str = Field(..., description="dialogue uid")
+    """
+    user input
+    """
+    user_input: str
+    """
+    the scene of chat
+    """
+    chat_mode: str = Field(..., description="the scene of chat ")
+    """
+    chat scene select param
+    """
+    select_param: str
+
+
+class MessageVo(BaseModel):
+    """
+    role that sends out the current message
+    """
+    role: str
+    """
+    current message
+    """
+    context: str
+    """
+    time the current message was sent
+    """
+    time_stamp: float
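Review note: pydantic models accept keyword arguments only, so `Result(True, None, None, data)` raises TypeError, and the second `faild` definition silently replaces the first. A keyword-based sketch (pydantic v1's GenericModel is needed for `Result[T]` to be parameterizable; `success=False` on failure is an assumption about the intent):

    from typing import Generic, List, Optional, TypeVar
    from pydantic.generics import GenericModel

    T = TypeVar("T")

    class Result(GenericModel, Generic[T]):
        success: bool
        err_code: Optional[str] = None
        err_msg: Optional[str] = None
        data: Optional[List[T]] = None

        @classmethod
        def succ(cls, data: List[T]) -> "Result[T]":
            return cls(success=True, data=data)

        @classmethod
        def faild(cls, msg: str, code: str = "E000X") -> "Result[T]":
            # one classmethod with a default code replaces the two same-named variants
            return cls(success=False, err_code=code, err_msg=msg)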
@@ -5,7 +5,6 @@ import asyncio
 import json
 import os
 import sys
-import traceback

 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
@@ -76,10 +75,8 @@ class ModelWorker:
        ):
            # Please do not open the output in production!
            # The gpt4all thread shares stdout with the parent process,
-            # and opening it may affect the frontend output
-            if not ("gptj" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL):
-                print("output: ", output)
+            # and opening it may affect the frontend output.
+            # print("output: ", output)

            ret = {
                "text": output,
                "error_code": 0,
@@ -90,13 +87,10 @@ class ModelWorker:
            ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
-            msg = "{}: {}".format(str(e), traceback.format_exc())
-
            ret = {
-                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {msg}",
+                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                "error_code": 0,
            }
-
            yield json.dumps(ret).encode() + b"\0"

    def get_embeddings(self, prompt):
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import signal
 import threading
 import traceback
 import argparse
@@ -12,12 +11,10 @@ import uuid

 import gradio as gr

-
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)

 from pilot.summary.db_summary_client import DBSummaryClient
-from pilot.commands.command_mange import CommandRegistry

 from pilot.scene.base_chat import BaseChat

@@ -25,8 +22,8 @@ from pilot.configs.config import Config
 from pilot.configs.model_config import (
     DATASETS_DIR,
     KNOWLEDGE_UPLOAD_ROOT_PATH,
-    LOGDIR,
     LLM_MODEL_CONFIG,
+    LOGDIR,
 )

 from pilot.conversation import (
@@ -35,7 +32,6 @@ from pilot.conversation import (
     chat_mode_title,
     default_conversation,
 )
-from pilot.common.plugins import scan_plugins, load_native_plugins

 from pilot.server.gradio_css import code_highlight_css
 from pilot.server.gradio_patch import Chatbot as grChatbot
@@ -49,6 +45,19 @@ from pilot.vector_store.extract_tovec import (
 from pilot.scene.base import ChatScene
 from pilot.scene.chat_factory import ChatFactory
 from pilot.language.translation_handler import get_lang_text
+from pilot.server.webserver_base import server_init


+import uvicorn
+from fastapi import BackgroundTasks, Request
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+from fastapi import FastAPI, applications
+from fastapi.openapi.docs import get_swagger_ui_html
+from fastapi.exceptions import RequestValidationError
+from fastapi.staticfiles import StaticFiles
+
+from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler

 # load plugins
 CFG = Config()
@@ -95,6 +104,19 @@ knowledge_qa_type_list = [
     add_knowledge_base_dialogue,
 ]

+
+def swagger_monkey_patch(*args, **kwargs):
+    return get_swagger_ui_html(
+        *args, **kwargs,
+        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
+        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+    )
+
+applications.get_swagger_ui_html = swagger_monkey_patch
+
+app = FastAPI()
+# app.mount("static", StaticFiles(directory="static"), name="static")
+app.include_router(api_v1)
+app.add_exception_handler(RequestValidationError, validation_exception_handler)


 def get_simlar(q):
     docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
@@ -324,15 +346,14 @@ def http_bot(
        response = chat.stream_call()
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
-                state.messages[-1][
-                    -1
-                ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
-                    chunk, chat.skip_echo_len
-                )
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                state.messages[-1][-1] =msg
+                chat.current_message.add_ai_message(msg)
                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+        chat.memory.append(chat.current_message)
    except Exception as e:
        print(traceback.format_exc())
-        state.messages[-1][-1] = "Error:" + str(e)
+        state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5


@@ -632,7 +653,7 @@ def knowledge_embedding_store(vs_id, files):
    )
    knowledge_embedding_client = KnowledgeEmbedding(
        file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
-        model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+        model_name=LLM_MODEL_CONFIG["text2vec"],
        vector_store_config={
            "vector_store_name": vector_store_name["vs_name"],
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -657,48 +678,36 @@ def signal_handler(sig, frame):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+    parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
+
+    # old version server config
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
     parser.add_argument("--concurrency-count", type=int, default=10)
-    parser.add_argument(
-        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
-    )
     parser.add_argument("--share", default=False, action="store_true")

+    # init server config
     args = parser.parse_args()
-    logger.info(f"args: {args}")
+    server_init(args)

+    if args.new:
+        import uvicorn
+        uvicorn.run(app, host="0.0.0.0", port=5000)
+    else:
+        ### Compatibility mode starts the old version server by default
+        demo = build_webdemo()
+        demo.queue(
+            concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+        ).launch(
+            server_name=args.host,
+            server_port=args.port,
+            share=args.share,
+            max_threads=200,
+        )
-    # init config
-    cfg = Config()
-
-    load_native_plugins(cfg)
-    dbs = cfg.local_db.get_database_list()
-    signal.signal(signal.SIGINT, signal_handler)
-    async_db_summery()
-    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
-
-    # Loader plugins and commands
-    command_categories = [
-        "pilot.commands.built_in.audio_text",
-        "pilot.commands.built_in.image_gen",
-    ]
-    # exclude commands
-    command_categories = [
-        x for x in command_categories if x not in cfg.disabled_command_categories
-    ]
-    command_registry = CommandRegistry()
-    for command_category in command_categories:
-        command_registry.import_commands(command_category)
-
-    cfg.command_registry = command_registry
-
-    logger.info(args)
-    demo = build_webdemo()
-    demo.queue(
-        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
-    ).launch(
-        server_name=args.host,
-        server_port=args.port,
-        share=args.share,
-        max_threads=200,
-    )
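Usage sketch, assuming the file modified here is pilot/server/webserver.py: running `python pilot/server/webserver.py --new` starts the new FastAPI service (routes under /v1, Swagger UI served via the patched CDN URLs), while `python pilot/server/webserver.py` with no flag keeps starting the old Gradio web demo. Note the uvicorn.run call hard-codes port 5000, so --port only affects the compatibility-mode server.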
pilot/server/webserver_base.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import signal
+import os
+import threading
+import traceback
+import sys
+
+from pilot.summary.db_summary_client import DBSummaryClient
+from pilot.commands.command_mange import CommandRegistry
+from pilot.configs.config import Config
+from pilot.configs.model_config import (
+    DATASETS_DIR,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+    LLM_MODEL_CONFIG,
+    LOGDIR,
+)
+from pilot.common.plugins import scan_plugins, load_native_plugins
+from pilot.utils import build_logger
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+logger = build_logger("webserver", LOGDIR + "webserver.log")
+
+
+def signal_handler(sig, frame):
+    print("in order to avoid chroma db atexit problem")
+    os._exit(0)
+
+
+def async_db_summery():
+    client = DBSummaryClient()
+    thread = threading.Thread(target=client.init_db_summary)
+    thread.start()
+
+
+def server_init(args):
+    logger.info(f"args: {args}")
+
+    # init config
+    cfg = Config()
+
+    load_native_plugins(cfg)
+    signal.signal(signal.SIGINT, signal_handler)
+    async_db_summery()
+    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
+
+    # Loader plugins and commands
+    command_categories = [
+        "pilot.commands.built_in.audio_text",
+        "pilot.commands.built_in.image_gen",
+    ]
+    # exclude commands
+    command_categories = [
+        x for x in command_categories if x not in cfg.disabled_command_categories
+    ]
+    command_registry = CommandRegistry()
+    for command_category in command_categories:
+        command_registry.import_commands(command_category)
+
+    cfg.command_registry = command_registry
@@ -35,7 +35,7 @@ class ElevenLabsSpeech(VoiceBase):
        }
        self._headers = {
            "Content-Type": "application/json",
-            "xi-api-key": cfg.elevenlabs_api_key,
+            "xi-api_v1-key": cfg.elevenlabs_api_key,
        }
        self._voices = default_voices.copy()
        if cfg.elevenlabs_voice_1_id in voice_options:
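Review note: `xi-api-key` is the literal header name the ElevenLabs HTTP API matches on; this rename, like the `.yaml` to `.api_v1` renames in the prompt-saving code above, looks like an over-broad api to api_v1 find-and-replace. Requests sent with the renamed header will fail to authenticate:

    # the header name is part of the ElevenLabs wire protocol and must stay unchanged
    self._headers = {
        "Content-Type": "application/json",
        "xi-api-key": cfg.elevenlabs_api_key,
    }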
@@ -1,13 +1,12 @@
 from pilot.vector_store.chroma_store import ChromaStore

-from pilot.vector_store.milvus_store import MilvusStore
-from pilot.vector_store.weaviate_store import WeaviateStore
+# from pilot.vector_store.milvus_store import MilvusStore

-connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore}
+connector = {"Chroma": ChromaStore, "Milvus": None}


 class VectorStoreConnector:
-    """vector store connector, can connect different vector db provided load document api and similar search api."""
+    """vector store connector, can connect different vector db provided load document api_v1 and similar search api_v1."""

     def __init__(self, vector_store_type, ctx: {}) -> None:
         """initialize vector store connector."""
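Review note: with "Milvus" mapped to None, selecting that store now fails with a "NoneType is not callable" error rather than a clear message. A sketch of a lazy-import alternative that keeps the module importable without the milvus client installed (this assumes the connector values are called with ctx, as the class values here suggest; illustrative, not part of the commit):

    def _milvus_store(ctx):
        # defer the import so environments without pymilvus can still load this module
        from pilot.vector_store.milvus_store import MilvusStore
        return MilvusStore(ctx)

    connector = {"Chroma": ChromaStore, "Milvus": _milvus_store}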
(deleted file, 146 lines)
@@ -1,146 +0,0 @@
-import os
-import json
-import weaviate
-from langchain.schema import Document
-from langchain.vectorstores import Weaviate
-from weaviate.exceptions import WeaviateBaseError
-
-from pilot.configs.config import Config
-from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
-from pilot.logs import logger
-from pilot.vector_store.vector_store_base import VectorStoreBase
-
-CFG = Config()
-
-
-class WeaviateStore(VectorStoreBase):
-    """Weaviate database"""
-
-    def __init__(self, ctx: dict) -> None:
-        """Initialize with Weaviate client."""
-        try:
-            import weaviate
-        except ImportError:
-            raise ValueError(
-                "Could not import weaviate python package. "
-                "Please install it with `pip install weaviate-client`."
-            )
-
-        self.ctx = ctx
-        self.weaviate_url = CFG.WEAVIATE_URL
-        self.embedding = ctx.get("embeddings", None)
-        self.vector_name = ctx["vector_store_name"]
-        self.persist_dir = os.path.join(
-            KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb"
-        )
-
-        self.vector_store_client = weaviate.Client(self.weaviate_url)
-
-    def similar_search(self, text: str, topk: int) -> None:
-        """Perform similar search in Weaviate"""
-        logger.info("Weaviate similar search")
-        # nearText = {
-        #     "concepts": [text],
-        #     "distance": 0.75,  # prior to v1.14 use "certainty" instead of "distance"
-        # }
-        # vector = self.embedding.embed_query(text)
-        response = (
-            self.vector_store_client.query.get(
-                self.vector_name, ["metadata", "page_content"]
-            )
-            # .with_near_vector({"vector": vector})
-            .with_limit(topk).do()
-        )
-        res = response["data"]["Get"][list(response["data"]["Get"].keys())[0]]
-        docs = []
-        for r in res:
-            docs.append(
-                Document(
-                    page_content=r["page_content"],
-                    metadata={"metadata": r["metadata"]},
-                )
-            )
-        return docs
-
-    def vector_name_exists(self) -> bool:
-        """Check if a vector name exists for a given class in Weaviate.
-        Returns:
-            bool: True if the vector name exists, False otherwise.
-        """
-        try:
-            if self.vector_store_client.schema.get(self.vector_name):
-                return True
-            return False
-        except WeaviateBaseError as e:
-            logger.error("vector_name_exists error", e.message)
-            return False
-
-    def _default_schema(self) -> None:
-        """
-        Create the schema for Weaviate with a Document class containing metadata and text properties.
-        """
-
-        schema = {
-            "classes": [
-                {
-                    "class": self.vector_name,
-                    "description": "A document with metadata and text",
-                    # "moduleConfig": {
-                    #     "text2vec-transformers": {
-                    #         "poolingStrategy": "masked_mean",
-                    #         "vectorizeClassName": False,
-                    #     }
-                    # },
-                    "properties": [
-                        {
-                            "dataType": ["text"],
-                            # "moduleConfig": {
-                            #     "text2vec-transformers": {
-                            #         "skip": False,
-                            #         "vectorizePropertyName": False,
-                            #     }
-                            # },
-                            "description": "Metadata of the document",
-                            "name": "metadata",
-                        },
-                        {
-                            "dataType": ["text"],
-                            # "moduleConfig": {
-                            #     "text2vec-transformers": {
-                            #         "skip": False,
-                            #         "vectorizePropertyName": False,
-                            #     }
-                            # },
-                            "description": "Text content of the document",
-                            "name": "page_content",
-                        },
-                    ],
-                    # "vectorizer": "text2vec-transformers",
-                }
-            ]
-        }
-
-        # Create the schema in Weaviate
-        self.vector_store_client.schema.create(schema)
-
-    def load_document(self, documents: list) -> None:
-        """Load documents into Weaviate"""
-        logger.info("Weaviate load document")
-        texts = [doc.page_content for doc in documents]
-        metadatas = [doc.metadata for doc in documents]
-
-        # Import data
-        with self.vector_store_client.batch as batch:
-            batch.batch_size = 100
-
-            # Batch import all documents
-            for i in range(len(texts)):
-                properties = {
-                    "metadata": metadatas[i]["source"],
-                    "page_content": texts[i],
-                }
-
-                self.vector_store_client.batch.add_data_object(
-                    data_object=properties, class_name=self.vector_name
-                )
-            self.vector_store_client.batch.flush()
|
Loading…
Reference in New Issue
Block a user