feat(ChatKnowledge): Add custom text separators and refactor log configuration

This commit is contained in:
FangYin Cheng 2023-09-28 11:54:58 +08:00
parent 20bdddec51
commit 5dfe611478
52 changed files with 705 additions and 158 deletions

View File

@ -55,6 +55,8 @@ EMBEDDING_MODEL=text2vec
#EMBEDDING_MODEL=bge-large-zh #EMBEDDING_MODEL=bge-large-zh
KNOWLEDGE_CHUNK_SIZE=500 KNOWLEDGE_CHUNK_SIZE=500
KNOWLEDGE_SEARCH_TOP_SIZE=5 KNOWLEDGE_SEARCH_TOP_SIZE=5
# Control whether to display the source document of knowledge on the front end.
KNOWLEDGE_CHAT_SHOW_RELATIONS=False
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs ## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs ## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
# EMBEDDING_MODEL=all-MiniLM-L6-v2 # EMBEDDING_MODEL=all-MiniLM-L6-v2
@ -155,3 +157,5 @@ SUMMARY_CONFIG=FAST
#*******************************************************************# #*******************************************************************#
# FATAL, ERROR, WARNING, WARNING, INFO, DEBUG, NOTSET # FATAL, ERROR, WARNING, WARNING, INFO, DEBUG, NOTSET
DBGPT_LOG_LEVEL=INFO DBGPT_LOG_LEVEL=INFO
# LOG dir, default: ./logs
#DBGPT_LOG_DIR=

View File

@ -1,6 +1,7 @@
import contextlib import contextlib
import json import json
from typing import Any, Dict from typing import Any, Dict
import logging
from colorama import Fore from colorama import Fore
from regex import regex from regex import regex
@ -11,9 +12,9 @@ from pilot.json_utils.json_fix_general import (
balance_braces, balance_braces,
fix_invalid_escape, fix_invalid_escape,
) )
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()

View File

@ -2,14 +2,15 @@
import io import io
import uuid import uuid
from base64 import b64decode from base64 import b64decode
import logging
import requests import requests
from PIL import Image from PIL import Image
from pilot.commands.command_mange import command from pilot.commands.command_mange import command
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()

View File

@ -1,5 +1,5 @@
import logging
from pandas import DataFrame from pandas import DataFrame
from pilot.commands.command_mange import command from pilot.commands.command_mange import command
from pilot.configs.config import Config from pilot.configs.config import Config
import pandas as pd import pandas as pd
@ -13,11 +13,10 @@ import matplotlib.pyplot as plt
import matplotlib.ticker as mtick import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager from matplotlib.font_manager import FontManager
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config() CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
logger = logging.getLogger(__name__)
static_message_img_path = os.path.join(os.getcwd(), "message/img") static_message_img_path = os.path.join(os.getcwd(), "message/img")

View File

@ -1,14 +1,13 @@
import logging
import pandas as pd import pandas as pd
from pandas import DataFrame from pandas import DataFrame
from pilot.commands.command_mange import command from pilot.commands.command_mange import command
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config() CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") logger = logging.getLogger(__name__)
@command( @command(

View File

@ -1,13 +1,12 @@
import logging
import pandas as pd import pandas as pd
from pandas import DataFrame from pandas import DataFrame
from pilot.commands.command_mange import command from pilot.commands.command_mange import command
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config() CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") logger = logging.getLogger(__name__)
@command( @command(

View File

@ -8,6 +8,7 @@ import zipfile
import requests import requests
import threading import threading
import datetime import datetime
import logging
from pathlib import Path from pathlib import Path
from typing import List, TYPE_CHECKING from typing import List, TYPE_CHECKING
from urllib.parse import urlparse from urllib.parse import urlparse
@ -17,7 +18,8 @@ import requests
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger
logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from auto_gpt_plugin_template import AutoGPTPluginTemplate from auto_gpt_plugin_template import AutoGPTPluginTemplate

View File

@ -32,7 +32,7 @@ class Config(metaclass=Singleton):
# self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1)) # self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))
self.execute_local_commands = ( self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True" os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
) )
# User agent header to use when making HTTP requests # User agent header to use when making HTTP requests
# Some websites might just completely deny request with an error code if # Some websites might just completely deny request with an error code if
@ -64,7 +64,7 @@ class Config(metaclass=Singleton):
self.milvus_username = os.getenv("MILVUS_USERNAME") self.milvus_username = os.getenv("MILVUS_USERNAME")
self.milvus_password = os.getenv("MILVUS_PASSWORD") self.milvus_password = os.getenv("MILVUS_PASSWORD")
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt") self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True" self.milvus_secure = os.getenv("MILVUS_SECURE", "False").lower() == "true"
self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y") self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
self.exit_key = os.getenv("EXIT_KEY", "n") self.exit_key = os.getenv("EXIT_KEY", "n")
@ -98,7 +98,7 @@ class Config(metaclass=Singleton):
self.disabled_command_categories = [] self.disabled_command_categories = []
self.execute_local_commands = ( self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True" os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
) )
### message stor file ### message stor file
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message") self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
@ -107,7 +107,7 @@ class Config(metaclass=Singleton):
self.plugins: List["AutoGPTPluginTemplate"] = [] self.plugins: List["AutoGPTPluginTemplate"] = []
self.plugins_openai = [] self.plugins_openai = []
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True" self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true"
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard") self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
@ -124,10 +124,10 @@ class Config(metaclass=Singleton):
self.plugins_denylist = [] self.plugins_denylist = []
### Native SQL Execution Capability Control Configuration ### Native SQL Execution Capability Control Configuration
self.NATIVE_SQL_CAN_RUN_DDL = ( self.NATIVE_SQL_CAN_RUN_DDL = (
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True" os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True").lower() == "true"
) )
self.NATIVE_SQL_CAN_RUN_WRITE = ( self.NATIVE_SQL_CAN_RUN_WRITE = (
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True" os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
) )
### default Local database connection configuration ### default Local database connection configuration
@ -170,8 +170,8 @@ class Config(metaclass=Singleton):
# QLoRA # QLoRA
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True") self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True") == "True" self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True" self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False").lower() == "true"
if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT: if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
self.IS_LOAD_8BIT = False self.IS_LOAD_8BIT = False
# In order to be compatible with the new and old model parameter design # In order to be compatible with the new and old model parameter design
@ -187,7 +187,9 @@ class Config(metaclass=Singleton):
os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000) os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
) )
### Control whether to display the source document of knowledge on the front end. ### Control whether to display the source document of knowledge on the front end.
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False self.KNOWLEDGE_CHAT_SHOW_RELATIONS = (
os.getenv("KNOWLEDGE_CHAT_SHOW_RELATIONS", "False").lower() == "true"
)
### SUMMARY_CONFIG Configuration ### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST") self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")

