mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-08 11:47:44 +00:00

WEB API independent

This commit is contained in:
parent 0558a8ba37
commit d372e73cd5
@@ -78,7 +78,6 @@ def load_native_plugins(cfg: Config):
    if not cfg.plugins_auto_load:
        print("not auto load_native_plugins")
        return

    def load_from_git(cfg: Config):
        print("async load_native_plugins")
        branch_name = cfg.plugins_git_branch
@@ -86,20 +85,16 @@ def load_native_plugins(cfg: Config):
        url = "https://github.com/csunny/{repo}/archive/{branch}.zip"
        try:
            session = requests.Session()
            response = session.get(
                url.format(repo=native_plugin_repo, branch=branch_name),
                headers={"Authorization": "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"},
            )
            response = session.get(url.format(repo=native_plugin_repo, branch=branch_name),
                                   headers={'Authorization': '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'})

            if response.status_code == 200:
                plugins_path_path = Path(PLUGINS_DIR)
                files = glob.glob(
                    os.path.join(plugins_path_path, f"{native_plugin_repo}*")
                )
                files = glob.glob(os.path.join(plugins_path_path, f'{native_plugin_repo}*'))
                for file in files:
                    os.remove(file)
                now = datetime.datetime.now()
                time_str = now.strftime("%Y%m%d%H%M%S")
                time_str = now.strftime('%Y%m%d%H%M%S')
                file_name = f"{plugins_path_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
                print(file_name)
                with open(file_name, "wb") as f:
@@ -115,6 +110,7 @@ def load_native_plugins(cfg: Config):
    t.start()


def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
    """Scan the plugins directory for plugins and loads them.
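The hunk above collapses the branch-archive download in the nested load_from_git helper into a single-line call. A minimal standalone sketch of the same download-and-replace pattern (assuming a public repo, so the Authorization header is omitted; PLUGINS_DIR and the repo/branch names mirror the hunk):

import datetime
import glob
import os
from pathlib import Path

import requests

PLUGINS_DIR = "plugins"  # assumed plugins directory, as in the diff


def fetch_branch_zip(native_plugin_repo: str, branch_name: str) -> str:
    """Download <repo>/archive/<branch>.zip and replace any stale copies."""
    url = f"https://github.com/csunny/{native_plugin_repo}/archive/{branch_name}.zip"
    response = requests.get(url, timeout=30)
    response.raise_for_status()

    plugins_path = Path(PLUGINS_DIR)
    # remove older archives of the same repo before writing the new one
    for old in glob.glob(os.path.join(plugins_path, f"{native_plugin_repo}*")):
        os.remove(old)

    time_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    file_name = f"{plugins_path}/{native_plugin_repo}-{branch_name}-{time_str}.zip"
    with open(file_name, "wb") as f:
        f.write(response.content)
    return file_name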
@@ -7,3 +7,7 @@ class SeparatorStyle(Enum):
    TWO = "</s>"
    THREE = auto()
    FOUR = auto()

class ExampleType(Enum):
    ONE_SHOT = "one_shot"
    FEW_SHOT = "few_shot"
@@ -150,8 +150,6 @@ class Config(metaclass=Singleton):
        self.MILVUS_USERNAME = os.getenv("MILVUS_USERNAME", None)
        self.MILVUS_PASSWORD = os.getenv("MILVUS_PASSWORD", None)

        self.WEAVIATE_URL = os.getenv("WEAVIATE_URL", "http://127.0.0.1:8080")

        # QLoRA
        self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
@@ -43,8 +43,6 @@ LLM_MODEL_CONFIG = {
    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
    "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
    # TODO Support baichuan-7b
    # "baichuan-7b" : os.path.join(MODEL_PATH, "baichuan-7b"),
    "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
    "proxyllm": "proxyllm",
}
@@ -6,6 +6,8 @@ import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
from test_cls_1 import TestBase, Test1
from test_cls_2 import Test2

CFG = Config()

@@ -56,20 +58,30 @@ CFG = Config()
#


if __name__ == "__main__":

    def __extract_json(s):
        i = s.index("{")
        count = 1  # current nesting depth, i.e. the number of '{' not yet closed
        for j, c in enumerate(s[i + 1 :], start=i + 1):
            if c == "}":
                count -= 1
            elif c == "{":
                count += 1
            if count == 0:
                break
        assert count == 0  # check that the final '}' was found
        return s[i : j + 1]

    # def __extract_json(s):
    #     i = s.index("{")
    #     count = 1  # current nesting depth, i.e. the number of '{' not yet closed
    #     for j, c in enumerate(s[i + 1 :], start=i + 1):
    #         if c == "}":
    #             count -= 1
    #         elif c == "{":
    #             count += 1
    #         if count == 0:
    #             break
    #     assert count == 0  # check that the final '}' was found
    #     return s[i : j + 1]
    #
    # ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
    # print(__extract_json(ss))

    ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
    print(__extract_json(ss))

    test1 = Test1()
    test2 = Test2()
    test1.write()
    test1.test()
    test2.write()
    test1.test()
    test2.test()
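__extract_json above recovers the first balanced {...} block from a free-form model reply by tracking brace depth. A standalone copy of the same logic with a small usage example (the reply string is hypothetical):

import json


def extract_json(s: str) -> str:
    """Return the first balanced {...} substring of s."""
    i = s.index("{")
    count = 1  # nesting depth: '{' seen but not yet closed
    for j, c in enumerate(s[i + 1 :], start=i + 1):
        if c == "}":
            count -= 1
        elif c == "{":
            count += 1
        if count == 0:
            break
    assert count == 0  # the final '}' must have been found
    return s[i : j + 1]


reply = 'model says: {"speak": "hi", "command": {"name": "noop"}} and more prose'
print(json.loads(extract_json(reply)))  # {'speak': 'hi', 'command': {'name': 'noop'}}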
pilot/connections/rdbms/py_study/test_cls_1.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase


class Test1(TestBase):

    def write(self):
        self.test_values.append("x")
        self.test_values.append("y")
        self.test_values.append("g")
pilot/connections/rdbms/py_study/test_cls_2.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

class Test2(TestBase):
    test_2_values: List = []

    def write(self):
        self.test_values.append(1)
        self.test_values.append(2)
        self.test_values.append(3)
        self.test_2_values.append("x")
        self.test_2_values.append("y")
        self.test_2_values.append("z")
pilot/connections/rdbms/py_study/test_cls_base.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union


class TestBase(BaseModel, ABC):
    test_values: List = []


    def test(self):
        print(self.__class__.__name__ + ":")
        print(self.test_values)
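The three py_study files appear to probe how pydantic treats the mutable class-level default test_values: List = []. In pydantic v1 (current when this commit landed), each instance receives its own deep copy of a field default, so Test1 and Test2 instances should not share appends the way plain class attributes would. A minimal sketch of that behaviour (hypothetical Base model):

from typing import List

from pydantic import BaseModel


class Base(BaseModel):
    values: List = []  # mutable default; pydantic v1 deep-copies it per instance


a, b = Base(), Base()
a.values.append("x")
print(a.values, b.values)  # ['x'] [] -- not shared between instances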
@@ -32,14 +32,9 @@ class BaseLLMAdaper:
        return True

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, use_fast=False, trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            **from_pretrained_kwargs,
            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return model, tokenizer

@@ -62,7 +57,7 @@ def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    raise ValueError(f"Invalid model adapter for {model_path}")


# TODO support cpu? for practice we support gpt4all or chatglm-6b-int4?
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?


class VicunaLLMAdapater(BaseLLMAdaper):
@@ -11,7 +11,7 @@ from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
def chatglm_generate_stream(
    model, tokenizer, params, device, context_len=2048, stream_interval=2
):
    """Generate text using chatglm model's chat api"""
    """Generate text using chatglm model's chat api_v1"""
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))
@@ -51,7 +51,7 @@ def create_chat_completion(
        return message

    response = None
    # TODO impl this use vicuna server api
    # TODO impl this use vicuna server api_v1


class Stream(transformers.StoppingCriteria):
@@ -32,7 +32,7 @@ class BaseOutputParser(ABC):
    Output parsers help structure language model responses.
    """

    def __init__(self, sep: str, is_stream_out: bool):
    def __init__(self, sep: str, is_stream_out: bool = True):
        self.sep = sep
        self.is_stream_out = is_stream_out
pilot/prompts/example_base.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

from pilot.common.schema import ExampleType

class ExampleSelector(BaseModel, ABC):
    examples: List[List]
    use_example: bool = False
    type: str = ExampleType.ONE_SHOT.value

    def examples(self, count: int = 2):
        if ExampleType.ONE_SHOT.value == self.type:
            return self.__one_show_context()
        else:
            return self.__few_shot_context(count)

    def __few_shot_context(self, count: int = 2) -> List[List]:
        """
        Use 2 or more examples, default 2
        Returns: example text
        """
        if self.use_example:
            need_use = self.examples[:count]
            return need_use
        return None

    def __one_show_context(self) -> List:
        """
        Use one example
        Returns:

        """
        if self.use_example:
            need_use = self.examples[:1]
            return need_use

        return None
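ExampleSelector returns either the first example (one-shot) or the first count examples (few-shot). Note that the field examples and the method examples collide on one name, so the method body would shadow the field's default and the class cannot work quite as intended; the sketch below renames the method to pick() purely to illustrate the selection logic:

from typing import List, Optional

from pydantic import BaseModel


class ExampleSelectorSketch(BaseModel):
    examples: List[List]
    use_example: bool = False
    type: str = "one_shot"

    def pick(self, count: int = 2) -> Optional[List[List]]:
        if not self.use_example:
            return None
        if self.type == "one_shot":
            return self.examples[:1]
        return self.examples[:count]


selector = ExampleSelectorSketch(examples=[["ex1"], ["ex2"], ["ex3"]], use_example=True)
print(selector.pick())  # [['ex1']] -- one-shot default
print(ExampleSelectorSketch(
    examples=[["ex1"], ["ex2"], ["ex3"]], use_example=True, type="few_shot"
).pick(count=2))  # [['ex1'], ['ex2']]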
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from pilot.common.formatting import formatter
from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle

from pilot.prompts.example_base import ExampleSelector


def jinja2_formatter(template: str, **kwargs: Any) -> str:
    """Format a template using jinja2."""
@@ -32,7 +32,6 @@ class PromptTemplate(BaseModel, ABC):
    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    template_scene: Optional[str]

    template_define: Optional[str]
    """this template define"""
    template: Optional[str]
@@ -46,6 +45,7 @@ class PromptTemplate(BaseModel, ABC):
    output_parser: BaseOutputParser = None
    """"""
    sep: str = SeparatorStyle.SINGLE.value
    example: ExampleSelector = None

    class Config:
        """Configuration for this pydantic object."""
@@ -182,7 +182,7 @@ class BasePromptTemplate(BaseModel, ABC):
        Example:
        .. code-block:: python

            prompt.save(file_path="path/prompt.yaml")
            prompt.save(file_path="path/prompt.api_v1")
        """
        if self.partial_variables:
            raise ValueError("Cannot save prompt with partial variables.")
@@ -201,11 +201,11 @@ class BasePromptTemplate(BaseModel, ABC):
        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(prompt_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
        elif save_path.suffix == ".api_v1":
            with open(file_path, "w") as f:
                yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")
            raise ValueError(f"{save_path} must be json or api_v1")


class StringPromptValue(PromptValue):
@@ -10,3 +10,7 @@ class ChatScene(Enum):
    ChatUrlKnowledge = "chat_url_knowledge"
    InnerChatDBSummary = "inner_chat_db_summary"
    ChatNormal = "chat_normal"

    @staticmethod
    def is_valid_mode(mode):
        return any(mode == item.value for item in ChatScene)
@@ -70,7 +70,7 @@ class BaseChat(ABC):
        self.current_user_input: str = current_user_input
        self.llm_model = CFG.LLM_MODEL
        ### the storage method is configurable
        self.memory = MemHistoryMemory(chat_session_id)
        self.memory = FileHistoryMemory(chat_session_id)

        ### load prompt template
        self.prompt_template: PromptTemplate = CFG.prompt_templates[
@@ -139,9 +139,7 @@ class BaseChat(ABC):

        self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
        logger.info(f"Request: \n{payload}")
        ai_response_text = ""
        try:
            show_info = ""
            response = requests.post(
                urljoin(CFG.MODEL_SERVER, "generate_stream"),
                headers=headers,
@@ -157,16 +155,10 @@ class BaseChat(ABC):
                # show_info = resp_text_trunck
                # yield resp_text_trunck + "▌"

            self.current_message.add_ai_message(show_info)

        except Exception as e:
            print(traceback.format_exc())
            logger.error("model response parse failed!" + str(e))
            self.current_message.add_view_message(
                f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
            )
            ### persist the conversation history
            self.memory.append(self.current_message)
            raise ValueError(str(e))

    def nostream_call(self):
        payload = self.__call_base()
@@ -188,20 +180,6 @@ class BaseChat(ABC):
                )
            )

        # ### MOCK
        # ai_response_text = """{
        #     "thoughts": "We can join the users table with the tran_order table, group by city to count the orders per city, and show the result as a histogram.",
        #     "reasoning": "To analyze how users are distributed across cities, query the users and tran_order tables joined with a LEFT JOIN, group by city, and count the orders in each city. A histogram makes the per-city order counts easy to compare at a glance.",
        #     "speak": "Based on your analysis goal, I queried the user and order tables, counted the orders per city, and generated a histogram.",
        #     "command": {
        #         "name": "histogram-executor",
        #         "args": {
        #             "title": "Histogram of order counts by city",
        #             "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
        #         }
        #     }
        # }"""

        self.current_message.add_ai_message(ai_response_text)
        prompt_define_response = (
            self.prompt_template.output_parser.parse_prompt_response(
@@ -293,6 +271,7 @@ class BaseChat(ABC):

        return text


    # temporary, for frontend compatibility
    def current_ai_response(self) -> str:
        for message in self.current_message.messages:
pilot/scene/chat_dashboard/__init__.py (new file, 0 lines)

pilot/scene/chat_execution/example.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from pilot.prompts.example_base import ExampleSelector

## Two examples are defined by default
EXAMPLES = [
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
    [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}],
]

example = ExampleSelector(examples=EXAMPLES, use_example=True)
@@ -1,36 +1,31 @@
import json
import importlib
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle
from pilot.common.schema import SeparatorStyle, ExampleType

from pilot.scene.chat_execution.out_parser import PluginChatOutputParser

from pilot.scene.chat_execution.example import example

CFG = Config()

PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."

PROMPT_SUFFIX = """

_DEFAULT_TEMPLATE = """
Goals:
{input}

"""

_DEFAULT_TEMPLATE = """
Constraints:
0.Exclusively use the commands listed in double quotes e.g. "command name"
{constraints}

Commands:
{commands_infos}
"""


PROMPT_RESPONSE = """
Please response strictly according to the following json format:
{response}
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""

@@ -40,20 +35,22 @@ RESPONSE_FORMAT = {
    "command": {"name": "command name", "args": {"arg name": "value"}},
}


EXAMPLE_TYPE = ExampleType.ONE_SHOT
PROMPT_SEP = SeparatorStyle.SINGLE.value
### Whether the model service is streaming output
PROMPT_NEED_NEED_STREAM_OUT = False
PROMPT_NEED_STREAM_OUT = False

prompt = PromptTemplate(
    template_scene=ChatScene.ChatExecution.value,
    input_variables=["input", "constraints", "commands_infos", "response"],
    response_format=json.dumps(RESPONSE_FORMAT, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=PROMPT_SUFFIX + _DEFAULT_TEMPLATE + PROMPT_RESPONSE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=PluginChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_STREAM_OUT,
    output_parser=PluginChatOutputParser(sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT),
    example=example
)

CFG.prompt_templates.update({prompt.template_scene: prompt})
pilot/scene/chat_execution/prompt_v2.py (new file, 1 line)

@@ -0,0 +1 @@
@@ -1,5 +1,3 @@
from chromadb.errors import NoIndexException

from pilot.scene.base_chat import BaseChat, logger, headers
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@@ -52,19 +50,12 @@ class ChatNewKnowledge(BaseChat):
        )

    def generate_input_values(self):
        try:
            docs = self.knowledge_embedding_client.similar_search(
                self.current_user_input, CFG.KNOWLEDGE_SEARCH_TOP_SIZE
            )
            context = [d.page_content for d in docs]
            self.metadata = [d.metadata for d in docs]
            context = context[:2000]
            input_values = {"context": context, "question": self.current_user_input}
        except NoIndexException:
            raise ValueError(
                f"you have no {self.knowledge_name} knowledge store, please upload your knowledge"
            )

        return input_values

    def do_with_prompt_response(self, prompt_response):
@@ -43,12 +43,29 @@ class OnceConversation:

    def add_ai_message(self, message: str) -> None:
        """Add an AI message to the store"""

        has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
        if has_message:
            raise ValueError("Already Have Ai message")
            self.__update_ai_message(message)
        else:
            self.messages.append(AIMessage(content=message))
        """ """

    def __update_ai_message(self, new_message: str) -> None:
        """
        stream out message update
        Args:
            new_message:

        Returns:

        """

        for item in self.messages:
            if item.type == "ai":
                item.content = new_message

    def add_view_message(self, message: str) -> None:
        """Add a view message to the store"""
pilot/server/api_v1/api_v1.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import uuid

from fastapi import APIRouter, Request, Body, status

from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from typing import List

from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo
from pilot.configs.config import Config
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import (LOGDIR)
from pilot.utils import build_logger
from pilot.scene.base_message import (BaseMessage)

router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")


async def validation_exception_handler(request: Request, exc: RequestValidationError):
    message = ""
    for error in exc.errors():
        message += ".".join(error.get("loc")) + ":" + error.get("msg") + ";"
    return Result.faild(message)


@router.get('/v1/chat/dialogue/list', response_model=Result[List[ConversationVo]])
async def dialogue_list(user_id: str):
    #### TODO

    conversations = [ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]"),
                     ConversationVo(conv_uid="123", chat_mode="user", select_param="test1", user_input="message[0]")]

    return Result[ConversationVo].succ(conversations)


@router.post('/v1/chat/dialogue/new', response_model=Result[str])
async def dialogue_new(user_id: str):
    unique_id = uuid.uuid1()
    return Result.succ(unique_id)


@router.post('/v1/chat/dialogue/delete')
async def dialogue_delete(con_uid: str, user_id: str):
    #### TODO
    return Result.succ(None)


@router.post('/v1/chat/completions', response_model=Result[MessageVo])
async def chat_completions(dialogue: ConversationVo = Body()):
    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")

    if not ChatScene.is_valid_mode(dialogue.chat_mode):
        raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))

    chat_param = {
        "chat_session_id": dialogue.conv_uid,
        "user_input": dialogue.user_input,
    }

    if ChatScene.ChatWithDbExecute == dialogue.chat_mode:
        chat_param.update("db_name", dialogue.select_param)
    elif ChatScene.ChatWithDbQA == dialogue.chat_mode:
        chat_param.update("db_name", dialogue.select_param)
    elif ChatScene.ChatExecution == dialogue.chat_mode:
        chat_param.update("plugin_selector", dialogue.select_param)
    elif ChatScene.ChatNewKnowledge == dialogue.chat_mode:
        chat_param.update("knowledge_name", dialogue.select_param)
    elif ChatScene.ChatUrlKnowledge == dialogue.chat_mode:
        chat_param.update("url", dialogue.select_param)

    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
    if not chat.prompt_template.stream_out:
        return non_stream_response(chat)
    else:
        return stream_response(chat)


def stream_generator(chat):
    model_response = chat.stream_call()
    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
            chat.current_message.add_ai_message(msg)
            messageVos = [message2Vo(element) for element in chat.current_message.messages]
            yield Result.succ(messageVos)

def stream_response(chat):
    logger.info("stream out start!")
    api_response = StreamingResponse(stream_generator(chat), media_type="application/json")
    return api_response


def message2Vo(message: BaseMessage) -> MessageVo:
    vo: MessageVo = MessageVo()
    vo.role = message.type
    vo.context = message.content
    vo.time_stamp = message.additional_kwargs.time_stamp if message.additional_kwargs["time_stamp"] else 0
    return vo


def non_stream_response(chat):
    logger.info("not stream out, wait model response!")
    chat.nostream_call()
    messageVos = [message2Vo(element) for element in chat.current_message.messages]
    return Result.succ(messageVos)


@router.get('/v1/db/types', response_model=Result[str])
async def db_types():
    return Result.succ(["mysql", "duckdb"])


@router.get('/v1/db/list', response_model=Result[str])
async def db_list():
    db = CFG.local_db
    dbs = db.get_database_list()
    return Result.succ(dbs)


@router.get('/v1/knowledge/list')
async def knowledge_list():
    return ["test1", "test2"]


@router.post('/v1/knowledge/add')
async def knowledge_add():
    return ["test1", "test2"]


@router.post('/v1/knowledge/delete')
async def knowledge_delete():
    return ["test1", "test2"]


@router.get('/v1/knowledge/types')
async def knowledge_types():
    return ["test1", "test2"]


@router.get('/v1/knowledge/detail')
async def knowledge_detail():
    return ["test1", "test2"]
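The new router makes the chat stack reachable over plain HTTP. A hedged client sketch against a local instance (assuming the --new startup path below, which serves the app with uvicorn on port 5000; note also that dict.update() expects a mapping, so the scene-specific branches above would likely need chat_param["db_name"] = dialogue.select_param to work as written):

import requests

BASE = "http://localhost:5000"

# create a dialogue; the response is a Result envelope wrapping a new uuid
new = requests.post(f"{BASE}/v1/chat/dialogue/new", params={"user_id": "u1"})
print(new.json())

# drive one chat turn; the JSON body mirrors ConversationVo
payload = {
    "conv_uid": "123",           # uid obtained from /v1/chat/dialogue/new
    "user_input": "hello",
    "chat_mode": "chat_normal",  # must satisfy ChatScene.is_valid_mode
    "select_param": "",
}
resp = requests.post(f"{BASE}/v1/chat/completions", json=payload)
print(resp.text)  # streamed or plain Result, depending on the scene's stream_out flag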
pilot/server/api_v1/api_view_model.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from pydantic import BaseModel, Field
from typing import TypeVar, Union, List, Generic

T = TypeVar('T')


class Result(Generic[T], BaseModel):
    success: bool
    err_code: str
    err_msg: str
    data: List[T]

    @classmethod
    def succ(cls, data: List[T]):
        return Result(True, None, None, data)

    @classmethod
    def faild(cls, msg):
        return Result(True, "E000X", msg, None)

    @classmethod
    def faild(cls, code, msg):
        return Result(True, code, msg, None)


class ConversationVo(BaseModel):
    """
    dialogue_uid
    """
    conv_uid: str = Field(..., description="dialogue uid")
    """
    user input
    """
    user_input: str
    """
    the scene of chat
    """
    chat_mode: str = Field(..., description="the scene of chat ")
    """
    chat scene select param
    """
    select_param: str


class MessageVo(BaseModel):
    """
    role that sends out the current message
    """
    role: str
    """
    current message
    """
    context: str
    """
    time the current message was sent
    """
    time_stamp: float
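Two things in Result above would fail at runtime: pydantic models must be constructed with keyword arguments, and the second faild definition silently replaces the first, since Python keeps only the last def bound to a name. A hedged corrected sketch using pydantic v1 generics (the ResultSketch name is hypothetical):

from typing import Generic, List, Optional, TypeVar

from pydantic.generics import GenericModel

T = TypeVar("T")


class ResultSketch(GenericModel, Generic[T]):
    success: bool
    err_code: Optional[str] = None
    err_msg: Optional[str] = None
    data: Optional[List[T]] = None

    @classmethod
    def succ(cls, data: List[T]) -> "ResultSketch[T]":
        return cls(success=True, data=data)  # keyword args, not positional

    @classmethod
    def faild(cls, msg: str, code: str = "E000X") -> "ResultSketch[T]":
        return cls(success=False, err_code=code, err_msg=msg)


print(ResultSketch[str].succ(["ok"]).json())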
@@ -5,7 +5,6 @@ import asyncio
import json
import os
import sys
import traceback

import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
@@ -76,10 +75,8 @@ class ModelWorker:
        ):
            # Please do not open the output in production!
            # The gpt4all thread shares stdout with the parent process,
            # and opening it may affect the frontend output
            if not ("gptj" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL):
                print("output: ", output)

            # and opening it may affect the frontend output.
            # print("output: ", output)
            ret = {
                "text": output,
                "error_code": 0,
@@ -90,13 +87,10 @@ class ModelWorker:
            ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            msg = "{}: {}".format(str(e), traceback.format_exc())

            ret = {
                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {msg}",
                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                "error_code": 0,
            }

            yield json.dumps(ret).encode() + b"\0"

    def get_embeddings(self, prompt):
@@ -1,6 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import signal
import threading
import traceback
import argparse
@@ -12,12 +11,10 @@ import uuid

import gradio as gr


ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry

from pilot.scene.base_chat import BaseChat

@@ -25,8 +22,8 @@ from pilot.configs.config import Config
from pilot.configs.model_config import (
    DATASETS_DIR,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LOGDIR,
    LLM_MODEL_CONFIG,
    LOGDIR,
)

from pilot.conversation import (
@@ -35,7 +32,6 @@ from pilot.conversation import (
    chat_mode_title,
    default_conversation,
)
from pilot.common.plugins import scan_plugins, load_native_plugins

from pilot.server.gradio_css import code_highlight_css
from pilot.server.gradio_patch import Chatbot as grChatbot
@@ -49,6 +45,19 @@ from pilot.vector_store.extract_tovec import (
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
from pilot.server.webserver_base import server_init


import uvicorn
from fastapi import BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.staticfiles import StaticFiles

from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler

# load plugins
CFG = Config()
@@ -95,6 +104,19 @@ knowledge_qa_type_list = [
    add_knowledge_base_dialogue,
]

def swagger_monkey_patch(*args, **kwargs):
    return get_swagger_ui_html(
        *args, **kwargs,
        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
    )
applications.get_swagger_ui_html = swagger_monkey_patch

app = FastAPI()
# app.mount("static", StaticFiles(directory="static"), name="static")
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)


def get_simlar(q):
    docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
@@ -324,15 +346,14 @@ def http_bot(
        response = chat.stream_call()
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                state.messages[-1][
                    -1
                ] = chat.prompt_template.output_parser.parse_model_stream_resp_ex(
                    chunk, chat.skip_echo_len
                )
                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
                state.messages[-1][-1] = msg
                chat.current_message.add_ai_message(msg)
                yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
        chat.memory.append(chat.current_message)
    except Exception as e:
        print(traceback.format_exc())
        state.messages[-1][-1] = "Error:" + str(e)
        state.messages[-1][-1] = f"""<span style=\"color:red\">ERROR!</span>{str(e)} """
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5


@@ -632,7 +653,7 @@ def knowledge_embedding_store(vs_id, files):
    )
    knowledge_embedding_client = KnowledgeEmbedding(
        file_path=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, vs_id, filename),
        model_name=LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL],
        model_name=LLM_MODEL_CONFIG["text2vec"],
        vector_store_config={
            "vector_store_name": vector_store_name["vs_name"],
            "vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
@@ -657,42 +678,25 @@ def signal_handler(sig, frame):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
    parser.add_argument('-new', '--new', action='store_true', help='enable new http mode')

    # old version server config
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument(
        "--model-list-mode", type=str, default="once", choices=["once", "reload"]
    )
    parser.add_argument("--share", default=False, action="store_true")


    # init server config
    args = parser.parse_args()
    logger.info(f"args: {args}")
    server_init(args)

    # init config
    cfg = Config()

    load_native_plugins(cfg)
    dbs = cfg.local_db.get_database_list()
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()
    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

    # Loader plugins and commands
    command_categories = [
        "pilot.commands.built_in.audio_text",
        "pilot.commands.built_in.image_gen",
    ]
    # exclude commands
    command_categories = [
        x for x in command_categories if x not in cfg.disabled_command_categories
    ]
    command_registry = CommandRegistry()
    for command_category in command_categories:
        command_registry.import_commands(command_category)

    cfg.command_registry = command_registry

    logger.info(args)
    if args.new:
        import uvicorn
        uvicorn.run(app, host="0.0.0.0", port=5000)
    else:
        ### Compatibility mode starts the old version server by default
        demo = build_webdemo()
        demo.queue(
            concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
@@ -702,3 +706,8 @@ if __name__ == "__main__":
            share=args.share,
            max_threads=200,
        )
pilot/server/webserver_base.py (new file, 60 lines)
@@ -0,0 +1,60 @@
import signal
import os
import threading
import traceback
import sys

from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
from pilot.configs.model_config import (
    DATASETS_DIR,
    KNOWLEDGE_UPLOAD_ROOT_PATH,
    LLM_MODEL_CONFIG,
    LOGDIR,
)
from pilot.common.plugins import scan_plugins, load_native_plugins
from pilot.utils import build_logger

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

logger = build_logger("webserver", LOGDIR + "webserver.log")


def signal_handler(sig, frame):
    print("in order to avoid chroma db atexit problem")
    os._exit(0)


def async_db_summery():
    client = DBSummaryClient()
    thread = threading.Thread(target=client.init_db_summary)
    thread.start()


def server_init(args):
    logger.info(f"args: {args}")

    # init config
    cfg = Config()

    load_native_plugins(cfg)
    signal.signal(signal.SIGINT, signal_handler)
    async_db_summery()
    cfg.set_plugins(scan_plugins(cfg, cfg.debug_mode))

    # Loader plugins and commands
    command_categories = [
        "pilot.commands.built_in.audio_text",
        "pilot.commands.built_in.image_gen",
    ]
    # exclude commands
    command_categories = [
        x for x in command_categories if x not in cfg.disabled_command_categories
    ]
    command_registry = CommandRegistry()
    for command_category in command_categories:
        command_registry.import_commands(command_category)

    cfg.command_registry = command_registry
@@ -35,7 +35,7 @@ class ElevenLabsSpeech(VoiceBase):
        }
        self._headers = {
            "Content-Type": "application/json",
            "xi-api-key": cfg.elevenlabs_api_key,
            "xi-api_v1-key": cfg.elevenlabs_api_key,
        }
        self._voices = default_voices.copy()
        if cfg.elevenlabs_voice_1_id in voice_options:
@@ -1,13 +1,12 @@
from pilot.vector_store.chroma_store import ChromaStore

from pilot.vector_store.milvus_store import MilvusStore
from pilot.vector_store.weaviate_store import WeaviateStore
# from pilot.vector_store.milvus_store import MilvusStore

connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore}
connector = {"Chroma": ChromaStore, "Milvus": None}


class VectorStoreConnector:
    """vector store connector, can connect different vector db provided load document api and similar search api."""
    """vector store connector, can connect different vector db provided load document api_v1 and similar search api_v1."""

    def __init__(self, vector_store_type, ctx: {}) -> None:
        """initialize vector store connector."""
@@ -1,146 +0,0 @@
import os
import json
import weaviate
from langchain.schema import Document
from langchain.vectorstores import Weaviate
from weaviate.exceptions import WeaviateBaseError

from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger
from pilot.vector_store.vector_store_base import VectorStoreBase

CFG = Config()


class WeaviateStore(VectorStoreBase):
    """Weaviate database"""

    def __init__(self, ctx: dict) -> None:
        """Initialize with Weaviate client."""
        try:
            import weaviate
        except ImportError:
            raise ValueError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            )

        self.ctx = ctx
        self.weaviate_url = CFG.WEAVIATE_URL
        self.embedding = ctx.get("embeddings", None)
        self.vector_name = ctx["vector_store_name"]
        self.persist_dir = os.path.join(
            KNOWLEDGE_UPLOAD_ROOT_PATH, self.vector_name + ".vectordb"
        )

        self.vector_store_client = weaviate.Client(self.weaviate_url)

    def similar_search(self, text: str, topk: int) -> None:
        """Perform similar search in Weaviate"""
        logger.info("Weaviate similar search")
        # nearText = {
        #     "concepts": [text],
        #     "distance": 0.75,  # prior to v1.14 use "certainty" instead of "distance"
        # }
        # vector = self.embedding.embed_query(text)
        response = (
            self.vector_store_client.query.get(
                self.vector_name, ["metadata", "page_content"]
            )
            # .with_near_vector({"vector": vector})
            .with_limit(topk).do()
        )
        res = response["data"]["Get"][list(response["data"]["Get"].keys())[0]]
        docs = []
        for r in res:
            docs.append(
                Document(
                    page_content=r["page_content"],
                    metadata={"metadata": r["metadata"]},
                )
            )
        return docs

    def vector_name_exists(self) -> bool:
        """Check if a vector name exists for a given class in Weaviate.
        Returns:
            bool: True if the vector name exists, False otherwise.
        """
        try:
            if self.vector_store_client.schema.get(self.vector_name):
                return True
            return False
        except WeaviateBaseError as e:
            logger.error("vector_name_exists error", e.message)
            return False

    def _default_schema(self) -> None:
        """
        Create the schema for Weaviate with a Document class containing metadata and text properties.
        """

        schema = {
            "classes": [
                {
                    "class": self.vector_name,
                    "description": "A document with metadata and text",
                    # "moduleConfig": {
                    #     "text2vec-transformers": {
                    #         "poolingStrategy": "masked_mean",
                    #         "vectorizeClassName": False,
                    #     }
                    # },
                    "properties": [
                        {
                            "dataType": ["text"],
                            # "moduleConfig": {
                            #     "text2vec-transformers": {
                            #         "skip": False,
                            #         "vectorizePropertyName": False,
                            #     }
                            # },
                            "description": "Metadata of the document",
                            "name": "metadata",
                        },
                        {
                            "dataType": ["text"],
                            # "moduleConfig": {
                            #     "text2vec-transformers": {
                            #         "skip": False,
                            #         "vectorizePropertyName": False,
                            #     }
                            # },
                            "description": "Text content of the document",
                            "name": "page_content",
                        },
                    ],
                    # "vectorizer": "text2vec-transformers",
                }
            ]
        }

        # Create the schema in Weaviate
        self.vector_store_client.schema.create(schema)

    def load_document(self, documents: list) -> None:
        """Load documents into Weaviate"""
        logger.info("Weaviate load document")
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]

        # Import data
        with self.vector_store_client.batch as batch:
            batch.batch_size = 100

            # Batch import all documents
            for i in range(len(texts)):
                properties = {
                    "metadata": metadatas[i]["source"],
                    "page_content": texts[i],
                }

                self.vector_store_client.batch.add_data_object(
                    data_object=properties, class_name=self.vector_name
                )
            self.vector_store_client.batch.flush()