feat(ChatKnowledge): Add custom text separators and refactor log configuration

This commit is contained in:
FangYin Cheng 2023-09-28 11:54:58 +08:00
parent 20bdddec51
commit 5dfe611478
52 changed files with 705 additions and 158 deletions

View File

@ -55,6 +55,8 @@ EMBEDDING_MODEL=text2vec
#EMBEDDING_MODEL=bge-large-zh
KNOWLEDGE_CHUNK_SIZE=500
KNOWLEDGE_SEARCH_TOP_SIZE=5
# Control whether to display the source document of knowledge on the front end.
KNOWLEDGE_CHAT_SHOW_RELATIONS=False
## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
# EMBEDDING_MODEL=all-MiniLM-L6-v2
@ -154,4 +156,6 @@ SUMMARY_CONFIG=FAST
#** LOG **#
#*******************************************************************#
# FATAL, ERROR, WARNING, WARNING, INFO, DEBUG, NOTSET
DBGPT_LOG_LEVEL=INFO
DBGPT_LOG_LEVEL=INFO
# LOG dir, default: ./logs
#DBGPT_LOG_DIR=

View File

@ -1,6 +1,7 @@
import contextlib
import json
from typing import Any, Dict
import logging
from colorama import Fore
from regex import regex
@ -11,9 +12,9 @@ from pilot.json_utils.json_fix_general import (
balance_braces,
fix_invalid_escape,
)
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config()

View File

@ -2,14 +2,15 @@
import io
import uuid
from base64 import b64decode
import logging
import requests
from PIL import Image
from pilot.commands.command_mange import command
from pilot.configs.config import Config
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config()

View File

@ -1,5 +1,5 @@
import logging
from pandas import DataFrame
from pilot.commands.command_mange import command
from pilot.configs.config import Config
import pandas as pd
@ -13,11 +13,10 @@ import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
logger = logging.getLogger(__name__)
static_message_img_path = os.path.join(os.getcwd(), "message/img")

View File

@ -1,14 +1,13 @@
import logging
import pandas as pd
from pandas import DataFrame
from pilot.commands.command_mange import command
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
logger = logging.getLogger(__name__)
@command(

View File

@ -1,13 +1,12 @@
import logging
import pandas as pd
from pandas import DataFrame
from pilot.commands.command_mange import command
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
logger = logging.getLogger(__name__)
@command(

View File

@ -8,6 +8,7 @@ import zipfile
import requests
import threading
import datetime
import logging
from pathlib import Path
from typing import List, TYPE_CHECKING
from urllib.parse import urlparse
@ -17,7 +18,8 @@ import requests
from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from auto_gpt_plugin_template import AutoGPTPluginTemplate

View File

@ -32,7 +32,7 @@ class Config(metaclass=Singleton):
# self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))
self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
)
# User agent header to use when making HTTP requests
# Some websites might just completely deny request with an error code if
@ -64,7 +64,7 @@ class Config(metaclass=Singleton):
self.milvus_username = os.getenv("MILVUS_USERNAME")
self.milvus_password = os.getenv("MILVUS_PASSWORD")
self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
self.milvus_secure = os.getenv("MILVUS_SECURE", "False").lower() == "true"
self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
self.exit_key = os.getenv("EXIT_KEY", "n")
@ -98,7 +98,7 @@ class Config(metaclass=Singleton):
self.disabled_command_categories = []
self.execute_local_commands = (
os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
)
### message stor file
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
@ -107,7 +107,7 @@ class Config(metaclass=Singleton):
self.plugins: List["AutoGPTPluginTemplate"] = []
self.plugins_openai = []
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true"
self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
@ -124,10 +124,10 @@ class Config(metaclass=Singleton):
self.plugins_denylist = []
### Native SQL Execution Capability Control Configuration
self.NATIVE_SQL_CAN_RUN_DDL = (
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True").lower() == "true"
)
self.NATIVE_SQL_CAN_RUN_WRITE = (
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
)
### default Local database connection configuration
@ -170,8 +170,8 @@ class Config(metaclass=Singleton):
# QLoRA
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True") == "True"
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True"
self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False").lower() == "true"
if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
self.IS_LOAD_8BIT = False
# In order to be compatible with the new and old model parameter design
@ -187,7 +187,9 @@ class Config(metaclass=Singleton):
os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
)
### Control whether to display the source document of knowledge on the front end.
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False
self.KNOWLEDGE_CHAT_SHOW_RELATIONS = (
os.getenv("KNOWLEDGE_CHAT_SHOW_RELATIONS", "False").lower() == "true"
)
### SUMMARY_CONFIG Configuration
self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")

View File

@ -3,13 +3,12 @@
import os
# import nltk
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
MODEL_PATH = os.path.join(ROOT_PATH, "models")
PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
LOGDIR = os.path.join(ROOT_PATH, "logs")
LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs"))
DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
DATA_DIR = os.path.join(PILOT_PATH, "data")
# nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path

View File

