mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat(ChatKnowledge): Add custom text separators and refactor log configuration
This commit is contained in:
parent
20bdddec51
commit
5dfe611478
@ -55,6 +55,8 @@ EMBEDDING_MODEL=text2vec
|
||||
#EMBEDDING_MODEL=bge-large-zh
|
||||
KNOWLEDGE_CHUNK_SIZE=500
|
||||
KNOWLEDGE_SEARCH_TOP_SIZE=5
|
||||
# Control whether to display the source document of knowledge on the front end.
|
||||
KNOWLEDGE_CHAT_SHOW_RELATIONS=False
|
||||
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
|
||||
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
|
||||
# EMBEDDING_MODEL=all-MiniLM-L6-v2
|
||||
@ -154,4 +156,6 @@ SUMMARY_CONFIG=FAST
|
||||
#** LOG **#
|
||||
#*******************************************************************#
|
||||
# FATAL, ERROR, WARNING, INFO, DEBUG, NOTSET
|
||||
DBGPT_LOG_LEVEL=INFO
|
||||
DBGPT_LOG_LEVEL=INFO
|
||||
# LOG dir, default: ./logs
|
||||
#DBGPT_LOG_DIR=
|
@ -1,6 +1,7 @@
|
||||
import contextlib
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
|
||||
from colorama import Fore
|
||||
from regex import regex
|
||||
@ -11,9 +12,9 @@ from pilot.json_utils.json_fix_general import (
|
||||
balance_braces,
|
||||
fix_invalid_escape,
|
||||
)
|
||||
from pilot.logs import logger
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
|
@ -2,14 +2,15 @@
|
||||
import io
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from pilot.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from pandas import DataFrame
|
||||
|
||||
from pilot.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
import pandas as pd
|
||||
@ -13,11 +13,10 @@ import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as mtick
|
||||
from matplotlib.font_manager import FontManager
|
||||
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
|
||||
CFG = Config()
|
||||
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
static_message_img_path = os.path.join(os.getcwd(), "message/img")
|
||||
|
||||
|
||||
|
@ -1,14 +1,13 @@
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
from pandas import DataFrame
|
||||
|
||||
from pilot.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
|
||||
CFG = Config()
|
||||
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
|
@ -1,13 +1,12 @@
|
||||
import logging
|
||||
import pandas as pd
|
||||
from pandas import DataFrame
|
||||
|
||||
from pilot.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
|
||||
CFG = Config()
|
||||
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
|
@ -8,6 +8,7 @@ import zipfile
|
||||
import requests
|
||||
import threading
|
||||
import datetime
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
@ -17,7 +18,8 @@ import requests
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import PLUGINS_DIR
|
||||
from pilot.logs import logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
@ -32,7 +32,7 @@ class Config(metaclass=Singleton):
|
||||
# self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))
|
||||
|
||||
self.execute_local_commands = (
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
|
||||
)
|
||||
# User agent header to use when making HTTP requests
|
||||
# Some websites might just completely deny request with an error code if
|
||||
@ -64,7 +64,7 @@ class Config(metaclass=Singleton):
|
||||
self.milvus_username = os.getenv("MILVUS_USERNAME")
|
||||
self.milvus_password = os.getenv("MILVUS_PASSWORD")
|
||||
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
|
||||
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
|
||||
self.milvus_secure = os.getenv("MILVUS_SECURE", "False").lower() == "true"
|
||||
|
||||
self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
|
||||
self.exit_key = os.getenv("EXIT_KEY", "n")
|
||||
@ -98,7 +98,7 @@ class Config(metaclass=Singleton):
|
||||
self.disabled_command_categories = []
|
||||
|
||||
self.execute_local_commands = (
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
|
||||
os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
|
||||
)
|
||||
### message store file
|
||||
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
|
||||
@ -107,7 +107,7 @@ class Config(metaclass=Singleton):
|
||||
|
||||
self.plugins: List["AutoGPTPluginTemplate"] = []
|
||||
self.plugins_openai = []
|
||||
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
|
||||
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true"
|
||||
|
||||
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
|
||||
|
||||
@ -124,10 +124,10 @@ class Config(metaclass=Singleton):
|
||||
self.plugins_denylist = []
|
||||
### Native SQL Execution Capability Control Configuration
|
||||
self.NATIVE_SQL_CAN_RUN_DDL = (
|
||||
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
|
||||
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True").lower() == "true"
|
||||
)
|
||||
self.NATIVE_SQL_CAN_RUN_WRITE = (
|
||||
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
|
||||
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
|
||||
)
|
||||
|
||||
### default Local database connection configuration
|
||||
@ -170,8 +170,8 @@ class Config(metaclass=Singleton):
|
||||
|
||||
# QLoRA
|
||||
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
|
||||
self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True") == "True"
|
||||
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True"
|
||||
self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
|
||||
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False").lower() == "true"
|
||||
if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
|
||||
self.IS_LOAD_8BIT = False
|
||||
# In order to be compatible with the new and old model parameter design
|
||||
@ -187,7 +187,9 @@ class Config(metaclass=Singleton):
|
||||
os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
|
||||
)
|
||||
### Control whether to display the source document of knowledge on the front end.
|
||||
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False
|
||||
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = (
|
||||
os.getenv("KNOWLEDGE_CHAT_SHOW_RELATIONS", "False").lower() == "true"
|
||||
)
|
||||
|
||||
### SUMMARY_CONFIG Configuration
|
||||
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
|
||||
|
@ -3,13 +3,12 @@
|
||||
|
||||
import os
|
||||
|
||||
# import nltk
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
MODEL_PATH = os.path.join(ROOT_PATH, "models")
|
||||
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
|
||||
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
|
||||
LOGDIR = os.path.join(ROOT_PATH, "logs")
|
||||
LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs"))
|
||||
|
||||
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
|
||||
DATA_DIR = os.path.join(PILOT_PATH, "data")
|
||||
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
|
||||
|
@ -1,10 +1,11 @@
|
||||
import logging
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from pilot.configs.config import Config
|
||||
from pilot.common.schema import DBType
|
||||
from pilot.connections.rdbms.base import RDBMSDatabase
|
||||
from pilot.logs import logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
|
@ -1,5 +1,12 @@
|
||||
from pilot.embedding_engine.source_embedding import SourceEmbedding, register
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from pilot.embedding_engine.knowledge_type import KnowledgeType
|
||||
from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
|
||||
|
||||
__all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
|
||||
__all__ = [
|
||||
"SourceEmbedding",
|
||||
"register",
|
||||
"EmbeddingEngine",
|
||||
"KnowledgeType",
|
||||
"PreTextSplitter",
|
||||
]
|
||||
|
30
pilot/embedding_engine/pre_text_splitter.py
Normal file
30
pilot/embedding_engine/pre_text_splitter.py
Normal file
@ -0,0 +1,30 @@
|
||||
from typing import Iterable, List
|
||||
from langchain.schema import Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
|
||||
def _single_document_split(
    document: Document, pre_separator: str
) -> Iterable[Document]:
    """Lazily cut one document into pieces on ``pre_separator``.

    Args:
        document: The document whose ``page_content`` is split.
        pre_separator: Literal separator string passed to ``str.split``; the
            separator itself does not appear in any yielded piece.

    Yields:
        One ``Document`` per piece, in order. Each piece gets its own shallow
        copy of the parent's metadata. When the metadata has a ``"source"``
        key, ``"_pre_split_<index>"`` is appended to its value so the pieces
        of one parent stay distinguishable.
    """
    segments = document.page_content.split(pre_separator)
    for position, segment in enumerate(segments):
        # Copy per piece so that tagging one piece never leaks into another
        # piece or back into the parent document.
        piece_metadata = document.metadata.copy()
        if "source" in piece_metadata:
            piece_metadata["source"] = (
                piece_metadata["source"] + "_pre_split_" + str(position)
            )
        yield Document(page_content=segment, metadata=piece_metadata)
