diff --git a/.env.template b/.env.template index 47b7c3b99..65fbb1391 100644 --- a/.env.template +++ b/.env.template @@ -55,6 +55,8 @@ EMBEDDING_MODEL=text2vec #EMBEDDING_MODEL=bge-large-zh KNOWLEDGE_CHUNK_SIZE=500 KNOWLEDGE_SEARCH_TOP_SIZE=5 +# Control whether to display the source document of knowledge on the front end. +KNOWLEDGE_CHAT_SHOW_RELATIONS=False ## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs ## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs # EMBEDDING_MODEL=all-MiniLM-L6-v2 @@ -154,4 +156,6 @@ SUMMARY_CONFIG=FAST #** LOG **# #*******************************************************************# # FATAL, ERROR, WARNING, WARNING, INFO, DEBUG, NOTSET -DBGPT_LOG_LEVEL=INFO \ No newline at end of file +DBGPT_LOG_LEVEL=INFO +# LOG dir, default: ./logs +#DBGPT_LOG_DIR= \ No newline at end of file diff --git a/pilot/agent/json_fix_llm.py b/pilot/agent/json_fix_llm.py index aa7c4cec8..ed551ffda 100644 --- a/pilot/agent/json_fix_llm.py +++ b/pilot/agent/json_fix_llm.py @@ -1,6 +1,7 @@ import contextlib import json from typing import Any, Dict +import logging from colorama import Fore from regex import regex @@ -11,9 +12,9 @@ from pilot.json_utils.json_fix_general import ( balance_braces, fix_invalid_escape, ) -from pilot.logs import logger +logger = logging.getLogger(__name__) CFG = Config() diff --git a/pilot/commands/built_in/image_gen.py b/pilot/commands/built_in/image_gen.py index d6492e2d9..5e46c6c3a 100644 --- a/pilot/commands/built_in/image_gen.py +++ b/pilot/commands/built_in/image_gen.py @@ -2,14 +2,15 @@ import io import uuid from base64 import b64decode +import logging import requests from PIL import Image from pilot.commands.command_mange import command from pilot.configs.config import Config -from pilot.logs import logger +logger = logging.getLogger(__name__) CFG = Config() diff --git a/pilot/commands/disply_type/show_chart_gen.py b/pilot/commands/disply_type/show_chart_gen.py index 097685369..e3bdc77e1 100644 --- a/pilot/commands/disply_type/show_chart_gen.py +++ b/pilot/commands/disply_type/show_chart_gen.py @@ -1,5 +1,5 @@ +import logging from pandas import DataFrame - from pilot.commands.command_mange import command from pilot.configs.config import Config import pandas as pd @@ -13,11 +13,10 @@ import matplotlib.pyplot as plt import matplotlib.ticker as mtick from matplotlib.font_manager import FontManager -from pilot.configs.model_config import LOGDIR -from pilot.utils import build_logger - CFG = Config() -logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log") + +logger = logging.getLogger(__name__) + static_message_img_path = os.path.join(os.getcwd(), "message/img") diff --git a/pilot/commands/disply_type/show_table_gen.py b/pilot/commands/disply_type/show_table_gen.py index d67e39eb9..3edc193d2 100644 --- a/pilot/commands/disply_type/show_table_gen.py +++ b/pilot/commands/disply_type/show_table_gen.py @@ -1,14 +1,13 @@ +import logging + import pandas as pd from pandas import DataFrame from pilot.commands.command_mange import command from pilot.configs.config import Config -from pilot.configs.model_config import LOGDIR -from pilot.utils import build_logger - CFG = Config() -logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") +logger = logging.getLogger(__name__) @command( diff --git a/pilot/commands/disply_type/show_text_gen.py b/pilot/commands/disply_type/show_text_gen.py index 16d9fbe91..8ee98fd7e 100644 --- a/pilot/commands/disply_type/show_text_gen.py +++ b/pilot/commands/disply_type/show_text_gen.py @@ -1,13 +1,12 @@ +import logging import pandas as pd from pandas import DataFrame from pilot.commands.command_mange import command from pilot.configs.config import Config -from pilot.configs.model_config import LOGDIR -from pilot.utils import build_logger CFG = Config() -logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") +logger = logging.getLogger(__name__) @command( diff --git a/pilot/common/plugins.py b/pilot/common/plugins.py index 517dc800a..1f63bc3be 100644 --- a/pilot/common/plugins.py +++ b/pilot/common/plugins.py @@ -8,6 +8,7 @@ import zipfile import requests import threading import datetime +import logging from pathlib import Path from typing import List, TYPE_CHECKING from urllib.parse import urlparse @@ -17,7 +18,8 @@ import requests from pilot.configs.config import Config from pilot.configs.model_config import PLUGINS_DIR -from pilot.logs import logger + +logger = logging.getLogger(__name__) if TYPE_CHECKING: from auto_gpt_plugin_template import AutoGPTPluginTemplate diff --git a/pilot/configs/config.py b/pilot/configs/config.py index b2af3ec3a..134bd66f6 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -32,7 +32,7 @@ class Config(metaclass=Singleton): # self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1)) self.execute_local_commands = ( - os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True" + os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true" ) # User agent header to use when making HTTP requests # Some websites might just completely deny request with an error code if @@ -64,7 +64,7 @@ class Config(metaclass=Singleton): self.milvus_username = os.getenv("MILVUS_USERNAME") self.milvus_password = os.getenv("MILVUS_PASSWORD") self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt") - self.milvus_secure = os.getenv("MILVUS_SECURE") == "True" + self.milvus_secure = os.getenv("MILVUS_SECURE", "False").lower() == "true" self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y") self.exit_key = os.getenv("EXIT_KEY", "n") @@ -98,7 +98,7 @@ class Config(metaclass=Singleton): self.disabled_command_categories = [] self.execute_local_commands = ( - os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True" + os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true" ) ### message stor file self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message") @@ -107,7 +107,7 @@ class Config(metaclass=Singleton): self.plugins: List["AutoGPTPluginTemplate"] = [] self.plugins_openai = [] - self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True" + self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true" self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard") @@ -124,10 +124,10 @@ class Config(metaclass=Singleton): self.plugins_denylist = [] ### Native SQL Execution Capability Control Configuration self.NATIVE_SQL_CAN_RUN_DDL = ( - os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True" + os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True").lower() == "true" ) self.NATIVE_SQL_CAN_RUN_WRITE = ( - os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True" + os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true" ) ### default Local database connection configuration @@ -170,8 +170,8 @@ class Config(metaclass=Singleton): # QLoRA self.QLoRA = os.getenv("QUANTIZE_QLORA", "True") - self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True") == "True" - self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True" + self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true" + self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False").lower() == "true" if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT: self.IS_LOAD_8BIT = False # In order to be compatible with the new and old model parameter design @@ -187,7 +187,9 @@ class Config(metaclass=Singleton): os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000) ) ### Control whether to display the source document of knowledge on the front end. - self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False + self.KNOWLEDGE_CHAT_SHOW_RELATIONS = ( + os.getenv("KNOWLEDGE_CHAT_SHOW_RELATIONS", "False").lower() == "true" + ) ### SUMMARY_CONFIG Configuration self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST") diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index c70a0912e..216a6f03f 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -3,13 +3,12 @@ import os -# import nltk - ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) MODEL_PATH = os.path.join(ROOT_PATH, "models") PILOT_PATH = os.path.join(ROOT_PATH, "pilot") VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store") -LOGDIR = os.path.join(ROOT_PATH, "logs") +LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs")) + DATASETS_DIR = os.path.join(PILOT_PATH, "datasets") DATA_DIR = os.path.join(PILOT_PATH, "data") # nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path diff --git a/pilot/connections/rdbms/base_dao.py b/pilot/connections/rdbms/base_dao.py index 82edb404f..e6a5be983 100644 --- a/pilot/connections/rdbms/base_dao.py +++ b/pilot/connections/rdbms/base_dao.py @@ -1,10 +1,11 @@ +import logging from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from pilot.configs.config import Config from pilot.common.schema import DBType from pilot.connections.rdbms.base import RDBMSDatabase -from pilot.logs import logger +logger = logging.getLogger(__name__) CFG = Config() diff --git a/pilot/embedding_engine/__init__.py b/pilot/embedding_engine/__init__.py index c12543a1f..3fa28f194 100644 --- a/pilot/embedding_engine/__init__.py +++ b/pilot/embedding_engine/__init__.py @@ -1,5 +1,12 @@ from pilot.embedding_engine.source_embedding import SourceEmbedding, register from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.knowledge_type import KnowledgeType +from pilot.embedding_engine.pre_text_splitter import PreTextSplitter -__all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"] +__all__ = [ + "SourceEmbedding", + "register", + "EmbeddingEngine", + "KnowledgeType", + "PreTextSplitter", +] diff --git a/pilot/embedding_engine/pre_text_splitter.py b/pilot/embedding_engine/pre_text_splitter.py new file mode 100644 index 000000000..cf7cdfdf5 --- /dev/null +++ b/pilot/embedding_engine/pre_text_splitter.py @@ -0,0 +1,30 @@ +from typing import Iterable, List +from langchain.schema import Document +from langchain.text_splitter import TextSplitter + + +def _single_document_split( + document: Document, pre_separator: str +) -> Iterable[Document]: + page_content = document.page_content + for i, content in enumerate(page_content.split(pre_separator)): + metadata = document.metadata.copy() + if "source" in metadata: + metadata["source"] = metadata["source"] + "_pre_split_" + str(i) + yield Document(page_content=content, metadata=metadata) + + +class PreTextSplitter(TextSplitter): + def __init__(self, pre_separator: str, text_splitter_impl: TextSplitter): + self.pre_separator = pre_separator + self._impl = text_splitter_impl + + def split_text(self, text: str) -> List[str]: + return self._impl.split_text(text) + + def split_documents(self, documents: Iterable[Document]) -> List[Document]: + def generator() -> Iterable[Document]: + for doc in documents: + yield from _single_document_split(doc, pre_separator=self.pre_separator) + + return self._impl.split_documents(generator()) diff --git a/pilot/json_utils/json_fix_general.py b/pilot/json_utils/json_fix_general.py index e24d02bbf..ea9ee9b5b 100644 --- a/pilot/json_utils/json_fix_general.py +++ b/pilot/json_utils/json_fix_general.py @@ -5,12 +5,13 @@ from __future__ import annotations import contextlib import json import re +import logging from typing import Optional from pilot.configs.config import Config from pilot.json_utils.utilities import extract_char_position -from pilot.logs import logger +logger = logging.getLogger(__name__) CFG = Config() diff --git a/pilot/json_utils/utilities.py b/pilot/json_utils/utilities.py index 9eb753912..452a8a536 100644 --- a/pilot/json_utils/utilities.py +++ b/pilot/json_utils/utilities.py @@ -3,12 +3,14 @@ import json import os.path import re import json +import logging from datetime import datetime from jsonschema import Draft7Validator from pilot.configs.config import Config -from pilot.logs import logger + +logger = logging.getLogger(__name__) CFG = Config() LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1" diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index c29427b1d..760e8722d 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -3,6 +3,7 @@ import os import re +import logging from pathlib import Path from typing import List, Tuple, Callable, Type from functools import cache @@ -19,8 +20,8 @@ from pilot.model.parameter import ( ) from pilot.configs.model_config import get_device from pilot.configs.config import Config -from pilot.logs import logger +logger = logging.getLogger(__name__) CFG = Config() diff --git a/pilot/model/cluster/controller/controller.py b/pilot/model/cluster/controller/controller.py index e93216929..826ffef03 100644 --- a/pilot/model/cluster/controller/controller.py +++ b/pilot/model/cluster/controller/controller.py @@ -13,6 +13,9 @@ from pilot.utils.api_utils import ( _api_remote as api_remote, _sync_api_remote as sync_api_remote, ) +from pilot.utils.utils import setup_logging + +logger = logging.getLogger(__name__) class BaseModelController(BaseComponent, ABC): @@ -59,7 +62,7 @@ class LocalModelController(BaseModelController): async def get_all_instances( self, model_name: str = None, healthy_only: bool = False ) -> List[ModelInstance]: - logging.info( + logger.info( f"Get all instances with {model_name}, healthy_only: {healthy_only}" ) if not model_name: @@ -178,6 +181,13 @@ def run_model_controller(): controller_params: ModelControllerParameters = parser.parse_args_into_dataclass( ModelControllerParameters, env_prefix=env_prefix ) + + setup_logging( + "pilot", + logging_level=controller_params.log_level, + logger_filename="dbgpt_model_controller.log", + ) + initialize_controller(host=controller_params.host, port=controller_params.port) diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py index e72e9ba31..c210fcb44 100644 --- a/pilot/model/cluster/worker/default_worker.py +++ b/pilot/model/cluster/worker/default_worker.py @@ -120,7 +120,9 @@ class DefaultModelWorker(ModelWorker): text=output, error_code=0, model_context=model_context ) yield model_output - print(f"\n\nfull stream output:\n{previous_response}") + print( + f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}" + ) except Exception as e: # Check if the exception is a torch.cuda.CudaError and if torch was imported. if torch_imported and isinstance(e, torch.cuda.CudaError): diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py index da34c314a..72d9c32d4 100644 --- a/pilot/model/cluster/worker/manager.py +++ b/pilot/model/cluster/worker/manager.py @@ -36,6 +36,7 @@ from pilot.utils.parameter_utils import ( ParameterDescription, _dict_to_command_args, ) +from pilot.utils.utils import setup_logging logger = logging.getLogger(__name__) @@ -885,6 +886,12 @@ def run_worker_manager( model_name=model_name, model_path=model_path, standalone=standalone, port=port ) + setup_logging( + "pilot", + logging_level=worker_params.log_level, + logger_filename="dbgpt_model_worker_manager.log", + ) + embedded_mod = True logger.info(f"Worker params: {worker_params}") if not app: diff --git a/pilot/model/llm/llama_cpp/llama_cpp.py b/pilot/model/llm/llama_cpp/llama_cpp.py index 5b7431a9e..be8292a43 100644 --- a/pilot/model/llm/llama_cpp/llama_cpp.py +++ b/pilot/model/llm/llama_cpp/llama_cpp.py @@ -3,11 +3,13 @@ Fork from text-generation-webui https://github.com/oobabooga/text-generation-web """ import re from typing import Dict +import logging import torch import llama_cpp from pilot.model.parameter import LlamaCppModelParameters -from pilot.logs import logger + +logger = logging.getLogger(__name__) if torch.cuda.is_available() and not torch.version.hip: try: diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 63a484151..39adf24ad 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -3,6 +3,7 @@ from typing import Optional, Dict +import logging from pilot.configs.model_config import get_device from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType from pilot.model.parameter import ( @@ -12,7 +13,8 @@ from pilot.model.parameter import ( ) from pilot.utils import get_gpu_memory from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case -from pilot.logs import logger + +logger = logging.getLogger(__name__) def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters): diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index 6fd8ec5ab..6a92870b5 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -31,6 +31,21 @@ class ModelControllerParameters(BaseParameters): daemon: Optional[bool] = field( default=False, metadata={"help": "Run Model Controller in background"} ) + log_level: Optional[str] = field( + default=None, + metadata={ + "help": "Logging level", + "valid_values": [ + "FATAL", + "ERROR", + "WARNING", + "WARNING", + "INFO", + "DEBUG", + "NOTSET", + ], + }, + ) @dataclass @@ -85,6 +100,22 @@ class ModelWorkerParameters(BaseModelParameters): default=20, metadata={"help": "The interval for sending heartbeats (seconds)"} ) + log_level: Optional[str] = field( + default=None, + metadata={ + "help": "Logging level", + "valid_values": [ + "FATAL", + "ERROR", + "WARNING", + "WARNING", + "INFO", + "DEBUG", + "NOTSET", + ], + }, + ) + @dataclass class BaseEmbeddingModelParameters(BaseModelParameters): diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 24bee6cdb..ec71f50ac 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -38,8 +38,6 @@ from pilot.server.knowledge.request.request import KnowledgeSpaceRequest from pilot.scene.base_chat import BaseChat from pilot.scene.base import ChatScene from pilot.scene.chat_factory import ChatFactory -from pilot.configs.model_config import LOGDIR -from pilot.utils import build_logger from pilot.common.schema import DBType from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.scene.message import OnceConversation @@ -409,7 +407,7 @@ async def model_types(controller: BaseModelController = Depends(get_model_contro @router.get("/v1/model/supports") -async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)): +async def model_supports(worker_manager: WorkerManager = Depends(get_worker_manager)): logger.info(f"/controller/model/supports") try: models = await worker_manager.supported_models() diff --git a/pilot/openapi/api_v1/editor/api_editor_v1.py b/pilot/openapi/api_v1/editor/api_editor_v1.py index e1b313664..7d8dce420 100644 --- a/pilot/openapi/api_v1/editor/api_editor_v1.py +++ b/pilot/openapi/api_v1/editor/api_editor_v1.py @@ -6,12 +6,11 @@ from fastapi import ( ) from typing import List +import logging from pilot.configs.config import Config from pilot.scene.chat_factory import ChatFactory -from pilot.configs.model_config import LOGDIR -from pilot.utils import build_logger from pilot.openapi.api_view_model import ( Result, @@ -34,7 +33,8 @@ from pilot.scene.chat_db.data_loader import DbDataLoader router = APIRouter() CFG = Config() CHAT_FACTORY = ChatFactory() -logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log") + +logger = logging.getLogger(__name__) @router.get("/v1/editor/db/tables", response_model=Result[DbTable]) diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index abd4d304b..78286f3d3 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -2,18 +2,17 @@ from __future__ import annotations import json from abc import ABC +import logging from dataclasses import asdict from typing import Any, Dict, TypeVar, Union from pilot.configs.config import Config -from pilot.configs.model_config import LOGDIR from pilot.model.base import ModelOutput -from pilot.utils import build_logger T = TypeVar("T") ResponseTye = Union[str, bytes, ModelOutput] -logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") +logger = logging.getLogger(__name__) CFG = Config() diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 70805ca80..daab56964 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -1,11 +1,11 @@ import datetime import traceback import warnings +import logging from abc import ABC, abstractmethod from typing import Any, List, Dict from pilot.configs.config import Config -from pilot.configs.model_config import LOGDIR from pilot.component import ComponentType from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory @@ -14,10 +14,10 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory from pilot.prompts.prompt_new import PromptTemplate from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from pilot.scene.message import OnceConversation -from pilot.utils import build_logger, get_or_create_event_loop +from pilot.utils import get_or_create_event_loop from pydantic import Extra -logger = build_logger("BaseChat", LOGDIR + "BaseChat.log") +logger = logging.getLogger(__name__) headers = {"User-Agent": "dbgpt Client"} CFG = Config() diff --git a/pilot/scene/chat_dashboard/data_loader.py b/pilot/scene/chat_dashboard/data_loader.py index 945bf76cb..faabe542a 100644 --- a/pilot/scene/chat_dashboard/data_loader.py +++ b/pilot/scene/chat_dashboard/data_loader.py @@ -1,13 +1,12 @@ from typing import List from decimal import Decimal +import logging from pilot.configs.config import Config -from pilot.configs.model_config import LOGDIR -from pilot.utils import build_logger from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem CFG = Config() -logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log") +logger = logging.getLogger(__name__) class DashboardDataLoader: diff --git a/pilot/scene/chat_dashboard/out_parser.py b/pilot/scene/chat_dashboard/out_parser.py index bf0fedef4..d593333e5 100644 --- a/pilot/scene/chat_dashboard/out_parser.py +++ b/pilot/scene/chat_dashboard/out_parser.py @@ -1,8 +1,8 @@ import json +import logging + from typing import NamedTuple, List -from pilot.utils import build_logger from pilot.out_parser.base import BaseOutputParser, T -from pilot.configs.model_config import LOGDIR from pilot.scene.base import ChatScene @@ -13,7 +13,7 @@ class ChartItem(NamedTuple): showcase: str -logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log") +logger = logging.getLogger(__name__) class ChatDashboardOutputParser(BaseOutputParser): diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/out_parser.py b/pilot/scene/chat_data/chat_excel/excel_analyze/out_parser.py index f627a80c1..691c652fa 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/out_parser.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/out_parser.py @@ -1,11 +1,7 @@ import json -import re -from abc import ABC, abstractmethod +import logging from typing import Dict, NamedTuple, List -import pandas as pd -from pilot.utils import build_logger from pilot.out_parser.base import BaseOutputParser, T -from pilot.configs.model_config import LOGDIR from pilot.configs.config import Config CFG = Config() @@ -17,7 +13,7 @@ class ExcelAnalyzeResponse(NamedTuple): display: str -logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log") +logger = logging.getLogger(__name__) class ChatExcelOutputParser(BaseOutputParser): diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py index 0d636c16d..e3bcfb8ea 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py @@ -1,11 +1,7 @@ import json -import re -from abc import ABC, abstractmethod +import logging from typing import Dict, NamedTuple, List -import pandas as pd -from pilot.utils import build_logger from pilot.out_parser.base import BaseOutputParser, T -from pilot.configs.model_config import LOGDIR from pilot.configs.config import Config CFG = Config() @@ -17,7 +13,7 @@ class ExcelResponse(NamedTuple): plans: List -logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log") +logger = logging.getLogger(__name__) class LearningExcelOutputParser(BaseOutputParser): diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py index 8a79bca33..03c71b431 100644 --- a/pilot/scene/chat_db/auto_execute/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -1,8 +1,7 @@ import json from typing import Dict, NamedTuple -from pilot.utils import build_logger +import logging from pilot.out_parser.base import BaseOutputParser, T -from pilot.configs.model_config import LOGDIR from pilot.configs.config import Config from pilot.scene.chat_db.data_loader import DbDataLoader @@ -14,7 +13,7 @@ class SqlAction(NamedTuple): thoughts: Dict -logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") +logger = logging.getLogger(__name__) class DbChatOutputParser(BaseOutputParser): diff --git a/pilot/scene/chat_db/professional_qa/out_parser.py b/pilot/scene/chat_db/professional_qa/out_parser.py index 0c15aedbe..73d61e349 100644 --- a/pilot/scene/chat_db/professional_qa/out_parser.py +++ b/pilot/scene/chat_db/professional_qa/out_parser.py @@ -1,8 +1,4 @@ -from pilot.configs.model_config import LOGDIR from pilot.out_parser.base import BaseOutputParser, T -from pilot.utils import build_logger - -logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") class NormalChatOutputParser(BaseOutputParser): diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py index 2078018d0..3826b35df 100644 --- a/pilot/scene/chat_execution/out_parser.py +++ b/pilot/scene/chat_execution/out_parser.py @@ -1,11 +1,10 @@ import json +import logging from typing import Dict, NamedTuple -from pilot.utils import build_logger from pilot.out_parser.base import BaseOutputParser, T -from pilot.configs.model_config import LOGDIR -logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") +logger = logging.getLogger(__name__) class PluginAction(NamedTuple): diff --git a/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py b/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py index 4f3065334..fa9937ae6 100644 --- a/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py +++ b/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py @@ -1,9 +1,7 @@ -from pilot.utils import build_logger +import logging from pilot.out_parser.base import BaseOutputParser, T -from pilot.configs.model_config import LOGDIR - -logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") +logger = logging.getLogger(__name__) class NormalChatOutputParser(BaseOutputParser): diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 18d7e5060..332e7c080 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -60,18 +60,25 @@ class ChatKnowledge(BaseChat): async def stream_call(self): input_values = self.generate_input_values() + # Source of knowledge file + relations = input_values.get("relations") + last_output = None async for output in super().stream_call(): - # Source of knowledge file - relations = input_values.get("relations") - if ( - CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS - and type(relations) == list - and len(relations) > 0 - and hasattr(output, "text") - ): - output.text = output.text + "\trelations:" + ",".join(relations) + last_output = output yield output + if ( + CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS + and last_output + and type(relations) == list + and len(relations) > 0 + and hasattr(last_output, "text") + ): + last_output.text = ( + last_output.text + "\n\nrelations:\n\n" + ",".join(relations) + ) + yield last_output + def generate_input_values(self): if self.space_context: self.prompt_template.template_define = self.space_context["prompt"]["scene"] diff --git a/pilot/scene/chat_knowledge/v1/out_parser.py b/pilot/scene/chat_knowledge/v1/out_parser.py index ba8cd428a..ba9f1d4a7 100644 --- a/pilot/scene/chat_knowledge/v1/out_parser.py +++ b/pilot/scene/chat_knowledge/v1/out_parser.py @@ -1,9 +1,8 @@ -from pilot.utils import build_logger +import logging from pilot.out_parser.base import BaseOutputParser, T -from pilot.configs.model_config import LOGDIR -logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") +logger = logging.getLogger(__name__) class NormalChatOutputParser(BaseOutputParser): diff --git a/pilot/scene/chat_normal/out_parser.py b/pilot/scene/chat_normal/out_parser.py index ba8cd428a..ba9f1d4a7 100644 --- a/pilot/scene/chat_normal/out_parser.py +++ b/pilot/scene/chat_normal/out_parser.py @@ -1,9 +1,8 @@ -from pilot.utils import build_logger +import logging from pilot.out_parser.base import BaseOutputParser, T -from pilot.configs.model_config import LOGDIR -logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") +logger = logging.getLogger(__name__) class NormalChatOutputParser(BaseOutputParser): diff --git a/pilot/server/base.py b/pilot/server/base.py index 5ebbc0003..8113b6fee 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -108,7 +108,7 @@ class WebWerverParameters(BaseParameters): }, ) log_level: Optional[str] = field( - default="INFO", + default=None, metadata={ "help": "Logging level", "valid_values": [ diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index c78258382..ab9747e36 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -33,7 +33,11 @@ from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1 from pilot.commands.disply_type.show_chart_gen import static_message_img_path from pilot.model.cluster import initialize_worker_manager_in_client -from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level +from pilot.utils.utils import ( + setup_logging, + _get_logging_level, + logging_str_to_uvicorn_level, +) static_file_path = os.path.join(os.getcwd(), "server/static") @@ -111,7 +115,11 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None): ) param = WebWerverParameters(**vars(parser.parse_args(args=args))) - setup_logging(logging_level=param.log_level) + if not param.log_level: + param.log_level = _get_logging_level() + setup_logging( + "pilot", logging_level=param.log_level, logger_filename="dbgpt_webserver.log" + ) # Before start system_app.before_start() diff --git a/pilot/server/knowledge/_cli/knowledge_cli.py b/pilot/server/knowledge/_cli/knowledge_cli.py index 9e9640dd5..ce988c22c 100644 --- a/pilot/server/knowledge/_cli/knowledge_cli.py +++ b/pilot/server/knowledge/_cli/knowledge_cli.py @@ -1,9 +1,12 @@ import click import logging +import os +import functools from pilot.configs.model_config import DATASETS_DIR -API_ADDRESS: str = "http://127.0.0.1:5000" +_DEFAULT_API_ADDRESS: str = "http://127.0.0.1:5000" +API_ADDRESS: str = _DEFAULT_API_ADDRESS logger = logging.getLogger("dbgpt_cli") @@ -20,33 +23,44 @@ logger = logging.getLogger("dbgpt_cli") def knowledge_cli_group(address: str): """Knowledge command line tool""" global API_ADDRESS + if address == _DEFAULT_API_ADDRESS: + address = os.getenv("API_ADDRESS", _DEFAULT_API_ADDRESS) API_ADDRESS = address +def add_knowledge_options(func): + @click.option( + "--space_name", + required=False, + type=str, + default="default", + show_default=True, + help="Your knowledge space name", + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + @knowledge_cli_group.command() -@click.option( - "--vector_name", - required=False, - type=str, - default="default", - show_default=True, - help="Your vector store name", -) +@add_knowledge_options @click.option( "--vector_store_type", required=False, type=str, default="Chroma", show_default=True, - help="Vector store type", + help="Vector store type.", ) @click.option( - "--local_doc_dir", + "--local_doc_path", required=False, type=str, default=DATASETS_DIR, show_default=True, - help="Your document directory", + help="Your document directory or document file path.", ) @click.option( "--skip_wrong_doc", @@ -54,31 +68,165 @@ def knowledge_cli_group(address: str): type=bool, default=False, is_flag=True, - show_default=True, - help="Skip wrong document", + help="Skip wrong document.", +) +@click.option( + "--overwrite", + required=False, + type=bool, + default=False, + is_flag=True, + help="Overwrite existing document(they has same name).", ) @click.option( "--max_workers", required=False, type=int, default=None, - help="The maximum number of threads that can be used to upload document", + help="The maximum number of threads that can be used to upload document.", +) +@click.option( + "--pre_separator", + required=False, + type=str, + default=None, + help="Preseparator, this separator is used for pre-splitting before the document is " + "actually split by the text splitter. Preseparator are not included in the vectorized text. ", +) +@click.option( + "--separator", + required=False, + type=str, + default=None, + help="This is the document separator. Currently, only one separator is supported.", +) +@click.option( + "--chunk_size", + required=False, + type=int, + default=None, + help="Maximum size of chunks to split.", +) +@click.option( + "--chunk_overlap", + required=False, + type=int, + default=None, + help="Overlap in characters between chunks.", ) def load( - vector_name: str, + space_name: str, vector_store_type: str, - local_doc_dir: str, + local_doc_path: str, skip_wrong_doc: bool, + overwrite: bool, max_workers: int, + pre_separator: str, + separator: str, + chunk_size: int, + chunk_overlap: int, ): """Load your local knowledge to DB-GPT""" from pilot.server.knowledge._cli.knowledge_client import knowledge_init knowledge_init( API_ADDRESS, - vector_name, + space_name, vector_store_type, - local_doc_dir, + local_doc_path, skip_wrong_doc, + overwrite, max_workers, + pre_separator, + separator, + chunk_size, + chunk_overlap, + ) + + +@knowledge_cli_group.command() +@add_knowledge_options +@click.option( + "--doc_name", + required=False, + type=str, + default=None, + help="The document name you want to delete. If doc_name is None, this command will delete the whole space.", +) +@click.option( + "-y", + required=False, + type=bool, + default=False, + is_flag=True, + help="Confirm your choice", +) +def delete(space_name: str, doc_name: str, y: bool): + """Delete your knowledge space or document in space""" + from pilot.server.knowledge._cli.knowledge_client import knowledge_delete + + knowledge_delete(API_ADDRESS, space_name, doc_name, confirm=y) + + +@knowledge_cli_group.command() +@click.option( + "--space_name", + required=False, + type=str, + default=None, + show_default=True, + help="Your knowledge space name. If None, list all spaces", +) +@click.option( + "--doc_id", + required=False, + type=int, + default=None, + show_default=True, + help="Your document id in knowledge space. If Not None, list all chunks in current document", +) +@click.option( + "--page", + required=False, + type=int, + default=1, + show_default=True, + help="The page for every query", +) +@click.option( + "--page_size", + required=False, + type=int, + default=20, + show_default=True, + help="The page size for every query", +) +@click.option( + "--show_content", + required=False, + type=bool, + default=False, + is_flag=True, + help="Query the document content of chunks", +) +@click.option( + "--output", + required=False, + type=click.Choice(["text", "html", "csv", "latex", "json"]), + default="text", + help="The output format", +) +def list( + space_name: str, + doc_id: int, + page: int, + page_size: int, + show_content: bool, + output: str, +): + """List knowledge space""" + from pilot.server.knowledge._cli.knowledge_client import knowledge_list + + knowledge_list( + API_ADDRESS, space_name, page, page_size, doc_id, show_content, output ) diff --git a/pilot/server/knowledge/_cli/knowledge_client.py b/pilot/server/knowledge/_cli/knowledge_client.py index 4cfeea339..403065a6c 100644 --- a/pilot/server/knowledge/_cli/knowledge_client.py +++ b/pilot/server/knowledge/_cli/knowledge_client.py @@ -62,6 +62,9 @@ class KnowledgeApiClient(ApiClient): else: raise e + def space_delete(self, request: KnowledgeSpaceRequest): + return self._post("/knowledge/space/delete", data=request) + def space_list(self, request: KnowledgeSpaceRequest): return self._post("/knowledge/space/list", data=request) @@ -69,6 +72,10 @@ class KnowledgeApiClient(ApiClient): url = f"/knowledge/{space_name}/document/add" return self._post(url, data=request) + def document_delete(self, space_name: str, request: KnowledgeDocumentRequest): + url = f"/knowledge/{space_name}/document/delete" + return self._post(url, data=request) + def document_list(self, space_name: str, query_request: DocumentQueryRequest): url = f"/knowledge/{space_name}/document/list" return self._post(url, data=query_request) @@ -97,15 +104,20 @@ class KnowledgeApiClient(ApiClient): def knowledge_init( api_address: str, - vector_name: str, + space_name: str, vector_store_type: str, - local_doc_dir: str, + local_doc_path: str, skip_wrong_doc: bool, - max_workers: int = None, + overwrite: bool, + max_workers: int, + pre_separator: str, + separator: str, + chunk_size: int, + chunk_overlap: int, ): client = KnowledgeApiClient(api_address) space = KnowledgeSpaceRequest() - space.name = vector_name + space.name = space_name space.desc = "DB-GPT cli" space.vector_type = vector_store_type space.owner = "DB-GPT" @@ -124,24 +136,260 @@ def knowledge_init( def upload(filename: str): try: logger.info(f"Begin upload document: {filename} to {space.name}") - doc_id = client.document_upload( - space.name, filename, KnowledgeType.DOCUMENT.value, filename - ) - client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id])) + doc_id = None + try: + doc_id = client.document_upload( + space.name, filename, KnowledgeType.DOCUMENT.value, filename + ) + except Exception as ex: + if overwrite and "have already named" in str(ex): + logger.warn( + f"Document {filename} already exist in space {space.name}, overwrite it" + ) + client.document_delete( + space.name, KnowledgeDocumentRequest(doc_name=filename) + ) + doc_id = client.document_upload( + space.name, filename, KnowledgeType.DOCUMENT.value, filename + ) + else: + raise ex + sync_req = DocumentSyncRequest(doc_ids=[doc_id]) + if pre_separator: + sync_req.pre_separator = pre_separator + if separator: + sync_req.separators = [separator] + if chunk_size: + sync_req.chunk_size = chunk_size + if chunk_overlap: + sync_req.chunk_overlap = chunk_overlap + + client.document_sync(space.name, sync_req) + return doc_id except Exception as e: if skip_wrong_doc: logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}") else: raise e + if not os.path.exists(local_doc_path): + raise Exception(f"{local_doc_path} not exists") + with ThreadPoolExecutor(max_workers=max_workers) as pool: tasks = [] - for root, _, files in os.walk(local_doc_dir, topdown=False): - for file in files: - filename = os.path.join(root, file) - tasks.append(pool.submit(upload, filename)) + file_names = [] + if os.path.isdir(local_doc_path): + for root, _, files in os.walk(local_doc_path, topdown=False): + for file in files: + file_names.append(os.path.join(root, file)) + else: + # Single file + file_names.append(local_doc_path) + + [tasks.append(pool.submit(upload, filename)) for filename in file_names] + doc_ids = [r.result() for r in as_completed(tasks)] doc_ids = list(filter(lambda x: x, doc_ids)) if not doc_ids: logger.warn("Warning: no document to sync") return + + +from prettytable import PrettyTable + + +class _KnowledgeVisualizer: + def __init__(self, api_address: str, out_format: str): + self.client = KnowledgeApiClient(api_address) + self.out_format = out_format + self.out_kwargs = {} + if out_format == "json": + self.out_kwargs["ensure_ascii"] = False + + def print_table(self, table): + print(table.get_formatted_string(out_format=self.out_format, **self.out_kwargs)) + + def list_spaces(self): + spaces = self.client.space_list(KnowledgeSpaceRequest()) + table = PrettyTable( + ["Space ID", "Space Name", "Vector Type", "Owner", "Description"], + title="All knowledge spaces", + ) + for sp in spaces: + context = sp.get("context") + table.add_row( + [ + sp.get("id"), + sp.get("name"), + sp.get("vector_type"), + sp.get("owner"), + sp.get("desc"), + ] + ) + self.print_table(table) + + def list_documents(self, space_name: str, page: int, page_size: int): + space_data = self.client.document_list( + space_name, DocumentQueryRequest(page=page, page_size=page_size) + ) + + space_table = PrettyTable( + [ + "Space Name", + "Total Documents", + "Current Page", + "Current Size", + "Page Size", + ], + title=f"Space {space_name} description", + ) + space_table.add_row( + [space_name, space_data["total"], page, len(space_data["data"]), page_size] + ) + + table = PrettyTable( + [ + "Space Name", + "Document ID", + "Document Name", + "Type", + "Chunks", + "Last Sync", + "Status", + "Result", + ], + title=f"Documents of space {space_name}", + ) + for doc in space_data["data"]: + table.add_row( + [ + space_name, + doc.get("id"), + doc.get("doc_name"), + doc.get("doc_type"), + doc.get("chunk_size"), + doc.get("last_sync"), + doc.get("status"), + doc.get("result"), + ] + ) + if self.out_format == "text": + self.print_table(space_table) + print("") + self.print_table(table) + + def list_chunks( + self, + space_name: str, + doc_id: int, + page: int, + page_size: int, + show_content: bool, + ): + doc_data = self.client.chunk_list( + space_name, + ChunkQueryRequest(document_id=doc_id, page=page, page_size=page_size), + ) + + doc_table = PrettyTable( + [ + "Space Name", + "Document ID", + "Total Chunks", + "Current Page", + "Current Size", + "Page Size", + ], + title=f"Document {doc_id} in {space_name} description", + ) + doc_table.add_row( + [ + space_name, + doc_id, + doc_data["total"], + page, + len(doc_data["data"]), + page_size, + ] + ) + + table = PrettyTable( + ["Space Name", "Document ID", "Document Name", "Content", "Meta Data"], + title=f"chunks of document id {doc_id} in space {space_name}", + ) + for chunk in doc_data["data"]: + table.add_row( + [ + space_name, + doc_id, + chunk.get("doc_name"), + chunk.get("content") if show_content else "[Hidden]", + chunk.get("meta_info"), + ] + ) + if self.out_format == "text": + self.print_table(doc_table) + print("") + self.print_table(table) + + +def knowledge_list( + api_address: str, + space_name: str, + page: int, + page_size: int, + doc_id: int, + show_content: bool, + out_format: str, +): + visualizer = _KnowledgeVisualizer(api_address, out_format) + if not space_name: + visualizer.list_spaces() + elif not doc_id: + visualizer.list_documents(space_name, page, page_size) + else: + visualizer.list_chunks(space_name, doc_id, page, page_size, show_content) + + +def knowledge_delete( + api_address: str, space_name: str, doc_name: str, confirm: bool = False +): + client = KnowledgeApiClient(api_address) + space = KnowledgeSpaceRequest() + space.name = space_name + space_list = client.space_list(KnowledgeSpaceRequest(name=space.name)) + if not space_list: + raise Exception(f"No knowledge space name {space_name}") + + if not doc_name: + if not confirm: + # Confirm by user + user_input = ( + input( + f"Are you sure you want to delete the whole knowledge space {space_name}? Type 'yes' to confirm: " + ) + .strip() + .lower() + ) + if user_input != "yes": + logger.warn("Delete operation cancelled.") + return + client.space_delete(space) + logger.info("Delete the whole knowledge space successfully!") + else: + if not confirm: + # Confirm by user + user_input = ( + input( + f"Are you sure you want to delete the doucment {doc_name} in knowledge space {space_name}? Type 'yes' to confirm: " + ) + .strip() + .lower() + ) + if user_input != "yes": + logger.warn("Delete operation cancelled.") + return + client.document_delete(space_name, KnowledgeDocumentRequest(doc_name=doc_name)) + logger.info( + f"Delete the doucment {doc_name} in knowledge space {space_name} successfully!" + ) diff --git a/pilot/server/knowledge/api.py b/pilot/server/knowledge/api.py index 71b939924..57fadb21e 100644 --- a/pilot/server/knowledge/api.py +++ b/pilot/server/knowledge/api.py @@ -1,6 +1,7 @@ import os import shutil import tempfile +import logging from fastapi import APIRouter, File, UploadFile, Form @@ -27,6 +28,8 @@ from pilot.server.knowledge.request.request import ( from pilot.server.knowledge.request.request import KnowledgeSpaceRequest +logger = logging.getLogger(__name__) + CFG = Config() router = APIRouter() @@ -159,10 +162,10 @@ async def document_upload( @router.post("/knowledge/{space_name}/document/sync") def document_sync(space_name: str, request: DocumentSyncRequest): - print(f"Received params: {space_name}, {request}") + logger.info(f"Received params: {space_name}, {request}") try: knowledge_space_service.sync_knowledge_document( - space_name=space_name, doc_ids=request.doc_ids + space_name=space_name, sync_request=request ) return Result.succ([]) except Exception as e: diff --git a/pilot/server/knowledge/request/request.py b/pilot/server/knowledge/request/request.py index f0c47abeb..b83165c19 100644 --- a/pilot/server/knowledge/request/request.py +++ b/pilot/server/knowledge/request/request.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from pydantic import BaseModel from fastapi import UploadFile @@ -54,10 +54,25 @@ class DocumentQueryRequest(BaseModel): class DocumentSyncRequest(BaseModel): - """doc_ids: doc ids""" + """Sync request""" + """doc_ids: doc ids""" doc_ids: List + """Preseparator, this separator is used for pre-splitting before the document is actually split by the text splitter. + Preseparator are not included in the vectorized text. + """ + pre_separator: Optional[str] = None + + """Custom separators""" + separators: Optional[List[str]] = None + + """Custom chunk size""" + chunk_size: Optional[int] = None + + """Custom chunk overlap""" + chunk_overlap: Optional[int] = None + class ChunkQueryRequest(BaseModel): """id: id""" diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py index d574ffdfd..bf994d22d 100644 --- a/pilot/server/knowledge/service.py +++ b/pilot/server/knowledge/service.py @@ -1,5 +1,5 @@ import json -import threading +import logging from datetime import datetime from pilot.vector_store.connector import VectorStoreConnector @@ -12,7 +12,6 @@ from pilot.configs.model_config import ( from pilot.component import ComponentType from pilot.utils.executor_utils import ExecutorFactory -from pilot.logs import logger from pilot.server.knowledge.chunk_db import ( DocumentChunkEntity, DocumentChunkDao, @@ -31,6 +30,7 @@ from pilot.server.knowledge.request.request import ( DocumentQueryRequest, ChunkQueryRequest, SpaceArgumentRequest, + DocumentSyncRequest, ) from enum import Enum @@ -44,6 +44,7 @@ knowledge_space_dao = KnowledgeSpaceDao() knowledge_document_dao = KnowledgeDocumentDao() document_chunk_dao = DocumentChunkDao() +logger = logging.getLogger(__name__) CFG = Config() @@ -107,7 +108,6 @@ class KnowledgeService: res.owner = space.owner res.gmt_created = space.gmt_created res.gmt_modified = space.gmt_modified - res.owner = space.owner res.context = space.context query = KnowledgeDocumentEntity(space=space.name) doc_count = knowledge_document_dao.get_knowledge_documents_count(query) @@ -155,9 +155,10 @@ class KnowledgeService: """sync knowledge document chunk into vector store""" - def sync_knowledge_document(self, space_name, doc_ids): + def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest): from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.embedding_factory import EmbeddingFactory + from pilot.embedding_engine.pre_text_splitter import PreTextSplitter from langchain.text_splitter import ( RecursiveCharacterTextSplitter, SpacyTextSplitter, @@ -165,6 +166,7 @@ class KnowledgeService: # import langchain is very very slow!!! + doc_ids = sync_request.doc_ids for doc_id in doc_ids: query = KnowledgeDocumentEntity( id=doc_id, @@ -190,24 +192,43 @@ class KnowledgeService: if space_context is None else int(space_context["embedding"]["chunk_overlap"]) ) + if sync_request.chunk_size: + chunk_size = sync_request.chunk_size + if sync_request.chunk_overlap: + chunk_overlap = sync_request.chunk_overlap + separators = sync_request.separators or None if CFG.LANGUAGE == "en": text_splitter = RecursiveCharacterTextSplitter( + separators=separators, chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len, ) else: + if separators and len(separators) > 1: + raise ValueError( + "SpacyTextSplitter do not support multiple separators" + ) try: + separator = "\n\n" if not separators else separators[0] text_splitter = SpacyTextSplitter( + separator=separator, pipeline="zh_core_web_sm", chunk_size=chunk_size, chunk_overlap=chunk_overlap, ) except Exception: text_splitter = RecursiveCharacterTextSplitter( + separators=separators, chunk_size=chunk_size, chunk_overlap=chunk_overlap, ) + if sync_request.pre_separator: + logger.info(f"Use preseparator, {sync_request.pre_separator}") + text_splitter = PreTextSplitter( + pre_separator=sync_request.pre_separator, + text_splitter_impl=text_splitter, + ) embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) diff --git a/pilot/speech/eleven_labs.py b/pilot/speech/eleven_labs.py index 671a3d729..f92501e09 100644 --- a/pilot/speech/eleven_labs.py +++ b/pilot/speech/eleven_labs.py @@ -1,6 +1,6 @@ """ElevenLabs speech module""" import os - +import logging import requests from pilot.configs.config import Config @@ -8,6 +8,8 @@ from pilot.speech.base import VoiceBase PLACEHOLDERS = {"your-voice-id"} +logger = logging.getLogger(__name__) + class ElevenLabsSpeech(VoiceBase): """ElevenLabs speech class""" @@ -68,7 +70,6 @@ class ElevenLabsSpeech(VoiceBase): Returns: bool: True if the request was successful, False otherwise """ - from pilot.logs import logger from playsound import playsound tts_url = ( diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 0fea50061..1c56db45b 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -1,5 +1,6 @@ import json import uuid +import logging from pilot.common.schema import DBType from pilot.component import SystemApp @@ -7,16 +8,14 @@ from pilot.configs.config import Config from pilot.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, EMBEDDING_MODEL_CONFIG, - LOGDIR, ) + from pilot.scene.base import ChatScene from pilot.scene.base_chat import BaseChat from pilot.scene.chat_factory import ChatFactory from pilot.summary.rdbms_db_summary import RdbmsSummary -from pilot.utils import build_logger - -logger = build_logger("db_summary", LOGDIR + "db_summary.log") +logger = logging.getLogger(__name__) CFG = Config() chat_factory = ChatFactory() diff --git a/pilot/utils/__init__.py b/pilot/utils/__init__.py index 8a84bc0ec..e64aa1ff9 100644 --- a/pilot/utils/__init__.py +++ b/pilot/utils/__init__.py @@ -1,6 +1,5 @@ from .utils import ( get_gpu_memory, - build_logger, StreamToLogger, disable_torch_init, pretty_print_semaphore, diff --git a/pilot/utils/api_utils.py b/pilot/utils/api_utils.py index 93b280188..1fd3499d2 100644 --- a/pilot/utils/api_utils.py +++ b/pilot/utils/api_utils.py @@ -5,6 +5,8 @@ from dataclasses import is_dataclass, asdict T = TypeVar("T") +logger = logging.getLogger(__name__) + def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]: import typing_inspect @@ -21,7 +23,7 @@ def _build_request(self, func, path, method, *args, **kwargs): raise TypeError("Return type must be annotated in the decorated function.") actual_dataclass = _extract_dataclass_from_generic(return_type) - logging.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}") + logger.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}") if not actual_dataclass: actual_dataclass = return_type sig = signature(func) @@ -55,7 +57,7 @@ def _build_request(self, func, path, method, *args, **kwargs): else: # For GET, DELETE, etc. request_params["params"] = request_data - logging.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}") + logger.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}") return return_type, actual_dataclass, request_params diff --git a/pilot/utils/utils.py b/pilot/utils/utils.py index f3cac70d2..b72745a33 100644 --- a/pilot/utils/utils.py +++ b/pilot/utils/utils.py @@ -20,7 +20,7 @@ def _get_logging_level() -> str: return os.getenv("DBGPT_LOG_LEVEL", "INFO") -def setup_logging(logging_level=None, logger_name: str = None): +def setup_logging_level(logging_level=None, logger_name: str = None): if not logging_level: logging_level = _get_logging_level() if type(logging_level) is str: @@ -32,6 +32,19 @@ def setup_logging(logging_level=None, logger_name: str = None): logging.basicConfig(level=logging_level, encoding="utf-8") +def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None): + if not logging_level: + logging_level = _get_logging_level() + logger = _build_logger(logger_name, logging_level, logger_filename) + try: + import coloredlogs + + color_level = logging_level if logging_level else "INFO" + coloredlogs.install(level=color_level, logger=logger) + except ImportError: + pass + + def get_gpu_memory(max_gpus=None): import torch @@ -53,7 +66,7 @@ def get_gpu_memory(max_gpus=None): return gpu_memory -def build_logger(logger_name, logger_filename): +def _build_logger(logger_name, logging_level=None, logger_filename: str = None): global handler formatter = logging.Formatter( @@ -63,7 +76,7 @@ def build_logger(logger_name, logger_filename): # Set the format of root handlers if not logging.getLogger().handlers: - setup_logging() + setup_logging_level(logging_level=logging_level) logging.getLogger().handlers[0].setFormatter(formatter) # Redirect stdout and stderr to loggers @@ -78,7 +91,7 @@ def build_logger(logger_name, logger_filename): # sys.stderr = sl # Add a file handler for all loggers - if handler is None: + if handler is None and logger_filename: os.makedirs(LOGDIR, exist_ok=True) filename = os.path.join(LOGDIR, logger_filename) handler = logging.handlers.TimedRotatingFileHandler( @@ -89,11 +102,9 @@ def build_logger(logger_name, logger_filename): for name, item in logging.root.manager.loggerDict.items(): if isinstance(item, logging.Logger): item.addHandler(handler) - setup_logging() - # Get logger logger = logging.getLogger(logger_name) - setup_logging(logger_name=logger_name) + setup_logging_level(logging_level=logging_level, logger_name=logger_name) return logger diff --git a/pilot/vector_store/chroma_store.py b/pilot/vector_store/chroma_store.py index 7b2866cae..58ae88bf6 100644 --- a/pilot/vector_store/chroma_store.py +++ b/pilot/vector_store/chroma_store.py @@ -1,11 +1,13 @@ import os +import logging from typing import Any from chromadb.config import Settings from chromadb import PersistentClient -from pilot.logs import logger from pilot.vector_store.base import VectorStoreBase +logger = logging.getLogger(__name__) + class ChromaStore(VectorStoreBase): """chroma database""" diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index 1eb08b1e4..104cdc5fc 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -1,12 +1,13 @@ from __future__ import annotations - +import logging from typing import Any, Iterable, List, Optional, Tuple from pymilvus import Collection, DataType, connections, utility -from pilot.logs import logger from pilot.vector_store.base import VectorStoreBase +logger = logging.getLogger(__name__) + class MilvusStore(VectorStoreBase): """Milvus database""" diff --git a/pilot/vector_store/weaviate_store.py b/pilot/vector_store/weaviate_store.py index 14c26c33d..9df9b39b1 100644 --- a/pilot/vector_store/weaviate_store.py +++ b/pilot/vector_store/weaviate_store.py @@ -1,5 +1,6 @@ import os import json +import logging import weaviate from langchain.schema import Document from langchain.vectorstores import Weaviate @@ -7,9 +8,9 @@ from weaviate.exceptions import WeaviateBaseError from pilot.configs.config import Config from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH -from pilot.logs import logger from pilot.vector_store.base import VectorStoreBase +logger = logging.getLogger(__name__) CFG = Config() diff --git a/setup.py b/setup.py index 6aa97b61d..4808fc4e5 100644 --- a/setup.py +++ b/setup.py @@ -287,6 +287,7 @@ def core_requires(): ] setup_spec.extras["framework"] = [ + "coloredlogs", "httpx", "sqlparse==0.4.4", "seaborn",