diff --git a/docs/modules/knownledge.md b/docs/modules/knownledge.md
index c108920b2..cd42922ed 100644
--- a/docs/modules/knownledge.md
+++ b/docs/modules/knownledge.md
@@ -31,6 +31,8 @@
 python tools/knowledge_init.py
 ```
 
+Optionally, you can run `python tools/knowledge_init.py -h` to see more usage information.
+
 3.Add the knowledge repository in the interface by entering the name of your knowledge repository (if not specified, enter "default") so you can use it for Q&A based on your knowledge base. Note that the default vector model used is text2vec-large-chinese (which is a large model, so if your personal computer configuration is not enough, it is recommended to use text2vec-base-chinese). Therefore, ensure that you download the model and place it in the models directory.
\ No newline at end of file
diff --git a/tools/cli/cli_scripts.py b/tools/cli/cli_scripts.py
new file mode 100644
index 000000000..545acb4a5
--- /dev/null
+++ b/tools/cli/cli_scripts.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+DB-GPT command line tools.
+
+You can use it for some background management tasks:
+- Bulk knowledge document initialization.
+- Load data into the database.
+- Show server status.
+- ...
+
+
+Maybe move this to the pilot module and append it to console_scripts in the future.
+
+"""
+import sys
+import click
+import os
+
+sys.path.append(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+)
+
+
+from pilot.configs.model_config import DATASETS_DIR
+
+from tools.cli.knowledge_client import knowledge_init
+
+API_ADDRESS: str = "http://127.0.0.1:5000"
+
+
+@click.group()
+@click.option(
+    "--api_address",
+    required=False,
+    default="http://127.0.0.1:5000",
+    type=str,
+    help="API server address",
+)
+@click.version_option()
+def cli(api_address: str):
+    global API_ADDRESS
+    API_ADDRESS = api_address
+
+
+@cli.command()
+@click.option(
+    "--vector_name",
+    required=False,
+    type=str,
+    default="default",
+    help="Your vector store name",
+)
+@click.option(
+    "--vector_store_type",
+    required=False,
+    type=str,
+    default="Chroma",
+    help="Vector store type",
+)
+@click.option(
+    "--local_doc_dir",
+    required=False,
+    type=str,
+    default=DATASETS_DIR,
+    help="Your document directory",
+)
+@click.option(
+    "--skip_wrong_doc",
+    required=False,
+    type=bool,
+    default=False,
+    help="Skip wrong document",
+)
+@click.option(
+    "--max_workers",
+    required=False,
+    type=int,
+    default=None,
+    help="The maximum number of threads that can be used to upload documents",
+)
+@click.option(
+    "-v",
+    "--verbose",
+    required=False,
+    is_flag=True,
+    hidden=True,
+    help="Show debugging information.",
+)
+def knowledge(
+    vector_name: str,
+    vector_store_type: str,
+    local_doc_dir: str,
+    skip_wrong_doc: bool,
+    max_workers: int,
+    verbose: bool,
+):
+    """Knowledge command line tool"""
+    knowledge_init(
+        API_ADDRESS,
+        vector_name,
+        vector_store_type,
+        local_doc_dir,
+        skip_wrong_doc,
+        verbose,
+        max_workers,
+    )
+
+
+# knowledge command
+cli.add_command(knowledge)
+# TODO: add more commands
+
+
+def main():
+    return cli()
+
+
+if __name__ == "__main__":
+    main()
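A quick way to exercise the new command group while reviewing is Click's built-in test runner, which avoids wiring up console_scripts. This is a minimal sketch, not part of the patch; it assumes the repository root is on `sys.path` and that a DB-GPT API server is reachable at the default address used by the CLI:

```python
# Sketch only: drive the new `cli` group via click.testing.CliRunner.
# Assumes the repo root is on sys.path and a DB-GPT server is
# listening on http://127.0.0.1:5000 (the CLI default).
from click.testing import CliRunner

from tools.cli.cli_scripts import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "--api_address",
        "http://127.0.0.1:5000",
        "knowledge",
        "--vector_name",
        "default",
        "--skip_wrong_doc",
        "true",
    ],
)
print(result.exit_code)
print(result.output)
```

The equivalent shell invocation is `python tools/cli/cli_scripts.py knowledge --vector_name default --skip_wrong_doc true`, since the module also has a `__main__` guard. The client code the command delegates to follows below.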
diff --git a/tools/cli/knowledge_client.py b/tools/cli/knowledge_client.py
new file mode 100644
index 000000000..282e1a28a
--- /dev/null
+++ b/tools/cli/knowledge_client.py
@@ -0,0 +1,147 @@
+import os
+import requests
+import json
+
+from urllib.parse import urljoin
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from pilot.openapi.api_v1.api_view_model import Result
+from pilot.server.knowledge.request.request import (
+    KnowledgeQueryRequest,
+    KnowledgeDocumentRequest,
+    DocumentSyncRequest,
+    ChunkQueryRequest,
+    DocumentQueryRequest,
+)
+
+from pilot.embedding_engine.knowledge_type import KnowledgeType
+from pilot.server.knowledge.request.request import DocumentSyncRequest
+
+from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
+
+
+HTTP_HEADERS = {"Content-Type": "application/json"}
+
+
+class ApiClient:
+    def __init__(self, api_address: str) -> None:
+        self.api_address = api_address
+
+    def _handle_response(self, response):
+        if 200 <= response.status_code <= 300:
+            result = Result(**response.json())
+            if not result.success:
+                raise Exception(result.err_msg)
+            return result.data
+        else:
+            raise Exception(
+                f"Http request error, code: {response.status_code}, message: {response.text}"
+            )
+
+    def _post(self, url: str, data=None):
+        if not isinstance(data, dict):
+            data = data.__dict__
+        response = requests.post(
+            urljoin(self.api_address, url), data=json.dumps(data), headers=HTTP_HEADERS
+        )
+        return self._handle_response(response)
+
+
+class KnowledgeApiClient(ApiClient):
+    def __init__(self, api_address: str) -> None:
+        super().__init__(api_address)
+
+    def space_add(self, request: KnowledgeSpaceRequest):
+        try:
+            return self._post("/knowledge/space/add", data=request)
+        except Exception as e:
+            if "have already named" in str(e):
+                print(f"Warning: you have already named {request.name}")
+            else:
+                raise e
+
+    def space_list(self, request: KnowledgeSpaceRequest):
+        return self._post("/knowledge/space/list", data=request)
+
+    def document_add(self, space_name: str, request: KnowledgeDocumentRequest):
+        url = f"/knowledge/{space_name}/document/add"
+        return self._post(url, data=request)
+
+    def document_list(self, space_name: str, query_request: DocumentQueryRequest):
+        url = f"/knowledge/{space_name}/document/list"
+        return self._post(url, data=query_request)
+
+    def document_upload(self, space_name, doc_name, doc_type, doc_file_path):
+        """Upload with multipart/form-data"""
+        url = f"{self.api_address}/knowledge/{space_name}/document/upload"
+        with open(doc_file_path, "rb") as f:
+            files = {"doc_file": f}
+            data = {"doc_name": doc_name, "doc_type": doc_type}
+            response = requests.post(url, data=data, files=files)
+            return self._handle_response(response)
+
+    def document_sync(self, space_name: str, request: DocumentSyncRequest):
+        url = f"/knowledge/{space_name}/document/sync"
+        return self._post(url, data=request)
+
+    def chunk_list(self, space_name: str, query_request: ChunkQueryRequest):
+        url = f"/knowledge/{space_name}/chunk/list"
+        return self._post(url, data=query_request)
+
+    def similar_query(self, vector_name: str, query_request: KnowledgeQueryRequest):
+        url = f"/knowledge/{vector_name}/query"
+        return self._post(url, data=query_request)
+
+
+def knowledge_init(
+    api_address: str,
+    vector_name: str,
+    vector_store_type: str,
+    local_doc_dir: str,
+    skip_wrong_doc: bool,
+    verbose: bool,
+    max_workers: int = None,
+):
+    client = KnowledgeApiClient(api_address)
+    space = KnowledgeSpaceRequest()
+    space.name = vector_name
+    space.desc = "DB-GPT cli"
+    space.vector_type = vector_store_type
+    space.owner = "DB-GPT"
+
+    # Create space
+    print(f"Create space: {space}")
+    client.space_add(space)
+    print("Create space successfully")
+    space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
+    if len(space_list) != 1:
+        raise Exception(f"List space {space.name} error")
+    space = KnowledgeSpaceRequest(**space_list[0])
+
+    doc_ids = []
+
+    def upload(filename: str):
+        try:
+            print(f"Begin upload document: {filename} to {space.name}")
+            return client.document_upload(
+                space.name, filename, KnowledgeType.DOCUMENT.value, filename
+            )
+        except Exception as e:
+            if skip_wrong_doc:
+                print(f"Warning: {str(e)}")
+            else:
+                raise e
+
+    with ThreadPoolExecutor(max_workers=max_workers) as pool:
+        tasks = []
+        for root, _, files in os.walk(local_doc_dir, topdown=False):
+            for file in files:
+                filename = os.path.join(root, file)
+                tasks.append(pool.submit(upload, filename))
+        doc_ids = [r.result() for r in as_completed(tasks)]
+        doc_ids = list(filter(lambda x: x, doc_ids))
+        if not doc_ids:
+            print("Warning: no document to sync")
+            return
+        print(f"Begin sync document: {doc_ids}")
+        client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))
diff --git a/tools/knowledge_init.py b/tools/knowledge_init.py
index c442de8c9..aca2baf1d 100644
--- a/tools/knowledge_init.py
+++ b/tools/knowledge_init.py
@@ -3,6 +3,7 @@
 import argparse
 import os
 import sys
+import traceback
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
 
@@ -31,7 +32,7 @@ class LocalKnowledgeInit:
         self.vector_store_config = vector_store_config
         self.model_name = LLM_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
 
-    def knowledge_persist(self, file_path):
+    def knowledge_persist(self, file_path: str, skip_wrong_doc: bool = False):
         """knowledge persist"""
         docs = []
         embedding_engine = None
@@ -44,9 +45,18 @@ class LocalKnowledgeInit:
                     model_name=self.model_name,
                     vector_store_config=self.vector_store_config,
                 )
-                embedding_engine = ke.init_knowledge_embedding()
-                doc = ke.read()
-                docs.extend(doc)
+                try:
+                    embedding_engine = ke.init_knowledge_embedding()
+                    doc = ke.read()
+                    docs.extend(doc)
+                except Exception as e:
+                    error_msg = traceback.format_exc()
+                    if skip_wrong_doc:
+                        print(
+                            f"Warning: document file {filename} embedding error, skip it, error message: {error_msg}"
+                        )
+                    else:
+                        raise e
         embedding_engine.index_to_store(docs)
         print(f"""begin create {self.vector_store_config["vector_store_name"]} space""")
         try:
@@ -64,11 +74,24 @@
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--vector_name", type=str, default="default")
+    # TODO https://github.com/csunny/DB-GPT/issues/354
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--vector_name", type=str, default="default", help="Your vector store name"
+    )
+    parser.add_argument(
+        "--file_path", type=str, default=DATASETS_DIR, help="Your document path"
+    )
+    parser.add_argument(
+        "--skip_wrong_doc", type=bool, default=False, help="Skip wrong document"
+    )
     args = parser.parse_args()
     vector_name = args.vector_name
     store_type = CFG.VECTOR_STORE_TYPE
+    file_path = args.file_path
+    skip_wrong_doc = args.skip_wrong_doc
     vector_store_config = {
         "vector_store_name": vector_name,
         "vector_store_type": CFG.VECTOR_STORE_TYPE,
@@ -76,5 +99,5 @@ if __name__ == "__main__":
     }
     print(vector_store_config)
     kv = LocalKnowledgeInit(vector_store_config=vector_store_config)
-    kv.knowledge_persist(file_path=DATASETS_DIR)
+    kv.knowledge_persist(file_path=file_path, skip_wrong_doc=skip_wrong_doc)
     print("your knowledge embedding success...")
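Outside the CLI, the `KnowledgeApiClient` added above can also be reused from ordinary Python scripts. A minimal sketch under the same assumptions as before (server on the default address, a space named "default" already created and synced); note that the `query` and `top_k` fields on `KnowledgeQueryRequest` are assumptions on my part and are not shown in this patch:

```python
# Sketch only: reuse the client added in tools/cli/knowledge_client.py.
# Assumes a DB-GPT server at http://127.0.0.1:5000 and an existing,
# already-synced space named "default".
from tools.cli.knowledge_client import KnowledgeApiClient
from pilot.server.knowledge.request.request import (
    KnowledgeQueryRequest,
    KnowledgeSpaceRequest,
)

client = KnowledgeApiClient("http://127.0.0.1:5000")

# List spaces by name, mirroring the space_list() call in knowledge_init().
spaces = client.space_list(KnowledgeSpaceRequest(name="default"))
print(spaces)

# Similarity query against the space. `query` and `top_k` are assumed
# field names on KnowledgeQueryRequest; they are not defined in this diff.
request = KnowledgeQueryRequest()
request.query = "How do I initialize the knowledge store?"
request.top_k = 5
print(client.similar_query("default", request))
```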