diff --git a/tools/cli/cli_scripts.py b/tools/cli/cli_scripts.py
new file mode 100644
index 000000000..545acb4a5
--- /dev/null
+++ b/tools/cli/cli_scripts.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+DB-GPT command line tools.
+
+You can use it for some background management:
+- Lots of knowledge document initialization.
+- Load the data into the database.
+- Show server status
+- ...
+
+
+Maybe move this to pilot module and append to console_scripts in the future.
+
+"""
+import sys
+import click
+import os
+
+sys.path.append(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+)
+
+
+from pilot.configs.model_config import DATASETS_DIR
+
+from tools.cli.knowledge_client import knowledge_init
+
+API_ADDRESS: str = "http://127.0.0.1:5000"
+
+
+@click.group()
+@click.option(
+    "--api_address",
+    required=False,
+    default="http://127.0.0.1:5000",
+    type=str,
+    help="Api server address",
+)
+@click.version_option()
+def cli(api_address: str):
+    global API_ADDRESS
+    API_ADDRESS = api_address
+
+
+@cli.command()
+@click.option(
+    "--vector_name",
+    required=False,
+    type=str,
+    default="default",
+    help="Your vector store name",
+)
+@click.option(
+    "--vector_store_type",
+    required=False,
+    type=str,
+    default="Chroma",
+    help="Vector store type",
+)
+@click.option(
+    "--local_doc_dir",
+    required=False,
+    type=str,
+    default=DATASETS_DIR,
+    help="Your document directory",
+)
+@click.option(
+    "--skip_wrong_doc",
+    required=False,
+    type=bool,
+    default=False,
+    help="Skip wrong document",
+)
+@click.option(
+    "--max_workers",
+    required=False,
+    type=int,
+    default=None,
+    help="The maximum number of threads that can be used to upload document",
+)
+@click.option(
+    "-v",
+    "--verbose",
+    required=False,
+    is_flag=True,
+    hidden=True,
+    help="Show debugging information.",
+)
+def knowledge(
+    vector_name: str,
+    vector_store_type: str,
+    local_doc_dir: str,
+    skip_wrong_doc: bool,
+    max_workers: int,
+    verbose: bool,
+):
+    """Knowledge command line tool"""
+    knowledge_init(
+        API_ADDRESS,
+        vector_name,
+        vector_store_type,
+        local_doc_dir,
+        skip_wrong_doc,
+        verbose,
+        max_workers,
+    )
+
+
+# knowledge command
+cli.add_command(knowledge)
+# TODO add more command
+
+
+def main():
+    return cli()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/cli/knowledge_client.py b/tools/cli/knowledge_client.py
new file mode 100644
index 000000000..282e1a28a
--- /dev/null
+++ b/tools/cli/knowledge_client.py
@@ -0,0 +1,147 @@
+import os
+import requests
+import json
+
+from urllib.parse import urljoin
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from pilot.openapi.api_v1.api_view_model import Result
+from pilot.server.knowledge.request.request import (
+    KnowledgeQueryRequest,
+    KnowledgeDocumentRequest,
+    DocumentSyncRequest,
+    ChunkQueryRequest,
+    DocumentQueryRequest,
+)
+
+from pilot.embedding_engine.knowledge_type import KnowledgeType
+from pilot.server.knowledge.request.request import DocumentSyncRequest
+
+from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
+
+
+HTTP_HEADERS = {"Content-Type": "application/json"}
+
+
+class ApiClient:
+    def __init__(self, api_address: str) -> None:
+        self.api_address = api_address
+
+    def _handle_response(self, response):
+        if 200 <= response.status_code <= 300:
+            result = Result(**response.json())
+            if not result.success:
+                raise Exception(result.err_msg)
+            return result.data
+        else:
+            raise Exception(
+                f"Http request error, code: {response.status_code}, message: {response.text}"
+            )
+
+    def _post(self, url: str, data=None):
+        if not isinstance(data, dict):
+            data = data.__dict__
+        response = requests.post(
+            urljoin(self.api_address, url), data=json.dumps(data), headers=HTTP_HEADERS
+        )
+        return self._handle_response(response)
+
+
+class KnowledgeApiClient(ApiClient):
+    def __init__(self, api_address: str) -> None:
+        super().__init__(api_address)
+
+    def space_add(self, request: KnowledgeSpaceRequest):
+        try:
+            return self._post("/knowledge/space/add", data=request)
+        except Exception as e:
+            if "have already named" in str(e):
+                print(f"Warning: you have already named {request.name}")
+            else:
+                raise e
+
+    def space_list(self, request: KnowledgeSpaceRequest):
+        return self._post("/knowledge/space/list", data=request)
+
+    def document_add(self, space_name: str, request: KnowledgeDocumentRequest):
+        url = f"/knowledge/{space_name}/document/add"
+        return self._post(url, data=request)
+
+    def document_list(self, space_name: str, query_request: DocumentQueryRequest):
+        url = f"/knowledge/{space_name}/document/list"
+        return self._post(url, data=query_request)
+
+    def document_upload(self, space_name, doc_name, doc_type, doc_file_path):
+        """Upload with multipart/form-data"""
+        url = f"{self.api_address}/knowledge/{space_name}/document/upload"
+        with open(doc_file_path, "rb") as f:
+            files = {"doc_file": f}
+            data = {"doc_name": doc_name, "doc_type": doc_type}
+            response = requests.post(url, data=data, files=files)
+            return self._handle_response(response)
+
+    def document_sync(self, space_name: str, request: DocumentSyncRequest):
+        url = f"/knowledge/{space_name}/document/sync"
+        return self._post(url, data=request)
+
+    def chunk_list(self, space_name: str, query_request: ChunkQueryRequest):
+        url = f"/knowledge/{space_name}/chunk/list"
+        return self._post(url, data=query_request)
+
+    def similar_query(self, vector_name: str, query_request: KnowledgeQueryRequest):
+        url = f"/knowledge/{vector_name}/query"
+        return self._post(url, data=query_request)
+
+
+def knowledge_init(
+    api_address: str,
+    vector_name: str,
+    vector_store_type: str,
+    local_doc_dir: str,
+    skip_wrong_doc: bool,
+    verbose: bool,
+    max_workers: int = None,
+):
+    client = KnowledgeApiClient(api_address)
+    space = KnowledgeSpaceRequest()
+    space.name = vector_name
+    space.desc = "DB-GPT cli"
+    space.vector_type = vector_store_type
+    space.owner = "DB-GPT"
+
+    # Create space
+    print(f"Create space: {space}")
+    client.space_add(space)
+    print("Create space successfully")
+    space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
+    if len(space_list) != 1:
+        raise Exception(f"List space {space.name} error")
+    space = KnowledgeSpaceRequest(**space_list[0])
+
+    doc_ids = []
+
+    def upload(filename: str):
+        try:
+            print(f"Begin upload document: {filename} to {space.name}")
+            return client.document_upload(
+                space.name, filename, KnowledgeType.DOCUMENT.value, filename
+            )
+        except Exception as e:
+            if skip_wrong_doc:
+                print(f"Warning: {str(e)}")
+            else:
+                raise e
+
+    with ThreadPoolExecutor(max_workers=max_workers) as pool:
+        tasks = []
+        for root, _, files in os.walk(local_doc_dir, topdown=False):
+            for file in files:
+                filename = os.path.join(root, file)
+                tasks.append(pool.submit(upload, filename))
+        doc_ids = [r.result() for r in as_completed(tasks)]
+        doc_ids = list(filter(lambda x: x, doc_ids))
+        if not doc_ids:
+            print("Warning: no document to sync")
+            return
+        print(f"Begin sync document: {doc_ids}")
+        client.document_sync(space.name, DocumentSyncRequest(doc_ids=doc_ids))
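
For reference, a minimal sketch of how the new `knowledge` sub-command could be exercised once this change lands. It is not part of the diff itself: it assumes a DB-GPT API server is already listening on the default `http://127.0.0.1:5000`, and the document directory is a placeholder. `click.testing.CliRunner` is used only to keep the example in Python; running the script directly (`python tools/cli/cli_scripts.py knowledge ...`) is the intended day-to-day usage.

```python
from click.testing import CliRunner

from tools.cli.cli_scripts import cli

runner = CliRunner()
# Roughly equivalent to:
#   python tools/cli/cli_scripts.py --api_address http://127.0.0.1:5000 \
#       knowledge --vector_name default --local_doc_dir ./my_docs
result = runner.invoke(
    cli,
    [
        "--api_address", "http://127.0.0.1:5000",
        "knowledge",
        "--vector_name", "default",        # knowledge space / vector store name
        "--vector_store_type", "Chroma",
        "--local_doc_dir", "./my_docs",    # placeholder document directory
    ],
)
print(result.output)
```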
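The same flow is available programmatically through `KnowledgeApiClient`, which is what `knowledge_init` drives internally: create (or reuse) a knowledge space, upload documents, then trigger a sync. A minimal sketch, assuming the server address, space name, and file path below are placeholders for your own deployment:

```python
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.server.knowledge.request.request import (
    DocumentSyncRequest,
    KnowledgeSpaceRequest,
)
from tools.cli.knowledge_client import KnowledgeApiClient

client = KnowledgeApiClient("http://127.0.0.1:5000")  # placeholder address

# Create the space; space_add only warns if the name already exists.
space = KnowledgeSpaceRequest()
space.name = "default"
space.desc = "created from a script"
space.vector_type = "Chroma"
space.owner = "DB-GPT"
client.space_add(space)

# Upload one document, then embed it into the space.
doc_id = client.document_upload(
    space.name,
    "example.md",                   # document name
    KnowledgeType.DOCUMENT.value,   # upload as a plain document
    "/path/to/example.md",          # placeholder local file path
)
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
```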