View File

@ -3,13 +3,12 @@
import os import os
# import nltk
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models") MODEL_PATH = os.path.join(ROOT_PATH, "models")
PILOT_PATH = os.path.join(ROOT_PATH, "pilot") PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store") VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
LOGDIR = os.path.join(ROOT_PATH, "logs") LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs"))
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets") DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data") DATA_DIR = os.path.join(PILOT_PATH, "data")
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path # nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path

View File

@ -1,10 +1,11 @@
import logging
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.common.schema import DBType from pilot.common.schema import DBType
from pilot.connections.rdbms.base import RDBMSDatabase from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()

View File

@ -1,5 +1,12 @@
from pilot.embedding_engine.source_embedding import SourceEmbedding, register from pilot.embedding_engine.source_embedding import SourceEmbedding, register
from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.knowledge_type import KnowledgeType from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
__all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"] __all__ = [
"SourceEmbedding",
"register",
"EmbeddingEngine",
"KnowledgeType",
"PreTextSplitter",
]

View File

@ -0,0 +1,30 @@
from typing import Iterable, List
from langchain.schema import Document
from langchain.text_splitter import TextSplitter
def _single_document_split(
document: Document, pre_separator: str
) -> Iterable[Document]:
page_content = document.page_content
for i, content in enumerate(page_content.split(pre_separator)):
metadata = document.metadata.copy()
if "source" in metadata:
metadata["source"] = metadata["source"] + "_pre_split_" + str(i)
yield Document(page_content=content, metadata=metadata)
class PreTextSplitter(TextSplitter):
def __init__(self, pre_separator: str, text_splitter_impl: TextSplitter):
self.pre_separator = pre_separator
self._impl = text_splitter_impl
def split_text(self, text: str) -> List[str]:
return self._impl.split_text(text)
def split_documents(self, documents: Iterable[Document]) -> List[Document]:
def generator() -> Iterable[Document]:
for doc in documents:
yield from _single_document_split(doc, pre_separator=self.pre_separator)
return self._impl.split_documents(generator())

View File

@ -5,12 +5,13 @@ from __future__ import annotations
import contextlib import contextlib
import json import json
import re import re
import logging
from typing import Optional from typing import Optional
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.json_utils.utilities import extract_char_position from pilot.json_utils.utilities import extract_char_position
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()

View File

@ -3,12 +3,14 @@ import json
import os.path import os.path
import re import re
import json import json
import logging
from datetime import datetime from datetime import datetime
from jsonschema import Draft7Validator from jsonschema import Draft7Validator
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()
LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1" LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1"

View File

@ -3,6 +3,7 @@
import os import os
import re import re
import logging
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Callable, Type from typing import List, Tuple, Callable, Type
from functools import cache from functools import cache
@ -19,8 +20,8 @@ from pilot.model.parameter import (
) )
from pilot.configs.model_config import get_device from pilot.configs.model_config import get_device
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()

View File

@ -13,6 +13,9 @@ from pilot.utils.api_utils import (
_api_remote as api_remote, _api_remote as api_remote,
_sync_api_remote as sync_api_remote, _sync_api_remote as sync_api_remote,
) )
from pilot.utils.utils import setup_logging
logger = logging.getLogger(__name__)
class BaseModelController(BaseComponent, ABC): class BaseModelController(BaseComponent, ABC):
@ -59,7 +62,7 @@ class LocalModelController(BaseModelController):
async def get_all_instances( async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]: ) -> List[ModelInstance]:
logging.info( logger.info(
f"Get all instances with {model_name}, healthy_only: {healthy_only}" f"Get all instances with {model_name}, healthy_only: {healthy_only}"
) )
if not model_name: if not model_name:
@ -178,6 +181,13 @@ def run_model_controller():
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass( controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
ModelControllerParameters, env_prefix=env_prefix ModelControllerParameters, env_prefix=env_prefix
) )
setup_logging(
"pilot",
logging_level=controller_params.log_level,
logger_filename="dbgpt_model_controller.log",
)
initialize_controller(host=controller_params.host, port=controller_params.port) initialize_controller(host=controller_params.host, port=controller_params.port)

View File

@ -120,7 +120,9 @@ class DefaultModelWorker(ModelWorker):
text=output, error_code=0, model_context=model_context text=output, error_code=0, model_context=model_context
) )
yield model_output yield model_output
print(f"\n\nfull stream output:\n{previous_response}") print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
)
except Exception as e: except Exception as e:
# Check if the exception is a torch.cuda.CudaError and if torch was imported. # Check if the exception is a torch.cuda.CudaError and if torch was imported.
if torch_imported and isinstance(e, torch.cuda.CudaError): if torch_imported and isinstance(e, torch.cuda.CudaError):

View File

@ -36,6 +36,7 @@ from pilot.utils.parameter_utils import (
ParameterDescription, ParameterDescription,
_dict_to_command_args, _dict_to_command_args,
) )
from pilot.utils.utils import setup_logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -885,6 +886,12 @@ def run_worker_manager(
model_name=model_name, model_path=model_path, standalone=standalone, port=port model_name=model_name, model_path=model_path, standalone=standalone, port=port
) )
setup_logging(
"pilot",
logging_level=worker_params.log_level,
logger_filename="dbgpt_model_worker_manager.log",
)
embedded_mod = True embedded_mod = True
logger.info(f"Worker params: {worker_params}") logger.info(f"Worker params: {worker_params}")
if not app: if not app:

View File

@ -3,11 +3,13 @@ Fork from text-generation-webui https://github.com/oobabooga/text-generation-web
""" """
import re import re
from typing import Dict from typing import Dict
import logging
import torch import torch
import llama_cpp import llama_cpp
from pilot.model.parameter import LlamaCppModelParameters from pilot.model.parameter import LlamaCppModelParameters
from pilot.logs import logger
logger = logging.getLogger(__name__)
if torch.cuda.is_available() and not torch.version.hip: if torch.cuda.is_available() and not torch.version.hip:
try: try:

