WEB API independent

tuyang.yhj 2023-06-25 14:46:46 +08:00
parent 0558a8ba37
commit d372e73cd5
32 changed files with 506 additions and 309 deletions

View File

@ -78,7 +78,6 @@ def load_native_plugins(cfg: Config):
if not cfg.plugins_auto_load:
print("not auto load_native_plugins")
return
def load_from_git(cfg: Config):
print("async load_native_plugins")
branch_name = cfg.plugins_git_branch
@ -86,20 +85,16 @@ def load_native_plugins(cfg: Config):
url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
try:
session = requests.Session()
response = session.get(
url.format(repo=native_plugin_repo, branch=branch_name),
headers={"Authorization": "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"},
)
response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
headers={'Authorization': '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'})
if response.status_code == 200:
plugins_path_path = Path(PLUGINS_DIR)
files = glob.glob(
os.path.join(plugins_path_path, f"{native_plugin_repo}*")
)
files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
for file in files:
os.remove(file)
now = datetime.datetime.now()
time_str = now.strftime("%Y%m%d%H%M%S")
time_str = now.strftime('%Y%m%d%H%M%S')
file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
print(file_name)
with open(file_name, "wb") as f:
@ -115,6 +110,7 @@ def load_native_plugins(cfg: Config):
t.start()
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
"""Scan the plugins directory for plugins and loads them.

View File

@ -7,3 +7,7 @@ class SeparatorStyle(Enum):
TWO = "</s>"
THREE = auto()
FOUR = auto()
class ExampleType(Enum):
ONE_SHOT = "one_shot"
FEW_SHOT = "few_shot"

View File

@ -150,8 +150,6 @@ class Config(metaclass=Singleton):
self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)
self.WEAVIATE_URL = os.getenv("WEAVIATE_URL", "http://127.0.0.1:8080")
# QLoRA
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")

View File

@ -43,8 +43,6 @@ LLM_MODEL_CONFIG = {
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
# TODO Support baichuan-7b
# "baichuan-7b" : os.path.join(MODEL_PATH, "baichuan-7b"),
"gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
"proxyllm": "proxyllm",
}

View File

@ -6,6 +6,8 @@ import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
from test_cls_1 import TestBase,Test1
from test_cls_2 import Test2
CFG = Config()
@ -56,20 +58,30 @@ CFG = Config()
#
if __name__ == "__main__":
def __extract_json(s):
i = s.index("{")
count = 1  # current nesting depth, i.e. how many '{' are not yet closed
for j, c in enumerate(s[i + 1 :], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
count += 1
if count == 0:
break
assert count == 0  # make sure the final '}' was found
return s[i : j + 1]
# def __extract_json(s):
# i = s.index("{")
# count = 1  # current nesting depth, i.e. how many '{' are not yet closed
# for j, c in enumerate(s[i + 1 :], start=i + 1):
# if c == "}":
# count -= 1
# elif c == "{":
# count += 1
# if count == 0:
# break
# assert count == 0  # make sure the final '}' was found
# return s[i : j + 1]
#
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
# print(__extract_json(ss))
ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
print(__extract_json(ss))
test1 = Test1()
test2 = Test2()
test1.write()
test1.test()
test2.write()
test1.test()
test2.test()

View File

@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase
class Test1(TestBase):
def write(self):
self.test_values.append("x")
self.test_values.append("y")
self.test_values.append("g")

View File

@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class Test2(TestBase):
test_2_values:List = []
def write(self):
self.test_values.append(1)
self.test_values.append(2)
self.test_values.append(3)
self.test_2_values.append("x")
self.test_2_values.append("y")
self.test_2_values.append("z")

View File

@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class TestBase(BaseModel, ABC):
test_values: List = []
def test(self):
print(self.__class__.__name__ + ":" )
print(self.test_values)

View File

@ -32,14 +32,9 @@ class BaseLLMAdaper:
return True
def loader(self, model_path: str, from_pretrained_kwargs: dict):
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=False, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
**from_pretrained_kwargs,
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
)
return model, tokenizer
@ -62,7 +57,7 @@ def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
raise ValueError(f"Invalid model adapter for {model_path}")
# TODO support cpu? for practice we support gpt4all or chatglm-6b-int4?
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
class VicunaLLMAdapater(BaseLLMAdaper):

View File

@ -11,7 +11,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
def chatglm_generate_stream(
model, tokenizer, params, device, context_len=2048, stream_interval=2
):
"""Generate text using chatglm model's chat api"""
"""Generate text using chatglm model's chat api_v1"""
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))

View File

@ -51,7 +51,7 @@ def create_chat_completion(
return message
response = None
# TODO impl this use vicuna server api
# TODO impl this use vicuna server api_v1
class Stream(transformers.StoppingCriteria):

View File

@ -32,7 +32,7 @@ class BaseOutputParser(ABC):
Output parsers help structure language model responses.
"""
def __init__(self, sep: str, is_stream_out: bool):
def __init__(self, sep: str, is_stream_out: bool = True):
self.sep = sep
self.is_stream_out = is_stream_out

View File

@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pilot.common.schema import ExampleType
class ExampleSelector(BaseModel, ABC):
examples: List[List]
use_example: bool = False
type: str = ExampleType.ONE_SHOT.value
def examples(self, count: int = 2):
if ExampleType.ONE_SHOT.value == self.type:
return self.__one_shot_context()
else:
return self.__few_shot_context(count)
def __few_shot_context(self, count: int = 2) -> List[List]:
"""
Use 2 or more examples, default 2
Returns: example text
"""
if self.use_example:
need_use = self.examples[:count]
return need_use
return None
def __one_shot_context(self) -> List:
"""
Use one example
Returns:
"""
if self.use_example:
need_use = self.examples[:1]
return need_use
return None

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from pilot.common.formatting import formatter
from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
from pilot.prompts.example_base import ExampleSelector
def jinja2_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using jinja2."""
@ -32,7 +32,6 @@ class PromptTemplate(BaseModel, ABC):
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
template_scene: Optional[str]
template_define: Optional[str]
"""this template define"""
template: Optional[str]
@ -46,6 +45,7 @@ class PromptTemplate(BaseModel, ABC):
output_parser: BaseOutputParser = None
""""""
sep: str = SeparatorStyle.SINGLE.value
example: ExampleSelector = None
class Config:
"""Configuration for this pydantic object."""

View File

@ -182,7 +182,7 @@ class BasePromptTemplate(BaseModel, ABC):
Example:
.. code-block:: python
prompt.save(file_path="path/prompt.yaml")
prompt.save(file_path="path/prompt.api_v1")
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
@ -201,11 +201,11 @@ class BasePromptTemplate(BaseModel, ABC):
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
elif save_path.suffix == ".api_v1":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
raise ValueError(f"{save_path} must be json or api_v1")
class StringPromptValue(PromptValue):

View File

@ -10,3 +10,7 @@ class ChatScene(Enum):
ChatUrlKnowledge = "chat_url_knowledge"
InnerChatDBSummary = "inner_chat_db_summary"
ChatNormal = "chat_normal"
@staticmethod
def is_valid_mode(mode):
return any(mode == item.value for item in ChatScene)
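For quick reference, a hypothetical usage sketch of the new helper (not part of the commit), assuming this enum is pilot.scene.base.ChatScene as imported elsewhere in this change:

from pilot.scene.base import ChatScene

# chat_mode strings coming from the web API are checked against the enum values above
assert ChatScene.is_valid_mode("chat_normal") is True
assert ChatScene.is_valid_mode("not_a_mode") is False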

View File

@ -70,7 +70,7 @@ class BaseChat(ABC):
self.current_user_input: str = current_user_input
self.llm_model = CFG.LLM_MODEL
### can configurable storage methods
self.memory = MemHistoryMemory(chat_session_id)
self.memory = FileHistoryMemory(chat_session_id)
### load prompt template
self.prompt_template: PromptTemplate = CFG.prompt_templates[
@ -139,9 +139,7 @@ class BaseChat(ABC):
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
show_info = ""
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
@ -157,16 +155,10 @@ class BaseChat(ABC):
# show_info = resp_text_trunck
# yield resp_text_trunck + "▌"
self.current_message.add_ai_message(show_info)
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
### store the conversation record
self.memory.append(self.current_message)
raise ValueError(str(e))
def nostream_call(self):
payload = self.__call_base()
@ -188,20 +180,6 @@ class BaseChat(ABC):
)
)
# ### MOCK
# ai_response_text = """{
# "thoughts": "可以从users表和tran_order表联合查询按城市和订单数量进行分组统计并使用柱状图展示。",
# "reasoning": "为了分析用户在不同城市的分布情况需要查询users表和tran_order表使用LEFT JOIN将两个表联合起来。按照城市进行分组统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量方便比较。",
# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
# "command": {
# "name": "histogram-executor",
# "args": {
# "title": "订单城市分布柱状图",
# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
# }
# }
# }"""
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
@ -293,6 +271,7 @@ class BaseChat(ABC):
return text
# kept temporarily for front-end compatibility
def current_ai_response(self) -> str:
for message in self.current_message.messages:

View File

View File

@ -0,0 +1,9 @@
from pilot.prompts.example_base import ExampleSelector
## Two examples are defined by default
EXAMPLES = [
[{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}],
[{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}]
]
example = ExampleSelector(examples=EXAMPLES, use_example=True)

View File

@ -1,36 +1,31 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.common.schema import SeparatorStyle, ExampleType
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
from pilot.scene.chat_execution.example import example
CFG = Config()
PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
PROMPT_SUFFIX = """
_DEFAULT_TEMPLATE = """
Goals:
{input}
"""
_DEFAULT_TEMPLATE = """
Constraints:
0.Exclusively use the commands listed in double quotes e.g. "command name"
{constraints}
Commands:
{commands_infos}
"""
PROMPT_RESPONSE = """
Please respond strictly according to the following json format:
{response}
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""
@ -40,20 +35,22 @@ RESPONSE_FORMAT = {
"command": {"name": "command name", "args": {"arg name": "value"}},
}
EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
PROMPT_NEED_NEED_STREAM_OUT = False
PROMPT_NEED_STREAM_OUT = False
prompt = PromptTemplate(
template_scene=ChatScene.ChatExecution.value,
input_variables=["input", "constraints", "commands_infos", "response"],
response_format=json.dumps(RESPONSE_FORMAT, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_STREAM_OUT,
output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
example=example
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@ -0,0 +1 @@

View File

@ -1,5 +1,3 @@
from chromadb.errors import NoIndexException
from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@ -52,19 +50,12 @@ class ChatNewKnowledge(BaseChat):
)
def generate_input_values(self):
try:
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
)
context = [d.page_content for d in docs]
self.metadata = [d.metadata for d in docs]
context = context[:2000]
input_values = {"context": context, "question": self.current_user_input}
except NoIndexException:
raise ValueError(
f"you have no {self.knowledge_name} knowledge store, please upload your knowledge"
)
return input_values
def do_with_prompt_response(self, prompt_response):

View File

@ -43,12 +43,29 @@ class OnceConversation:
def add_ai_message(self, message: str) -> None:
"""Add an AI message to the store"""
has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
if has_message:
raise ValueError("Already Have Ai message")
self.__update_ai_message(message)
else:
self.messages.append(AIMessage(content=message))
""" """
def __update_ai_message(self, new_message: str) -> None:
"""
stream out message update
Args:
new_message:
Returns:
"""
for item in self.messages:
if item.type == "ai":
item.content = new_message
def add_view_message(self, message: str) -> None:
"""Add an AI message to the store"""

View File

@ -0,0 +1,146 @@
import uuid
from fastapi import APIRouter, Request, Body, status
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from typing import List
from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR)
from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage)
router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
async def validation_exception_handler(request: Request, exc: RequestValidationError):
message = ""
for error in exc.errors():
message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
return Result.faild(message)
@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]])
async def dialogue_list(user_id: str):
#### TODO
conversations = [ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]")]
return Result[ConversationVo].succ(conversations)
@router.post('/v1/chat/dialogue/new', response_model=Result[str])
async def dialogue_new(user_id: str):
unique_id = uuid.uuid1()
return Result.succ(unique_id)
@router.post('/v1/chat/dialogue/delete')
async def dialogue_delete(con_uid: str, user_id: str):
#### TODO
return Result.succ(None)
@router.post('/v1/chat/completions', response_model=Result[MessageVo])
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
if not ChatScene.is_valid_mode(dialogue.chat_mode):
return Result.faild("Unsupported Chat Mode: " + dialogue.chat_mode + "!")
chat_param = {
"chat_session_id": dialogue.conv_uid,
"user_input": dialogue.user_input,
}
if ChatScene.ChatWithDbExecute.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatWithDbQA.value == dialogue.chat_mode:
chat_param.update({"db_name": dialogue.select_param})
elif ChatScene.ChatExecution.value == dialogue.chat_mode:
chat_param.update({"plugin_selector": dialogue.select_param})
elif ChatScene.ChatNewKnowledge.value == dialogue.chat_mode:
chat_param.update({"knowledge_name": dialogue.select_param})
elif ChatScene.ChatUrlKnowledge.value == dialogue.chat_mode:
chat_param.update({"url": dialogue.select_param})
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
if not chat.prompt_template.stream_out:
return non_stream_response(chat)
else:
return stream_response(chat)
def stream_generator(chat):
model_response = chat.stream_call()
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
messageVos = [message2Vo(element) for element in chat.current_message.messages]
yield Result.succ(messageVos)
def stream_response(chat):
logger.info("stream out start!")
api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
return api_response
def message2Vo(message: BaseMessage) -> MessageVo:
return MessageVo(role=message.type, context=message.content, time_stamp=message.additional_kwargs.get("time_stamp", 0))
def non_stream_response(chat):
logger.info("not stream out, wait model response!")
chat.nostream_call()
messageVos = [message2Vo(element) for element in chat.current_message.messages]
return Result.succ(messageVos)
@router.get('/v1/db/types', response_model=Result[str])
async def db_types():
return Result.succ(["mysql", "duckdb"])
@router.get('/v1/db/list', response_model=Result[str])
async def db_list():
db = CFG.local_db
dbs = db.get_database_list()
return Result.succ(dbs)
@router.get('/v1/knowledge/list')
async def knowledge_list():
return ["test1", "test2"]
@router.post('/v1/knowledge/add')
async def knowledge_add():
return ["test1", "test2"]
@router.post('/v1/knowledge/delete')
async def knowledge_delete():
return ["test1", "test2"]
@router.get('/v1/knowledge/types')
async def knowledge_types():
return ["test1", "test2"]
@router.get('/v1/knowledge/detail')
async def knowledge_detail():
return ["test1", "test2"]

View File

@ -0,0 +1,57 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic, Optional
T = TypeVar('T')
class Result(Generic[T], BaseModel):
success: bool
err_code: Optional[str] = None
err_msg: Optional[str] = None
data: T = None
@classmethod
def succ(cls, data: T):
return Result(success=True, err_code=None, err_msg=None, data=data)
@classmethod
def faild(cls, msg, code="E000X"):
return Result(success=False, err_code=code, err_msg=msg, data=None)
class ConversationVo(BaseModel):
"""
dialogue_uid
"""
conv_uid: str = Field(..., description="dialogue uid")
"""
user input
"""
user_input: str
"""
the scene of chat
"""
chat_mode: str = Field(..., description="the scene of chat ")
"""
chat scene select param
"""
select_param: str
class MessageVo(BaseModel):
"""
role that sends out the current message
"""
role: str
"""
current message
"""
context: str
"""
time the current message was sent
"""
time_stamp: float
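For illustration only, a hypothetical sketch of how these view models compose (values made up; import path as used by api_v1.py above):

from pilot.server.api_v1.api_view_model import Result, ConversationVo

ok = Result.succ(["mysql", "duckdb"])  # e.g. the payload behind /v1/db/types
err = Result.faild("Unsupported Chat Mode")
print(ok.success, ok.data)        # True ['mysql', 'duckdb']
print(err.success, err.err_code)  # False E000X

conv = ConversationVo(conv_uid="123", user_input="hello", chat_mode="chat_normal", select_param="")
print(conv.chat_mode)  # chat_normal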

View File

@ -5,7 +5,6 @@ import asyncio
import json
import os
import sys
import traceback
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
@ -76,10 +75,8 @@ class ModelWorker:
):
# Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,
# and opening it may affect the frontend output
if not ("gptj" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL):
print("output: ", output)
# and opening it may affect the frontend output.
# print("output: ", output)
ret = {
"text": output,
"error_code": 0,
@ -90,13 +87,10 @@ class ModelWorker:
ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
yield json.dumps(ret).encode() + b"\0"
except Exception as e:
msg = "{}: {}".format(str(e), traceback.format_exc())
ret = {
"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {msg}",
"text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
"error_code": 0,
}
yield json.dumps(ret).encode() + b"\0"
def get_embeddings(self, prompt):

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import signal
import threading
import traceback
import argparse
@ -12,12 +11,10 @@ import uuid
import gradio as gr
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.scene.base_chat import BaseChat
@ -25,8 +22,8 @@ from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LOGDIR,
LLM_MODEL_CONFIG,
LOGDIR,
)
from pilot.conversation import (
@ -35,7 +32,6 @@ from pilot.conversation import (
chat_mode_title,
default_conversation,
)
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
@ -49,6 +45,19 @@ from pilot.vector_store.extract_tovec import (
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
from pilot.server.webserver_base import server_init
import uvicorn
from fastapi import BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.staticfiles import StaticFiles
from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler
# load plugins
CFG = Config()
@ -95,6 +104,19 @@ knowledge_qa_type_list = [
add_knowledge_base_dialogue,
]
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args, **kwargs,
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
)
applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI()
# app.mount("static", StaticFiles(directory="static"), name="static")
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
@ -324,15 +346,14 @@ def http_bot(
response = chat.stream_call()
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
state.messages[-1][
-1
] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
chunk, chat.skip_echo_len
)
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
state.messages[-1][-1] = msg
chat.current_message.add_ai_message(msg)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
chat.memory.append(chat.current_message)
except Exception as e:
print(traceback.format_exc())
state.messages[-1][-1] = "Error:" + str(e)
state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
@ -632,7 +653,7 @@ def knowledge_embedding_store(vs_id, files):
)
knowledge_embedding_client = KnowledgeEmbedding(
file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
model_name=LLM_MODEL_CONFIG["text2vec"],
vector_store_config={
"vector_store_name": vector_store_name["vs_name"],
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@ -657,42 +678,25 @@ def signal_handler(sig, frame):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument(
"--model-list-mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument("--share", default=False, action="store_true")
# init server config
args = parser.parse_args()
logger.info(f"args: {args}")
server_init(args)
# init config
cfg = Config()
load_native_plugins(cfg)
dbs = cfg.local_db.get_database_list()
signal.signal(signal.SIGINT, signal_handler)
async_db_summery()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# Loader plugins and commands
command_categories = [
"pilot.commands.built_in.audio_text",
"pilot.commands.built_in.image_gen",
]
# exclude commands
command_categories = [
x for x in command_categories if x not in cfg.disabled_command_categories
]
command_registry = CommandRegistry()
for command_category in command_categories:
command_registry.import_commands(command_category)
cfg.command_registry = command_registry
logger.info(args)
if args.new:
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
else:
### Compatibility mode starts the old version server by default
demo = build_webdemo()
demo.queue(
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
@ -702,3 +706,8 @@ if __name__ == "__main__":
share=args.share,
max_threads=200,
)

View File

@ -0,0 +1,60 @@
import signal
import os
import threading
import traceback
import sys
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
)
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
logger = build_logger("webserver", LOGDIR + "webserver.log")
def signal_handler(sig, frame):
print("in order to avoid chroma db atexit problem")
os._exit(0)
def async_db_summery():
client = DBSummaryClient()
thread = threading.Thread(target=client.init_db_summary)
thread.start()
def server_init(args):
logger.info(f"args: {args}")
# init config
cfg = Config()
load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
async_db_summery()
cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))
# Load plugins and commands
command_categories = [
"pilot.commands.built_in.audio_text",
"pilot.commands.built_in.image_gen",
]
# exclude commands
command_categories = [
x for x in command_categories if x not in cfg.disabled_command_categories
]
command_registry = CommandRegistry()
for command_category in command_categories:
command_registry.import_commands(command_category)
cfg.command_registry = command_registry

View File

@ -35,7 +35,7 @@ class ElevenLabsSpeech(VoiceBase):
}
self._headers = {
"Content-Type": "application/json",
"xi-api-key": cfg.elevenlabs_api_key,
"xi-api_v1-key": cfg.elevenlabs_api_key,
}
self._voices = default_voices.copy()
if cfg.elevenlabs_voice_1_id in voice_options:

View File

@ -1,13 +1,12 @@
from pilot.vector_store.chroma_store import ChromaStore
from pilot.vector_store.milvus_store import MilvusStore
from pilot.vector_store.weaviate_store import WeaviateStore
# from pilot.vector_store.milvus_store import MilvusStore
connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore}
connector = {"Chroma": ChromaStore, "Milvus": None}
class VectorStoreConnector:
"""vector store connector, can connect different vector db provided load document api and similar search api."""
"""vector store connector, can connect different vector db provided load document api_v1 and similar search api_v1."""
def __init__(self, vector_store_type, ctx: {}) -> None:
"""initialize vector store connector."""

View File

@ -1,146 +0,0 @@
import os
import json
import weaviate
from langchain.schema import Document
from langchain.vectorstores import Weaviate
from weaviate.exceptions import WeaviateBaseError
from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger
from pilot.vector_store.vector_store_base import VectorStoreBase
CFG = Config()
class WeaviateStore(VectorStoreBase):
"""Weaviate database"""
def __init__(self, ctx: dict) -> None:
"""Initialize with Weaviate client."""
try:
import weaviate
except ImportError:
raise ValueError(
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`."
)
self.ctx = ctx
self.weaviate_url = CFG.WEAVIATE_URL
self.embedding = ctx.get("embeddings", None)
self.vector_name = ctx["vector_store_name"]
self.persist_dir = os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb"
)
self.vector_store_client = weaviate.Client(self.weaviate_url)
def similar_search(self, text: str, topk: int) -> None:
"""Perform similar search in Weaviate"""
logger.info("Weaviate similar search")
# nearText = {
# "concepts": [text],
# "distance": 0.75, # prior to v1.14 use "certainty" instead of "distance"
# }
# vector = self.embedding.embed_query(text)
response = (
self.vector_store_client.query.get(
self.vector_name, ["metadata", "page_content"]
)
# .with_near_vector({"vector": vector})
.with_limit(topk).do()
)
res = response["data"]["Get"][list(response["data"]["Get"].keys())[0]]
docs = []
for r in res:
docs.append(
Document(
page_content=r["page_content"],
metadata={"metadata": r["metadata"]},
)
)
return docs
def vector_name_exists(self) -> bool:
"""Check if a vector name exists for a given class in Weaviate.
Returns:
bool: True if the vector name exists, False otherwise.
"""
try:
if self.vector_store_client.schema.get(self.vector_name):
return True
return False
except WeaviateBaseError as e:
logger.error("vector_name_exists error", e.message)
return False
def _default_schema(self) -> None:
"""
Create the schema for Weaviate with a Document class containing metadata and text properties.
"""
schema = {
"classes": [
{
"class": self.vector_name,
"description": "A document with metadata and text",
# "moduleConfig": {
# "text2vec-transformers": {
# "poolingStrategy": "masked_mean",
# "vectorizeClassName": False,
# }
# },
"properties": [
{
"dataType": ["text"],
# "moduleConfig": {
# "text2vec-transformers": {
# "skip": False,
# "vectorizePropertyName": False,
# }
# },
"description": "Metadata of the document",
"name": "metadata",
},
{
"dataType": ["text"],
# "moduleConfig": {
# "text2vec-transformers": {
# "skip": False,
# "vectorizePropertyName": False,
# }
# },
"description": "Text content of the document",
"name": "page_content",
},
],
# "vectorizer": "text2vec-transformers",
}
]
}
# Create the schema in Weaviate
self.vector_store_client.schema.create(schema)
def load_document(self, documents: list) -> None:
"""Load documents into Weaviate"""
logger.info("Weaviate load document")
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
# Import data
with self.vector_store_client.batch as batch:
batch.batch_size = 100
# Batch import all documents
for i in range(len(texts)):
properties = {
"metadata": metadatas[i]["source"],
"page_content": texts[i],
}
self.vector_store_client.batch.add_data_object(
data_object=properties, class_name=self.vector_name
)
self.vector_store_client.batch.flush()