|
||||
|
||||
|
||||
class PreTextSplitter(TextSplitter):
    """Splitter that cuts documents on a fixed separator before delegating.

    ``split_documents`` first breaks every incoming document on
    ``pre_separator`` and then hands the resulting pieces to the wrapped
    splitter. ``split_text`` delegates to the wrapped splitter directly,
    without any pre-splitting.
    """

    def __init__(self, pre_separator: str, text_splitter_impl: TextSplitter):
        """Create a pre-splitting wrapper.

        Args:
            pre_separator: Literal separator used for the pre-split pass.
            text_splitter_impl: The splitter that performs the actual
                chunking after the pre-split pass.
        """
        # Initialise the base-class state with its defaults; without this,
        # inherited TextSplitter methods that read base attributes would
        # fail with AttributeError on this subclass.
        super().__init__()
        self.pre_separator = pre_separator
        self._impl = text_splitter_impl

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` with the wrapped splitter (no pre-split is applied)."""
        return self._impl.split_text(text)

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Pre-split each document on ``pre_separator``, then chunk the pieces.

        Args:
            documents: Documents to split.

        Returns:
            The result of the wrapped splitter's ``split_documents`` applied
            to the stream of pre-split pieces.
        """
        # Lazy stream: pieces are produced only as the wrapped splitter
        # consumes them.
        pre_split_pieces = (
            piece
            for doc in documents
            for piece in _single_document_split(
                doc, pre_separator=self.pre_separator
            )
        )
        return self._impl.split_documents(pre_split_pieces)
|
@ -5,12 +5,13 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.json_utils.utilities import extract_char_position
|
||||
from pilot.logs import logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
|
@ -3,12 +3,14 @@ import json
|
||||
import os.path
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from jsonschema import Draft7Validator
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1"
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Callable, Type
|
||||
from functools import cache
|
||||
@ -19,8 +20,8 @@ from pilot.model.parameter import (
|
||||
)
|
||||
from pilot.configs.model_config import get_device
|
||||
from pilot.configs.config import Config
|
||||
from pilot.logs import logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
@ -13,6 +13,9 @@ from pilot.utils.api_utils import (
|
||||
_api_remote as api_remote,
|
||||
_sync_api_remote as sync_api_remote,
|
||||
)
|
||||
from pilot.utils.utils import setup_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseModelController(BaseComponent, ABC):
|
||||
@ -59,7 +62,7 @@ class LocalModelController(BaseModelController):
|
||||
async def get_all_instances(
|
||||
self, model_name: str = None, healthy_only: bool = False
|
||||
) -> List[ModelInstance]:
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
|
||||
)
|
||||
if not model_name:
|
||||
@ -178,6 +181,13 @@ def run_model_controller():
|
||||
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
|
||||
ModelControllerParameters, env_prefix=env_prefix
|
||||
)
|
||||
|
||||
setup_logging(
|
||||
"pilot",
|
||||
logging_level=controller_params.log_level,
|
||||
logger_filename="dbgpt_model_controller.log",
|
||||
)
|
||||
|
||||
initialize_controller(host=controller_params.host, port=controller_params.port)
|
||||
|
||||
|
||||
|
@ -120,7 +120,9 @@ class DefaultModelWorker(ModelWorker):
|
||||
text=output, error_code=0, model_context=model_context
|
||||
)
|
||||
yield model_output
|
||||
print(f"\n\nfull stream output:\n{previous_response}")
|
||||
print(
|
||||
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
|
||||
if torch_imported and isinstance(e, torch.cuda.CudaError):
|
||||
|
@ -36,6 +36,7 @@ from pilot.utils.parameter_utils import (
|
||||
ParameterDescription,
|
||||
_dict_to_command_args,
|
||||
)
|
||||
from pilot.utils.utils import setup_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -885,6 +886,12 @@ def run_worker_manager(
|
||||
model_name=model_name, model_path=model_path, standalone=standalone, port=port
|
||||
)
|
||||
|
||||
setup_logging(
|
||||
"pilot",
|
||||
logging_level=worker_params.log_level,
|
||||
logger_filename="dbgpt_model_worker_manager.log",
|
||||
)
|
||||
|
||||
embedded_mod = True
|
||||
logger.info(f"Worker params: {worker_params}")
|
||||
if not app:
|
||||
|
@ -3,11 +3,13 @@ Fork from text-generation-webui https://github.com/oobabooga/text-generation-web
|
||||
"""
|
||||
import re
|
||||
from typing import Dict
|
||||
import logging
|
||||
import torch
|
||||
import llama_cpp
|
||||
|
||||
from pilot.model.parameter import LlamaCppModelParameters
|
||||
from pilot.logs import logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if torch.cuda.is_available() and not torch.version.hip:
|
||||
try:
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
from typing import Optional, Dict
|
||||
|
||||
import logging
|
||||
from pilot.configs.model_config import get_device
|
||||
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
|
||||
from pilot.model.parameter import (
|
||||
@ -12,7 +13,8 @@ from pilot.model.parameter import (
|
||||
)
|
||||
from pilot.utils import get_gpu_memory
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
|
||||
from pilot.logs import logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
|
||||
|
@ -31,6 +31,21 @@ class ModelControllerParameters(BaseParameters):
|
||||
daemon: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Run Model Controller in background"}
|
||||
)
|
||||
log_level: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Logging level",
|
||||
"valid_values": [
|
||||
"FATAL",
|
||||
"ERROR",
|
||||
"WARNING",
|
||||
"WARNING",
|
||||
"INFO",
|
||||
"DEBUG",
|
||||
"NOTSET",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -85,6 +100,22 @@ class ModelWorkerParameters(BaseModelParameters):
|
||||
default=20, metadata={"help": "The interval for sending heartbeats (seconds)"}
|
||||
)
|
||||
|
||||
log_level: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Logging level",
|
||||
"valid_values": [
|
||||
"FATAL",
|
||||
"ERROR",
|
||||
"WARNING",
|
||||
"WARNING",
|
||||
"INFO",
|
||||
"DEBUG",
|
||||
"NOTSET",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseEmbeddingModelParameters(BaseModelParameters):
|
||||
|
@ -38,8 +38,6 @@ from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
from pilot.common.schema import DBType
|
||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||
from pilot.scene.message import OnceConversation
|
||||
@ -409,7 +407,7 @@ async def model_types(controller: BaseModelController = Depends(get_model_contro
|
||||
|
||||
|
||||
@router.get("/v1/model/supports")
|
||||
async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)):
|
||||
async def model_supports(worker_manager: WorkerManager = Depends(get_worker_manager)):
|
||||
logger.info(f"/controller/model/supports")
|
||||
try:
|
||||
models = await worker_manager.supported_models()
|
||||
|
@ -6,12 +6,11 @@ from fastapi import (
|
||||
)
|
||||
|
||||
from typing import List
|
||||
import logging
|
||||
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
|
||||
from pilot.openapi.api_view_model import (
|
||||
Result,
|
||||
@ -34,7 +33,8 @@ from pilot.scene.chat_db.data_loader import DbDataLoader
|
||||
router = APIRouter()
|
||||
CFG = Config()
|
||||
CHAT_FACTORY = ChatFactory()
|
||||
logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
|
||||
|
@ -2,18 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, TypeVar, Union
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.model.base import ModelOutput
|
||||
from pilot.utils import build_logger
|
||||
|
||||
T = TypeVar("T")
|
||||
ResponseTye = Union[str, bytes, ModelOutput]
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
import datetime
|
||||
import traceback
|
||||
import warnings
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.component import ComponentType
|
||||
from pilot.memory.chat_history.base import BaseChatHistoryMemory
|
||||
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
|
||||
@ -14,10 +14,10 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
|
||||
from pilot.prompts.prompt_new import PromptTemplate
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
from pilot.scene.message import OnceConversation
|
||||
from pilot.utils import build_logger, get_or_create_event_loop
|
||||
from pilot.utils import get_or_create_event_loop
|
||||
from pydantic import Extra
|
||||
|
||||
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
headers = {"User-Agent": "dbgpt Client"}
|
||||
CFG = Config()
|
||||
|
||||
|
@ -1,13 +1,12 @@
|
||||
from typing import List
|
||||
from decimal import Decimal
|
||||
import logging
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem
|
||||
|
||||
CFG = Config()
|
||||
logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DashboardDataLoader:
|
||||
|
@ -1,8 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import NamedTuple, List
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.scene.base import ChatScene
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ class ChartItem(NamedTuple):
|
||||
showcase: str
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatDashboardOutputParser(BaseOutputParser):
|
||||
|
@ -1,11 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
from typing import Dict, NamedTuple, List
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
@ -17,7 +13,7 @@ class ExcelAnalyzeResponse(NamedTuple):
|
||||
display: str
|
||||
|
||||
|
||||
logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatExcelOutputParser(BaseOutputParser):
|
||||
|
@ -1,11 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
from typing import Dict, NamedTuple, List
|
||||
import pandas as pd
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
@ -17,7 +13,7 @@ class ExcelResponse(NamedTuple):
|
||||
plans: List
|
||||
|
||||
|
||||
logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LearningExcelOutputParser(BaseOutputParser):
|
||||
|
@ -1,8 +1,7 @@
|
||||
import json
|
||||
from typing import Dict, NamedTuple
|
||||
from pilot.utils import build_logger
|
||||
import logging
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.configs.config import Config
|
||||
from pilot.scene.chat_db.data_loader import DbDataLoader
|
||||
|
||||
@ -14,7 +13,7 @@ class SqlAction(NamedTuple):
|
||||
thoughts: Dict
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DbChatOutputParser(BaseOutputParser):
|
||||
|
@ -1,8 +1,4 @@
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.utils import build_logger
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
|
||||
|
@ -1,11 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, NamedTuple
|
||||
from pilot.utils import build_logger
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PluginAction(NamedTuple):
|
||||
|
@ -1,9 +1,7 @@
|
||||
from pilot.utils import build_logger
|
||||
import logging
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
|
||||
|
@ -60,18 +60,25 @@ class ChatKnowledge(BaseChat):
|
||||
|
||||
async def stream_call(self):
|
||||
input_values = self.generate_input_values()
|
||||
# Source of knowledge file
|
||||
relations = input_values.get("relations")
|
||||
last_output = None
|
||||
async for output in super().stream_call():
|
||||
# Source of knowledge file
|
||||
relations = input_values.get("relations")
|
||||
if (
|
||||
CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
|
||||
and type(relations) == list
|
||||
and len(relations) > 0
|
||||
and hasattr(output, "text")
|
||||
):
|
||||
output.text = output.text + "\trelations:" + ",".join(relations)
|
||||
last_output = output
|
||||
yield output
|
||||
|
||||
if (
|
||||
CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
|
||||
and last_output
|
||||
and type(relations) == list
|
||||
and len(relations) > 0
|
||||
and hasattr(last_output, "text")
|
||||
):
|
||||
last_output.text = (
|
||||
last_output.text + "\n\nrelations:\n\n" + ",".join(relations)
|
||||
)
|
||||
yield last_output
|
||||
|
||||
def generate_input_values(self):
|
||||
if self.space_context:
|
||||
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
||||
|
@ -1,9 +1,8 @@
|
||||
from pilot.utils import build_logger
|
||||
import logging
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
|
||||
|
@ -1,9 +1,8 @@
|
||||
from pilot.utils import build_logger
|
||||
import logging
|
||||
from pilot.out_parser.base import BaseOutputParser, T
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
|
||||
|
||||
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NormalChatOutputParser(BaseOutputParser):
|
||||
|
@ -108,7 +108,7 @@ class WebWerverParameters(BaseParameters):
|
||||
},
|
||||
)
|
||||
log_level: Optional[str] = field(
|
||||
default="INFO",
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Logging level",
|
||||
"valid_values": [
|
||||
|
@ -33,7 +33,11 @@ from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route
|
||||
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
|
||||
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||
from pilot.model.cluster import initialize_worker_manager_in_client
|
||||
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
|
||||
from pilot.utils.utils import (
|
||||
setup_logging,
|
||||
_get_logging_level,
|
||||
logging_str_to_uvicorn_level,
|
||||
)
|
||||
|
||||
static_file_path = os.path.join(os.getcwd(), "server/static")
|
||||
|
||||
@ -111,7 +115,11 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
|
||||
)
|
||||
param = WebWerverParameters(**vars(parser.parse_args(args=args)))
|
||||
|
||||
setup_logging(logging_level=param.log_level)
|
||||
if not param.log_level:
|
||||
param.log_level = _get_logging_level()
|
||||
setup_logging(
|
||||
"pilot", logging_level=param.log_level, logger_filename="dbgpt_webserver.log"
|
||||
)
|
||||
# Before start
|
||||
system_app.before_start()
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
import click
|
||||
import logging
|
||||
import os
|
||||
import functools
|
||||
|
||||
from pilot.configs.model_config import DATASETS_DIR
|
||||
|
||||
API_ADDRESS: str = "http://127.0.0.1:5000"
|
||||
_DEFAULT_API_ADDRESS: str = "http://127.0.0.1:5000"
|
||||
API_ADDRESS: str = _DEFAULT_API_ADDRESS
|
||||
|
||||
logger = logging.getLogger("dbgpt_cli")
|
||||
|
||||
@ -20,33 +23,44 @@ logger = logging.getLogger("dbgpt_cli")
|
||||
def knowledge_cli_group(address: str):
    """Knowledge command line tool"""
    global API_ADDRESS
    # An address different from the built-in default is used as given.
    # Otherwise the API_ADDRESS environment variable is consulted, falling
    # back to the built-in default when it is not set.
    is_default_address = address == _DEFAULT_API_ADDRESS
    API_ADDRESS = (
        os.getenv("API_ADDRESS", _DEFAULT_API_ADDRESS)
        if is_default_address
        else address
    )
|
||||
|
||||
|
||||
def add_knowledge_options(func):
    """Attach the shared ``--space_name`` click option to ``func``.

    Args:
        func: The command callback to decorate.

    Returns:
        A wrapper around ``func`` (metadata preserved via ``functools.wraps``)
        with the ``--space_name`` option applied to it.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    space_name_option = click.option(
        "--space_name",
        required=False,
        type=str,
        default="default",
        show_default=True,
        help="Your knowledge space name",
    )
    return space_name_option(wrapper)
|
||||
|
||||
|
||||
@knowledge_cli_group.command()
|
||||
@click.option(
|
||||
"--vector_name",
|
||||
required=False,
|
||||
type=str,
|
||||
default="default",
|
||||
show_default=True,
|
||||
help="Your vector store name",
|
||||
)
|
||||
@add_knowledge_options
|
||||
@click.option(
|
||||
"--vector_store_type",
|
||||
required=False,
|
||||
type=str,
|
||||
default="Chroma",
|
||||
show_default=True,
|
||||
help="Vector store type",
|
||||
help="Vector store type.",
|
||||
)
|
||||
@click.option(
|
||||
"--local_doc_dir",
|
||||
"--local_doc_path",
|
||||
required=False,
|
||||
type=str,
|
||||
default=DATASETS_DIR,
|
||||
show_default=True,
|
||||
help="Your document directory",
|
||||
help="Your document directory or document file path.",
|
||||
)
|
||||
@click.option(
|
||||
"--skip_wrong_doc",
|
||||
@ -54,31 +68,165 @@ def knowledge_cli_group(address: str):
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
show_default=True,
|
||||
help="Skip wrong document",
|
||||
help="Skip wrong document.",
|
||||
)
|
||||
@click.option(
|
||||
"--overwrite",
|
||||
required=False,
|
||||
type=bool,
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Overwrite existing document(they has same name).",
|
||||
)
|
||||
@click.option(
|
||||
"--max_workers",
|
||||
required=False,
|
||||
type=int,
|
||||
default=None,
|
||||
help="The maximum number of threads that can be used to upload document",
|
||||
help="The maximum number of threads that can be used to upload document.",
|
||||
)
|
||||
@click.option(
|
||||
"--pre_separator",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="Preseparator, this separator is used for pre-splitting before the document is "
|
||||
"actually split by the text splitter. Preseparator are not included in the vectorized text. ",
|
||||
)
|
||||
@click.option(
|
||||
"--separator",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="This is the document separator. Currently, only one separator is supported.",
|
||||
)
|
||||
@click.option(
|
||||
"--chunk_size",
|
||||
required=False,
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum size of chunks to split.",
|
||||
)
|
||||
@click.option(
|
||||
"--chunk_overlap",
|
||||
required=False,
|
||||
type=int,
|
||||
default=None,
|
||||
help="Overlap in characters between chunks.",
|
||||
)
|
||||
def load(
|
||||
vector_name: str,
|
||||
space_name: str,
|
||||
vector_store_type: str,
|
||||
local_doc_dir: str,
|
||||
local_doc_path: str,
|
||||
skip_wrong_doc: bool,
|
||||
overwrite: bool,
|
||||
max_workers: int,
|
||||
pre_separator: str,
|
||||
separator: str,
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
):
|
||||
"""Load your local knowledge to DB-GPT"""
|
||||
from pilot.server.knowledge._cli.knowledge_client import knowledge_init
|
||||
|
||||
knowledge_init(
|
||||
API_ADDRESS,
|
||||
vector_name,
|
||||
space_name,
|
||||
vector_store_type,
|
||||
local_doc_dir,
|
||||
local_doc_path,
|
||||
skip_wrong_doc,
|
||||
overwrite,
|
||||
max_workers,
|
||||
pre_separator,
|
||||
separator,
|
||||
chunk_size,
|
||||
chunk_overlap,
|
||||
)
|
||||
|
||||
|
||||
@knowledge_cli_group.command()
@add_knowledge_options
@click.option(
    "--doc_name",
    type=str,
    default=None,
    required=False,
    help="The document name you want to delete. If doc_name is None, this command will delete the whole space.",
)
@click.option(
    "-y",
    is_flag=True,
    type=bool,
    default=False,
    required=False,
    help="Confirm your choice",
)
def delete(space_name: str, doc_name: str, y: bool):
    """Delete your knowledge space or document in space"""
    # NOTE: the docstring above doubles as the click help text; keep it short.
    # The client module is imported inside the command body, so it is only
    # loaded when this command actually runs.
    from pilot.server.knowledge._cli import knowledge_client

    knowledge_client.knowledge_delete(
        api_address=API_ADDRESS,
        space_name=space_name,
        doc_name=doc_name,
        confirm=y,
    )
|
||||
|
||||
|
||||
@knowledge_cli_group.command()
@click.option(
    "--space_name",
    type=str,
    default=None,
    required=False,
    show_default=True,
    help="Your knowledge space name. If None, list all spaces",
)
@click.option(
    "--doc_id",
    type=int,
    default=None,
    required=False,
    show_default=True,
    help="Your document id in knowledge space. If Not None, list all chunks in current document",
)
@click.option(
    "--page",
    type=int,
    default=1,
    required=False,
    show_default=True,
    help="The page for every query",
)
@click.option(
    "--page_size",
    type=int,
    default=20,
    required=False,
    show_default=True,
    help="The page size for every query",
)
@click.option(
    "--show_content",
    is_flag=True,
    type=bool,
    default=False,
    required=False,
    help="Query the document content of chunks",
)
@click.option(
    "--output",
    type=click.Choice(["text", "html", "csv", "latex", "json"]),
    default="text",
    required=False,
    help="The output format",
)
def list(
    space_name: str,
    doc_id: int,
    page: int,
    page_size: int,
    show_content: bool,
    output: str,
):
    """List knowledge space"""
    # NOTE: the docstring above doubles as the click help text; keep it short.
    # The client module is imported inside the command body, so it is only
    # loaded when this command actually runs.
    from pilot.server.knowledge._cli import knowledge_client

    knowledge_client.knowledge_list(
        api_address=API_ADDRESS,
        space_name=space_name,
        page=page,
        page_size=page_size,
        doc_id=doc_id,
        show_content=show_content,
        out_format=output,
    )
|
||||
|
@ -62,6 +62,9 @@ class KnowledgeApiClient(ApiClient):
|
||||
else:
|
||||
raise e
|
||||
|
||||
def space_delete(self, request: KnowledgeSpaceRequest):
    """POST ``request`` to the space-delete endpoint and return the ``_post`` result."""
    url = "/knowledge/space/delete"
    return self._post(url, data=request)
|
||||
|
||||
def space_list(self, request: KnowledgeSpaceRequest):
    """POST ``request`` to the space-list endpoint and return the ``_post`` result."""
    url = "/knowledge/space/list"
    return self._post(url, data=request)
|
||||
|
||||
@ -69,6 +72,10 @@ class KnowledgeApiClient(ApiClient):
|
||||
url = f"/knowledge/{space_name}/document/add"
|
||||
return self._post(url, data=request)
|
||||
|
||||
def document_delete(self, space_name: str, request: KnowledgeDocumentRequest):
    """POST ``request`` to the document-delete endpoint of ``space_name``."""
    return self._post(f"/knowledge/{space_name}/document/delete", data=request)
|
||||
|
||||
def document_list(self, space_name: str, query_request: DocumentQueryRequest):
    """POST ``query_request`` to the document-list endpoint of ``space_name``."""
    return self._post(f"/knowledge/{space_name}/document/list", data=query_request)
|
||||
@ -97,15 +104,20 @@ class KnowledgeApiClient(ApiClient):
|
||||
|
||||
def knowledge_init(
|
||||
api_address: str,
|
||||
vector_name: str,
|
||||
space_name: str,
|
||||
vector_store_type: str,
|
||||
local_doc_dir: str,
|
||||
local_doc_path: str,
|
||||
skip_wrong_doc: bool,
|
||||
max_workers: int = None,
|
||||
overwrite: bool,
|
||||
max_workers: int,
|
||||
pre_separator: str,
|
||||
separator: str,
|
||||
chunk_size: int,
|
||||
chunk_overlap: int,
|
||||
):
|
||||
client = KnowledgeApiClient(api_address)
|
||||
space = KnowledgeSpaceRequest()
|
||||
space.name = vector_name
|
||||
space.name = space_name
|
||||
space.desc = "DB-GPT cli"
|
||||
space.vector_type = vector_store_type
|
||||
space.owner = "DB-GPT"
|
||||
@ -124,24 +136,260 @@ def knowledge_init(
|
||||
def upload(filename: str):
|
||||
try:
|
||||
logger.info(f"Begin upload document: {filename} to {space.name}")
|
||||
doc_id = client.document_upload(
|
||||
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
||||
)
|
||||
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
|
||||
doc_id = None
|
||||
try:
|
||||
doc_id = client.document_upload(
|
||||
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
||||
)
|
||||
except Exception as ex:
|
||||
if overwrite and "have already named" in str(ex):
|
||||
logger.warn(
|
||||
f"Document {filename} already exist in space {space.name}, overwrite it"
|
||||
)
|
||||
client.document_delete(
|
||||
space.name, KnowledgeDocumentRequest(doc_name=filename)
|
||||
)
|
||||
doc_id = client.document_upload(
|
||||
space.name, filename, KnowledgeType.DOCUMENT.value, filename
|
||||
)
|
||||
else:
|
||||
raise ex
|
||||
sync_req = DocumentSyncRequest(doc_ids=[doc_id])
|
||||
if pre_separator:
|
||||
sync_req.pre_separator = pre_separator
|
||||
if separator:
|
||||
sync_req.separators = [separator]
|
||||
if chunk_size:
|
||||
sync_req.chunk_size = chunk_size
|
||||
if chunk_overlap:
|
||||
sync_req.chunk_overlap = chunk_overlap
|
||||
|
||||
client.document_sync(space.name, sync_req)
|
||||
return doc_id
|
||||
except Exception as e:
|
||||
if skip_wrong_doc:
|
||||
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
|
||||
else:
|
||||
raise e
|
||||
|
||||
if not os.path.exists(local_doc_path):
|
||||
raise Exception(f"{local_doc_path} not exists")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
tasks = []
|
||||
for root, _, files in os.walk(local_doc_dir, topdown=False):
|
||||
for file in files:
|
||||
filename = os.path.join(root, file)
|
||||
tasks.append(pool.submit(upload, filename))
|
||||
file_names = []
|
||||
if os.path.isdir(local_doc_path):
|
||||
for root, _, files in os.walk(local_doc_path, topdown=False):
|
||||
for file in files:
|
||||
file_names.append(os.path.join(root, file))
|
||||
else:
|
||||
# Single file
|
||||
file_names.append(local_doc_path)
|
||||
|
||||
[tasks.append(pool.submit(upload, filename)) for filename in file_names]
|
||||
|
||||
doc_ids = [r.result() for r in as_completed(tasks)]
|
||||
doc_ids = list(filter(lambda x: x, doc_ids))
|
||||
if not doc_ids:
|
||||
logger.warn("Warning: no document to sync")
|
||||
return
|
||||
|
||||
|
||||
from prettytable import PrettyTable
|
||||
|
||||
|
||||
class _KnowledgeVisualizer:
    """Fetch knowledge spaces, documents and chunks and print them as tables.

    Data comes from a ``KnowledgeApiClient``; tables are rendered with
    ``PrettyTable.get_formatted_string`` in the requested output format.
    """

    def __init__(self, api_address: str, out_format: str):
        """Create the API client and remember how tables should be rendered.

        Args:
            api_address: Address handed to ``KnowledgeApiClient``.
            out_format: Output format name passed to PrettyTable
                (the CLI offers text, html, csv, latex and json).
        """
        self.client = KnowledgeApiClient(api_address)
        self.out_format = out_format
        self.out_kwargs = {}
        if out_format == "json":
            # Forwarded to the JSON formatter; presumably keeps non-ASCII
            # text unescaped in the output (verify against PrettyTable docs).
            self.out_kwargs["ensure_ascii"] = False

    def print_table(self, table):
        """Print ``table`` to stdout in the configured output format."""
        print(table.get_formatted_string(out_format=self.out_format, **self.out_kwargs))

    def _print_with_summary(self, summary_table, detail_table):
        """Print the detail table, preceded by the summary table for text output.

        Non-text formats get only the detail table, so the output stays a
        single document.
        """
        if self.out_format == "text":
            self.print_table(summary_table)
            print("")
        self.print_table(detail_table)

    def list_spaces(self):
        """Print one row per knowledge space returned by the server."""
        spaces = self.client.space_list(KnowledgeSpaceRequest())
        table = PrettyTable(
            ["Space ID", "Space Name", "Vector Type", "Owner", "Description"],
            title="All knowledge spaces",
        )
        for sp in spaces:
            table.add_row(
                [
                    sp.get("id"),
                    sp.get("name"),
                    sp.get("vector_type"),
                    sp.get("owner"),
                    sp.get("desc"),
                ]
            )
        self.print_table(table)

    def list_documents(self, space_name: str, page: int, page_size: int):
        """Print one page of the documents stored in ``space_name``.

        Args:
            space_name: Knowledge space to query.
            page: Page number sent to the server.
            page_size: Page size sent to the server.
        """
        space_data = self.client.document_list(
            space_name, DocumentQueryRequest(page=page, page_size=page_size)
        )

        space_table = PrettyTable(
            [
                "Space Name",
                "Total Documents",
                "Current Page",
                "Current Size",
                "Page Size",
            ],
            title=f"Space {space_name} description",
        )
        space_table.add_row(
            [space_name, space_data["total"], page, len(space_data["data"]), page_size]
        )

        table = PrettyTable(
            [
                "Space Name",
                "Document ID",
                "Document Name",
                "Type",
                "Chunks",
                "Last Sync",
                "Status",
                "Result",
            ],
            title=f"Documents of space {space_name}",
        )
        for doc in space_data["data"]:
            table.add_row(
                [
                    space_name,
                    doc.get("id"),
                    doc.get("doc_name"),
                    doc.get("doc_type"),
                    doc.get("chunk_size"),
                    doc.get("last_sync"),
                    doc.get("status"),
                    doc.get("result"),
                ]
            )
        self._print_with_summary(space_table, table)

    def list_chunks(
        self,
        space_name: str,
        doc_id: int,
        page: int,
        page_size: int,
        show_content: bool,
    ):
        """Print one page of the chunks of document ``doc_id`` in ``space_name``.

        Args:
            space_name: Knowledge space that owns the document.
            doc_id: Document id sent to the server as ``document_id``.
            page: Page number sent to the server.
            page_size: Page size sent to the server.
            show_content: When false, the content column shows "[Hidden]"
                instead of the chunk text.
        """
        doc_data = self.client.chunk_list(
            space_name,
            ChunkQueryRequest(document_id=doc_id, page=page, page_size=page_size),
        )

        doc_table = PrettyTable(
            [
                "Space Name",
                "Document ID",
                "Total Chunks",
                "Current Page",
                "Current Size",
                "Page Size",
            ],
            title=f"Document {doc_id} in {space_name} description",
        )
        doc_table.add_row(
            [
                space_name,
                doc_id,
                doc_data["total"],
                page,
                len(doc_data["data"]),
                page_size,
            ]
        )

        table = PrettyTable(
            ["Space Name", "Document ID", "Document Name", "Content", "Meta Data"],
            title=f"chunks of document id {doc_id} in space {space_name}",
        )
        for chunk in doc_data["data"]:
            table.add_row(
                [
                    space_name,
                    doc_id,
                    chunk.get("doc_name"),
                    chunk.get("content") if show_content else "[Hidden]",
                    chunk.get("meta_info"),
                ]
            )
        self._print_with_summary(doc_table, table)
|
||||
|
||||
|
||||
def knowledge_list(
    api_address: str,
    space_name: str,
    page: int,
    page_size: int,
    doc_id: int,
    show_content: bool,
    out_format: str,
):
    """Print spaces, documents or chunks, depending on which filters are set.

    - no ``space_name``: list all knowledge spaces;
    - ``space_name`` without ``doc_id``: list that space's documents;
    - both given: list the chunks of that document.
    """
    viz = _KnowledgeVisualizer(api_address, out_format)
    if space_name and doc_id:
        viz.list_chunks(space_name, doc_id, page, page_size, show_content)
    elif space_name:
        viz.list_documents(space_name, page, page_size)
    else:
        viz.list_spaces()
|
||||
|
||||
|
||||
def _confirm_deletion(target: str) -> bool:
    """Ask on stdin whether ``target`` should be deleted; True only for 'yes'."""
    answer = input(
        f"Are you sure you want to delete {target}? Type 'yes' to confirm: "
    )
    return answer.strip().lower() == "yes"


def knowledge_delete(
    api_address: str, space_name: str, doc_name: str, confirm: bool = False
):
    """Delete a whole knowledge space, or a single document inside it.

    Args:
        api_address: Address handed to ``KnowledgeApiClient``.
        space_name: Name of the knowledge space to operate on.
        doc_name: Document to delete. When empty/None the whole space is
            deleted instead.
        confirm: When true, skip the interactive confirmation prompt.

    Raises:
        Exception: If the server returns no space named ``space_name``.
    """
    client = KnowledgeApiClient(api_address)
    space = KnowledgeSpaceRequest()
    space.name = space_name
    space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
    if not space_list:
        raise Exception(f"No knowledge space name {space_name}")

    if doc_name:
        target = f"the document {doc_name} in knowledge space {space_name}"
    else:
        target = f"the whole knowledge space {space_name}"

    # Only prompt when the caller did not pre-confirm (e.g. the CLI "-y" flag).
    if not confirm and not _confirm_deletion(target):
        logger.warning("Delete operation cancelled.")
        return

    if not doc_name:
        client.space_delete(space)
        logger.info("Delete the whole knowledge space successfully!")
    else:
        client.document_delete(space_name, KnowledgeDocumentRequest(doc_name=doc_name))
        logger.info(
            f"Delete the document {doc_name} in knowledge space {space_name} successfully!"
        )
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile, Form
|
||||
|
||||
@ -27,6 +28,8 @@ from pilot.server.knowledge.request.request import (
|
||||
|
||||
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
router = APIRouter()
|
||||
|
||||
@ -159,10 +162,10 @@ async def document_upload(
|
||||
|
||||
@router.post("/knowledge/{space_name}/document/sync")
|
||||
def document_sync(space_name: str, request: DocumentSyncRequest):
|
||||
print(f"Received params: {space_name}, {request}")
|
||||
logger.info(f"Received params: {space_name}, {request}")
|
||||
try:
|
||||
knowledge_space_service.sync_knowledge_document(
|
||||
space_name=space_name, doc_ids=request.doc_ids
|
||||
space_name=space_name, sync_request=request
|
||||
)
|
||||
return Result.succ([])
|
||||
except Exception as e:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from fastapi import UploadFile
|
||||
@ -54,10 +54,25 @@ class DocumentQueryRequest(BaseModel):
|
||||
|
||||
|
||||
class DocumentSyncRequest(BaseModel):
    """Sync request"""

    # doc_ids: ids of the documents to sync.
    doc_ids: List

    # Preseparator: used for pre-splitting before the document is actually
    # split by the text splitter. Preseparators are not included in the
    # vectorized text.
    pre_separator: Optional[str] = None

    # Custom separators.
    separators: Optional[List[str]] = None

    # Custom chunk size.
    chunk_size: Optional[int] = None

    # Custom chunk overlap.
    chunk_overlap: Optional[int] = None
|
||||
|
||||
|
||||
class ChunkQueryRequest(BaseModel):
|
||||
"""id: id"""
|
||||
|
@ -1,5 +1,5 @@
|
||||
import json
|
||||
import threading
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from pilot.vector_store.connector import VectorStoreConnector
|
||||
@ -12,7 +12,6 @@ from pilot.configs.model_config import (
|
||||
from pilot.component import ComponentType
|
||||
from pilot.utils.executor_utils import ExecutorFactory
|
||||
|
||||
from pilot.logs import logger
|
||||
from pilot.server.knowledge.chunk_db import (
|
||||
DocumentChunkEntity,
|
||||
DocumentChunkDao,
|
||||
@ -31,6 +30,7 @@ from pilot.server.knowledge.request.request import (
|
||||
DocumentQueryRequest,
|
||||
ChunkQueryRequest,
|
||||
SpaceArgumentRequest,
|
||||
DocumentSyncRequest,
|
||||
)
|
||||
from enum import Enum
|
||||
|
||||
@ -44,6 +44,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
|
||||
knowledge_document_dao = KnowledgeDocumentDao()
|
||||
document_chunk_dao = DocumentChunkDao()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
@ -107,7 +108,6 @@ class KnowledgeService:
|
||||
res.owner = space.owner
|
||||
res.gmt_created = space.gmt_created
|
||||
res.gmt_modified = space.gmt_modified
|
||||
res.owner = space.owner
|
||||
res.context = space.context
|
||||
query = KnowledgeDocumentEntity(space=space.name)
|
||||
doc_count = knowledge_document_dao.get_knowledge_documents_count(query)
|
||||
@ -155,9 +155,10 @@ class KnowledgeService:
|
||||
|
||||
"""sync knowledge document chunk into vector store"""
|
||||
|
||||
def sync_knowledge_document(self, space_name, doc_ids):
|
||||
def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
|
||||
from langchain.text_splitter import (
|
||||
RecursiveCharacterTextSplitter,
|
||||
SpacyTextSplitter,
|
||||
@ -165,6 +166,7 @@ class KnowledgeService:
|
||||
|
||||
# import langchain is very very slow!!!
|
||||
|
||||
doc_ids = sync_request.doc_ids
|
||||
for doc_id in doc_ids:
|
||||
query = KnowledgeDocumentEntity(
|
||||
id=doc_id,
|
||||
@ -190,24 +192,43 @@ class KnowledgeService:
|
||||
if space_context is None
|
||||
else int(space_context["embedding"]["chunk_overlap"])
|
||||
)
|
||||
if sync_request.chunk_size:
|
||||
chunk_size = sync_request.chunk_size
|
||||
if sync_request.chunk_overlap:
|
||||
chunk_overlap = sync_request.chunk_overlap
|
||||
separators = sync_request.separators or None
|
||||
if CFG.LANGUAGE == "en":
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
separators=separators,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len,
|
||||
)
|
||||
else:
|
||||
if separators and len(separators) > 1:
|
||||
raise ValueError(
|
||||
"SpacyTextSplitter do not support multiple separators"
|
||||
)
|
||||
try:
|
||||
separator = "\n\n" if not separators else separators[0]
|
||||
text_splitter = SpacyTextSplitter(
|
||||
separator=separator,
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
except Exception:
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
separators=separators,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
if sync_request.pre_separator:
|
||||
logger.info(f"Use preseparator, {sync_request.pre_separator}")
|
||||
text_splitter = PreTextSplitter(
|
||||
pre_separator=sync_request.pre_separator,
|
||||
text_splitter_impl=text_splitter,
|
||||
)
|
||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||
"embedding_factory", EmbeddingFactory
|
||||
)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""ElevenLabs speech module"""
|
||||
import os
|
||||
|
||||
import logging
|
||||
import requests
|
||||
|
||||
from pilot.configs.config import Config
|
||||
@ -8,6 +8,8 @@ from pilot.speech.base import VoiceBase
|
||||
|
||||
PLACEHOLDERS = {"your-voice-id"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ElevenLabsSpeech(VoiceBase):
|
||||
"""ElevenLabs speech class"""
|
||||
@ -68,7 +70,6 @@ class ElevenLabsSpeech(VoiceBase):
|
||||
Returns:
|
||||
bool: True if the request was successful, False otherwise
|
||||
"""
|
||||
from pilot.logs import logger
|
||||
from playsound import playsound
|
||||
|
||||
tts_url = (
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from pilot.common.schema import DBType
|
||||
from pilot.component import SystemApp
|
||||
@ -7,16 +8,14 @@ from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import (
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
EMBEDDING_MODEL_CONFIG,
|
||||
LOGDIR,
|
||||
)
|
||||
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.chat_factory import ChatFactory
|
||||
from pilot.summary.rdbms_db_summary import RdbmsSummary
|
||||
from pilot.utils import build_logger
|
||||
|
||||
logger = build_logger("db_summary", LOGDIR + "db_summary.log")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
chat_factory = ChatFactory()
|
||||
|
@ -1,6 +1,5 @@
|
||||
from .utils import (
|
||||
get_gpu_memory,
|
||||
build_logger,
|
||||
StreamToLogger,
|
||||
disable_torch_init,
|
||||
pretty_print_semaphore,
|
||||
|
@ -5,6 +5,8 @@ from dataclasses import is_dataclass, asdict
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
|
||||
import typing_inspect
|
||||
@ -21,7 +23,7 @@ def _build_request(self, func, path, method, *args, **kwargs):
|
||||
raise TypeError("Return type must be annotated in the decorated function.")
|
||||
|
||||
actual_dataclass = _extract_dataclass_from_generic(return_type)
|
||||
logging.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
|
||||
logger.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
|
||||
if not actual_dataclass:
|
||||
actual_dataclass = return_type
|
||||
sig = signature(func)
|
||||
@ -55,7 +57,7 @@ def _build_request(self, func, path, method, *args, **kwargs):
|
||||
else: # For GET, DELETE, etc.
|
||||
request_params["params"] = request_data
|
||||
|
||||
logging.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
|
||||
logger.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
|
||||
return return_type, actual_dataclass, request_params
|
||||
|
||||
|
||||
|
@ -20,7 +20,7 @@ def _get_logging_level() -> str:
|
||||
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
|
||||
|
||||
|
||||
def setup_logging(logging_level=None, logger_name: str = None):
|
||||
def setup_logging_level(logging_level=None, logger_name: str = None):
|
||||
if not logging_level:
|
||||
logging_level = _get_logging_level()
|
||||
if type(logging_level) is str:
|
||||
@ -32,6 +32,19 @@ def setup_logging(logging_level=None, logger_name: str = None):
|
||||
logging.basicConfig(level=logging_level, encoding="utf-8")
|
||||
|
||||
|
||||
def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
    """Build the logger named ``logger_name`` and colorize it when possible.

    Args:
        logger_name: Name passed to ``_build_logger``.
        logging_level: Level to use; falls back to ``_get_logging_level()``
            when empty.
        logger_filename: Optional log file name passed to ``_build_logger``.
    """
    logging_level = logging_level or _get_logging_level()
    logger = _build_logger(logger_name, logging_level, logger_filename)
    try:
        import coloredlogs

        coloredlogs.install(level=logging_level or "INFO", logger=logger)
    except ImportError:
        # coloredlogs is treated as optional: without it, logging stays uncolored.
        pass
|
||||
|
||||
|
||||
def get_gpu_memory(max_gpus=None):
|
||||
import torch
|
||||
|
||||
@ -53,7 +66,7 @@ def get_gpu_memory(max_gpus=None):
|
||||
return gpu_memory
|
||||
|
||||
|
||||
def build_logger(logger_name, logger_filename):
|
||||
def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
|
||||
global handler
|
||||
|
||||
formatter = logging.Formatter(
|
||||
@ -63,7 +76,7 @@ def build_logger(logger_name, logger_filename):
|
||||
|
||||
# Set the format of root handlers
|
||||
if not logging.getLogger().handlers:
|
||||
setup_logging()
|
||||
setup_logging_level(logging_level=logging_level)
|
||||
logging.getLogger().handlers[0].setFormatter(formatter)
|
||||
|
||||
# Redirect stdout and stderr to loggers
|
||||
@ -78,7 +91,7 @@ def build_logger(logger_name, logger_filename):
|
||||
# sys.stderr = sl
|
||||
|
||||
# Add a file handler for all loggers
|
||||
if handler is None:
|
||||
if handler is None and logger_filename:
|
||||
os.makedirs(LOGDIR, exist_ok=True)
|
||||
filename = os.path.join(LOGDIR, logger_filename)
|
||||
handler = logging.handlers.TimedRotatingFileHandler(
|
||||
@ -89,11 +102,9 @@ def build_logger(logger_name, logger_filename):
|
||||
for name, item in logging.root.manager.loggerDict.items():
|
||||
if isinstance(item, logging.Logger):
|
||||
item.addHandler(handler)
|
||||
setup_logging()
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger(logger_name)
|
||||
setup_logging(logger_name=logger_name)
|
||||
setup_logging_level(logging_level=logging_level, logger_name=logger_name)
|
||||
|
||||
return logger
|
||||
|
||||
|
@ -1,11 +1,13 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from chromadb.config import Settings
|
||||
from chromadb import PersistentClient
|
||||
from pilot.logs import logger
|
||||
from pilot.vector_store.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChromaStore(VectorStoreBase):
|
||||
"""chroma database"""
|
||||
|
@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
from pymilvus import Collection, DataType, connections, utility
|
||||
|
||||
from pilot.logs import logger
|
||||
from pilot.vector_store.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MilvusStore(VectorStoreBase):
|
||||
"""Milvus database"""
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import weaviate
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Weaviate
|
||||
@ -7,9 +8,9 @@ from weaviate.exceptions import WeaviateBaseError
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
from pilot.logs import logger
|
||||
from pilot.vector_store.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user