View File

@ -3,6 +3,7 @@
from typing import Optional, Dict from typing import Optional, Dict
import logging
from pilot.configs.model_config import get_device from pilot.configs.model_config import get_device
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.parameter import ( from pilot.model.parameter import (
@ -12,7 +13,8 @@ from pilot.model.parameter import (
) )
from pilot.utils import get_gpu_memory from pilot.utils import get_gpu_memory
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
from pilot.logs import logger
logger = logging.getLogger(__name__)
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters): def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):

View File

@ -31,6 +31,21 @@ class ModelControllerParameters(BaseParameters):
daemon: Optional[bool] = field( daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Model Controller in background"} default=False, metadata={"help": "Run Model Controller in background"}
) )
log_level: Optional[str] = field(
default=None,
metadata={
"help": "Logging level",
"valid_values": [
"FATAL",
"ERROR",
"WARNING",
"WARNING",
"INFO",
"DEBUG",
"NOTSET",
],
},
)
@dataclass @dataclass
@ -85,6 +100,22 @@ class ModelWorkerParameters(BaseModelParameters):
default=20, metadata={"help": "The interval for sending heartbeats (seconds)"} default=20, metadata={"help": "The interval for sending heartbeats (seconds)"}
) )
log_level: Optional[str] = field(
default=None,
metadata={
"help": "Logging level",
"valid_values": [
"FATAL",
"ERROR",
"WARNING",
"WARNING",
"INFO",
"DEBUG",
"NOTSET",
],
},
)
@dataclass @dataclass
class BaseEmbeddingModelParameters(BaseModelParameters): class BaseEmbeddingModelParameters(BaseModelParameters):

View File

@ -38,8 +38,6 @@ from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.common.schema import DBType from pilot.common.schema import DBType
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation from pilot.scene.message import OnceConversation
@ -409,7 +407,7 @@ async def model_types(controller: BaseModelController = Depends(get_model_contro
@router.get("/v1/model/supports") @router.get("/v1/model/supports")
async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)): async def model_supports(worker_manager: WorkerManager = Depends(get_worker_manager)):
logger.info(f"/controller/model/supports") logger.info(f"/controller/model/supports")
try: try:
models = await worker_manager.supported_models() models = await worker_manager.supported_models()

View File

@ -6,12 +6,11 @@ from fastapi import (
) )
from typing import List from typing import List
import logging
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_factory import ChatFactory from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.openapi.api_view_model import ( from pilot.openapi.api_view_model import (
Result, Result,
@ -34,7 +33,8 @@ from pilot.scene.chat_db.data_loader import DbDataLoader
router = APIRouter() router = APIRouter()
CFG = Config() CFG = Config()
CHAT_FACTORY = ChatFactory() CHAT_FACTORY = ChatFactory()
logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
logger = logging.getLogger(__name__)
@router.get("/v1/editor/db/tables", response_model=Result[DbTable]) @router.get("/v1/editor/db/tables", response_model=Result[DbTable])

View File

@ -2,18 +2,17 @@ from __future__ import annotations
import json import json
from abc import ABC from abc import ABC
import logging
from dataclasses import asdict from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union from typing import Any, Dict, TypeVar, Union
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.model.base import ModelOutput from pilot.model.base import ModelOutput
from pilot.utils import build_logger
T = TypeVar("T") T = TypeVar("T")
ResponseTye = Union[str, bytes, ModelOutput] ResponseTye = Union[str, bytes, ModelOutput]
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()

View File

@ -1,11 +1,11 @@
import datetime import datetime
import traceback import traceback
import warnings import warnings
import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Dict from typing import Any, List, Dict
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.component import ComponentType from pilot.component import ComponentType
from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
@ -14,10 +14,10 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation from pilot.scene.message import OnceConversation
from pilot.utils import build_logger, get_or_create_event_loop from pilot.utils import get_or_create_event_loop
from pydantic import Extra from pydantic import Extra
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log") logger = logging.getLogger(__name__)
headers = {"User-Agent": "dbgpt Client"} headers = {"User-Agent": "dbgpt Client"}
CFG = Config() CFG = Config()

View File

@ -1,13 +1,12 @@
from typing import List from typing import List
from decimal import Decimal from decimal import Decimal
import logging
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem
CFG = Config() CFG = Config()
logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log") logger = logging.getLogger(__name__)
class DashboardDataLoader: class DashboardDataLoader:

View File

@ -1,8 +1,8 @@
import json import json
import logging
from typing import NamedTuple, List from typing import NamedTuple, List
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
@ -13,7 +13,7 @@ class ChartItem(NamedTuple):
showcase: str showcase: str
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log") logger = logging.getLogger(__name__)
class ChatDashboardOutputParser(BaseOutputParser): class ChatDashboardOutputParser(BaseOutputParser):

View File

@ -1,11 +1,7 @@
import json import json
import re import logging
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config from pilot.configs.config import Config
CFG = Config() CFG = Config()
@ -17,7 +13,7 @@ class ExcelAnalyzeResponse(NamedTuple):
display: str display: str
logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log") logger = logging.getLogger(__name__)
class ChatExcelOutputParser(BaseOutputParser): class ChatExcelOutputParser(BaseOutputParser):

View File

@ -1,11 +1,7 @@
import json import json
import re import logging
from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, List from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config from pilot.configs.config import Config
CFG = Config() CFG = Config()
@ -17,7 +13,7 @@ class ExcelResponse(NamedTuple):
plans: List plans: List
logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log") logger = logging.getLogger(__name__)
class LearningExcelOutputParser(BaseOutputParser): class LearningExcelOutputParser(BaseOutputParser):

View File

@ -1,8 +1,7 @@
import json import json
from typing import Dict, NamedTuple from typing import Dict, NamedTuple
from pilot.utils import build_logger import logging
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.data_loader import DbDataLoader from pilot.scene.chat_db.data_loader import DbDataLoader
@ -14,7 +13,7 @@ class SqlAction(NamedTuple):
thoughts: Dict thoughts: Dict
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") logger = logging.getLogger(__name__)
class DbChatOutputParser(BaseOutputParser): class DbChatOutputParser(BaseOutputParser):

View File

@ -1,8 +1,4 @@
from pilot.configs.model_config import LOGDIR
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.utils import build_logger
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser): class NormalChatOutputParser(BaseOutputParser):

View File

@ -1,11 +1,10 @@
import json import json
import logging
from typing import Dict, NamedTuple from typing import Dict, NamedTuple
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") logger = logging.getLogger(__name__)
class PluginAction(NamedTuple): class PluginAction(NamedTuple):

View File

@ -1,9 +1,7 @@
from pilot.utils import build_logger import logging
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = logging.getLogger(__name__)
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser): class NormalChatOutputParser(BaseOutputParser):

