feat: support multiple knowledge file paths and skip some errors in knowledge embedding (#379)

Close #353 

Additionally, I created DB-GPT command-line tool scripts, which currently
support initialization of the knowledge documents.
We can support more commands in the future.
This commit is contained in:
Aries-ckt 2023-07-28 17:23:21 +08:00 committed by GitHub
commit 166c914922
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 300 additions and 7 deletions

View File

@ -31,6 +31,8 @@ python tools/knowledge_init.py
```
Optionally, you can run `python tools/knowledge_init.py -h` command to see more usage.
3.Add the knowledge repository in the interface by entering the name of your knowledge repository (if not specified, enter "default") so you can use it for Q&A based on your knowledge base.
Note that the default vector model is text2vec-large-chinese. It is a large model, so if your computer's hardware is not powerful enough, text2vec-base-chinese is recommended instead. In either case, make sure you download the model and place it in the models directory.

121
tools/cli/cli_scripts.py Normal file
View File

@ -0,0 +1,121 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DB-GPT command line tools.
You can use it for some background management:
- Lots of knowledge document initialization.
- Load the data into the database.
- Show server status
- ...
Maybe move this to pilot module and append to console_scripts in the future.
"""
import sys
import click
import os
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
)
from pilot.configs.model_config import DATASETS_DIR
from tools.cli.knowledge_client import knowledge_init
# Address of the API server that sub-commands talk to. This is the default
# value; the `cli` group callback overwrites it with `--api_address`.
API_ADDRESS: str = "http://127.0.0.1:5000"


@click.group()
@click.option(
    "--api_address",
    required=False,
    default=API_ADDRESS,
    type=str,
    help="Api server address",
)
@click.version_option()
def cli(api_address: str):
    # Root command group. It only records the chosen server address in the
    # module-level global so that sub-commands can read it afterwards.
    # (Kept as a comment rather than a docstring: click would display a
    # docstring as the group's help text.)
    global API_ADDRESS
    API_ADDRESS = api_address
@cli.command()
@click.option(
    "--vector_name",
    required=False,
    type=str,
    default="default",
    help="Your vector store name",
)
@click.option(
    "--vector_store_type",
    required=False,
    type=str,
    default="Chroma",
    help="Vector store type",
)
@click.option(
    "--local_doc_dir",
    required=False,
    type=str,
    default=DATASETS_DIR,
    help="Your document directory",
)
@click.option(
    "--skip_wrong_doc",
    required=False,
    type=bool,
    default=False,
    help="Skip wrong document",
)
@click.option(
    "--max_workers",
    required=False,
    type=int,
    default=None,
    help="The maximum number of threads that can be used to upload document",
)
@click.option(
    "-v",
    "--verbose",
    required=False,
    is_flag=True,
    hidden=True,
    help="Show debugging information.",
)
def knowledge(
    vector_name: str,
    vector_store_type: str,
    local_doc_dir: str,
    skip_wrong_doc: bool,
    max_workers: int,
    verbose: bool,
):
    """Knowledge command line tool"""
    # The docstring above is shown by click as the command's help text, so it
    # is intentionally left unchanged.
    #
    # Arguments are passed by keyword because `knowledge_init` declares
    # `verbose` before `max_workers`, which differs from the option order
    # here; keywords make a silent positional mix-up impossible.
    knowledge_init(
        api_address=API_ADDRESS,  # set by the `cli` group callback
        vector_name=vector_name,
        vector_store_type=vector_store_type,
        local_doc_dir=local_doc_dir,
        skip_wrong_doc=skip_wrong_doc,
        verbose=verbose,
        max_workers=max_workers,
    )
# Register the knowledge sub-command on the root group.
# NOTE(review): `@cli.command()` on `knowledge` appears to register it
# already, so this call looks redundant — confirm before removing.
cli.add_command(knowledge)

# TODO: add more commands.


def main():
    """Run the click command group and return whatever it returns."""
    return cli()


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,147 @@
import os
import requests
import json
from urllib.parse import urljoin
from concurrent.futures import ThreadPoolExecutor, as_completed
from pilot.openapi.api_v1.api_view_model import Result
from pilot.server.knowledge.request.request import (
KnowledgeQueryRequest,
KnowledgeDocumentRequest,
DocumentSyncRequest,
ChunkQueryRequest,
DocumentQueryRequest,
)
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.server.knowledge.request.request import DocumentSyncRequest
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
# Headers sent with every JSON request issued by `ApiClient._post`.
HTTP_HEADERS = {"Content-Type": "application/json"}


class ApiClient:
    """Minimal HTTP client for the API server.

    Responses are decoded into the project's ``Result`` model; failures
    (non-2xx status or ``success`` being false) are raised as ``Exception``.
    """

    def __init__(self, api_address: str) -> None:
        """Remember the base address that relative URLs are joined onto."""
        self.api_address = api_address

    def _handle_response(self, response):
        """Return the ``data`` payload of a successful response.

        Raises:
            Exception: if the HTTP status is not 2xx, or if the decoded
                ``Result`` reports ``success`` as false (the message is the
                result's ``err_msg``).
        """
        # Only 2xx counts as success. The upper bound is exclusive: 300 is a
        # redirection status and must not be parsed as a successful result.
        if not 200 <= response.status_code < 300:
            raise Exception(
                f"Http request error, code: {response.status_code}, message: {response.text}"
            )
        result = Result(**response.json())
        if not result.success:
            raise Exception(result.err_msg)
        return result.data

    def _post(self, url: str, data=None):
        """POST ``data`` as JSON to ``url`` (joined onto the base address).

        ``data`` may be a dict, an object whose ``__dict__`` holds the fields
        to send, or ``None`` for an empty JSON object.
        """
        if data is None:
            # The default used to crash on ``None.__dict__``; send an empty
            # body instead.
            data = {}
        elif not isinstance(data, dict):
            data = data.__dict__
        response = requests.post(
            urljoin(self.api_address, url), data=json.dumps(data), headers=HTTP_HEADERS
        )
        return self._handle_response(response)
class KnowledgeApiClient(ApiClient):
    """Client for the ``/knowledge/...`` endpoints of the API server."""

    def __init__(self, api_address: str) -> None:
        """Create a client bound to ``api_address``."""
        super().__init__(api_address)

    def space_add(self, request: KnowledgeSpaceRequest):
        """Create a knowledge space.

        An error whose message contains "have already named" is downgraded to
        a printed warning (and ``None`` is returned); any other error
        propagates.
        """
        try:
            return self._post("/knowledge/space/add", data=request)
        except Exception as e:
            if "have already named" not in str(e):
                raise e
            print(f"Warning: you have already named {request.name}")

    def space_list(self, request: KnowledgeSpaceRequest):
        """POST ``request`` to the space list endpoint and return its data."""
        return self._post("/knowledge/space/list", data=request)

    def document_add(self, space_name: str, request: KnowledgeDocumentRequest):
        """POST ``request`` to the document add endpoint of ``space_name``."""
        return self._post(f"/knowledge/{space_name}/document/add", data=request)

    def document_list(self, space_name: str, query_request: DocumentQueryRequest):
        """POST ``query_request`` to the document list endpoint of ``space_name``."""
        return self._post(
            f"/knowledge/{space_name}/document/list", data=query_request
        )

    def document_upload(self, space_name, doc_name, doc_type, doc_file_path):
        """Upload with multipart/form-data"""
        # Unlike the JSON helpers, this builds the absolute URL itself and
        # lets `requests` encode the multipart body.
        endpoint = f"{self.api_address}/knowledge/{space_name}/document/upload"
        form_fields = {"doc_name": doc_name, "doc_type": doc_type}
        with open(doc_file_path, "rb") as doc_file:
            response = requests.post(
                endpoint, data=form_fields, files={"doc_file": doc_file}
            )
        return self._handle_response(response)

    def document_sync(self, space_name: str, request: DocumentSyncRequest):
        """POST ``request`` to the document sync endpoint of ``space_name``."""
        return self._post(f"/knowledge/{space_name}/document/sync", data=request)

    def chunk_list(self, space_name: str, query_request: ChunkQueryRequest):
        """POST ``query_request`` to the chunk list endpoint of ``space_name``."""
        return self._post(f"/knowledge/{space_name}/chunk/list", data=query_request)

    def similar_query(self, vector_name: str, query_request: KnowledgeQueryRequest):
        """POST ``query_request`` to the query endpoint of ``vector_name``."""
        return self._post(f"/knowledge/{vector_name}/query", data=query_request)
def knowledge_init(
    api_address: str,
    vector_name: str,
    vector_store_type: str,
    local_doc_dir: str,
    skip_wrong_doc: bool,
    verbose: bool,
    max_workers: int = None,
):
    """Create a knowledge space and upload every file under a directory to it.

    Steps: create the space named ``vector_name`` (an existing one only
    triggers a warning in the client), look it up again, upload each file
    found by walking ``local_doc_dir`` using a thread pool, then request a
    sync of the uploaded documents.

    Args:
        api_address: Base address of the API server.
        vector_name: Name of the knowledge space.
        vector_store_type: Stored as the space's ``vector_type``.
        local_doc_dir: Directory walked recursively for files to upload.
        skip_wrong_doc: If true, a failed upload is printed as a warning and
            skipped; otherwise the error is raised.
        verbose: Currently unused by this function.
        max_workers: Passed to ``ThreadPoolExecutor``.

    Raises:
        Exception: if listing the space does not return exactly one entry, or
            on any client error that is not skipped.
    """
    client = KnowledgeApiClient(api_address)

    space = KnowledgeSpaceRequest()
    space.name = vector_name
    space.desc = "DB-GPT cli"
    space.vector_type = vector_store_type
    space.owner = "DB-GPT"

    # Create space
    print(f"Create space: {space}")
    client.space_add(space)
    print("Create space successfully")

    # Re-read the space from the server; exactly one match is expected.
    matching_spaces = client.space_list(KnowledgeSpaceRequest(name=space.name))
    if len(matching_spaces) != 1:
        raise Exception(f"List space {space.name} error")
    space = KnowledgeSpaceRequest(**matching_spaces[0])

    def upload(filename: str):
        """Upload one file; return the client's result, or None if skipped."""
        try:
            print(f"Begin upload document: {filename} to {space.name}")
            # The file path doubles as the document name.
            return client.document_upload(
                space.name, filename, KnowledgeType.DOCUMENT.value, filename
            )
        except Exception as e:
            if not skip_wrong_doc:
                raise e
            print(f"Warning: {str(e)}")

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            pool.submit(upload, os.path.join(dirpath, name))
            for dirpath, _, names in os.walk(local_doc_dir, topdown=False)
            for name in names
        ]
        # `result()` re-raises here any upload error that was not skipped.
        upload_results = [future.result() for future in as_completed(futures)]

    # Skipped uploads yield None; keep only truthy results.
    doc_ids = [doc_id for doc_id in upload_results if doc_id]
    if not doc_ids:
        print("Warning: no document to sync")
        return

    print(f"Begin sync document: {doc_ids}")
    client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))

View File

@ -3,6 +3,7 @@
import argparse
import os
import sys
import traceback
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
@ -31,7 +32,7 @@ class LocalKnowledgeInit:
self.vector_store_config = vector_store_config
self.model_name = LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
def knowledge_persist(self, file_path):
def knowledge_persist(self, file_path: str, skip_wrong_doc: bool = False):
"""knowledge persist"""
docs = []
embedding_engine = None
@ -44,9 +45,18 @@ class LocalKnowledgeInit:
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
embedding_engine = ke.init_knowledge_embedding()
doc = ke.read()
docs.extend(doc)
try:
embedding_engine = ke.init_knowledge_embedding()
doc = ke.read()
docs.extend(doc)
except Exception as e:
error_msg = traceback.format_exc()
if skip_wrong_doc:
print(
f"Warning: document file {filename} embedding error, skip it, error message: {error_msg}"
)
else:
raise e
embedding_engine.index_to_store(docs)
print(f"""begin create {self.vector_store_config["vector_store_name"]} space""")
try:
@ -64,11 +74,24 @@ class LocalKnowledgeInit:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--vector_name", type=str, default="default")
# TODO https://github.com/csunny/DB-GPT/issues/354
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--vector_name", type=str, default="default", help="Your vector store name"
)
parser.add_argument(
"--file_path", type=str, default=DATASETS_DIR, help="Your document path"
)
parser.add_argument(
"--skip_wrong_doc", type=bool, default=False, help="Skip wrong document"
)
args = parser.parse_args()
vector_name = args.vector_name
store_type = CFG.VECTOR_STORE_TYPE
file_path = args.file_path
skip_wrong_doc = args.skip_wrong_doc
vector_store_config = {
"vector_store_name": vector_name,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
@ -76,5 +99,5 @@ if __name__ == "__main__":
}
print(vector_store_config)
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
kv.knowledge_persist(file_path=DATASETS_DIR)
kv.knowledge_persist(file_path=file_path, skip_wrong_doc=skip_wrong_doc)
print("your knowledge embedding success...")