@ -1,10 +1,11 @@
import logging
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from pilot.configs.config import Config
from pilot.common.schema import DBType
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config()

View File

@ -1,5 +1,12 @@
from pilot.embedding_engine.source_embedding import SourceEmbedding, register
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.knowledge_type import KnowledgeType
from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
__all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
__all__ = [
"SourceEmbedding",
"register",
"EmbeddingEngine",
"KnowledgeType",
"PreTextSplitter",
]

View File

@ -0,0 +1,30 @@
from typing import Iterable, List
from langchain.schema import Document
from langchain.text_splitter import TextSplitter
def _single_document_split(
document: Document, pre_separator: str
) -> Iterable[Document]:
page_content = document.page_content
for i, content in enumerate(page_content.split(pre_separator)):
metadata = document.metadata.copy()
if "source" in metadata:
metadata["source"] = metadata["source"] + "_pre_split_" + str(i)
yield Document(page_content=content, metadata=metadata)
class PreTextSplitter(TextSplitter):
def __init__(self, pre_separator: str, text_splitter_impl: TextSplitter):
self.pre_separator = pre_separator
self._impl = text_splitter_impl
def split_text(self, text: str) -> List[str]:
return self._impl.split_text(text)
def split_documents(self, documents: Iterable[Document]) -> List[Document]:
def generator() -> Iterable[Document]:
for doc in documents:
yield from _single_document_split(doc, pre_separator=self.pre_separator)
return self._impl.split_documents(generator())

View File

@ -5,12 +5,13 @@ from __future__ import annotations
import contextlib
import json
import re
import logging
from typing import Optional
from pilot.configs.config import Config
from pilot.json_utils.utilities import extract_char_position
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config()

View File

@ -3,12 +3,14 @@ import json
import os.path
import re
import json
import logging
from datetime import datetime
from jsonschema import Draft7Validator
from pilot.configs.config import Config
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config()
LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1"

View File

@ -3,6 +3,7 @@
import os
import re
import logging
from pathlib import Path
from typing import List, Tuple, Callable, Type
from functools import cache
@ -19,8 +20,8 @@ from pilot.model.parameter import (
)
from pilot.configs.model_config import get_device
from pilot.configs.config import Config
from pilot.logs import logger
logger = logging.getLogger(__name__)
CFG = Config()

View File

@ -13,6 +13,9 @@ from pilot.utils.api_utils import (
_api_remote as api_remote,
_sync_api_remote as sync_api_remote,
)
from pilot.utils.utils import setup_logging
logger = logging.getLogger(__name__)
class BaseModelController(BaseComponent, ABC):
@ -59,7 +62,7 @@ class LocalModelController(BaseModelController):
async def get_all_instances(
self, model_name: str = None, healthy_only: bool = False
) -> List[ModelInstance]:
logging.info(
logger.info(
f"Get all instances with {model_name}, healthy_only: {healthy_only}"
)
if not model_name:
@ -178,6 +181,13 @@ def run_model_controller():
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
ModelControllerParameters, env_prefix=env_prefix
)
setup_logging(
"pilot",
logging_level=controller_params.log_level,
logger_filename="dbgpt_model_controller.log",
)
initialize_controller(host=controller_params.host, port=controller_params.port)

View File

@ -120,7 +120,9 @@ class DefaultModelWorker(ModelWorker):
text=output, error_code=0, model_context=model_context
)
yield model_output
print(f"\n\nfull stream output:\n{previous_response}")
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
)
except Exception as e:
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
if torch_imported and isinstance(e, torch.cuda.CudaError):

View File

@ -36,6 +36,7 @@ from pilot.utils.parameter_utils import (
ParameterDescription,
_dict_to_command_args,
)
from pilot.utils.utils import setup_logging
logger = logging.getLogger(__name__)
@ -885,6 +886,12 @@ def run_worker_manager(
model_name=model_name, model_path=model_path, standalone=standalone, port=port
)
setup_logging(
"pilot",
logging_level=worker_params.log_level,
logger_filename="dbgpt_model_worker_manager.log",
)
embedded_mod = True
logger.info(f"Worker params: {worker_params}")
if not app:

View File