View File

@ -60,17 +60,24 @@ class ChatKnowledge(BaseChat):
async def stream_call(self): async def stream_call(self):
input_values = self.generate_input_values() input_values = self.generate_input_values()
async for output in super().stream_call():
# Source of knowledge file # Source of knowledge file
relations = input_values.get("relations") relations = input_values.get("relations")
last_output = None
async for output in super().stream_call():
last_output = output
yield output
if ( if (
CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
and last_output
and type(relations) == list and type(relations) == list
and len(relations) > 0 and len(relations) > 0
and hasattr(output, "text") and hasattr(last_output, "text")
): ):
output.text = output.text + "\trelations:" + ",".join(relations) last_output.text = (
yield output last_output.text + "\n\nrelations:\n\n" + ",".join(relations)
)
yield last_output
def generate_input_values(self): def generate_input_values(self):
if self.space_context: if self.space_context:

View File

@ -1,9 +1,8 @@
from pilot.utils import build_logger import logging
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") logger = logging.getLogger(__name__)
class NormalChatOutputParser(BaseOutputParser): class NormalChatOutputParser(BaseOutputParser):

View File

@ -1,9 +1,8 @@
from pilot.utils import build_logger import logging
from pilot.out_parser.base import BaseOutputParser, T from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") logger = logging.getLogger(__name__)
class NormalChatOutputParser(BaseOutputParser): class NormalChatOutputParser(BaseOutputParser):

View File

