WEB API independent

This commit is contained in:
tuyang.yhj 2023-06-25 14:46:46 +08:00
parent 0558a8ba37
commit d372e73cd5
32 changed files with 506 additions and 309 deletions

View File

@@ -78,7 +78,6 @@ def load_native_plugins(cfg: Config):
    if not cfg.plugins_auto_load:
        print("not auto load_native_plugins")
        return
    def load_from_git(cfg: Config):
        print("async load_native_plugins")
        branch_name = cfg.plugins_git_branch
@@ -86,20 +85,16 @@ def load_native_plugins(cfg: Config):
        url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
        try:
            session = requests.Session()
-            response = session.get(
-                url.format(repo=native_plugin_repo, branch=branch_name),
-                headers={"Authorization": "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"},
-            )
+            response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
+                                   headers={'Authorization': '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'})
            if response.status_code == 200:
                plugins_path_path = Path(PLUGINS_DIR)
-                files = glob.glob(
-                    os.path.join(plugins_path_path, f"{native_plugin_repo}*")
-                )
+                files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
                for file in files:
                    os.remove(file)
                now = datetime.datetime.now()
-                time_str = now.strftime("%Y%m%d%H%M%S")
+                time_str = now.strftime('%Y%m%d%H%M%S')
                file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
                print(file_name)
                with open(file_name, "wb") as f:
@@ -115,6 +110,7 @@ def load_native_plugins(cfg: Config):
    t.start()

def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
    """Scan the plugins directory for plugins and loads them.

View File

@@ -7,3 +7,7 @@ class SeparatorStyle(Enum):
    TWO = "</s>"
    THREE = auto()
    FOUR = auto()
+
+
+class ExampleType(Enum):
+    ONE_SHOT = "one_shot"
+    FEW_SHOT = "few_shot"
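The new enum gives prompt code a single place to branch on example style. A minimal sketch of such a check (the helper name is an assumption, not part of the diff):

    from pilot.common.schema import ExampleType

    def pick_example_count(example_type: str) -> int:
        # one_shot keeps a single example dialogue, few_shot keeps several
        if example_type == ExampleType.ONE_SHOT.value:
            return 1
        return 2

    print(pick_example_count(ExampleType.FEW_SHOT.value))  # 2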

View File

@@ -90,7 +90,7 @@ class Config(metaclass=Singleton):
        ### The associated configuration parameters of the plug-in control the loading and use of the plug-in
        self.plugins: List[AutoGPTPluginTemplate] = []
        self.plugins_openai = []
        self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
        self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
@@ -150,8 +150,6 @@ class Config(metaclass=Singleton):
        self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
        self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
-        self.WEAVIATE_URL = os.getenv("WEAVIATE_URL", "http://127.0.0.1:8080")

        # QLoRA
        self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")

View File

@@ -43,8 +43,6 @@ LLM_MODEL_CONFIG = {
    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
    "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
-   # TODO Support baichuan-7b
-   # "baichuan-7b" : os.path.join(MODEL_PATH, "baichuan-7b"),
    "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
    "proxyllm": "proxyllm",
}

View File

@@ -6,6 +6,8 @@ import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
+from test_cls_1 import TestBase,Test1
+from test_cls_2 import Test2

CFG = Config()
@@ -56,20 +58,30 @@ CFG = Config()
#
if __name__ == "__main__":
-    def __extract_json(s):
-        i = s.index("{")
-        count = 1  # current nesting depth, i.e. the number of unclosed '{'
-        for j, c in enumerate(s[i + 1 :], start=i + 1):
-            if c == "}":
-                count -= 1
-            elif c == "{":
-                count += 1
-            if count == 0:
-                break
-        assert count == 0  # check that the final '}' was found
-        return s[i : j + 1]
-
-    ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
-    print(__extract_json(ss))
+    # def __extract_json(s):
+    #     i = s.index("{")
+    #     count = 1  # current nesting depth, i.e. the number of unclosed '{'
+    #     for j, c in enumerate(s[i + 1 :], start=i + 1):
+    #         if c == "}":
+    #             count -= 1
+    #         elif c == "{":
+    #             count += 1
+    #         if count == 0:
+    #             break
+    #     assert count == 0  # check that the final '}' was found
+    #     return s[i : j + 1]
+    #
+    # ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
+    # print(__extract_json(ss))
+    test1 = Test1()
+    test2 = Test2()
+    test1.write()
+    test1.test()
+    test2.write()
+    test1.test()
+    test2.test()

View File

@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel

from test_cls_base import TestBase


class Test1(TestBase):
    def write(self):
        self.test_values.append("x")
        self.test_values.append("y")
        self.test_values.append("g")

View File

@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

from test_cls_base import TestBase


class Test2(TestBase):
    test_2_values: List = []

    def write(self):
        self.test_values.append(1)
        self.test_values.append(2)
        self.test_values.append(3)
        self.test_2_values.append("x")
        self.test_2_values.append("y")
        self.test_2_values.append("z")

View File

@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union


class TestBase(BaseModel, ABC):
    test_values: List = []

    def test(self):
        print(self.__class__.__name__ + ":")
        print(self.test_values)

View File

@@ -32,14 +32,9 @@ class BaseLLMAdaper:
        return True

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path, use_fast=False, trust_remote_code=True
-        )
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            low_cpu_mem_usage=True,
-            trust_remote_code=True,
-            **from_pretrained_kwargs,
+            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return model, tokenizer
@@ -62,7 +57,7 @@ def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    raise ValueError(f"Invalid model adapter for {model_path}")

-# TODO support cpu? for practice we support gpt4all or chatglm-6b-int4?
+# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
class VicunaLLMAdapater(BaseLLMAdaper):

View File

@@ -11,7 +11,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
def chatglm_generate_stream(
    model, tokenizer, params, device, context_len=2048, stream_interval=2
):
-    """Generate text using chatglm model's chat api"""
+    """Generate text using chatglm model's chat api_v1"""
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))

View File

@@ -51,7 +51,7 @@ def create_chat_completion(
        return message

    response = None
-    # TODO impl this use vicuna server api
+    # TODO impl this use vicuna server api_v1

class Stream(transformers.StoppingCriteria):

View File

@@ -32,7 +32,7 @@ class BaseOutputParser(ABC):
    Output parsers help structure language model responses.
    """

-    def __init__(self, sep: str, is_stream_out: bool):
+    def __init__(self, sep: str, is_stream_out: bool = True):
        self.sep = sep
        self.is_stream_out = is_stream_out

View File

@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

from pilot.common.schema import ExampleType


class ExampleSelector(BaseModel, ABC):
    examples: List[List]
    use_example: bool = False
    type: str = ExampleType.ONE_SHOT.value

    def examples(self, count: int = 2):
        if ExampleType.ONE_SHOT.value == self.type:
            return self.__one_show_context()
        else:
            return self.__few_shot_context(count)

    def __few_shot_context(self, count: int = 2) -> List[List]:
        """
        Use 2 or more examples, default 2
        Returns: example text
        """
        if self.use_example:
            need_use = self.examples[:count]
            return need_use
        return None

    def __one_show_context(self) -> List:
        """
        Use one examples
        Returns:
        """
        if self.use_example:
            need_use = self.examples[:1]
            return need_use
        return None
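The selection logic above amounts to slicing the configured example list by count. A self-contained sketch of that behaviour (a standalone helper for illustration, not the module's API):

    from typing import List, Optional

    def select_examples(examples: List[List], use_example: bool,
                        example_type: str = "one_shot", count: int = 2) -> Optional[List[List]]:
        if not use_example:
            return None
        # one_shot keeps a single example dialogue, few_shot keeps the first `count`
        return examples[:1] if example_type == "one_shot" else examples[:count]

    EXAMPLES = [
        [{"System": "123"}, {"User": "xxx"}, {"Assistant": "xxx"}],
        [{"System": "456"}, {"User": "yyy"}, {"Assistant": "yyy"}],
    ]
    print(select_examples(EXAMPLES, use_example=True))                           # one dialogue
    print(select_examples(EXAMPLES, use_example=True, example_type="few_shot"))  # two dialogues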

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from pilot.common.formatting import formatter
from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
+from pilot.prompts.example_base import ExampleSelector

def jinja2_formatter(template: str, **kwargs: Any) -> str:
    """Format a template using jinja2."""
@@ -32,7 +32,6 @@ class PromptTemplate(BaseModel, ABC):
    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    template_scene: Optional[str]
    template_define: Optional[str]
    """this template define"""
    template: Optional[str]
@@ -46,6 +45,7 @@ class PromptTemplate(BaseModel, ABC):
    output_parser: BaseOutputParser = None
    """"""
    sep: str = SeparatorStyle.SINGLE.value
+    example: ExampleSelector = None

    class Config:
        """Configuration for this pydantic object."""

View File

@@ -182,7 +182,7 @@ class BasePromptTemplate(BaseModel, ABC):
        Example:
        .. code-block:: python

-            prompt.save(file_path="path/prompt.yaml")
+            prompt.save(file_path="path/prompt.api_v1")
        """
        if self.partial_variables:
            raise ValueError("Cannot save prompt with partial variables.")
@@ -201,11 +201,11 @@ class BasePromptTemplate(BaseModel, ABC):
        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(prompt_dict, f, indent=4)
-        elif save_path.suffix == ".yaml":
+        elif save_path.suffix == ".api_v1":
            with open(file_path, "w") as f:
                yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
-            raise ValueError(f"{save_path} must be json or yaml")
+            raise ValueError(f"{save_path} must be json or api_v1")

class StringPromptValue(PromptValue):

View File

@@ -10,3 +10,7 @@ class ChatScene(Enum):
    ChatUrlKnowledge = "chat_url_knowledge"
    InnerChatDBSummary = "inner_chat_db_summary"
    ChatNormal = "chat_normal"
+
+    @staticmethod
+    def is_valid_mode(mode):
+        return any(mode == item.value for item in ChatScene)
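is_valid_mode just compares the incoming string against every enum member's value, for example:

    from pilot.scene.base import ChatScene

    print(ChatScene.is_valid_mode("chat_normal"))   # True, matches ChatScene.ChatNormal
    print(ChatScene.is_valid_mode("chat_unknown"))  # False, no member carries this value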

View File

@@ -70,7 +70,7 @@ class BaseChat(ABC):
        self.current_user_input: str = current_user_input
        self.llm_model = CFG.LLM_MODEL
        ### can configurable storage methods
-        self.memory = MemHistoryMemory(chat_session_id)
+        self.memory = FileHistoryMemory(chat_session_id)

        ### load prompt template
        self.prompt_template: PromptTemplate = CFG.prompt_templates[
@@ -139,9 +139,7 @@ class BaseChat(ABC):
        self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
        logger.info(f"Requert: \n{payload}")
-        ai_response_text = ""
        try:
-            show_info = ""
            response = requests.post(
                urljoin(CFG.MODEL_SERVER, "generate_stream"),
                headers=headers,
@@ -157,16 +155,10 @@ class BaseChat(ABC):
                # show_info = resp_text_trunck
                # yield resp_text_trunck + "▌"
-            self.current_message.add_ai_message(show_info)
        except Exception as e:
            print(traceback.format_exc())
            logger.error("model response parase faild" + str(e))
-            self.current_message.add_view_message(
-                f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
-            )
+            raise ValueError(str(e))
-        ### persist the conversation history
-        self.memory.append(self.current_message)

    def nostream_call(self):
        payload = self.__call_base()
@@ -188,20 +180,6 @@ class BaseChat(ABC):
            )
        )
-        # ### MOCK
-        # ai_response_text = """{
-        #     "thoughts": "可以从users表和tran_order表联合查询按城市和订单数量进行分组统计并使用柱状图展示。",
-        #     "reasoning": "为了分析用户在不同城市的分布情况需要查询users表和tran_order表使用LEFT JOIN将两个表联合起来。按照城市进行分组统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量方便比较。",
-        #     "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
-        #     "command": {
-        #         "name": "histogram-executor",
-        #         "args": {
-        #             "title": "订单城市分布柱状图",
-        #             "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
-        #         }
-        #     }
-        # }"""
        self.current_message.add_ai_message(ai_response_text)
        prompt_define_response = (
            self.prompt_template.output_parser.parse_prompt_response(
@@ -293,6 +271,7 @@ class BaseChat(ABC):
        return text

    # Temporarily kept for frontend compatibility
    def current_ai_response(self) -> str:
        for message in self.current_message.messages:

View File

View File

@@ -0,0 +1,9 @@
from pilot.prompts.example_base import ExampleSelector

## Two examples are defined by default
EXAMPLES = [
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
]

example = ExampleSelector(examples=EXAMPLES, use_example=True)

View File

@@ -1,36 +1,31 @@
import json
+import importlib

from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
-from pilot.common.schema import SeparatorStyle
+from pilot.common.schema import SeparatorStyle, ExampleType
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
+from pilot.scene.chat_execution.example import example

CFG = Config()

-PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
+# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
+PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."

-PROMPT_SUFFIX = """
+_DEFAULT_TEMPLATE = """
Goals:
{input}
-"""
-
-_DEFAULT_TEMPLATE = """
Constraints:
0.Exclusively use the commands listed in double quotes e.g. "command name"
{constraints}
Commands:
{commands_infos}
-"""
-
-PROMPT_RESPONSE = """
Please response strictly according to the following json format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""
@@ -40,20 +35,22 @@ RESPONSE_FORMAT = {
    "command": {"name": "command name", "args": {"arg name": "value"}},
}

+EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
-PROMPT_NEED_NEED_STREAM_OUT = False
+PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ChatExecution.value,
    input_variables=["input", "constraints", "commands_infos", "response"],
    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
-    template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
-    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
-    output_parser=PluginChatOutputParser(
-        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
-    ),
+    template=_DEFAULT_TEMPLATE,
+    stream_out=PROMPT_NEED_STREAM_OUT,
+    output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
+    example=example
)

CFG.prompt_templates.update({prompt.template_scene: prompt})
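For readers tracing how the merged _DEFAULT_TEMPLATE gets its values, a rough standalone sketch using plain str.format; the real flow goes through PromptTemplate and the output parser, and the sample values below are made up:

    import json

    template = (
        "Goals:\n{input}\n"
        "Constraints:\n{constraints}\n"
        "Commands:\n{commands_infos}\n"
        "Please response strictly according to the following json format:\n{response}\n"
    )
    response_format = {"command": {"name": "command name", "args": {"arg name": "value"}}}
    prompt_text = template.format(
        input="plot user orders by city",
        constraints='0.Exclusively use the commands listed in double quotes e.g. "command name"',
        commands_infos='"histogram-executor": draw a histogram from a SQL result',
        response=json.dumps(response_format, indent=4),
    )
    print(prompt_text)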

View File

@@ -0,0 +1 @@

View File

@@ -1,5 +1,3 @@
-from chromadb.errors import NoIndexException
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@@ -52,19 +50,12 @@ class ChatNewKnowledge(BaseChat):
        )

    def generate_input_values(self):
-        try:
-            docs = self.knowledge_embedding_client.similar_search(
-                self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
-            )
-            context = [d.page_content for d in docs]
-            self.metadata = [d.metadata for d in docs]
-            context = context[:2000]
-            input_values = {"context": context, "question": self.current_user_input}
-        except NoIndexException:
-            raise ValueError(
-                f"you have no {self.knowledge_name} knowledge store, please upload your knowledge"
-            )
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
+        )
+        context = [d.page_content for d in docs]
+        context = context[:2000]
+        input_values = {"context": context, "question": self.current_user_input}
        return input_values

    def do_with_prompt_response(self, prompt_response):

View File

@@ -43,12 +43,29 @@ class OnceConversation:
    def add_ai_message(self, message: str) -> None:
        """Add an AI message to the store"""
        has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
        if has_message:
-            raise ValueError("Already Have Ai message")
-        self.messages.append(AIMessage(content=message))
+            self.update_ai_message(message)
+        else:
+            self.messages.append(AIMessage(content=message))
        """ """

+    def __update_ai_message(self, new_message: str) -> None:
+        """
+        stream out message update
+        Args:
+            new_message:
+
+        Returns:
+        """
+        for item in self.messages:
+            if item.type == "ai":
+                item.content = new_message
+
    def add_view_message(self, message: str) -> None:
        """Add an AI message to the store"""

View File

@@ -0,0 +1,146 @@
import uuid
from fastapi import APIRouter, Request, Body, status
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from typing import List

from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR)
from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage)

router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")


async def validation_exception_handler(request: Request, exc: RequestValidationError):
    message = ""
    for error in exc.errors():
        message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
    return Result.faild(message)


@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]])
async def dialogue_list(user_id: str):
    #### TODO
    conversations = [
        ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
        ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
    ]
    return Result[ConversationVo].succ(conversations)


@router.post('/v1/chat/dialogue/new', response_model=Result[str])
async def dialogue_new(user_id: str):
    unique_id = uuid.uuid1()
    return Result.succ(unique_id)


@router.post('/v1/chat/dialogue/delete')
async def dialogue_delete(con_uid: str, user_id: str):
    #### TODO
    return Result.succ(None)


@router.post('/v1/chat/completions', response_model=Result[MessageVo])
async def chat_completions(dialogue: ConversationVo = Body()):
    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))

    chat_param = {
        "chat_session_id": dialogue.conv_uid,
        "user_input": dialogue.user_input,
    }

    if ChatScene.ChatWithDbExecute == dialogue.chat_mode:
        chat_param.update("db_name", dialogue.select_param)
    elif ChatScene.ChatWithDbQA == dialogue.chat_mode:
        chat_param.update("db_name", dialogue.select_param)
    elif ChatScene.ChatExecution == dialogue.chat_mode:
        chat_param.update("plugin_selector", dialogue.select_param)
    elif ChatScene.ChatNewKnowledge == dialogue.chat_mode:
        chat_param.update("knowledge_name", dialogue.select_param)
    elif ChatScene.ChatUrlKnowledge == dialogue.chat_mode:
        chat_param.update("url", dialogue.select_param)

    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
    if not chat.prompt_template.stream_out:
        return non_stream_response(chat)
    else:
        return stream_response(chat)


def stream_generator(chat):
    model_response = chat.stream_call()
    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
            chat.current_message.add_ai_message(msg)
            messageVos = [message2Vo(element) for element in chat.current_message.messages]
            yield Result.succ(messageVos)


def stream_response(chat):
    logger.info("stream out start!")
    api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
    return api_response


def message2Vo(message: BaseMessage) -> MessageVo:
    vo: MessageVo = MessageVo()
    vo.role = message.type
    vo.role = message.content
    vo.time_stamp = message.additional_kwargs.time_stamp if message.additional_kwargs["time_stamp"] else 0


def non_stream_response(chat):
    logger.info("not stream out, wait model response!")
    chat.nostream_call()
    messageVos = [message2Vo(element) for element in chat.current_message.messages]
    return Result.succ(messageVos)


@router.get('/v1/db/types', response_model=Result[str])
async def db_types():
    return Result.succ(["mysql", "duckdb"])


@router.get('/v1/db/list', response_model=Result[str])
async def db_list():
    db = CFG.local_db
    dbs = db.get_database_list()
    return Result.succ(dbs)


@router.get('/v1/knowledge/list')
async def knowledge_list():
    return ["test1", "test2"]


@router.post('/v1/knowledge/add')
async def knowledge_add():
    return ["test1", "test2"]


@router.post('/v1/knowledge/delete')
async def knowledge_delete():
    return ["test1", "test2"]


@router.get('/v1/knowledge/types')
async def knowledge_types():
    return ["test1", "test2"]


@router.get('/v1/knowledge/detail')
async def knowledge_detail():
    return ["test1", "test2"]
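Once the new FastAPI app is running, the endpoints can be exercised with a plain HTTP client; a quick sketch, assuming the host/port used elsewhere in this commit (127.0.0.1:5000) and illustrative payload values:

    import uuid
    import requests

    BASE = "http://127.0.0.1:5000"          # assumed address of the new FastAPI service
    body = {
        "conv_uid": str(uuid.uuid1()),       # normally obtained from /v1/chat/dialogue/new
        "user_input": "show user orders by city",
        "chat_mode": "chat_normal",
        "select_param": "",
    }
    resp = requests.post(f"{BASE}/v1/chat/completions", json=body)
    print(resp.status_code, resp.text)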

View File

@@ -0,0 +1,57 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic

T = TypeVar('T')


class Result(Generic[T], BaseModel):
    success: bool
    err_code: str
    err_msg: str
    data: List[T]

    @classmethod
    def succ(cls, data: List[T]):
        return Result(True, None, None, data)

    @classmethod
    def faild(cls, msg):
        return Result(True, "E000X", msg, None)

    @classmethod
    def faild(cls, code, msg):
        return Result(True, code, msg, None)


class ConversationVo(BaseModel):
    """
    dialogue_uid
    """
    conv_uid: str = Field(..., description="dialogue uid")
    """
    user input
    """
    user_input: str
    """
    the scene of chat
    """
    chat_mode: str = Field(..., description="the scene of chat ")
    """
    chat scene select param
    """
    select_param: str


class MessageVo(BaseModel):
    """
    role that sends out the current message
    """
    role: str
    """
    current message
    """
    context: str
    """
    time the current message was sent
    """
    time_stamp: float
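A brief sketch of how these view models are meant to be populated, using keyword arguments as pydantic expects; the field values here are placeholders, not anything from the commit:

    from pilot.server.api_v1.api_view_model import ConversationVo, MessageVo

    vo = ConversationVo(
        conv_uid="3f0e-placeholder",         # hypothetical dialogue uid
        user_input="show user orders by city",
        chat_mode="chat_normal",
        select_param="",
    )
    msg = MessageVo(role="ai", context="done", time_stamp=0)
    print(vo.json())
    print(msg.json())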

View File

@@ -5,7 +5,6 @@ import asyncio
import json
import os
import sys
-import traceback

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
@@ -76,10 +75,8 @@ class ModelWorker:
        ):
            # Please do not open the output in production!
            # The gpt4all thread shares stdout with the parent process,
-            # and opening it may affect the frontend output
-            if not ("gptj" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL):
-                print("output: ", output)
+            # and opening it may affect the frontend output.
+            # print("output: ", output)
            ret = {
                "text": output,
                "error_code": 0,
@@ -90,13 +87,10 @@ class ModelWorker:
            ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
-            msg = "{}: {}".format(str(e), traceback.format_exc())
            ret = {
-                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {msg}",
+                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                "error_code": 0,
            }
            yield json.dumps(ret).encode() + b"\0"

    def get_embeddings(self, prompt):

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-import signal
import threading
import traceback
import argparse
@@ -12,12 +11,10 @@ import uuid
import gradio as gr

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.summary.db_summary_client import DBSummaryClient
-from pilot.commands.command_mange import CommandRegistry
from pilot.scene.base_chat import BaseChat
@@ -25,8 +22,8 @@ from pilot.configs.config import Config
from pilot.configs.model_config import (
    DATASETS_DIR,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
-    LOGDIR,
    LLM_MODEL_CONFIG,
+    LOGDIR,
)
from pilot.conversation import (
@@ -35,7 +32,6 @@ from pilot.conversation import (
    chat_mode_title,
    default_conversation,
)
-from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
@@ -49,6 +45,19 @@ from pilot.vector_store.extract_tovec import (
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
+from pilot.server.webserver_base import server_init
+
+import uvicorn
+from fastapi import BackgroundTasks, Request
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+from fastapi import FastAPI, applications
+from fastapi.openapi.docs import get_swagger_ui_html
+from fastapi.exceptions import RequestValidationError
+from fastapi.staticfiles import StaticFiles
+from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler

# Load plugins
CFG = Config()
@@ -95,6 +104,19 @@ knowledge_qa_type_list = [
    add_knowledge_base_dialogue,
]

+def swagger_monkey_patch(*args, **kwargs):
+    return get_swagger_ui_html(
+        *args, **kwargs,
+        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
+        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+    )
+
+applications.get_swagger_ui_html = swagger_monkey_patch
+
+app = FastAPI()
+# app.mount("static", StaticFiles(directory="static"), name="static")
+app.include_router(api_v1)
+app.add_exception_handler(RequestValidationError, validation_exception_handler)

def get_simlar(q):
    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
@@ -324,15 +346,14 @@ def http_bot(
        response = chat.stream_call()
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
-                state.messages[-1][
-                    -1
-                ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
-                    chunk, chat.skip_echo_len
-                )
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                state.messages[-1][-1] = msg
+                chat.current_message.add_ai_message(msg)
                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+        chat.memory.append(chat.current_message)
    except Exception as e:
        print(traceback.format_exc())
-        state.messages[-1][-1] = "Error:" + str(e)
+        state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
@@ -632,7 +653,7 @@ def knowledge_embedding_store(vs_id, files):
        )
        knowledge_embedding_client = KnowledgeEmbedding(
            file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
-            model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
+            model_name=LLM_MODEL_CONFIG["text2vec"],
            vector_store_config={
                "vector_store_name": vector_store_name["vs_name"],
                "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -657,48 +678,36 @@ def signal_handler(sig, frame):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+    parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
+
+    # old version server config
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
    parser.add_argument("--concurrency-count", type=int, default=10)
-    parser.add_argument(
-        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
-    )
    parser.add_argument("--share", default=False, action="store_true")

+    # init server config
    args = parser.parse_args()
-    logger.info(f"args: {args}")
+    server_init(args)

-    # init config
-    cfg = Config()
-    load_native_plugins(cfg)
-    dbs = cfg.local_db.get_database_list()
-    signal.signal(signal.SIGINT, signal_handler)
-    async_db_summery()
-    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
-
-    # Loader plugins and commands
-    command_categories = [
-        "pilot.commands.built_in.audio_text",
-        "pilot.commands.built_in.image_gen",
-    ]
-    # exclude commands
-    command_categories = [
-        x for x in command_categories if x not in cfg.disabled_command_categories
-    ]
-    command_registry = CommandRegistry()
-    for command_category in command_categories:
-        command_registry.import_commands(command_category)
-    cfg.command_registry = command_registry
-
-    logger.info(args)
-    demo = build_webdemo()
-    demo.queue(
-        concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
-    ).launch(
-        server_name=args.host,
-        server_port=args.port,
-        share=args.share,
-        max_threads=200,
-    )
+    if args.new:
+        import uvicorn
+        uvicorn.run(app, host="0.0.0.0", port=5000)
+    else:
+        ### Compatibility mode starts the old version server by default
+        demo = build_webdemo()
+        demo.queue(
+            concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+        ).launch(
+            server_name=args.host,
+            server_port=args.port,
+            share=args.share,
+            max_threads=200,
+        )
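With the argument handling above, running the entry point (presumably pilot/server/webserver.py) with --new starts the standalone FastAPI service via uvicorn on port 5000, while omitting --new keeps the original Gradio web demo listening on CFG.WEB_SERVER_PORT.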

View File

@@ -0,0 +1,60 @@
import signal
import os
import threading
import traceback
import sys

from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
from pilot.configs.model_config import (
    DATASETS_DIR,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LLM_MODEL_CONFIG,
    LOGDIR,
)
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

logger = build_logger("webserver", LOGDIR + "webserver.log")


def signal_handler(sig, frame):
    print("in order to avoid chroma db atexit problem")
    os._exit(0)


def async_db_summery():
    client = DBSummaryClient()
    thread = threading.Thread(target=client.init_db_summary)
    thread.start()


def server_init(args):
    logger.info(f"args: {args}")
    # init config
    cfg = Config()
    load_native_plugins(cfg)
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()
    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

    # Loader plugins and commands
    command_categories = [
        "pilot.commands.built_in.audio_text",
        "pilot.commands.built_in.image_gen",
    ]
    # exclude commands
    command_categories = [
        x for x in command_categories if x not in cfg.disabled_command_categories
    ]
    command_registry = CommandRegistry()
    for command_category in command_categories:
        command_registry.import_commands(command_category)
    cfg.command_registry = command_registry
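server_init only logs the parsed arguments and then wires up plugins, signal handling and the command registry, so any argparse namespace can drive it; a minimal sketch of an entry point reusing it (the --port flag is just an assumption for illustration):

    import argparse
    from pilot.server.webserver_base import server_init

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--port", type=int, default=5000)
        args = parser.parse_args()
        server_init(args)  # loads native plugins, registers commands, starts the DB-summary thread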

View File

@@ -35,7 +35,7 @@ class ElevenLabsSpeech(VoiceBase):
        }
        self._headers = {
            "Content-Type": "application/json",
-            "xi-api-key": cfg.elevenlabs_api_key,
+            "xi-api_v1-key": cfg.elevenlabs_api_key,
        }
        self._voices = default_voices.copy()
        if cfg.elevenlabs_voice_1_id in voice_options:

View File

@@ -1,13 +1,12 @@
from pilot.vector_store.chroma_store import ChromaStore
-from pilot.vector_store.milvus_store import MilvusStore
-from pilot.vector_store.weaviate_store import WeaviateStore
+# from pilot.vector_store.milvus_store import MilvusStore

-connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore}
+connector = {"Chroma": ChromaStore, "Milvus": None}


class VectorStoreConnector:
-    """vector store connector, can connect different vector db provided load document api and similar search api."""
+    """vector store connector, can connect different vector db provided load document api_v1 and similar search api_v1."""

    def __init__(self, vector_store_type, ctx: {}) -> None:
        """initialize vector store connector."""

View File

@@ -1,146 +0,0 @@
import os
import json

import weaviate
from langchain.schema import Document
from langchain.vectorstores import Weaviate
from weaviate.exceptions import WeaviateBaseError

from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger
from pilot.vector_store.vector_store_base import VectorStoreBase

CFG = Config()


class WeaviateStore(VectorStoreBase):
    """Weaviate database"""

    def __init__(self, ctx: dict) -> None:
        """Initialize with Weaviate client."""
        try:
            import weaviate
        except ImportError:
            raise ValueError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            )

        self.ctx = ctx
        self.weaviate_url = CFG.WEAVIATE_URL
        self.embedding = ctx.get("embeddings", None)
        self.vector_name = ctx["vector_store_name"]
        self.persist_dir = os.path.join(
            KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb"
        )

        self.vector_store_client = weaviate.Client(self.weaviate_url)

    def similar_search(self, text: str, topk: int) -> None:
        """Perform similar search in Weaviate"""
        logger.info("Weaviate similar search")
        # nearText = {
        #     "concepts": [text],
        #     "distance": 0.75,  # prior to v1.14 use "certainty" instead of "distance"
        # }
        # vector = self.embedding.embed_query(text)
        response = (
            self.vector_store_client.query.get(
                self.vector_name, ["metadata", "page_content"]
            )
            # .with_near_vector({"vector": vector})
            .with_limit(topk).do()
        )
        res = response["data"]["Get"][list(response["data"]["Get"].keys())[0]]
        docs = []
        for r in res:
            docs.append(
                Document(
                    page_content=r["page_content"],
                    metadata={"metadata": r["metadata"]},
                )
            )
        return docs

    def vector_name_exists(self) -> bool:
        """Check if a vector name exists for a given class in Weaviate.
        Returns:
            bool: True if the vector name exists, False otherwise.
        """
        try:
            if self.vector_store_client.schema.get(self.vector_name):
                return True
            return False
        except WeaviateBaseError as e:
            logger.error("vector_name_exists error", e.message)
            return False

    def _default_schema(self) -> None:
        """
        Create the schema for Weaviate with a Document class containing metadata and text properties.
        """
        schema = {
            "classes": [
                {
                    "class": self.vector_name,
                    "description": "A document with metadata and text",
                    # "moduleConfig": {
                    #     "text2vec-transformers": {
                    #         "poolingStrategy": "masked_mean",
                    #         "vectorizeClassName": False,
                    #     }
                    # },
                    "properties": [
                        {
                            "dataType": ["text"],
                            # "moduleConfig": {
                            #     "text2vec-transformers": {
                            #         "skip": False,
                            #         "vectorizePropertyName": False,
                            #     }
                            # },
                            "description": "Metadata of the document",
                            "name": "metadata",
                        },
                        {
                            "dataType": ["text"],
                            # "moduleConfig": {
                            #     "text2vec-transformers": {
                            #         "skip": False,
                            #         "vectorizePropertyName": False,
                            #     }
                            # },
                            "description": "Text content of the document",
                            "name": "page_content",
                        },
                    ],
                    # "vectorizer": "text2vec-transformers",
                }
            ]
        }

        # Create the schema in Weaviate
        self.vector_store_client.schema.create(schema)

    def load_document(self, documents: list) -> None:
        """Load documents into Weaviate"""
        logger.info("Weaviate load document")
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]

        # Import data
        with self.vector_store_client.batch as batch:
            batch.batch_size = 100
            # Batch import all documents
            for i in range(len(texts)):
                properties = {
                    "metadata": metadatas[i]["source"],
                    "page_content": texts[i],
                }

                self.vector_store_client.batch.add_data_object(
                    data_object=properties, class_name=self.vector_name
                )
            self.vector_store_client.batch.flush()