@ -3,11 +3,13 @@ Fork from text-generation-webui https://github.com/oobabooga/text-generation-web
"""
import re
from typing import Dict
import logging
import torch
import llama_cpp
from pilot.model.parameter import LlamaCppModelParameters
from pilot.logs import logger
logger = logging.getLogger(__name__)
if torch.cuda.is_available() and not torch.version.hip:
try:

View File

@ -3,6 +3,7 @@
from typing import Optional, Dict
import logging
from pilot.configs.model_config import get_device
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.parameter import (
@ -12,7 +13,8 @@ from pilot.model.parameter import (
)
from pilot.utils import get_gpu_memory
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
from pilot.logs import logger
logger = logging.getLogger(__name__)
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):

View File

@ -31,6 +31,21 @@ class ModelControllerParameters(BaseParameters):
daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Model Controller in background"}
)
log_level: Optional[str] = field(
default=None,
metadata={
"help": "Logging level",
"valid_values": [
"FATAL",
"ERROR",
"WARNING",
"WARNING",
"INFO",
"DEBUG",
"NOTSET",
],
},
)
@dataclass
@ -85,6 +100,22 @@ class ModelWorkerParameters(BaseModelParameters):
default=20, metadata={"help": "The interval for sending heartbeats (seconds)"}
)
log_level: Optional[str] = field(
default=None,
metadata={
"help": "Logging level",
"valid_values": [
"FATAL",
"ERROR",
"WARNING",
"WARNING",
"INFO",
"DEBUG",
"NOTSET",
],
},
)
@dataclass
class BaseEmbeddingModelParameters(BaseModelParameters):

View File

@ -38,8 +38,6 @@ from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.common.schema import DBType
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
from pilot.scene.message import OnceConversation
@ -409,7 +407,7 @@ async def model_types(controller: BaseModelController = Depends(get_model_contro
@router.get("/v1/model/supports")
async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)):
async def model_supports(worker_manager: WorkerManager = Depends(get_worker_manager)):
logger.info(f"/controller/model/supports")
try:
models = await worker_manager.supported_models()

View File

@ -6,12 +6,11 @@ from fastapi import (
)
from typing import List
import logging
from pilot.configs.config import Config
from pilot.scene.chat_factory import ChatFactory
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.openapi.api_view_model import (
Result,
@ -34,7 +33,8 @@ from pilot.scene.chat_db.data_loader import DbDataLoader
router = APIRouter()
CFG = Config()
CHAT_FACTORY = ChatFactory()
logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
logger = logging.getLogger(__name__)
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])

View File

@ -2,18 +2,17 @@ from __future__ import annotations
import json
from abc import ABC
import logging
from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.model.base import ModelOutput
from pilot.utils import build_logger
T = TypeVar("T")
ResponseTye = Union[str, bytes, ModelOutput]
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
logger = logging.getLogger(__name__)
CFG = Config()

View File

@ -1,11 +1,11 @@
import datetime
import traceback
import warnings
import logging
from abc import ABC, abstractmethod
from typing import Any, List, Dict
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.component import ComponentType
from pilot.memory.chat_history.base import BaseChatHistoryMemory
from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
@ -14,10 +14,10 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
from pilot.prompts.prompt_new import PromptTemplate
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.scene.message import OnceConversation
from pilot.utils import build_logger, get_or_create_event_loop
from pilot.utils import get_or_create_event_loop
from pydantic import Extra
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
logger = logging.getLogger(__name__)
headers = {"User-Agent": "dbgpt Client"}
CFG = Config()

View File

@ -1,13 +1,12 @@
from typing import List
from decimal import Decimal
import logging
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger
from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem
CFG = Config()
logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")
logger = logging.getLogger(__name__)
class DashboardDataLoader:

View File

@ -1,8 +1,8 @@
import json
import logging
from typing import NamedTuple, List
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.scene.base import ChatScene
@ -13,7 +13,7 @@ class ChartItem(NamedTuple):
showcase: str
logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
logger = logging.getLogger(__name__)
class ChatDashboardOutputParser(BaseOutputParser):

View File

@ -1,11 +1,7 @@
import json
import re
from abc import ABC, abstractmethod
import logging
from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
CFG = Config()
@ -17,7 +13,7 @@ class ExcelAnalyzeResponse(NamedTuple):
display: str
logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log")
logger = logging.getLogger(__name__)
class ChatExcelOutputParser(BaseOutputParser):

View File

@ -1,11 +1,7 @@
import json
import re
from abc import ABC, abstractmethod
import logging
from typing import Dict, NamedTuple, List
import pandas as pd
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
CFG = Config()
@ -17,7 +13,7 @@ class ExcelResponse(NamedTuple):
plans: List
logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log")
logger = logging.getLogger(__name__)
class LearningExcelOutputParser(BaseOutputParser):

View File

@ -1,8 +1,7 @@
import json
from typing import Dict, NamedTuple
from pilot.utils import build_logger
import logging
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
from pilot.scene.chat_db.data_loader import DbDataLoader
@ -14,7 +13,7 @@ class SqlAction(NamedTuple):
thoughts: Dict
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
logger = logging.getLogger(__name__)
class DbChatOutputParser(BaseOutputParser):

View File

@ -1,8 +1,4 @@
from pilot.configs.model_config import LOGDIR
from pilot.out_parser.base import BaseOutputParser, T
from pilot.utils import build_logger
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class NormalChatOutputParser(BaseOutputParser):

View File

@ -1,11 +1,10 @@
import json
import logging
from typing import Dict, NamedTuple
from pilot.utils import build_logger
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
logger = logging.getLogger(__name__)
class PluginAction(NamedTuple):

View File

@ -1,9 +1,7 @@
from pilot.utils import build_logger
import logging
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
logger = logging.getLogger(__name__)
class NormalChatOutputParser(BaseOutputParser):

View File

@ -60,18 +60,25 @@ class ChatKnowledge(BaseChat):
async def stream_call(self):
input_values = self.generate_input_values()
# Source of knowledge file
relations = input_values.get("relations")
last_output = None
async for output in super().stream_call():
# Source of knowledge file
relations = input_values.get("relations")
if (
CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
and type(relations) == list
and len(relations) > 0
and hasattr(output, "text")
):
output.text = output.text + "\trelations:" + ",".join(relations)
last_output = output
yield output
if (
CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
and last_output
and type(relations) == list
and len(relations) > 0
and hasattr(last_output, "text")
):
last_output.text = (
last_output.text + "\n\nrelations:\n\n" + ",".join(relations)
)
yield last_output
def generate_input_values(self):
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]

View File

@ -1,9 +1,8 @@
from pilot.utils import build_logger
import logging
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
logger = logging.getLogger(__name__)
class NormalChatOutputParser(BaseOutputParser):

View File

@ -1,9 +1,8 @@
from pilot.utils import build_logger
import logging
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
logger = logging.getLogger(__name__)
class NormalChatOutputParser(BaseOutputParser):

View File

@ -108,7 +108,7 @@ class WebWerverParameters(BaseParameters):
},
)
log_level: Optional[str] = field(
default="INFO",
default=None,
metadata={
"help": "Logging level",
"valid_values": [

View File

@ -33,7 +33,11 @@ from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route
from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.cluster import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
from pilot.utils.utils import (
setup_logging,
_get_logging_level,
logging_str_to_uvicorn_level,
)
static_file_path = os.path.join(os.getcwd(), "server/static")
@ -111,7 +115,11 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
)
param = WebWerverParameters(**vars(parser.parse_args(args=args)))
setup_logging(logging_level=param.log_level)
if not param.log_level:
param.log_level = _get_logging_level()
setup_logging(
"pilot", logging_level=param.log_level, logger_filename="dbgpt_webserver.log"
)
# Before start
system_app.before_start()

View File

@ -1,9 +1,12 @@
import click
import logging
import os
import functools
from pilot.configs.model_config import DATASETS_DIR
API_ADDRESS: str = "http://127.0.0.1:5000"
_DEFAULT_API_ADDRESS: str = "http://127.0.0.1:5000"
API_ADDRESS: str = _DEFAULT_API_ADDRESS
logger = logging.getLogger("dbgpt_cli")
@ -20,33 +23,44 @@ logger = logging.getLogger("dbgpt_cli")
def knowledge_cli_group(address: str):
"""Knowledge command line tool"""
global API_ADDRESS
if address == _DEFAULT_API_ADDRESS:
address = os.getenv("API_ADDRESS", _DEFAULT_API_ADDRESS)
API_ADDRESS = address
def add_knowledge_options(func):
@click.option(
"--space_name",
required=False,
type=str,
default="default",
show_default=True,
help="Your knowledge space name",
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
@knowledge_cli_group.command()
@click.option(
"--vector_name",
required=False,
type=str,
default="default",
show_default=True,
help="Your vector store name",
)
@add_knowledge_options
@click.option(
"--vector_store_type",
required=False,
type=str,
default="Chroma",
show_default=True,
help="Vector store type",
help="Vector store type.",
)
@click.option(
"--local_doc_dir",
"--local_doc_path",
required=False,
type=str,
default=DATASETS_DIR,
show_default=True,
help="Your document directory",
help="Your document directory or document file path.",
)
@click.option(
"--skip_wrong_doc",
@ -54,31 +68,165 @@ def knowledge_cli_group(address: str):
type=bool,
default=False,
is_flag=True,
show_default=True,
help="Skip wrong document",
help="Skip wrong document.",
)
@click.option(
"--overwrite",
required=False,
type=bool,
default=False,
is_flag=True,
help="Overwrite existing document(they has same name).",
)
@click.option(
"--max_workers",
required=False,
type=int,
default=None,
help="The maximum number of threads that can be used to upload document",
help="The maximum number of threads that can be used to upload document.",
)
@click.option(
"--pre_separator",
required=False,
type=str,
default=None,
help="Preseparator, this separator is used for pre-splitting before the document is "
"actually split by the text splitter. Preseparator are not included in the vectorized text. ",
)
@click.option(
"--separator",
required=False,
type=str,
default=None,
help="This is the document separator. Currently, only one separator is supported.",
)
@click.option(
"--chunk_size",
required=False,
type=int,
default=None,
help="Maximum size of chunks to split.",
)
@click.option(
"--chunk_overlap",
required=False,
type=int,
default=None,
help="Overlap in characters between chunks.",
)
def load(
vector_name: str,
space_name: str,
vector_store_type: str,
local_doc_dir: str,
local_doc_path: str,
skip_wrong_doc: bool,
overwrite: bool,
max_workers: int,
pre_separator: str,
separator: str,
chunk_size: int,
chunk_overlap: int,
):
"""Load your local knowledge to DB-GPT"""
from pilot.server.knowledge._cli.knowledge_client import knowledge_init
knowledge_init(
API_ADDRESS,
vector_name,
space_name,
vector_store_type,
local_doc_dir,
local_doc_path,
skip_wrong_doc,
overwrite,
max_workers,
pre_separator,
separator,
chunk_size,
chunk_overlap,
)
@knowledge_cli_group.command()
@add_knowledge_options
@click.option(
"--doc_name",
required=False,
type=str,
default=None,
help="The document name you want to delete. If doc_name is None, this command will delete the whole space.",
)
@click.option(
"-y",
required=False,
type=bool,
default=False,
is_flag=True,
help="Confirm your choice",
)
def delete(space_name: str, doc_name: str, y: bool):
"""Delete your knowledge space or document in space"""
from pilot.server.knowledge._cli.knowledge_client import knowledge_delete
knowledge_delete(API_ADDRESS, space_name, doc_name, confirm=y)
@knowledge_cli_group.command()
@click.option(
"--space_name",
required=False,
type=str,
default=None,
show_default=True,
help="Your knowledge space name. If None, list all spaces",
)
@click.option(
"--doc_id",
required=False,
type=int,
default=None,
show_default=True,
help="Your document id in knowledge space. If Not None, list all chunks in current document",
)
@click.option(
"--page",
required=False,
type=int,
default=1,
show_default=True,
help="The page for every query",
)
@click.option(
"--page_size",
required=False,
type=int,
default=20,
show_default=True,
help="The page size for every query",
)
@click.option(
"--show_content",
required=False,
type=bool,
default=False,
is_flag=True,
help="Query the document content of chunks",
)
@click.option(
"--output",
required=False,
type=click.Choice(["text", "html", "csv", "latex", "json"]),
default="text",
help="The output format",
)
def list(
space_name: str,
doc_id: int,
page: int,
page_size: int,
show_content: bool,
output: str,
):
"""List knowledge space"""
from pilot.server.knowledge._cli.knowledge_client import knowledge_list
knowledge_list(
API_ADDRESS, space_name, page, page_size, doc_id, show_content, output
)

View File

@ -62,6 +62,9 @@ class KnowledgeApiClient(ApiClient):
else:
raise e
def space_delete(self, request: KnowledgeSpaceRequest):
return self._post("/knowledge/space/delete", data=request)
def space_list(self, request: KnowledgeSpaceRequest):
return self._post("/knowledge/space/list", data=request)
@ -69,6 +72,10 @@ class KnowledgeApiClient(ApiClient):
url = f"/knowledge/{space_name}/document/add"
return self._post(url, data=request)
def document_delete(self, space_name: str, request: KnowledgeDocumentRequest):
url = f"/knowledge/{space_name}/document/delete"
return self._post(url, data=request)
def document_list(self, space_name: str, query_request: DocumentQueryRequest):
url = f"/knowledge/{space_name}/document/list"
return self._post(url, data=query_request)
@ -97,15 +104,20 @@ class KnowledgeApiClient(ApiClient):
def knowledge_init(
api_address: str,
vector_name: str,
space_name: str,
vector_store_type: str,
local_doc_dir: str,
local_doc_path: str,
skip_wrong_doc: bool,
max_workers: int = None,
overwrite: bool,
max_workers: int,
pre_separator: str,
separator: str,
chunk_size: int,
chunk_overlap: int,
):
client = KnowledgeApiClient(api_address)
space = KnowledgeSpaceRequest()
space.name = vector_name
space.name = space_name
space.desc = "DB-GPT cli"
space.vector_type = vector_store_type
space.owner = "DB-GPT"
@ -124,24 +136,260 @@ def knowledge_init(
def upload(filename: str):
try:
logger.info(f"Begin upload document: {filename} to {space.name}")
doc_id = client.document_upload(
space.name, filename, KnowledgeType.DOCUMENT.value, filename
)
client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
doc_id = None
try:
doc_id = client.document_upload(
space.name, filename, KnowledgeType.DOCUMENT.value, filename
)
except Exception as ex:
if overwrite and "have already named" in str(ex):
logger.warn(
f"Document {filename} already exist in space {space.name}, overwrite it"
)
client.document_delete(
space.name, KnowledgeDocumentRequest(doc_name=filename)
)
doc_id = client.document_upload(
space.name, filename, KnowledgeType.DOCUMENT.value, filename
)
else:
raise ex
sync_req = DocumentSyncRequest(doc_ids=[doc_id])
if pre_separator:
sync_req.pre_separator = pre_separator
if separator:
sync_req.separators = [separator]
if chunk_size:
sync_req.chunk_size = chunk_size
if chunk_overlap:
sync_req.chunk_overlap = chunk_overlap
client.document_sync(space.name, sync_req)
return doc_id
except Exception as e:
if skip_wrong_doc:
logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
else:
raise e
if not os.path.exists(local_doc_path):
raise Exception(f"{local_doc_path} not exists")
with ThreadPoolExecutor(max_workers=max_workers) as pool:
tasks = []
for root, _, files in os.walk(local_doc_dir, topdown=False):
for file in files:
filename = os.path.join(root, file)
tasks.append(pool.submit(upload, filename))
file_names = []
if os.path.isdir(local_doc_path):
for root, _, files in os.walk(local_doc_path, topdown=False):
for file in files:
file_names.append(os.path.join(root, file))
else:
# Single file
file_names.append(local_doc_path)
[tasks.append(pool.submit(upload, filename)) for filename in file_names]
doc_ids = [r.result() for r in as_completed(tasks)]
doc_ids = list(filter(lambda x: x, doc_ids))
if not doc_ids:
logger.warn("Warning: no document to sync")
return
from prettytable import PrettyTable
class _KnowledgeVisualizer:
def __init__(self, api_address: str, out_format: str):
self.client = KnowledgeApiClient(api_address)
self.out_format = out_format
self.out_kwargs = {}
if out_format == "json":
self.out_kwargs["ensure_ascii"] = False
def print_table(self, table):
print(table.get_formatted_string(out_format=self.out_format, **self.out_kwargs))
def list_spaces(self):
spaces = self.client.space_list(KnowledgeSpaceRequest())
table = PrettyTable(
["Space ID", "Space Name", "Vector Type", "Owner", "Description"],
title="All knowledge spaces",
)
for sp in spaces:
context = sp.get("context")
table.add_row(
[
sp.get("id"),
sp.get("name"),
sp.get("vector_type"),
sp.get("owner"),
sp.get("desc"),
]
)
self.print_table(table)
def list_documents(self, space_name: str, page: int, page_size: int):
space_data = self.client.document_list(
space_name, DocumentQueryRequest(page=page, page_size=page_size)
)
space_table = PrettyTable(
[
"Space Name",
"Total Documents",
"Current Page",
"Current Size",
"Page Size",
],
title=f"Space {space_name} description",
)
space_table.add_row(
[space_name, space_data["total"], page, len(space_data["data"]), page_size]
)
table = PrettyTable(
[
"Space Name",
"Document ID",
"Document Name",
"Type",
"Chunks",
"Last Sync",
"Status",
"Result",
],
title=f"Documents of space {space_name}",
)
for doc in space_data["data"]:
table.add_row(
[
space_name,
doc.get("id"),
doc.get("doc_name"),
doc.get("doc_type"),
doc.get("chunk_size"),
doc.get("last_sync"),
doc.get("status"),
doc.get("result"),
]
)
if self.out_format == "text":
self.print_table(space_table)
print("")
self.print_table(table)
def list_chunks(
self,
space_name: str,
doc_id: int,
page: int,
page_size: int,
show_content: bool,
):
doc_data = self.client.chunk_list(
space_name,
ChunkQueryRequest(document_id=doc_id, page=page, page_size=page_size),
)
doc_table = PrettyTable(
[
"Space Name",
"Document ID",
"Total Chunks",
"Current Page",
"Current Size",
"Page Size",
],
title=f"Document {doc_id} in {space_name} description",
)
doc_table.add_row(
[
space_name,
doc_id,
doc_data["total"],
page,
len(doc_data["data"]),
page_size,
]
)
table = PrettyTable(
["Space Name", "Document ID", "Document Name", "Content", "Meta Data"],
title=f"chunks of document id {doc_id} in space {space_name}",
)
for chunk in doc_data["data"]:
table.add_row(
[
space_name,
doc_id,
chunk.get("doc_name"),
chunk.get("content") if show_content else "[Hidden]",
chunk.get("meta_info"),
]
)
if self.out_format == "text":
self.print_table(doc_table)
print("")
self.print_table(table)
def knowledge_list(
api_address: str,
space_name: str,
page: int,
page_size: int,
doc_id: int,
show_content: bool,
out_format: str,
):
visualizer = _KnowledgeVisualizer(api_address, out_format)
if not space_name:
visualizer.list_spaces()
elif not doc_id:
visualizer.list_documents(space_name, page, page_size)
else:
visualizer.list_chunks(space_name, doc_id, page, page_size, show_content)
def knowledge_delete(
api_address: str, space_name: str, doc_name: str, confirm: bool = False
):
client = KnowledgeApiClient(api_address)
space = KnowledgeSpaceRequest()
space.name = space_name
space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
if not space_list:
raise Exception(f"No knowledge space name {space_name}")
if not doc_name:
if not confirm:
# Confirm by user
user_input = (
input(
f"Are you sure you want to delete the whole knowledge space {space_name}? Type 'yes' to confirm: "
)
.strip()
.lower()
)
if user_input != "yes":
logger.warn("Delete operation cancelled.")
return
client.space_delete(space)
logger.info("Delete the whole knowledge space successfully!")
else:
if not confirm:
# Confirm by user
user_input = (
input(
f"Are you sure you want to delete the doucment {doc_name} in knowledge space {space_name}? Type 'yes' to confirm: "
)
.strip()
.lower()
)
if user_input != "yes":
logger.warn("Delete operation cancelled.")
return
client.document_delete(space_name, KnowledgeDocumentRequest(doc_name=doc_name))
logger.info(
f"Delete the doucment {doc_name} in knowledge space {space_name} successfully!"
)

View File

@ -1,6 +1,7 @@
import os
import shutil
import tempfile
import logging
from fastapi import APIRouter, File, UploadFile, Form
@ -27,6 +28,8 @@ from pilot.server.knowledge.request.request import (
from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
logger = logging.getLogger(__name__)
CFG = Config()
router = APIRouter()
@ -159,10 +162,10 @@ async def document_upload(
@router.post("/knowledge/{space_name}/document/sync")
def document_sync(space_name: str, request: DocumentSyncRequest):
print(f"Received params: {space_name}, {request}")
logger.info(f"Received params: {space_name}, {request}")
try:
knowledge_space_service.sync_knowledge_document(
space_name=space_name, doc_ids=request.doc_ids
space_name=space_name, sync_request=request
)
return Result.succ([])
except Exception as e:

View File

@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from pydantic import BaseModel
from fastapi import UploadFile
@ -54,10 +54,25 @@ class DocumentQueryRequest(BaseModel):
class DocumentSyncRequest(BaseModel):
"""doc_ids: doc ids"""
"""Sync request"""
"""doc_ids: doc ids"""
doc_ids: List
"""Preseparator, this separator is used for pre-splitting before the document is actually split by the text splitter.
Preseparator are not included in the vectorized text.
"""
pre_separator: Optional[str] = None
"""Custom separators"""
separators: Optional[List[str]] = None
"""Custom chunk size"""
chunk_size: Optional[int] = None
"""Custom chunk overlap"""
chunk_overlap: Optional[int] = None
class ChunkQueryRequest(BaseModel):
"""id: id"""

View File

@ -1,5 +1,5 @@
import json
import threading
import logging
from datetime import datetime
from pilot.vector_store.connector import VectorStoreConnector
@ -12,7 +12,6 @@ from pilot.configs.model_config import (
from pilot.component import ComponentType
from pilot.utils.executor_utils import ExecutorFactory
from pilot.logs import logger
from pilot.server.knowledge.chunk_db import (
DocumentChunkEntity,
DocumentChunkDao,
@ -31,6 +30,7 @@ from pilot.server.knowledge.request.request import (
DocumentQueryRequest,
ChunkQueryRequest,
SpaceArgumentRequest,
DocumentSyncRequest,
)
from enum import Enum
@ -44,6 +44,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
knowledge_document_dao = KnowledgeDocumentDao()
document_chunk_dao = DocumentChunkDao()
logger = logging.getLogger(__name__)
CFG = Config()
@ -107,7 +108,6 @@ class KnowledgeService:
res.owner = space.owner
res.gmt_created = space.gmt_created
res.gmt_modified = space.gmt_modified
res.owner = space.owner
res.context = space.context
query = KnowledgeDocumentEntity(space=space.name)
doc_count = knowledge_document_dao.get_knowledge_documents_count(query)
@ -155,9 +155,10 @@ class KnowledgeService:
"""sync knowledge document chunk into vector store"""
def sync_knowledge_document(self, space_name, doc_ids):
def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
from langchain.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
@ -165,6 +166,7 @@ class KnowledgeService:
# import langchain is very very slow!!!
doc_ids = sync_request.doc_ids
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(
id=doc_id,
@ -190,24 +192,43 @@ class KnowledgeService:
if space_context is None
else int(space_context["embedding"]["chunk_overlap"])
)
if sync_request.chunk_size:
chunk_size = sync_request.chunk_size
if sync_request.chunk_overlap:
chunk_overlap = sync_request.chunk_overlap
separators = sync_request.separators or None
if CFG.LANGUAGE == "en":
text_splitter = RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
)
else:
if separators and len(separators) > 1:
raise ValueError(
"SpacyTextSplitter do not support multiple separators"
)
try:
separator = "\n\n" if not separators else separators[0]
text_splitter = SpacyTextSplitter(
separator=separator,
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
except Exception:
text_splitter = RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
if sync_request.pre_separator:
logger.info(f"Use preseparator, {sync_request.pre_separator}")
text_splitter = PreTextSplitter(
pre_separator=sync_request.pre_separator,
text_splitter_impl=text_splitter,
)
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)

View File

@ -1,6 +1,6 @@
"""ElevenLabs speech module"""
import os
import logging
import requests
from pilot.configs.config import Config
@ -8,6 +8,8 @@ from pilot.speech.base import VoiceBase
PLACEHOLDERS = {"your-voice-id"}
logger = logging.getLogger(__name__)
class ElevenLabsSpeech(VoiceBase):
"""ElevenLabs speech class"""
@ -68,7 +70,6 @@ class ElevenLabsSpeech(VoiceBase):
Returns:
bool: True if the request was successful, False otherwise
"""
from pilot.logs import logger
from playsound import playsound
tts_url = (

View File

@ -1,5 +1,6 @@
import json
import uuid
import logging
from pilot.common.schema import DBType
from pilot.component import SystemApp
@ -7,16 +8,14 @@ from pilot.configs.config import Config
from pilot.configs.model_config import (
KNOWLEDGE_UPLOAD_ROOT_PATH,
EMBEDDING_MODEL_CONFIG,
LOGDIR,
)
from pilot.scene.base import ChatScene
from pilot.scene.base_chat import BaseChat
from pilot.scene.chat_factory import ChatFactory
from pilot.summary.rdbms_db_summary import RdbmsSummary
from pilot.utils import build_logger
logger = build_logger("db_summary", LOGDIR + "db_summary.log")
logger = logging.getLogger(__name__)
CFG = Config()
chat_factory = ChatFactory()

View File

@ -1,6 +1,5 @@
from .utils import (
get_gpu_memory,
build_logger,
StreamToLogger,
disable_torch_init,
pretty_print_semaphore,

View File

@ -5,6 +5,8 @@ from dataclasses import is_dataclass, asdict
T = TypeVar("T")
logger = logging.getLogger(__name__)
def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
import typing_inspect
@ -21,7 +23,7 @@ def _build_request(self, func, path, method, *args, **kwargs):
raise TypeError("Return type must be annotated in the decorated function.")
actual_dataclass = _extract_dataclass_from_generic(return_type)
logging.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
logger.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
if not actual_dataclass:
actual_dataclass = return_type
sig = signature(func)
@ -55,7 +57,7 @@ def _build_request(self, func, path, method, *args, **kwargs):
else: # For GET, DELETE, etc.
request_params["params"] = request_data
logging.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
logger.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
return return_type, actual_dataclass, request_params

View File

@ -20,7 +20,7 @@ def _get_logging_level() -> str:
return os.getenv("DBGPT_LOG_LEVEL", "INFO")
def setup_logging(logging_level=None, logger_name: str = None):
def setup_logging_level(logging_level=None, logger_name: str = None):
if not logging_level:
logging_level = _get_logging_level()
if type(logging_level) is str:
@ -32,6 +32,19 @@ def setup_logging(logging_level=None, logger_name: str = None):
logging.basicConfig(level=logging_level, encoding="utf-8")
def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
if not logging_level:
logging_level = _get_logging_level()
logger = _build_logger(logger_name, logging_level, logger_filename)
try:
import coloredlogs
color_level = logging_level if logging_level else "INFO"
coloredlogs.install(level=color_level, logger=logger)
except ImportError:
pass
def get_gpu_memory(max_gpus=None):
import torch
@ -53,7 +66,7 @@ def get_gpu_memory(max_gpus=None):
return gpu_memory
def build_logger(logger_name, logger_filename):
def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
global handler
formatter = logging.Formatter(
@ -63,7 +76,7 @@ def build_logger(logger_name, logger_filename):
# Set the format of root handlers
if not logging.getLogger().handlers:
setup_logging()
setup_logging_level(logging_level=logging_level)
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
@ -78,7 +91,7 @@ def build_logger(logger_name, logger_filename):
# sys.stderr = sl
# Add a file handler for all loggers
if handler is None:
if handler is None and logger_filename:
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
@ -89,11 +102,9 @@ def build_logger(logger_name, logger_filename):
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
setup_logging()
# Get logger
logger = logging.getLogger(logger_name)
setup_logging(logger_name=logger_name)
setup_logging_level(logging_level=logging_level, logger_name=logger_name)
return logger

View File

@ -1,11 +1,13 @@
import os
import logging
from typing import Any
from chromadb.config import Settings
from chromadb import PersistentClient
from pilot.logs import logger
from pilot.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
class ChromaStore(VectorStoreBase):
"""chroma database"""

View File

@ -1,12 +1,13 @@
from __future__ import annotations
import logging
from typing import Any, Iterable, List, Optional, Tuple
from pymilvus import Collection, DataType, connections, utility
from pilot.logs import logger
from pilot.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
class MilvusStore(VectorStoreBase):
"""Milvus database"""

View File

@ -1,5 +1,6 @@
import os
import json
import logging
import weaviate
from langchain.schema import Document
from langchain.vectorstores import Weaviate
@ -7,9 +8,9 @@ from weaviate.exceptions import WeaviateBaseError
from pilot.configs.config import Config
from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.logs import logger
from pilot.vector_store.base import VectorStoreBase
logger = logging.getLogger(__name__)
CFG = Config()

View File

@ -287,6 +287,7 @@ def core_requires():
]
setup_spec.extras["framework"] = [
"coloredlogs",
"httpx",
"sqlparse==0.4.4",
"seaborn",