Merge branch 'main' into dev

# Conflicts:
#	pilot/connections/mysql.py
#	pilot/prompts/prompt_generator.py
#	pilot/scene/base_chat.py
#	pilot/scene/chat_db/chat.py
#	pilot/scene/chat_db/out_parser.py
#	pilot/scene/chat_execution/chat.py
#	pilot/server/webserver.py
yhjun1026 · 2023-05-30 10:57:57 +08:00
21 changed files with 420 additions and 204 deletions

View File

@@ -17,6 +17,9 @@ class Config(metaclass=Singleton):
def __init__(self) -> None:
"""Initialize the Config class"""
# Gradio language version: en, zh
self.LANGUAGE = os.getenv("LANGUAGE", "en")
self.debug_mode = False
self.skip_reprompt = False
self.temperature = float(os.getenv("TEMPERATURE", 0.7))
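`Config(metaclass=Singleton)` means every `Config()` call returns the same shared instance, so env-driven settings like `LANGUAGE` are read once per process. A minimal sketch of that metaclass pattern (the project's real implementation lives elsewhere in `pilot`; this standalone version is illustrative):

```python
import os


class Singleton(type):
    """Metaclass that caches one instance per class (illustrative sketch)."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # First call constructs the instance; later calls return the cached one.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class Config(metaclass=Singleton):
    def __init__(self) -> None:
        self.LANGUAGE = os.getenv("LANGUAGE", "en")
        self.temperature = float(os.getenv("TEMPERATURE", 0.7))


assert Config() is Config()  # one shared instance, env read exactly once
```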

View File

@@ -5,6 +5,7 @@ import dataclasses
import uuid
from enum import auto, Enum
from typing import List, Any
from pilot.language.translation_handler import get_lang_text
from pilot.configs.config import Config
@@ -263,15 +264,17 @@ conv_qa_prompt_template = """ 基于以下已知的信息, 专业、简要的回
default_conversation = conv_one_shot
conversation_sql_mode = {
"auto_execute_ai_response": "直接执行结果",
"dont_execute_ai_response": "不直接执行结果",
"auto_execute_ai_response": get_lang_text("sql_generate_mode_direct"),
"dont_execute_ai_response": get_lang_text("sql_generate_mode_none"),
}
conversation_types = {
"native": "LLM原生对话",
"default_knownledge": "默认知识库对话",
"custome": "新增知识库对话",
"auto_execute_plugin": "对话使用插件",
"native": get_lang_text("knowledge_qa_type_llm_native_dialogue"),
"default_knownledge": get_lang_text(
"knowledge_qa_type_default_knowledge_base_dialogue"
),
"custome": get_lang_text("knowledge_qa_type_add_knowledge_base_dialogue"),
"auto_execute_plugin": get_lang_text("dialogue_use_plugin"),
}
conv_templates = {

View File

@@ -0,0 +1,74 @@
## Configured in this file for the short term; longer term these mappings will live in the default database, with support for configuring additional languages
lang_dicts = {
"zh": {
"unique_id": "中文内容",
"db_gpt_introduction": "[DB-GPT](https://github.com/csunny/DB-GPT) 是一个开源的以数据库为基础的GPT实验项目使用本地化的GPT大模型与您的数据和环境进行交互无数据泄露风险100% 私密100% 安全。",
"learn_more_markdown": "该服务是仅供非商业用途的研究预览。受 Vicuna-13B 模型 [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) 的约束",
"model_control_param": "模型参数",
"sql_generate_mode_direct": "直接执行结果",
"sql_generate_mode_none": "不直接执行结果",
"max_input_token_size": "最大输出Token数",
"please_choose_database": "请选择数据",
"sql_generate_diagnostics": "SQL生成与诊断",
"knowledge_qa_type_llm_native_dialogue": "LLM原生对话",
"knowledge_qa_type_default_knowledge_base_dialogue": "默认知识库对话",
"knowledge_qa_type_add_knowledge_base_dialogue": "新增知识库对话",
"dialogue_use_plugin": "对话使用插件",
"create_knowledge_base": "新建知识库",
"sql_schema_info": "数据库{}的Schema信息如下: {}\n",
"current_dialogue_mode": "当前对话模式",
"database_smart_assistant": "数据库智能助手",
"sql_vs_setting": "自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力",
"knowledge_qa": "知识问答",
"configure_knowledge_base": "配置知识库",
"new_klg_name": "新知识库名称",
"add_as_new_klg": "添加为新知识库",
"add_file_to_klg": "向知识库中添加文件",
"upload_file": "上传文件",
"add_file": "添加文件",
"upload_and_load_to_klg": "上传并加载到知识库",
"upload_folder": "上传文件夹",
"add_folder": "添加文件夹",
"send": "发送",
"regenerate": "重新生成",
"clear_box": "清理",
},
"en": {
"unique_id": "English Content",
"db_gpt_introduction": "[DB-GPT](https://github.com/csunny/DB-GPT) is an experimental open-source project that uses localized GPT large models to interact with your data and environment. With this solution, you can be assured that there is no risk of data leakage, and your data is 100% private and secure.",
"learn_more_markdown": "The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B",
"model_control_param": "Model Parameters",
"sql_generate_mode_direct": "Execute directly",
"sql_generate_mode_none": "Execute without model",
"max_input_token_size": "Maximum output token size",
"please_choose_database": "Please choose database",
"sql_generate_diagnostics": "SQL Generation & Diagnostics",
"knowledge_qa_type_llm_native_dialogue": "LLM native dialogue",
"knowledge_qa_type_default_knowledge_base_dialogue": "Default documents",
"knowledge_qa_type_add_knowledge_base_dialogue": "Added documents",
"dialogue_use_plugin": "Dialogue Extension",
"create_knowledge_base": "Create Knowledge Base",
"sql_schema_info": "the schema information of database {}: {}\n",
"current_dialogue_mode": "Current dialogue mode",
"database_smart_assistant": "Database smart assistant",
"sql_vs_setting": "In the automatic execution mode, DB-GPT can have the ability to execute SQL, read data from the network, automatically store and learn",
"knowledge_qa": "Documents QA",
"configure_knowledge_base": "Configure Documents",
"new_klg_name": "New document name",
"add_as_new_klg": "Add as new documents",
"add_file_to_klg": "Add file to documents",
"upload_file": "Upload file",
"add_file": "Add file",
"upload_and_load_to_klg": "Upload and load to documents",
"upload_folder": "Upload folder",
"add_folder": "Add folder",
"send": "Send",
"regenerate": "Regenerate",
"clear_box": "Clear",
},
}
def get_lang_content(key, language="zh"):
return lang_dicts.get(language, {}).get(key, "")
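Note that lookups never raise: an unknown key or an unrecognized language code both fall back to an empty string. A quick usage sketch:

```python
# Both the language-dict lookup and the key lookup default safely.
assert get_lang_content("send", "en") == "Send"
assert get_lang_content("send", "zh") == "发送"
assert get_lang_content("no_such_key", "en") == ""   # unknown key -> ""
assert get_lang_content("send", "fr") == ""          # unknown language -> ""
```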

View File

@@ -0,0 +1,8 @@
from pilot.configs.config import Config
from pilot.language.lang_content_mapping import get_lang_content
CFG = Config()
def get_lang_text(key):
return get_lang_content(key, CFG.LANGUAGE)
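Because `Config` is a singleton that reads `LANGUAGE` from the environment at construction time, the variable has to be set before the first `Config()` call anywhere in the process; setting it later has no effect. A usage sketch under that assumption:

```python
import os

os.environ["LANGUAGE"] = "zh"  # must happen before Config() is first instantiated

from pilot.language.translation_handler import get_lang_text

print(get_lang_text("send"))  # -> "发送"
```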

View File

@@ -5,6 +5,7 @@ from langchain.prompts import PromptTemplate
from pilot.configs.model_config import VECTOR_SEARCH_TOP_K
from pilot.conversation import conv_qa_prompt_template
from pilot.logs import logger
from pilot.model.vicuna_llm import VicunaLLM
from pilot.vector_store.file_loader import KnownLedge2Vector
@@ -28,3 +29,27 @@ class KnownLedgeBaseQA:
context = [d.page_content for d in docs]
result = prompt.format(context="\n".join(context), question=query)
return result
@staticmethod
def build_knowledge_prompt(query, docs, state):
prompt_template = PromptTemplate(
template=conv_qa_prompt_template, input_variables=["context", "question"]
)
context = [d.page_content for d in docs]
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
if len(prompt) > 4000:
logger.info("prompt length greater than 4000, rebuild")
context = context[:2000]
prompt_template = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"],
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
print("new prompt length:" + str(len(prompt)))
return prompt
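One subtlety in the rebuild branch above: `context` is a list of document strings, so `context[:2000]` keeps the first 2000 *documents* rather than trimming the text to 2000 characters. If character-level truncation was the intent, a hedged variant would look like this (same names, assumed intent):

```python
# Assumed intent: cap the joined context at ~2000 characters before re-prompting.
joined = "\n".join(context)
if len(joined) > 2000:
    joined = joined[:2000]
result = prompt_template.format(context=joined, question=query)
```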

View File

@@ -13,7 +13,7 @@ from urllib.parse import urljoin
import gradio as gr
import requests
from langchain import PromptTemplate
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -56,7 +56,11 @@ from pilot.vector_store.extract_tovec import (
from pilot.commands.command import execute_ai_response_json
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.language.translation_handler import get_lang_text
# Load plugins
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
headers = {"User-Agent": "dbgpt Client"}
@@ -67,15 +71,13 @@ disable_btn = gr.Button.update(interactive=True)
enable_moderation = False
models = []
dbs = []
vs_list = ["新建知识库"] + get_vector_storelist()
vs_list = [get_lang_text("create_knowledge_base")] + get_vector_storelist()
autogpt = False
vector_store_client = None
vector_store_name = {"vs_name": ""}
priority = {"vicuna-13b": "aaa"}
# Load plugins
CFG = Config()
CHAT_FACTORY = ChatFactory()
DB_SETTINGS = {
@@ -86,6 +88,20 @@ DB_SETTINGS = {
}
llm_native_dialogue = get_lang_text("knowledge_qa_type_llm_native_dialogue")
default_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_default_knowledge_base_dialogue"
)
add_knowledge_base_dialogue = get_lang_text(
"knowledge_qa_type_add_knowledge_base_dialogue"
)
knowledge_qa_type_list = [
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
]
def get_simlar(q):
docsearch = knownledge_tovec_st(os.path.join(DATASETS_DIR, "plan.md"))
docs = docsearch.similarity_search_with_score(q, k=1)
@@ -100,7 +116,7 @@ def gen_sqlgen_conversation(dbname):
schemas = CFG.local_db.table_simple_info(db_connect)
for s in schemas:
message += s["schema_info"] + ";"
return f"数据库{dbname}的Schema信息如下: {message}\n"
return get_lang_text("sql_schema_info").format(dbname, message)
def plugins_select_info():
@@ -127,6 +143,7 @@ function() {
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
# dbs = get_database_list()
dropdown_update = gr.Dropdown.update(visible=True)
if dbs:
gr.Dropdown.update(choices=dbs)
@@ -213,7 +230,7 @@ def http_bot(
):
logger.info(f"User message send!{state.conv_id},{selected},{mode},{sql_mode},{db_selector},{plugin_selector}")
start_tstamp = time.time()
scene: ChatScene = get_chat_mode(selected, mode, sql_mode, db_selector)
scene:ChatScene = get_chat_mode(mode, sql_mode, db_selector)
print(f"当前对话模式:{scene.value}")
model_name = CFG.LLM_MODEL
@@ -222,7 +239,7 @@ def http_bot(
chat_param = {
"chat_session_id": state.conv_id,
"db_name": db_selector,
"current_user_input": state.last_user_input,
"user_input": state.last_user_input,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(scene.value, **chat_param)
chat.call()
@@ -241,6 +258,7 @@ def http_bot(
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
else:
dbname = db_selector
# TODO: this request should incorporate the existing knowledge base so the answer is grounded in it; the prompt still needs further tuning
if state.skip_next:
@@ -290,18 +308,25 @@ def http_bot(
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
},
)
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, 1)
context = [d.page_content for d in docs]
prompt_template = PromptTemplate(
template=conv_qa_prompt_template,
input_variables=["context", "question"],
print("vector store name: ", vector_store_name["vs_name"])
vector_store_config = {
"vector_store_name": vector_store_name["vs_name"],
"text_field": "content",
"vector_store_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
knowledge_embedding_client = KnowledgeEmbedding(
file_path="",
model_name=LLM_MODEL_CONFIG["text2vec"],
local_persist=False,
vector_store_config=vector_store_config,
)
result = prompt_template.format(context="\n".join(context), question=query)
state.messages[-2][1] = result
prompt = state.get_prompt()
query = state.messages[-2][1]
docs = knowledge_embedding_client.similar_search(query, VECTOR_SEARCH_TOP_K)
prompt = KnownLedgeBaseQA.build_knowledge_prompt(query, docs, state)
state.messages[-2][1] = query
skip_echo_len = len(prompt.replace("</s>", " ")) + 1
# Make requests
payload = {
"model": model_name,
@@ -346,7 +371,56 @@ def http_bot(
enable_btn,
)
return
try:
# Stream output
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=20,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
output = data["text"].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + "▌"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
) * 5
else:
output = (
data["text"] + f" (error_code: {data['error_code']})"
)
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg + f" (error_code: 4)"
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
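The wire protocol in the block above is NUL-delimited JSON over a streamed HTTP response. A stripped-down consumer of such an endpoint (URL and payload fields are assumptions based on the code above, not a documented API):

```python
import json

import requests


def stream_generate(url: str, payload: dict):
    """Yield text chunks from a NUL-delimited JSON stream (sketch)."""
    response = requests.post(url, json=payload, stream=True, timeout=20)
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data.get("error_code", 0) != 0:
            raise RuntimeError(f"server error {data['error_code']}: {data.get('text')}")
        yield data["text"]
```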
@@ -358,24 +432,24 @@ def http_bot(
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
# Log the run
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(start_tstamp, 4),
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
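The conversation log is plain JSONL, one object per turn, so it can be post-processed with nothing but the standard library:

```python
import json

# Read the JSONL conversation log back (path is illustrative;
# the real one comes from get_conv_log_filename()).
with open("conv.log") as fin:
    for line in fin:
        record = json.loads(line)
        print(record["tstamp"], record["model"], record["type"])
```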
block_css = (
@@ -396,14 +470,14 @@ block_css = (
def change_sql_mode(sql_mode):
if sql_mode in ["直接执行结果"]:
if sql_mode in [get_lang_text("sql_generate_mode_direct")]:
return gr.update(visible=True)
else:
return gr.update(visible=False)
def change_mode(mode):
if mode in ["默认知识库对话", "LLM原生对话"]:
if mode in [default_knowledge_base_dialogue, llm_native_dialogue]:
return gr.update(visible=False)
else:
return gr.update(visible=True)
@@ -413,27 +487,16 @@ def change_tab():
autogpt = True
def change_func(xx):
print("123")
print(str(xx))
def build_single_model_ui():
notice_markdown = """
# DB-GPT
[DB-GPT](https://github.com/csunny/DB-GPT) 是一个开源的以数据库为基础的GPT实验项目，使用本地化的GPT大模型与您的数据和环境进行交互，无数据泄露风险，100% 私密，100% 安全。
"""
learn_more_markdown = """
### Licence
The service is a research preview intended for non-commercial use only. subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of Vicuna-13B
"""
notice_markdown = get_lang_text("db_gpt_introduction")
learn_more_markdown = get_lang_text("learn_more_markdown")
state = gr.State()
gr.Markdown(notice_markdown, elem_id="notice_markdown")
with gr.Accordion("参数", open=False, visible=False) as parameter_row:
with gr.Accordion(
get_lang_text("model_control_param"), open=False, visible=False
) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
@@ -449,7 +512,7 @@ def build_single_model_ui():
value=512,
step=64,
interactive=True,
label="最大输出Token数",
label=get_lang_text("max_input_token_size"),
)
tabs = gr.Tabs()
@@ -462,24 +525,30 @@ def build_single_model_ui():
tabs.select(on_select, None, selected)
with tabs:
tab_sql = gr.TabItem("SQL生成与诊断", elem_id="SQL")
tab_sql.select(on_select, None, None)
tab_sql = gr.TabItem(get_lang_text("sql_generate_diagnostics"), elem_id="SQL")
with tab_sql:
print("tab_sql in...")
# TODO A selector to choose database
with gr.Row(elem_id="db_selector"):
db_selector = gr.Dropdown(
label="请选择数据库",
label=get_lang_text("please_choose_database"),
choices=dbs,
value=dbs[0] if len(dbs) > 0 else "",
interactive=True,
show_label=True,
).style(container=False)
sql_mode = gr.Radio(["直接执行结果", "不执行结果"], show_label=False, value="不执行结果")
sql_vs_setting = gr.Markdown("自动执行模式下, DB-GPT可以具备执行SQL、从网络读取知识自动化存储学习的能力")
sql_mode = gr.Radio(
[
get_lang_text("sql_generate_mode_direct"),
get_lang_text("sql_generate_mode_none"),
],
show_label=False,
value=get_lang_text("sql_generate_mode_none"),
)
sql_vs_setting = gr.Markdown(get_lang_text("sql_vs_setting"))
sql_mode.change(fn=change_sql_mode, inputs=sql_mode, outputs=sql_vs_setting)
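The `sql_mode.change(...)` wiring is the standard Gradio show/hide pattern: the handler returns a `gr.update(visible=...)` for the output component. A self-contained sketch (labels are illustrative):

```python
import gradio as gr


def toggle_note(choice):
    # Show the helper text only when direct execution is selected.
    return gr.update(visible=(choice == "Execute directly"))


with gr.Blocks() as demo:
    mode = gr.Radio(
        ["Execute directly", "Don't execute directly"],
        value="Don't execute directly",
        show_label=False,
    )
    note = gr.Markdown("In auto-execute mode, DB-GPT can run SQL directly.")
    mode.change(fn=toggle_note, inputs=mode, outputs=note)
```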
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
tab_plugin = gr.TabItem("插件模式", elem_id="PLUGIN")
# tab_plugin.select(change_func)
with tab_plugin:
@@ -502,37 +571,50 @@ def build_single_model_ui():
plugin_selected = gr.Textbox(show_label=False, visible=False, placeholder="Selected")
plugin_selector.select(plugin_change, None, plugin_selected)
tab_qa = gr.TabItem("知识问答", elem_id="QA")
tab_qa = gr.TabItem(get_lang_text("knowledge_qa"), elem_id="QA")
with tab_qa:
print("tab_qa in...")
mode = gr.Radio(
["LLM原生对话", "默认知识库对话", "新增知识库对话"], show_label=False, value="LLM原生对话"
[
llm_native_dialogue,
default_knowledge_base_dialogue,
add_knowledge_base_dialogue,
],
show_label=False,
value=llm_native_dialogue,
)
vs_setting = gr.Accordion(
get_lang_text("configure_knowledge_base"), open=False
)
vs_setting = gr.Accordion("配置知识库", open=False)
mode.change(fn=change_mode, inputs=mode, outputs=vs_setting)
with vs_setting:
vs_name = gr.Textbox(label="新知识库名称", lines=1, interactive=True)
vs_add = gr.Button("添加为新知识库")
vs_name = gr.Textbox(
label=get_lang_text("new_klg_name"), lines=1, interactive=True
)
vs_add = gr.Button(get_lang_text("add_as_new_klg"))
with gr.Column() as doc2vec:
gr.Markdown("向知识库中添加文件")
with gr.Tab("上传文件"):
gr.Markdown(get_lang_text("add_file_to_klg"))
with gr.Tab(get_lang_text("upload_file")):
files = gr.File(
label="添加文件",
label=get_lang_text("add_file"),
file_types=[".txt", ".md", ".docx", ".pdf"],
file_count="multiple",
allow_flagged_uploads=True,
show_label=False,
)
load_file_button = gr.Button("上传并加载到知识库")
with gr.Tab("上传文件夹"):
load_file_button = gr.Button(
get_lang_text("upload_and_load_to_klg")
)
with gr.Tab(get_lang_text("upload_folder")):
folder_files = gr.File(
label="添加文件夹",
label=get_lang_text("add_folder"),
accept_multiple_files=True,
file_count="directory",
show_label=False,
)
load_folder_button = gr.Button("上传并加载到知识库")
load_folder_button = gr.Button(
get_lang_text("upload_and_load_to_klg")
)
with gr.Blocks():
chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
@@ -544,11 +626,11 @@ def build_single_model_ui():
visible=False,
).style(container=False)
with gr.Column(scale=2, min_width=50):
send_btn = gr.Button(value="发送", visible=False)
send_btn = gr.Button(value=get_lang_text("send"), visible=False)
with gr.Row(visible=False) as button_row:
regenerate_btn = gr.Button(value="重新生成", interactive=False)
clear_btn = gr.Button(value="清理", interactive=False)
regenerate_btn = gr.Button(value=get_lang_text("regenerate"), interactive=False)
clear_btn = gr.Button(value=get_lang_text("clear_box"), interactive=False)
gr.Markdown(learn_more_markdown)
btn_list = [regenerate_btn, clear_btn]
@@ -594,10 +676,10 @@ def build_single_model_ui():
def build_webdemo():
with gr.Blocks(
title="数据库智能助手",
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
title=get_lang_text("database_smart_assistant"),
# theme=gr.themes.Base(),
theme=gr.themes.Default(),
css=block_css,
) as demo:
url_params = gr.JSON(visible=False)
(

View File

@@ -2,7 +2,7 @@ import os
import markdown
from bs4 import BeautifulSoup
from langchain.document_loaders import PyPDFLoader, TextLoader, markdown
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config

View File

@@ -2,9 +2,6 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector
@@ -35,9 +32,7 @@ class SourceEmbedding(ABC):
self.model_name = model_name
self.vector_store_config = vector_store_config
self.embedding_args = embedding_args
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
vector_store_config["embeddings"] = self.embeddings
self.embeddings = vector_store_config["embeddings"]
self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, vector_store_config
)
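The change above stops each `SourceEmbedding` from constructing its own `HuggingFaceEmbeddings` and instead expects a shared instance in `vector_store_config["embeddings"]`, so the embedding model is loaded once per process. The calling side would look roughly like this (model name and config keys are assumptions based on the surrounding code):

```python
from langchain.embeddings import HuggingFaceEmbeddings

# Load the embedding model once and share it across every source and store.
embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")

vector_store_config = {
    "vector_store_name": "my_docs",   # illustrative
    "embeddings": embeddings,         # consumed by SourceEmbedding.__init__ above
}
```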

View File

@@ -1,11 +1,12 @@
from typing import Any, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document
from pymilvus import Collection, DataType, connections
from pymilvus import Collection, DataType, connections, utility
from pilot.configs.config import Config
from pilot.vector_store.vector_store_base import VectorStoreBase
CFG = Config()
@@ -29,6 +30,7 @@ class MilvusStore(VectorStoreBase):
self.secure = ctx.get("secure", None)
self.embedding = ctx.get("embeddings", None)
self.fields = []
self.alias = "default"
# use HNSW by default.
self.index_params = {
@@ -48,7 +50,9 @@ class MilvusStore(VectorStoreBase):
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"params": {"search_k": 10}},
}
# default collection schema
self.primary_field = "pk_id"
self.vector_field = "vector"
self.text_field = "content"
if (self.username is None) != (self.password is None):
@@ -98,56 +102,43 @@ class MilvusStore(VectorStoreBase):
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
embeddings = self.embedding.embed_query(texts[0])
if utility.has_collection(self.collection_name):
self.col = Collection(self.collection_name, using=self.alias)
self.fields = []
for x in self.col.schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if (
x.dtype == DataType.FLOAT_VECTOR
or x.dtype == DataType.BINARY_VECTOR
):
self.vector_field = x.name
self._add_documents(texts, metadatas)
return self.collection_name
dim = len(embeddings)
# Generate unique names
primary_field = "pk_id"
vector_field = "vector"
text_field = "content"
self.text_field = text_field
primary_field = self.primary_field
vector_field = self.vector_field
text_field = self.text_field
# self.text_field = text_field
collection_name = vector_name
fields = []
# Determine metadata schema
# if metadatas:
# # Check if all metadata keys line up
# key = metadatas[0].keys()
# for x in metadatas:
# if key != x.keys():
# raise ValueError(
# "Mismatched metadata. "
# "Make sure all metadata has the same keys and datatype."
# )
# # Create FieldSchema for each entry in singular metadata.
# for key, value in metadatas[0].items():
# # Infer the corresponding datatype of the metadata
# dtype = infer_dtype_bydata(value)
# if dtype == DataType.UNKNOWN:
# raise ValueError(f"Unrecognized datatype for {key}.")
# elif dtype == DataType.VARCHAR:
# # Find out max length text based metadata
# max_length = 0
# for subvalues in metadatas:
# max_length = max(max_length, len(subvalues[key]))
# fields.append(
# FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1)
# )
# else:
# fields.append(FieldSchema(key, dtype))
# Find out max length of texts
max_length = 0
for y in texts:
max_length = max(max_length, len(y))
# Create the text field
fields.append(
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
)
fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
# primary key field
fields.append(
FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True)
)
# vector field
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
# milvus the schema for the collection
schema = CollectionSchema(fields)
# Create the collection
collection = Collection(collection_name, schema)
@@ -165,7 +156,7 @@ class MilvusStore(VectorStoreBase):
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
self._add_texts(texts, metadatas)
self._add_documents(texts, metadatas)
return self.collection_name
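For reference, creating a collection like the one assembled above needs `FieldSchema` and `CollectionSchema` in addition to the pymilvus imports shown in this hunk. A minimal standalone sketch (endpoint and dimension are illustrative):

```python
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
)

connections.connect(alias="default", host="127.0.0.1", port="19530")

fields = [
    FieldSchema("pk_id", DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema("content", DataType.VARCHAR, max_length=65535),
    # dim must match the embedding model's output dimension.
    FieldSchema("vector", DataType.FLOAT_VECTOR, dim=768),
]
collection = Collection("example_docs", CollectionSchema(fields))
```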
@@ -224,7 +215,7 @@ class MilvusStore(VectorStoreBase):
# )
# return _text
def _add_texts(
def _add_documents(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
@@ -257,7 +248,14 @@ class MilvusStore(VectorStoreBase):
def load_document(self, documents) -> None:
"""load document in vector database."""
self.init_schema_and_load(self.collection_name, documents)
# self.init_schema_and_load(self.collection_name, documents)
batch_size = 500
batched_list = [
documents[i : i + batch_size] for i in range(0, len(documents), batch_size)
]
# docs = []
for doc_batch in batched_list:
self.init_schema_and_load(self.collection_name, doc_batch)
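Splitting the documents into fixed batches of 500 keeps each Milvus insert below size limits; the slicing idiom generalizes to any sequence:

```python
def batched(items, batch_size=500):
    """Yield consecutive slices of at most batch_size items."""
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


# e.g. 1200 documents -> batches of 500, 500, 200
for batch in batched(list(range(1200))):
    print(len(batch))
```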
def similar_search(self, text, topk) -> None:
"""similar_search in vector database."""
@@ -320,3 +318,6 @@ class MilvusStore(VectorStoreBase):
)
return data[0], ret
def close(self):
connections.disconnect()