Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-08-07 11:23:40 +00:00
feat: support multiple knowledge file path and skip some error in knowledge embedding (#379)
Close #353. Additionally, I created a DB-GPT command line tool script, which currently supports initializing the knowledge documents. We can add more commands in the future.
This commit is contained in commit 166c914922.
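As a usage sketch (option values are illustrative; the `knowledge` subcommand and its options come from the new tools/cli/cli_scripts.py shown below):

python tools/cli/cli_scripts.py knowledge --vector_name default --skip_wrong_doc True --max_workers 4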
@@ -31,6 +31,8 @@ python tools/knowledge_init.py
```

Optionally, you can run `python tools/knowledge_init.py -h` to see more usage.

3. Add the knowledge repository in the interface by entering the name of your knowledge repository (if not specified, enter "default") so you can use it for Q&A based on your knowledge base.

Note that the default vector model used is text2vec-large-chinese (which is a large model, so if your personal computer configuration is not enough, it is recommended to use text2vec-base-chinese). Therefore, ensure that you download the model and place it in the models directory.
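For example, with the options this commit adds to tools/knowledge_init.py (the document path shown here is only illustrative):

python tools/knowledge_init.py --vector_name default --file_path ./pilot/datasets --skip_wrong_doc True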
tools/cli/cli_scripts.py (new file, 121 lines)
@@ -0,0 +1,121 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
DB-GPT command line tools.

You can use it for some background management:
- Lots of knowledge document initialization.
- Load the data into the database.
- Show server status
- ...

Maybe move this to pilot module and append to console_scripts in the future.
"""
import sys
import click
import os

sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
)

from pilot.configs.model_config import DATASETS_DIR
from tools.cli.knowledge_client import knowledge_init

API_ADDRESS: str = "http://127.0.0.1:5000"


@click.group()
@click.option(
    "--api_address",
    required=False,
    default="http://127.0.0.1:5000",
    type=str,
    help="Api server address",
)
@click.version_option()
def cli(api_address: str):
    global API_ADDRESS
    API_ADDRESS = api_address


@cli.command()
@click.option(
    "--vector_name",
    required=False,
    type=str,
    default="default",
    help="Your vector store name",
)
@click.option(
    "--vector_store_type",
    required=False,
    type=str,
    default="Chroma",
    help="Vector store type",
)
@click.option(
    "--local_doc_dir",
    required=False,
    type=str,
    default=DATASETS_DIR,
    help="Your document directory",
)
@click.option(
    "--skip_wrong_doc",
    required=False,
    type=bool,
    default=False,
    help="Skip wrong document",
)
@click.option(
    "--max_workers",
    required=False,
    type=int,
    default=None,
    help="The maximum number of threads that can be used to upload documents",
)
@click.option(
    "-v",
    "--verbose",
    required=False,
    is_flag=True,
    hidden=True,
    help="Show debugging information.",
)
def knowledge(
    vector_name: str,
    vector_store_type: str,
    local_doc_dir: str,
    skip_wrong_doc: bool,
    max_workers: int,
    verbose: bool,
):
    """Knowledge command line tool"""
    knowledge_init(
        API_ADDRESS,
        vector_name,
        vector_store_type,
        local_doc_dir,
        skip_wrong_doc,
        verbose,
        max_workers,
    )


# knowledge command
cli.add_command(knowledge)
# TODO add more commands


def main():
    return cli()


if __name__ == "__main__":
    main()
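One usage note on the listing above: `--api_address` is declared on the `cli` group rather than on the `knowledge` command, so on the command line it goes before the subcommand, e.g. `python tools/cli/cli_scripts.py --api_address http://127.0.0.1:5000 knowledge --vector_name default`.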
tools/cli/knowledge_client.py (new file, 147 lines)
@@ -0,0 +1,147 @@
import os
import requests
import json

from urllib.parse import urljoin
from concurrent.futures import ThreadPoolExecutor, as_completed

from pilot.openapi.api_v1.api_view_model import Result
from pilot.server.knowledge.request.request import (
    KnowledgeQueryRequest,
    KnowledgeDocumentRequest,
    DocumentSyncRequest,
    ChunkQueryRequest,
    DocumentQueryRequest,
)

from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest


HTTP_HEADERS = {"Content-Type": "application/json"}


class ApiClient:
    def __init__(self, api_address: str) -> None:
        self.api_address = api_address

    def _handle_response(self, response):
        if 200 <= response.status_code <= 300:
            result = Result(**response.json())
            if not result.success:
                raise Exception(result.err_msg)
            return result.data
        else:
            raise Exception(
                f"Http request error, code: {response.status_code}, message: {response.text}"
            )

    def _post(self, url: str, data=None):
        if not isinstance(data, dict):
            data = data.__dict__
        response = requests.post(
            urljoin(self.api_address, url), data=json.dumps(data), headers=HTTP_HEADERS
        )
        return self._handle_response(response)


class KnowledgeApiClient(ApiClient):
    def __init__(self, api_address: str) -> None:
        super().__init__(api_address)

    def space_add(self, request: KnowledgeSpaceRequest):
        try:
            return self._post("/knowledge/space/add", data=request)
        except Exception as e:
            if "have already named" in str(e):
                print(f"Warning: you have already named {request.name}")
            else:
                raise e

    def space_list(self, request: KnowledgeSpaceRequest):
        return self._post("/knowledge/space/list", data=request)

    def document_add(self, space_name: str, request: KnowledgeDocumentRequest):
        url = f"/knowledge/{space_name}/document/add"
        return self._post(url, data=request)

    def document_list(self, space_name: str, query_request: DocumentQueryRequest):
        url = f"/knowledge/{space_name}/document/list"
        return self._post(url, data=query_request)

    def document_upload(self, space_name, doc_name, doc_type, doc_file_path):
        """Upload with multipart/form-data"""
        url = f"{self.api_address}/knowledge/{space_name}/document/upload"
        with open(doc_file_path, "rb") as f:
            files = {"doc_file": f}
            data = {"doc_name": doc_name, "doc_type": doc_type}
            response = requests.post(url, data=data, files=files)
        return self._handle_response(response)

    def document_sync(self, space_name: str, request: DocumentSyncRequest):
        url = f"/knowledge/{space_name}/document/sync"
        return self._post(url, data=request)

    def chunk_list(self, space_name: str, query_request: ChunkQueryRequest):
        url = f"/knowledge/{space_name}/chunk/list"
        return self._post(url, data=query_request)

    def similar_query(self, vector_name: str, query_request: KnowledgeQueryRequest):
        url = f"/knowledge/{vector_name}/query"
        return self._post(url, data=query_request)


def knowledge_init(
    api_address: str,
    vector_name: str,
    vector_store_type: str,
    local_doc_dir: str,
    skip_wrong_doc: bool,
    verbose: bool,
    max_workers: int = None,
):
    client = KnowledgeApiClient(api_address)
    space = KnowledgeSpaceRequest()
    space.name = vector_name
    space.desc = "DB-GPT cli"
    space.vector_type = vector_store_type
    space.owner = "DB-GPT"

    # Create space
    print(f"Create space: {space}")
    client.space_add(space)
    print("Create space successfully")
    space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
    if len(space_list) != 1:
        raise Exception(f"List space {space.name} error")
    space = KnowledgeSpaceRequest(**space_list[0])

    doc_ids = []

    def upload(filename: str):
        try:
            print(f"Begin upload document: {filename} to {space.name}")
            return client.document_upload(
                space.name, filename, KnowledgeType.DOCUMENT.value, filename
            )
        except Exception as e:
            if skip_wrong_doc:
                print(f"Warning: {str(e)}")
            else:
                raise e

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        tasks = []
        for root, _, files in os.walk(local_doc_dir, topdown=False):
            for file in files:
                filename = os.path.join(root, file)
                tasks.append(pool.submit(upload, filename))
        doc_ids = [r.result() for r in as_completed(tasks)]
        doc_ids = list(filter(lambda x: x, doc_ids))
        if not doc_ids:
            print("Warning: no document to sync")
            return
        print(f"Begin sync document: {doc_ids}")
        client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))
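For reference, a minimal sketch (not part of this commit) of calling the uploader programmatically instead of through the CLI. It assumes the DB-GPT API server is already running on http://127.0.0.1:5000, that the snippet is run from the repository root so `tools` and `pilot` are importable, and that `./my_docs` is a directory of documents:

```python
# Hypothetical usage sketch of knowledge_init, bypassing the click CLI.
from tools.cli.knowledge_client import knowledge_init

knowledge_init(
    "http://127.0.0.1:5000",  # api_address of the running DB-GPT server
    "default",                # vector_name: knowledge space to create/use
    "Chroma",                 # vector_store_type
    "./my_docs",              # local_doc_dir: walked recursively for files to upload
    True,                     # skip_wrong_doc: warn instead of raising on bad files
    False,                    # verbose
    max_workers=4,            # size of the upload thread pool
)
```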
tools/knowledge_init.py
@@ -3,6 +3,7 @@
 import argparse
 import os
 import sys
+import traceback

 sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

@@ -31,7 +32,7 @@ class LocalKnowledgeInit:
         self.vector_store_config = vector_store_config
         self.model_name = LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]

-    def knowledge_persist(self, file_path):
+    def knowledge_persist(self, file_path: str, skip_wrong_doc: bool = False):
         """knowledge persist"""
         docs = []
         embedding_engine = None
@@ -44,9 +45,18 @@ class LocalKnowledgeInit:
                     model_name=self.model_name,
                     vector_store_config=self.vector_store_config,
                 )
-                embedding_engine = ke.init_knowledge_embedding()
-                doc = ke.read()
-                docs.extend(doc)
+                try:
+                    embedding_engine = ke.init_knowledge_embedding()
+                    doc = ke.read()
+                    docs.extend(doc)
+                except Exception as e:
+                    error_msg = traceback.format_exc()
+                    if skip_wrong_doc:
+                        print(
+                            f"Warning: document file {filename} embedding error, skip it, error message: {error_msg}"
+                        )
+                    else:
+                        raise e
         embedding_engine.index_to_store(docs)
         print(f"""begin create {self.vector_store_config["vector_store_name"]} space""")
         try:
@@ -64,11 +74,24 @@


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--vector_name", type=str, default="default")
+    # TODO https://github.com/csunny/DB-GPT/issues/354
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--vector_name", type=str, default="default", help="Your vector store name"
+    )
+    parser.add_argument(
+        "--file_path", type=str, default=DATASETS_DIR, help="Your document path"
+    )
+    parser.add_argument(
+        "--skip_wrong_doc", type=bool, default=False, help="Skip wrong document"
+    )
     args = parser.parse_args()
     vector_name = args.vector_name
     store_type = CFG.VECTOR_STORE_TYPE
+    file_path = args.file_path
+    skip_wrong_doc = args.skip_wrong_doc
     vector_store_config = {
         "vector_store_name": vector_name,
         "vector_store_type": CFG.VECTOR_STORE_TYPE,
@@ -76,5 +99,5 @@ if __name__ == "__main__":
     }
     print(vector_store_config)
     kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
-    kv.knowledge_persist(file_path=DATASETS_DIR)
+    kv.knowledge_persist(file_path=file_path, skip_wrong_doc=skip_wrong_doc)
     print("your knowledge embedding success...")
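A small caveat on the `--skip_wrong_doc` option added above: argparse's `type=bool` simply calls `bool()` on the raw string, so `--skip_wrong_doc False` still evaluates to True (`bool("False")` is truthy); to keep skipping disabled, omit the flag and rely on the default.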