@ -108,7 +108,7 @@ class WebWerverParameters(BaseParameters):
}, },
) )
log_level: Optional[str] = field( log_level: Optional[str] = field(
default="INFO", default=None,
metadata={ metadata={
"help": "Logging level", "help": "Logging level",
"valid_values": [ "valid_values": [

View File

@ -33,7 +33,11 @@ from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1 from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.cluster import initialize_worker_manager_in_client from pilot.model.cluster import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level from pilot.utils.utils import (
setup_logging,
_get_logging_level,
logging_str_to_uvicorn_level,
)
static_file_path = os.path.join(os.getcwd(), "server/static") static_file_path = os.path.join(os.getcwd(), "server/static")
@ -111,7 +115,11 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
) )
param = WebWerverParameters(**vars(parser.parse_args(args=args))) param = WebWerverParameters(**vars(parser.parse_args(args=args)))
setup_logging(logging_level=param.log_level) if not param.log_level:
param.log_level = _get_logging_level()
setup_logging(
"pilot", logging_level=param.log_level, logger_filename="dbgpt_webserver.log"
)
# Before start # Before start
system_app.before_start() system_app.before_start()

View File

@ -1,9 +1,12 @@
import click import click
import logging import logging
import os
import functools
from pilot.configs.model_config import DATASETS_DIR from pilot.configs.model_config import DATASETS_DIR
API_ADDRESS: str = "http://127.0.0.1:5000" _DEFAULT_API_ADDRESS: str = "http://127.0.0.1:5000"
API_ADDRESS: str = _DEFAULT_API_ADDRESS
logger = logging.getLogger("dbgpt_cli") logger = logging.getLogger("dbgpt_cli")
@ -20,33 +23,44 @@ logger = logging.getLogger("dbgpt_cli")
def knowledge_cli_group(address: str): def knowledge_cli_group(address: str):
"""Knowledge command line tool""" """Knowledge command line tool"""
global API_ADDRESS global API_ADDRESS
if address == _DEFAULT_API_ADDRESS:
address = os.getenv("API_ADDRESS", _DEFAULT_API_ADDRESS)
API_ADDRESS = address API_ADDRESS = address
@knowledge_cli_group.command() def add_knowledge_options(func):
@click.option( @click.option(
"--vector_name", "--space_name",
required=False, required=False,
type=str, type=str,
default="default", default="default",
show_default=True, show_default=True,
help="Your vector store name", help="Your knowledge space name",
) )
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
@knowledge_cli_group.command()
@add_knowledge_options
@click.option( @click.option(
"--vector_store_type", "--vector_store_type",
required=False, required=False,
type=str, type=str,
default="Chroma", default="Chroma",
show_default=True, show_default=True,
help="Vector store type", help="Vector store type.",
) )
@click.option( @click.option(
"--local_doc_dir", "--local_doc_path",
required=False, required=False,
type=str, type=str,
default=DATASETS_DIR, default=DATASETS_DIR,
show_default=True, show_default=True,
help="Your document directory", help="Your document directory or document file path.",
) )
@click.option( @click.option(
"--skip_wrong_doc", "--skip_wrong_doc",
@ -54,31 +68,165 @@ def knowledge_cli_group(address: str):
type=bool, type=bool,
default=False, default=False,
is_flag=True, is_flag=True,
show_default=True, help="Skip wrong document.",
help="Skip wrong document", )
@click.option(
"--overwrite",
required=False,
type=bool,
default=False,
is_flag=True,
help="Overwrite existing document(they has same name).",
) )
@click.option( @click.option(
"--max_workers", "--max_workers",
required=False, required=False,
type=int, type=int,
default=None, default=None,
help="The maximum number of threads that can be used to upload document", help="The maximum number of threads that can be used to upload document.",
)
@click.option(
"--pre_separator",
required=False,
type=str,
default=None,
help="Preseparator, this separator is used for pre-splitting before the document is "
"actually split by the text splitter. Preseparator are not included in the vectorized text. ",
)
@click.option(
"--separator",
required=False,
type=str,
default=None,
help="This is the document separator. Currently, only one separator is supported.",
)
@click.option(
"--chunk_size",
required=False,
type=int,
default=None,
help="Maximum size of chunks to split.",
)
@click.option(
"--chunk_overlap",
required=False,
type=int,
default=None,
help="Overlap in characters between chunks.",
) )
def load( def load(
vector_name: str, space_name: str,
vector_store_type: str, vector_store_type: str,
local_doc_dir: str, local_doc_path: str,
skip_wrong_doc: bool, skip_wrong_doc: bool,
overwrite: bool,
max_workers: int, max_workers: int,
pre_separator: str,
separator: str,
chunk_size: int,
chunk_overlap: int,
): ):
"""Load your local knowledge to DB-GPT""" """Load your local knowledge to DB-GPT"""
from pilot.server.knowledge._cli.knowledge_client import knowledge_init from pilot.server.knowledge._cli.knowledge_client import knowledge_init
knowledge_init( knowledge_init(
API_ADDRESS, API_ADDRESS,
vector_name, space_name,
vector_store_type, vector_store_type,
local_doc_dir, local_doc_path,
skip_wrong_doc, skip_wrong_doc,
overwrite,
max_workers, max_workers,
pre_separator,
separator,
chunk_size,
chunk_overlap,
)
@knowledge_cli_group.command()
@add_knowledge_options
@click.option(
"--doc_name",
required=False,
type=str,
default=None,
help="The document name you want to delete. If doc_name is None, this command will delete the whole space.",
)
@click.option(
"-y",
required=False,
type=bool,
default=False,
is_flag=True,
help="Confirm your choice",
)
def delete(space_name: str, doc_name: str, y: bool):
"""Delete your knowledge space or document in space"""
from pilot.server.knowledge._cli.knowledge_client import knowledge_delete
knowledge_delete(API_ADDRESS, space_name, doc_name, confirm=y)
@knowledge_cli_group.command()
@click.option(
"--space_name",
required=False,
type=str,
default=None,
show_default=True,
help="Your knowledge space name. If None, list all spaces",
)
@click.option(
"--doc_id",
required=False,
type=int,
default=None,
show_default=True,
help="Your document id in knowledge space. If Not None, list all chunks in current document",
)
@click.option(
"--page",
required=False,
type=int,
default=1,
show_default=True,
help="The page for every query",
)
@click.option(
"--page_size",
required=False,
type=int,
default=20,
show_default=True,
help="The page size for every query",
)
@click.option(
"--show_content",
required=False,
type=bool,
default=False,
is_flag=True,
help="Query the document content of chunks",
)
@click.option(
"--output",
required=False,
type=click.Choice(["text", "html", "csv", "latex", "json"]),
default="text",
help="The output format",
)
def list(
space_name: str,
doc_id: int,
page: int,
page_size: int,
show_content: bool,
output: str,
):
"""List knowledge space"""
from pilot.server.knowledge._cli.knowledge_client import knowledge_list
knowledge_list(
API_ADDRESS, space_name, page, page_size, doc_id, show_content, output
) )

View File

@ -62,6 +62,9 @@ class KnowledgeApiClient(ApiClient):
else: else:
raise e raise e
def space_delete(self, request: KnowledgeSpaceRequest):
return self._post("/knowledge/space/delete", data=request)
def space_list(self, request: KnowledgeSpaceRequest): def space_list(self, request: KnowledgeSpaceRequest):
return self._post("/knowledge/space/list", data=request) return self._post("/knowledge/space/list", data=request)
@ -69,6 +72,10 @@ class KnowledgeApiClient(ApiClient):
url = f"/knowledge/{space_name}/document/add" url = f"/knowledge/{space_name}/document/add"
return self._post(url, data=request) return self._post(url, data=request)
def document_delete(self, space_name: str, request: KnowledgeDocumentRequest):
url = f"/knowledge/{space_name}/document/delete"
return self._post(url, data=request)
def document_list(self, space_name: str, query_request: DocumentQueryRequest): def document_list(self, space_name: str, query_request: DocumentQueryRequest):
url = f"/knowledge/{space_name}/document/list" url = f"/knowledge/{space_name}/document/list"
return self._post(url, data=query_request) return self._post(url, data=query_request)
@ -97,15 +104,20 @@ class KnowledgeApiClient(ApiClient):
def knowledge_init( def knowledge_init(
api_address: str, api_address: str,
vector_name: str, space_name: str,
vector_store_type: str, vector_store_type: str,
local_doc_dir: str, local_doc_path: str,
skip_wrong_doc: bool, skip_wrong_doc: bool,
max_workers: int = None, overwrite: bool,
max_workers: int,
pre_separator: str,
separator: str,
chunk_size: int,
chunk_overlap: int,
): ):
client = KnowledgeApiClient(api_address) client = KnowledgeApiClient(api_address)
space = KnowledgeSpaceRequest() space = KnowledgeSpaceRequest()
space.name = vector_name space.name = space_name
space.desc = "DB-GPT cli" space.desc = "DB-GPT cli"
space.vector_type = vector_store_type space.vector_type = vector_store_type
space.owner = "DB-GPT" space.owner = "DB-GPT"
@ -124,24 +136,260 @@ def knowledge_init(
def upload(filename: str): def upload(filename: str):
try: try:
logger.info(f"Begin upload document: {filename} to {space.name}") logger.info(f"Begin upload document: {filename} to {space.name}")
doc_id = None
try:
doc_id = client.document_upload( doc_id = client.document_upload(
space.name, filename, KnowledgeType.DOCUMENT.value, filename space.name, filename, KnowledgeType.DOCUMENT.value, filename
) )
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id])) except Exception as ex:
if overwrite and "have already named" in str(ex):
logger.warn(
f"Document {filename} already exist in space {space.name}, overwrite it"
)
client.document_delete(
space.name, KnowledgeDocumentRequest(doc_name=filename)
)
doc_id = client.document_upload(
space.name, filename, KnowledgeType.DOCUMENT.value, filename
)
else:
raise ex
sync_req = DocumentSyncRequest(doc_ids=[doc_id])
if pre_separator:
sync_req.pre_separator = pre_separator
if separator:
sync_req.separators = [separator]
if chunk_size:
sync_req.chunk_size = chunk_size
if chunk_overlap:
sync_req.chunk_overlap = chunk_overlap
client.document_sync(space.name, sync_req)
return doc_id
except Exception as e: except Exception as e:
if skip_wrong_doc: if skip_wrong_doc:
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}") logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
else: else:
raise e raise e
if not os.path.exists(local_doc_path):
raise Exception(f"{local_doc_path} not exists")
with ThreadPoolExecutor(max_workers=max_workers) as pool: with ThreadPoolExecutor(max_workers=max_workers) as pool:
tasks = [] tasks = []
for root, _, files in os.walk(local_doc_dir, topdown=False): file_names = []
if os.path.isdir(local_doc_path):
for root, _, files in os.walk(local_doc_path, topdown=False):
for file in files: for file in files:
filename = os.path.join(root, file) file_names.append(os.path.join(root, file))
tasks.append(pool.submit(upload, filename)) else:
# Single file
file_names.append(local_doc_path)
[tasks.append(pool.submit(upload, filename)) for filename in file_names]
doc_ids = [r.result() for r in as_completed(tasks)] doc_ids = [r.result() for r in as_completed(tasks)]
doc_ids = list(filter(lambda x: x, doc_ids)) doc_ids = list(filter(lambda x: x, doc_ids))
if not doc_ids: if not doc_ids:
logger.warn("Warning: no document to sync") logger.warn("Warning: no document to sync")
return return
from prettytable import PrettyTable
class _KnowledgeVisualizer:
def __init__(self, api_address: str, out_format: str):
self.client = KnowledgeApiClient(api_address)
self.out_format = out_format
self.out_kwargs = {}
if out_format == "json":
self.out_kwargs["ensure_ascii"] = False
def print_table(self, table):
print(table.get_formatted_string(out_format=self.out_format, **self.out_kwargs))
def list_spaces(self):
spaces = self.client.space_list(KnowledgeSpaceRequest())
table = PrettyTable(
["Space ID", "Space Name", "Vector Type", "Owner", "Description"],
title="All knowledge spaces",
)
for sp in spaces:
context = sp.get("context")
table.add_row(
[
sp.get("id"),
sp.get("name"),
sp.get("vector_type"),
sp.get("owner"),
sp.get("desc"),
]
)
self.print_table(table)
def list_documents(self, space_name: str, page: int, page_size: int):
space_data = self.client.document_list(
space_name, DocumentQueryRequest(page=page, page_size=page_size)
)
space_table = PrettyTable(
[
"Space Name",
"Total Documents",
"Current Page",
"Current Size",
"Page Size",
],
title=f"Space {space_name} description",
)
space_table.add_row(
[space_name, space_data["total"], page, len(space_data["data"]), page_size]
)
table = PrettyTable(
[
"Space Name",
"Document ID",
"Document Name",
"Type",
"Chunks",
"Last Sync",
"Status",
"Result",
],
title=f"Documents of space {space_name}",
)
for doc in space_data["data"]:
table.add_row(
[
space_name,
doc.get("id"),
doc.get("doc_name"),
doc.get("doc_type"),
doc.get("chunk_size"),
doc.get("last_sync"),
doc.get("status"),
doc.get("result"),
]
)
if self.out_format == "text":
self.print_table(space_table)
print("")
self.print_table(table)
def list_chunks(
self,
space_name: str,
doc_id: int,
page: int,
page_size: int,
show_content: bool,
):
doc_data = self.client.chunk_list(
space_name,
ChunkQueryRequest(document_id=doc_id, page=page, page_size=page_size),
)
doc_table = PrettyTable(
[
"Space Name",
"Document ID",
"Total Chunks",
"Current Page",
"Current Size",
"Page Size",
],
title=f"Document {doc_id} in {space_name} description",
)
doc_table.add_row(
[
space_name,
doc_id,
doc_data["total"],
page,
len(doc_data["data"]),
page_size,
]
)
table = PrettyTable(
["Space Name", "Document ID", "Document Name", "Content", "Meta Data"],
title=f"chunks of document id {doc_id} in space {space_name}",
)
for chunk in doc_data["data"]:
table.add_row(
[
space_name,
doc_id,
chunk.get("doc_name"),
chunk.get("content") if show_content else "[Hidden]",
chunk.get("meta_info"),
]
)
if self.out_format == "text":
self.print_table(doc_table)
print("")
self.print_table(table)
def knowledge_list(
api_address: str,
space_name: str,
page: int,
page_size: int,
doc_id: int,
show_content: bool,
out_format: str,
):
visualizer = _KnowledgeVisualizer(api_address, out_format)
if not space_name:
visualizer.list_spaces()
elif not doc_id:
visualizer.list_documents(space_name, page, page_size)
else:
visualizer.list_chunks(space_name, doc_id, page, page_size, show_content)
def knowledge_delete(
api_address: str, space_name: str, doc_name: str, confirm: bool = False
):
client = KnowledgeApiClient(api_address)
space = KnowledgeSpaceRequest()
space.name = space_name
space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
if not space_list:
raise Exception(f"No knowledge space name {space_name}")
if not doc_name:
if not confirm:
# Confirm by user
user_input = (
input(
f"Are you sure you want to delete the whole knowledge space {space_name}? Type 'yes' to confirm: "
)
.strip()
.lower()
)
if user_input != "yes":
logger.warn("Delete operation cancelled.")
return
client.space_delete(space)
logger.info("Delete the whole knowledge space successfully!")
else:
if not confirm:
# Confirm by user
user_input = (
input(
f"Are you sure you want to delete the doucment {doc_name} in knowledge space {space_name}? Type 'yes' to confirm: "
)
.strip()
.lower()
)
if user_input != "yes":
logger.warn("Delete operation cancelled.")
return
client.document_delete(space_name, KnowledgeDocumentRequest(doc_name=doc_name))
logger.info(
f"Delete the doucment {doc_name} in knowledge space {space_name} successfully!"
)

View File

@ -1,6 +1,7 @@
import os import os
import shutil import shutil
import tempfile import tempfile
import logging
from fastapi import APIRouter, File, UploadFile, Form from fastapi import APIRouter, File, UploadFile, Form
@ -27,6 +28,8 @@ from pilot.server.knowledge.request.request import (
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()
router = APIRouter() router = APIRouter()
@ -159,10 +162,10 @@ async def document_upload(
@router.post("/knowledge/{space_name}/document/sync") @router.post("/knowledge/{space_name}/document/sync")
def document_sync(space_name: str, request: DocumentSyncRequest): def document_sync(space_name: str, request: DocumentSyncRequest):
print(f"Received params: {space_name}, {request}") logger.info(f"Received params: {space_name}, {request}")
try: try:
knowledge_space_service.sync_knowledge_document( knowledge_space_service.sync_knowledge_document(
space_name=space_name, doc_ids=request.doc_ids space_name=space_name, sync_request=request
) )
return Result.succ([]) return Result.succ([])
except Exception as e: except Exception as e:

View File

@ -1,4 +1,4 @@
from typing import List from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
from fastapi import UploadFile from fastapi import UploadFile
@ -54,10 +54,25 @@ class DocumentQueryRequest(BaseModel):
class DocumentSyncRequest(BaseModel): class DocumentSyncRequest(BaseModel):
"""doc_ids: doc ids""" """Sync request"""
"""doc_ids: doc ids"""
doc_ids: List doc_ids: List
"""Preseparator, this separator is used for pre-splitting before the document is actually split by the text splitter.
Preseparator are not included in the vectorized text.
"""
pre_separator: Optional[str] = None
"""Custom separators"""
separators: Optional[List[str]] = None
"""Custom chunk size"""
chunk_size: Optional[int] = None
"""Custom chunk overlap"""
chunk_overlap: Optional[int] = None
class ChunkQueryRequest(BaseModel): class ChunkQueryRequest(BaseModel):
"""id: id""" """id: id"""

View File

@ -1,5 +1,5 @@
import json import json
import threading import logging
from datetime import datetime from datetime import datetime
from pilot.vector_store.connector import VectorStoreConnector from pilot.vector_store.connector import VectorStoreConnector
@ -12,7 +12,6 @@ from pilot.configs.model_config import (
from pilot.component import ComponentType from pilot.component import ComponentType
from pilot.utils.executor_utils import ExecutorFactory from pilot.utils.executor_utils import ExecutorFactory
from pilot.logs import logger
from pilot.server.knowledge.chunk_db import ( from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity, DocumentChunkEntity,
DocumentChunkDao, DocumentChunkDao,
@ -31,6 +30,7 @@ from pilot.server.knowledge.request.request import (
DocumentQueryRequest, DocumentQueryRequest,
ChunkQueryRequest, ChunkQueryRequest,
SpaceArgumentRequest, SpaceArgumentRequest,
DocumentSyncRequest,
) )
from enum import Enum from enum import Enum
@ -44,6 +44,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao() knowledge_document_dao = KnowledgeDocumentDao()
document_chunk_dao = DocumentChunkDao() document_chunk_dao = DocumentChunkDao()
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()
@ -107,7 +108,6 @@ class KnowledgeService:
res.owner = space.owner res.owner = space.owner
res.gmt_created = space.gmt_created res.gmt_created = space.gmt_created
res.gmt_modified = space.gmt_modified res.gmt_modified = space.gmt_modified
res.owner = space.owner
res.context = space.context res.context = space.context
query = KnowledgeDocumentEntity(space=space.name) query = KnowledgeDocumentEntity(space=space.name)
doc_count = knowledge_document_dao.get_knowledge_documents_count(query) doc_count = knowledge_document_dao.get_knowledge_documents_count(query)
@ -155,9 +155,10 @@ class KnowledgeService:
"""sync knowledge document chunk into vector store""" """sync knowledge document chunk into vector store"""
def sync_knowledge_document(self, space_name, doc_ids): def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
from langchain.text_splitter import ( from langchain.text_splitter import (
RecursiveCharacterTextSplitter, RecursiveCharacterTextSplitter,
SpacyTextSplitter, SpacyTextSplitter,
@ -165,6 +166,7 @@ class KnowledgeService:
# import langchain is very very slow!!! # import langchain is very very slow!!!
doc_ids = sync_request.doc_ids
for doc_id in doc_ids: for doc_id in doc_ids:
query = KnowledgeDocumentEntity( query = KnowledgeDocumentEntity(
id=doc_id, id=doc_id,
@ -190,24 +192,43 @@ class KnowledgeService:
if space_context is None if space_context is None
else int(space_context["embedding"]["chunk_overlap"]) else int(space_context["embedding"]["chunk_overlap"])
) )
if sync_request.chunk_size:
chunk_size = sync_request.chunk_size
if sync_request.chunk_overlap:
chunk_overlap = sync_request.chunk_overlap
separators = sync_request.separators or None
if CFG.LANGUAGE == "en": if CFG.LANGUAGE == "en":
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
length_function=len, length_function=len,
) )
else: else:
if separators and len(separators) > 1:
raise ValueError(
"SpacyTextSplitter do not support multiple separators"
)
try: try:
separator = "\n\n" if not separators else separators[0]
text_splitter = SpacyTextSplitter( text_splitter = SpacyTextSplitter(
separator=separator,
pipeline="zh_core_web_sm", pipeline="zh_core_web_sm",
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
) )
except Exception: except Exception:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
) )
if sync_request.pre_separator:
logger.info(f"Use preseparator, {sync_request.pre_separator}")
text_splitter = PreTextSplitter(
pre_separator=sync_request.pre_separator,
text_splitter_impl=text_splitter,
)
embedding_factory = CFG.SYSTEM_APP.get_component( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )

View File

@ -1,6 +1,6 @@
"""ElevenLabs speech module""" """ElevenLabs speech module"""
import os import os
import logging
import requests import requests
from pilot.configs.config import Config from pilot.configs.config import Config
@ -8,6 +8,8 @@ from pilot.speech.base import VoiceBase
PLACEHOLDERS = {"your-voice-id"} PLACEHOLDERS = {"your-voice-id"}
logger = logging.getLogger(__name__)
class ElevenLabsSpeech(VoiceBase): class ElevenLabsSpeech(VoiceBase):
"""ElevenLabs speech class""" """ElevenLabs speech class"""
@ -68,7 +70,6 @@ class ElevenLabsSpeech(VoiceBase):
Returns: Returns:
bool: True if the request was successful, False otherwise bool: True if the request was successful, False otherwise
""" """
from pilot.logs import logger
from playsound import playsound from playsound import playsound
tts_url = ( tts_url = (

View File

@ -1,5 +1,6 @@
import json import json
import uuid import uuid
import logging
from pilot.common.schema import DBType from pilot.common.schema import DBType
from pilot.component import SystemApp from pilot.component import SystemApp
@ -7,16 +8,14 @@ from pilot.configs.config import Config
from pilot.configs.model_config import ( from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH, KNOWLEDGE_UPLOAD_ROOT_PATH,
EMBEDDING_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG,
LOGDIR,
) )
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.chat_factory import ChatFactory from pilot.scene.chat_factory import ChatFactory
from pilot.summary.rdbms_db_summary import RdbmsSummary from pilot.summary.rdbms_db_summary import RdbmsSummary
from pilot.utils import build_logger
logger = build_logger("db_summary", LOGDIR + "db_summary.log")
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()
chat_factory = ChatFactory() chat_factory = ChatFactory()

View File

@ -1,6 +1,5 @@
from .utils import ( from .utils import (
get_gpu_memory, get_gpu_memory,
build_logger,
StreamToLogger, StreamToLogger,
disable_torch_init, disable_torch_init,
pretty_print_semaphore, pretty_print_semaphore,

View File

@ -5,6 +5,8 @@ from dataclasses import is_dataclass, asdict
T = TypeVar("T") T = TypeVar("T")
logger = logging.getLogger(__name__)
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]: def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
import typing_inspect import typing_inspect
@ -21,7 +23,7 @@ def _build_request(self, func, path, method, *args, **kwargs):
raise TypeError("Return type must be annotated in the decorated function.") raise TypeError("Return type must be annotated in the decorated function.")
actual_dataclass = _extract_dataclass_from_generic(return_type) actual_dataclass = _extract_dataclass_from_generic(return_type)
logging.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}") logger.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
if not actual_dataclass: if not actual_dataclass:
actual_dataclass = return_type actual_dataclass = return_type
sig = signature(func) sig = signature(func)
@ -55,7 +57,7 @@ def _build_request(self, func, path, method, *args, **kwargs):
else: # For GET, DELETE, etc. else: # For GET, DELETE, etc.
request_params["params"] = request_data request_params["params"] = request_data
logging.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}") logger.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
return return_type, actual_dataclass, request_params return return_type, actual_dataclass, request_params

View File

@ -20,7 +20,7 @@ def _get_logging_level() -> str:
return os.getenv("DBGPT_LOG_LEVEL", "INFO") return os.getenv("DBGPT_LOG_LEVEL", "INFO")
def setup_logging(logging_level=None, logger_name: str = None): def setup_logging_level(logging_level=None, logger_name: str = None):
if not logging_level: if not logging_level:
logging_level = _get_logging_level() logging_level = _get_logging_level()
if type(logging_level) is str: if type(logging_level) is str:
@ -32,6 +32,19 @@ def setup_logging(logging_level=None, logger_name: str = None):
logging.basicConfig(level=logging_level, encoding="utf-8") logging.basicConfig(level=logging_level, encoding="utf-8")
def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
if not logging_level:
logging_level = _get_logging_level()
logger = _build_logger(logger_name, logging_level, logger_filename)
try:
import coloredlogs
color_level = logging_level if logging_level else "INFO"
coloredlogs.install(level=color_level, logger=logger)
except ImportError:
pass
def get_gpu_memory(max_gpus=None): def get_gpu_memory(max_gpus=None):
import torch import torch
@ -53,7 +66,7 @@ def get_gpu_memory(max_gpus=None):
return gpu_memory return gpu_memory
def build_logger(logger_name, logger_filename): def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
global handler global handler
formatter = logging.Formatter( formatter = logging.Formatter(
@ -63,7 +76,7 @@ def build_logger(logger_name, logger_filename):
# Set the format of root handlers # Set the format of root handlers
if not logging.getLogger().handlers: if not logging.getLogger().handlers:
setup_logging() setup_logging_level(logging_level=logging_level)
logging.getLogger().handlers[0].setFormatter(formatter) logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers # Redirect stdout and stderr to loggers
@ -78,7 +91,7 @@ def build_logger(logger_name, logger_filename):
# sys.stderr = sl # sys.stderr = sl
# Add a file handler for all loggers # Add a file handler for all loggers
if handler is None: if handler is None and logger_filename:
os.makedirs(LOGDIR, exist_ok=True) os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename) filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler( handler = logging.handlers.TimedRotatingFileHandler(
@ -89,11 +102,9 @@ def build_logger(logger_name, logger_filename):
for name, item in logging.root.manager.loggerDict.items(): for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger): if isinstance(item, logging.Logger):
item.addHandler(handler) item.addHandler(handler)
setup_logging()
# Get logger # Get logger
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
setup_logging(logger_name=logger_name) setup_logging_level(logging_level=logging_level, logger_name=logger_name)
return logger return logger

View File

@ -1,11 +1,13 @@
import os import os
import logging
from typing import Any from typing import Any
from chromadb.config import Settings from chromadb.config import Settings
from chromadb import PersistentClient from chromadb import PersistentClient
from pilot.logs import logger
from pilot.vector_store.base import VectorStoreBase from pilot.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
class ChromaStore(VectorStoreBase): class ChromaStore(VectorStoreBase):
"""chroma database""" """chroma database"""

View File

@ -1,12 +1,13 @@
from __future__ import annotations from __future__ import annotations
import logging
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
from pymilvus import Collection, DataType, connections, utility from pymilvus import Collection, DataType, connections, utility
from pilot.logs import logger
from pilot.vector_store.base import VectorStoreBase from pilot.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
class MilvusStore(VectorStoreBase): class MilvusStore(VectorStoreBase):
"""Milvus database""" """Milvus database"""

View File

@ -1,5 +1,6 @@
import os import os
import json import json
import logging
import weaviate import weaviate
from langchain.schema import Document from langchain.schema import Document
from langchain.vectorstores import Weaviate from langchain.vectorstores import Weaviate
@ -7,9 +8,9 @@ from weaviate.exceptions import WeaviateBaseError
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger
from pilot.vector_store.base import VectorStoreBase from pilot.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
CFG = Config() CFG = Config()

View File

@ -287,6 +287,7 @@ def core_requires():
] ]
setup_spec.extras["framework"] = [ setup_spec.extras["framework"] = [
"coloredlogs",
"httpx", "httpx",
"sqlparse==0.4.4", "sqlparse==0.4.4",
"seaborn", "seaborn",