feat: Command-line tool with knowledge repository initialization

This commit is contained in:
FangYin Cheng
2023-09-01 18:21:22 +08:00
parent d42afb50a7
commit e5bbd0bd86
9 changed files with 153 additions and 257 deletions

View File

@@ -1,122 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DB-GPT command line tools.
You can use it for some background management:
- Lots of knowledge document initialization.
- Load the data into the database.
- Show server status
- ...
Maybe move this to pilot module and append to console_scripts in the future.
"""
import sys
import click
import os
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
)
from pilot.configs.model_config import DATASETS_DIR
API_ADDRESS: str = "http://127.0.0.1:5000"
@click.group()
@click.option(
"--api_address",
required=False,
default="http://127.0.0.1:5000",
type=str,
help="Api server address",
)
@click.version_option()
def cli(api_address: str):
global API_ADDRESS
API_ADDRESS = api_address
@cli.command()
@click.option(
"--vector_name",
required=False,
type=str,
default="default",
help="Your vector store name",
)
@click.option(
"--vector_store_type",
required=False,
type=str,
default="Chroma",
help="Vector store type",
)
@click.option(
"--local_doc_dir",
required=False,
type=str,
default=DATASETS_DIR,
help="Your document directory",
)
@click.option(
"--skip_wrong_doc",
required=False,
type=bool,
default=False,
help="Skip wrong document",
)
@click.option(
"--max_workers",
required=False,
type=int,
default=None,
help="The maximum number of threads that can be used to upload document",
)
@click.option(
"-v",
"--verbose",
required=False,
is_flag=True,
hidden=True,
help="Show debuggging information.",
)
def knowledge(
vector_name: str,
vector_store_type: str,
local_doc_dir: str,
skip_wrong_doc: bool,
max_workers: int,
verbose: bool,
):
"""Knowledge command line tool"""
from tools.cli.knowledge_client import knowledge_init
knowledge_init(
API_ADDRESS,
vector_name,
vector_store_type,
local_doc_dir,
skip_wrong_doc,
verbose,
max_workers,
)
# knowledge command
cli.add_command(knowledge)
# TODO add more command
def main():
return cli()
if __name__ == "__main__":
main()
raise Exception(
"The functionality of this script has been moved to the command line tool `dbgpt`. For details on usage, please execute the command `dbgpt --help`."
)

View File

@@ -1,146 +0,0 @@
import os
import requests
import json
from urllib.parse import urljoin
from concurrent.futures import ThreadPoolExecutor, as_completed
from pilot.openapi.api_view_model import Result
from pilot.server.knowledge.request.request import (
KnowledgeQueryRequest,
KnowledgeDocumentRequest,
ChunkQueryRequest,
DocumentQueryRequest,
)
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.server.knowledge.request.request import DocumentSyncRequest
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
HTTP_HEADERS = {"Content-Type": "application/json"}
class ApiClient:
def __init__(self, api_address: str) -> None:
self.api_address = api_address
def _handle_response(self, response):
if 200 <= response.status_code <= 300:
result = Result(**response.json())
if not result.success:
raise Exception(result.err_msg)
return result.data
else:
raise Exception(
f"Http request error, code: {response.status_code}, message: {response.text}"
)
def _post(self, url: str, data=None):
if not isinstance(data, dict):
data = data.__dict__
response = requests.post(
urljoin(self.api_address, url), data=json.dumps(data), headers=HTTP_HEADERS
)
return self._handle_response(response)
class KnowledgeApiClient(ApiClient):
def __init__(self, api_address: str) -> None:
super().__init__(api_address)
def space_add(self, request: KnowledgeSpaceRequest):
try:
return self._post("/knowledge/space/add", data=request)
except Exception as e:
if "have already named" in str(e):
print(f"Warning: you have already named {request.name}")
else:
raise e
def space_list(self, request: KnowledgeSpaceRequest):
return self._post("/knowledge/space/list", data=request)
def document_add(self, space_name: str, request: KnowledgeDocumentRequest):
url = f"/knowledge/{space_name}/document/add"
return self._post(url, data=request)
def document_list(self, space_name: str, query_request: DocumentQueryRequest):
url = f"/knowledge/{space_name}/document/list"
return self._post(url, data=query_request)
def document_upload(self, space_name, doc_name, doc_type, doc_file_path):
"""Upload with multipart/form-data"""
url = f"{self.api_address}/knowledge/{space_name}/document/upload"
with open(doc_file_path, "rb") as f:
files = {"doc_file": f}
data = {"doc_name": doc_name, "doc_type": doc_type}
response = requests.post(url, data=data, files=files)
return self._handle_response(response)
def document_sync(self, space_name: str, request: DocumentSyncRequest):
url = f"/knowledge/{space_name}/document/sync"
return self._post(url, data=request)
def chunk_list(self, space_name: str, query_request: ChunkQueryRequest):
url = f"/knowledge/{space_name}/chunk/list"
return self._post(url, data=query_request)
def similar_query(self, vector_name: str, query_request: KnowledgeQueryRequest):
url = f"/knowledge/{vector_name}/query"
return self._post(url, data=query_request)
def knowledge_init(
api_address: str,
vector_name: str,
vector_store_type: str,
local_doc_dir: str,
skip_wrong_doc: bool,
verbose: bool,
max_workers: int = None,
):
client = KnowledgeApiClient(api_address)
space = KnowledgeSpaceRequest()
space.name = vector_name
space.desc = "DB-GPT cli"
space.vector_type = vector_store_type
space.owner = "DB-GPT"
# Create space
print(f"Create space: {space}")
client.space_add(space)
print("Create space successfully")
space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
if len(space_list) != 1:
raise Exception(f"List space {space.name} error")
space = KnowledgeSpaceRequest(**space_list[0])
doc_ids = []
def upload(filename: str):
try:
print(f"Begin upload document: {filename} to {space.name}")
return client.document_upload(
space.name, filename, KnowledgeType.DOCUMENT.value, filename
)
except Exception as e:
if skip_wrong_doc:
print(f"Warning: {str(e)}")
else:
raise e
with ThreadPoolExecutor(max_workers=max_workers) as pool:
tasks = []
for root, _, files in os.walk(local_doc_dir, topdown=False):
for file in files:
filename = os.path.join(root, file)
tasks.append(pool.submit(upload, filename))
doc_ids = [r.result() for r in as_completed(tasks)]
doc_ids = list(filter(lambda x: x, doc_ids))
if not doc_ids:
print("Warning: no document to sync")
return
print(f"Begin sync document: {doc_ids}")
client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))

View File

@@ -1,103 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import sys
import traceback
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.server.knowledge.service import KnowledgeService
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
LLM_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH,
)
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
knowledge_space_service = KnowledgeService()
CFG = Config()
class LocalKnowledgeInit:
embeddings: object = None
def __init__(self, vector_store_config) -> None:
self.vector_store_config = vector_store_config
self.model_name = LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
def knowledge_persist(self, file_path: str, skip_wrong_doc: bool = False):
"""knowledge persist"""
docs = []
embedding_engine = None
for root, _, files in os.walk(file_path, topdown=False):
for file in files:
filename = os.path.join(root, file)
ke = EmbeddingEngine(
knowledge_source=filename,
knowledge_type=KnowledgeType.DOCUMENT.value,
model_name=self.model_name,
vector_store_config=self.vector_store_config,
)
try:
embedding_engine = ke.init_knowledge_embedding()
doc = ke.read()
docs.extend(doc)
except Exception as e:
error_msg = traceback.format_exc()
if skip_wrong_doc:
print(
f"Warning: document file {filename} embedding error, skip it, error message: {error_msg}"
)
else:
raise e
embedding_engine.index_to_store(docs)
print(f"""begin create {self.vector_store_config["vector_store_name"]} space""")
try:
space = KnowledgeSpaceRequest
space.name = self.vector_store_config["vector_store_name"]
space.desc = "knowledge_init.py"
space.vector_type = CFG.VECTOR_STORE_TYPE
space.owner = "DB-GPT"
knowledge_space_service.create_knowledge_space(space)
except Exception as e:
if "have already named" in str(e):
print(f"Warning: you have already named {space.name}")
else:
raise e
if __name__ == "__main__":
# TODO https://github.com/csunny/DB-GPT/issues/354
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
raise Exception(
"The functionality of this script has been moved to the command line tool `dbgpt`. For details on usage, please execute the command `dbgpt --help`."
)
parser.add_argument(
"--vector_name", type=str, default="default", help="Your vector store name"
)
parser.add_argument(
"--file_path", type=str, default=DATASETS_DIR, help="Your document path"
)
parser.add_argument(
"--skip_wrong_doc", type=bool, default=False, help="Skip wrong document"
)
args = parser.parse_args()
vector_name = args.vector_name
store_type = CFG.VECTOR_STORE_TYPE
file_path = args.file_path
skip_wrong_doc = args.skip_wrong_doc
vector_store_config = {
"vector_store_name": vector_name,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
}
print(vector_store_config)
kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
kv.knowledge_persist(file_path=file_path, skip_wrong_doc=skip_wrong_doc)
print("your knowledge embedding success...")