Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-09)
feat(ChatKnowledge): Add custom text separators and refactor log configuration
Parent: 20bdddec51
Commit: 5dfe611478
@@ -55,6 +55,8 @@ EMBEDDING_MODEL=text2vec
 #EMBEDDING_MODEL=bge-large-zh
 KNOWLEDGE_CHUNK_SIZE=500
 KNOWLEDGE_SEARCH_TOP_SIZE=5
+# Control whether to display the source document of knowledge on the front end.
+KNOWLEDGE_CHAT_SHOW_RELATIONS=False
 ## EMBEDDING_TOKENIZER - Tokenizer to use for chunking large inputs
 ## EMBEDDING_TOKEN_LIMIT - Chunk size limit for large inputs
 # EMBEDDING_MODEL=all-MiniLM-L6-v2
@@ -155,3 +157,5 @@ SUMMARY_CONFIG=FAST
 #*******************************************************************#
 # FATAL, ERROR, WARNING, INFO, DEBUG, NOTSET
 DBGPT_LOG_LEVEL=INFO
+# LOG dir, default: ./logs
+#DBGPT_LOG_DIR=
@@ -1,6 +1,7 @@
 import contextlib
 import json
 from typing import Any, Dict
+import logging
 
 from colorama import Fore
 from regex import regex
@@ -11,9 +12,9 @@ from pilot.json_utils.json_fix_general import (
     balance_braces,
     fix_invalid_escape,
 )
-from pilot.logs import logger
 
 
+logger = logging.getLogger(__name__)
 CFG = Config()
 
 
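Note: the hunk above is the template for most of this commit — the shared `pilot.logs` logger is replaced with a module-scoped logger in each file. A minimal sketch of the convention (the handler wiring shown here is illustrative, not part of this diff):

    import logging

    # One logger per module; its name follows the package path,
    # so handlers configured on the "pilot" root logger apply here automatically.
    logger = logging.getLogger(__name__)

    def do_work() -> None:
        logger.info("work started")  # routed through the shared "pilot" hierarchy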
@@ -2,14 +2,15 @@
 import io
 import uuid
 from base64 import b64decode
+import logging
 
 import requests
 from PIL import Image
 
 from pilot.commands.command_mange import command
 from pilot.configs.config import Config
-from pilot.logs import logger
 
+logger = logging.getLogger(__name__)
 CFG = Config()
 
 
@@ -1,5 +1,5 @@
+import logging
 from pandas import DataFrame
 
 from pilot.commands.command_mange import command
 from pilot.configs.config import Config
 import pandas as pd
@@ -13,11 +13,10 @@ import matplotlib.pyplot as plt
 import matplotlib.ticker as mtick
 from matplotlib.font_manager import FontManager
 
-from pilot.configs.model_config import LOGDIR
-from pilot.utils import build_logger
 
 CFG = Config()
-logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
+logger = logging.getLogger(__name__)
 
 static_message_img_path = os.path.join(os.getcwd(), "message/img")
 
 
@@ -1,14 +1,13 @@
+import logging
 
 import pandas as pd
 from pandas import DataFrame
 
 from pilot.commands.command_mange import command
 from pilot.configs.config import Config
 
-from pilot.configs.model_config import LOGDIR
-from pilot.utils import build_logger
 
 CFG = Config()
-logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
+logger = logging.getLogger(__name__)
 
 
 @command(
@@ -1,13 +1,12 @@
+import logging
 import pandas as pd
 from pandas import DataFrame
 
 from pilot.commands.command_mange import command
 from pilot.configs.config import Config
-from pilot.configs.model_config import LOGDIR
-from pilot.utils import build_logger
 
 CFG = Config()
-logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
+logger = logging.getLogger(__name__)
 
 
 @command(
@@ -8,6 +8,7 @@ import zipfile
 import requests
 import threading
 import datetime
+import logging
 from pathlib import Path
 from typing import List, TYPE_CHECKING
 from urllib.parse import urlparse
@@ -17,7 +18,8 @@ import requests
 
 from pilot.configs.config import Config
 from pilot.configs.model_config import PLUGINS_DIR
-from pilot.logs import logger
+
+logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from auto_gpt_plugin_template import AutoGPTPluginTemplate
@@ -32,7 +32,7 @@ class Config(metaclass=Singleton):
         # self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))
 
         self.execute_local_commands = (
-            os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
+            os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
         )
         # User agent header to use when making HTTP requests
         # Some websites might just completely deny request with an error code if
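Note: the same `.lower() == "true"` change recurs throughout this file, making boolean env flags case-insensitive. A small helper in that spirit (the `env_flag` name is illustrative; the commit inlines the expression instead):

    import os

    def env_flag(name: str, default: str = "False") -> bool:
        # "True", "true", and "TRUE" all parse as truthy; anything else is False.
        return os.getenv(name, default).lower() == "true"

    execute_local_commands = env_flag("EXECUTE_LOCAL_COMMANDS")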
@@ -64,7 +64,7 @@ class Config(metaclass=Singleton):
         self.milvus_username = os.getenv("MILVUS_USERNAME")
         self.milvus_password = os.getenv("MILVUS_PASSWORD")
         self.milvus_collection = os.getenv("MILVUS_COLLECTION", "dbgpt")
-        self.milvus_secure = os.getenv("MILVUS_SECURE") == "True"
+        self.milvus_secure = os.getenv("MILVUS_SECURE", "False").lower() == "true"
 
         self.authorise_key = os.getenv("AUTHORISE_COMMAND_KEY", "y")
         self.exit_key = os.getenv("EXIT_KEY", "n")
@@ -98,7 +98,7 @@ class Config(metaclass=Singleton):
         self.disabled_command_categories = []
 
         self.execute_local_commands = (
-            os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"
+            os.getenv("EXECUTE_LOCAL_COMMANDS", "False").lower() == "true"
         )
         ### message store file
         self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
@@ -107,7 +107,7 @@ class Config(metaclass=Singleton):
 
         self.plugins: List["AutoGPTPluginTemplate"] = []
         self.plugins_openai = []
-        self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
+        self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True").lower() == "true"
 
         self.plugins_git_branch = os.getenv("PLUGINS_GIT_BRANCH", "plugin_dashboard")
 
@@ -124,10 +124,10 @@ class Config(metaclass=Singleton):
         self.plugins_denylist = []
         ### Native SQL Execution Capability Control Configuration
         self.NATIVE_SQL_CAN_RUN_DDL = (
-            os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True") == "True"
+            os.getenv("NATIVE_SQL_CAN_RUN_DDL", "True").lower() == "true"
         )
         self.NATIVE_SQL_CAN_RUN_WRITE = (
-            os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True") == "True"
+            os.getenv("NATIVE_SQL_CAN_RUN_WRITE", "True").lower() == "true"
         )
 
         ### default Local database connection configuration
@@ -170,8 +170,8 @@ class Config(metaclass=Singleton):
 
         # QLoRA
         self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
-        self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True") == "True"
-        self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True"
+        self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True").lower() == "true"
+        self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False").lower() == "true"
         if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
             self.IS_LOAD_8BIT = False
         # In order to be compatible with the new and old model parameter design
@@ -187,7 +187,9 @@ class Config(metaclass=Singleton):
             os.getenv("KNOWLEDGE_SEARCH_MAX_TOKEN", 2000)
         )
         ### Control whether to display the source document of knowledge on the front end.
-        self.KNOWLEDGE_CHAT_SHOW_RELATIONS = False
+        self.KNOWLEDGE_CHAT_SHOW_RELATIONS = (
+            os.getenv("KNOWLEDGE_CHAT_SHOW_RELATIONS", "False").lower() == "true"
+        )
 
         ### SUMMARY_CONFIG Configuration
         self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
@@ -3,13 +3,12 @@
 
 import os
 
-# import nltk
 
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 MODEL_PATH = os.path.join(ROOT_PATH, "models")
 PILOT_PATH = os.path.join(ROOT_PATH, "pilot")
 VECTORE_PATH = os.path.join(PILOT_PATH, "vector_store")
-LOGDIR = os.path.join(ROOT_PATH, "logs")
+LOGDIR = os.getenv("DBGPT_LOG_DIR", os.path.join(ROOT_PATH, "logs"))
 
 DATASETS_DIR = os.path.join(PILOT_PATH, "datasets")
 DATA_DIR = os.path.join(PILOT_PATH, "data")
 # nltk.data.path = [os.path.join(PILOT_PATH, "nltk_data")] + nltk.data.path
@@ -1,10 +1,11 @@
+import logging
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker
 from pilot.configs.config import Config
 from pilot.common.schema import DBType
 from pilot.connections.rdbms.base import RDBMSDatabase
-from pilot.logs import logger
 
+logger = logging.getLogger(__name__)
 CFG = Config()
 
 
@@ -1,5 +1,12 @@
 from pilot.embedding_engine.source_embedding import SourceEmbedding, register
 from pilot.embedding_engine.embedding_engine import EmbeddingEngine
 from pilot.embedding_engine.knowledge_type import KnowledgeType
+from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
 
-__all__ = ["SourceEmbedding", "register", "EmbeddingEngine", "KnowledgeType"]
+__all__ = [
+    "SourceEmbedding",
+    "register",
+    "EmbeddingEngine",
+    "KnowledgeType",
+    "PreTextSplitter",
+]
pilot/embedding_engine/pre_text_splitter.py (new file, +30 lines)
@@ -0,0 +1,30 @@
+from typing import Iterable, List
+from langchain.schema import Document
+from langchain.text_splitter import TextSplitter
+
+
+def _single_document_split(
+    document: Document, pre_separator: str
+) -> Iterable[Document]:
+    page_content = document.page_content
+    for i, content in enumerate(page_content.split(pre_separator)):
+        metadata = document.metadata.copy()
+        if "source" in metadata:
+            metadata["source"] = metadata["source"] + "_pre_split_" + str(i)
+        yield Document(page_content=content, metadata=metadata)
+
+
+class PreTextSplitter(TextSplitter):
+    def __init__(self, pre_separator: str, text_splitter_impl: TextSplitter):
+        self.pre_separator = pre_separator
+        self._impl = text_splitter_impl
+
+    def split_text(self, text: str) -> List[str]:
+        return self._impl.split_text(text)
+
+    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
+        def generator() -> Iterable[Document]:
+            for doc in documents:
+                yield from _single_document_split(doc, pre_separator=self.pre_separator)
+
+        return self._impl.split_documents(generator())
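Note: a quick sketch of how the new splitter composes with an existing langchain splitter (the separator and chunk sizes below are illustrative, not values prescribed by the commit):

    from langchain.schema import Document
    from langchain.text_splitter import CharacterTextSplitter
    from pilot.embedding_engine import PreTextSplitter

    # Pre-split on a page-break marker first, then let the inner splitter chunk each piece.
    inner = CharacterTextSplitter(separator="\n", chunk_size=500, chunk_overlap=50)
    splitter = PreTextSplitter(pre_separator="\f", text_splitter_impl=inner)

    docs = [Document(page_content="page one\fpage two", metadata={"source": "a.txt"})]
    chunks = splitter.split_documents(docs)
    # Chunks from different pre-split segments get distinct metadata["source"]
    # values ("a.txt_pre_split_0", "a.txt_pre_split_1", ...), and the
    # pre-separator itself never appears in the vectorized text.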
@@ -5,12 +5,13 @@ from __future__ import annotations
 import contextlib
 import json
 import re
+import logging
 from typing import Optional
 
 from pilot.configs.config import Config
 from pilot.json_utils.utilities import extract_char_position
-from pilot.logs import logger
 
+logger = logging.getLogger(__name__)
 CFG = Config()
 
 
@@ -3,12 +3,14 @@ import json
 import os.path
 import re
 import json
+import logging
 from datetime import datetime
 
 from jsonschema import Draft7Validator
 
 from pilot.configs.config import Config
-from pilot.logs import logger
+
+logger = logging.getLogger(__name__)
 
 CFG = Config()
 LLM_DEFAULT_RESPONSE_FORMAT = "llm_response_format_1"
@@ -3,6 +3,7 @@
 
 import os
 import re
+import logging
 from pathlib import Path
 from typing import List, Tuple, Callable, Type
 from functools import cache
@@ -19,8 +20,8 @@ from pilot.model.parameter import (
 )
 from pilot.configs.model_config import get_device
 from pilot.configs.config import Config
-from pilot.logs import logger
 
+logger = logging.getLogger(__name__)
 
 CFG = Config()
 
@@ -13,6 +13,9 @@ from pilot.utils.api_utils import (
     _api_remote as api_remote,
     _sync_api_remote as sync_api_remote,
 )
+from pilot.utils.utils import setup_logging
+
+logger = logging.getLogger(__name__)
 
 
 class BaseModelController(BaseComponent, ABC):
@@ -59,7 +62,7 @@ class LocalModelController(BaseModelController):
     async def get_all_instances(
         self, model_name: str = None, healthy_only: bool = False
     ) -> List[ModelInstance]:
-        logging.info(
+        logger.info(
             f"Get all instances with {model_name}, healthy_only: {healthy_only}"
         )
         if not model_name:
@@ -178,6 +181,13 @@ def run_model_controller():
     controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
         ModelControllerParameters, env_prefix=env_prefix
     )
+
+    setup_logging(
+        "pilot",
+        logging_level=controller_params.log_level,
+        logger_filename="dbgpt_model_controller.log",
+    )
+
     initialize_controller(host=controller_params.host, port=controller_params.port)
 
 
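Note: `setup_logging("pilot", ...)` is called the same way from the controller here, the worker manager, and the webserver below. Its body is not shown in this diff; a plausible minimal shape, assuming it attaches a console handler plus a file handler under the log dir, might be:

    import logging
    import os

    def setup_logging(logger_name: str, logging_level: str = None, logger_filename: str = None):
        # Hypothetical sketch; the real pilot.utils.utils.setup_logging may differ.
        level = (logging_level or os.getenv("DBGPT_LOG_LEVEL", "INFO")).upper()
        root = logging.getLogger(logger_name)
        root.setLevel(level)
        root.addHandler(logging.StreamHandler())
        if logger_filename:
            log_dir = os.getenv("DBGPT_LOG_DIR", "./logs")
            os.makedirs(log_dir, exist_ok=True)
            root.addHandler(logging.FileHandler(os.path.join(log_dir, logger_filename)))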
@@ -120,7 +120,9 @@ class DefaultModelWorker(ModelWorker):
                 text=output, error_code=0, model_context=model_context
             )
             yield model_output
-            print(f"\n\nfull stream output:\n{previous_response}")
+            print(
+                f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
+            )
         except Exception as e:
             # Check if the exception is a torch.cuda.CudaError and if torch was imported.
             if torch_imported and isinstance(e, torch.cuda.CudaError):
@@ -36,6 +36,7 @@ from pilot.utils.parameter_utils import (
     ParameterDescription,
     _dict_to_command_args,
 )
+from pilot.utils.utils import setup_logging
 
 logger = logging.getLogger(__name__)
 
@@ -885,6 +886,12 @@ def run_worker_manager(
         model_name=model_name, model_path=model_path, standalone=standalone, port=port
     )
 
+    setup_logging(
+        "pilot",
+        logging_level=worker_params.log_level,
+        logger_filename="dbgpt_model_worker_manager.log",
+    )
+
     embedded_mod = True
     logger.info(f"Worker params: {worker_params}")
     if not app:
@@ -3,11 +3,13 @@ Fork from text-generation-webui https://github.com/oobabooga/text-generation-web
 """
 import re
 from typing import Dict
+import logging
 import torch
 import llama_cpp
 
 from pilot.model.parameter import LlamaCppModelParameters
-from pilot.logs import logger
+
+logger = logging.getLogger(__name__)
 
 if torch.cuda.is_available() and not torch.version.hip:
     try:
@@ -3,6 +3,7 @@
 
 from typing import Optional, Dict
 
+import logging
 from pilot.configs.model_config import get_device
 from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
 from pilot.model.parameter import (
@@ -12,7 +13,8 @@ from pilot.model.parameter import (
 )
 from pilot.utils import get_gpu_memory
 from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
-from pilot.logs import logger
+
+logger = logging.getLogger(__name__)
 
 
 def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
@@ -31,6 +31,21 @@ class ModelControllerParameters(BaseParameters):
     daemon: Optional[bool] = field(
         default=False, metadata={"help": "Run Model Controller in background"}
     )
+    log_level: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Logging level",
+            "valid_values": [
+                "FATAL",
+                "ERROR",
+                "WARNING",
+                "INFO",
+                "DEBUG",
+                "NOTSET",
+            ],
+        },
+    )
 
 
 @dataclass
@@ -85,6 +100,22 @@ class ModelWorkerParameters(BaseModelParameters):
         default=20, metadata={"help": "The interval for sending heartbeats (seconds)"}
     )
 
+    log_level: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Logging level",
+            "valid_values": [
+                "FATAL",
+                "ERROR",
+                "WARNING",
+                "INFO",
+                "DEBUG",
+                "NOTSET",
+            ],
+        },
+    )
 
 
 @dataclass
 class BaseEmbeddingModelParameters(BaseModelParameters):
@@ -38,8 +38,6 @@ from pilot.server.knowledge.request.request import KnowledgeSpaceRequest
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.scene.chat_factory import ChatFactory
-from pilot.configs.model_config import LOGDIR
-from pilot.utils import build_logger
 from pilot.common.schema import DBType
 from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
 from pilot.scene.message import OnceConversation
@@ -409,7 +407,7 @@ async def model_types(controller: BaseModelController = Depends(get_model_contro
 
 
 @router.get("/v1/model/supports")
-async def model_types(worker_manager: WorkerManager = Depends(get_worker_manager)):
+async def model_supports(worker_manager: WorkerManager = Depends(get_worker_manager)):
     logger.info(f"/controller/model/supports")
     try:
         models = await worker_manager.supported_models()
@@ -6,12 +6,11 @@ from fastapi import (
 )
 
 from typing import List
+import logging
 
 from pilot.configs.config import Config
 
 from pilot.scene.chat_factory import ChatFactory
-from pilot.configs.model_config import LOGDIR
-from pilot.utils import build_logger
 
 from pilot.openapi.api_view_model import (
     Result,
@@ -34,7 +33,8 @@ from pilot.scene.chat_db.data_loader import DbDataLoader
 router = APIRouter()
 CFG = Config()
 CHAT_FACTORY = ChatFactory()
-logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
+
+logger = logging.getLogger(__name__)
 
 
 @router.get("/v1/editor/db/tables", response_model=Result[DbTable])
@@ -2,18 +2,17 @@ from __future__ import annotations
 
 import json
 from abc import ABC
+import logging
 from dataclasses import asdict
 from typing import Any, Dict, TypeVar, Union
 
 from pilot.configs.config import Config
-from pilot.configs.model_config import LOGDIR
 from pilot.model.base import ModelOutput
-from pilot.utils import build_logger
 
 T = TypeVar("T")
 ResponseTye = Union[str, bytes, ModelOutput]
 
-logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+logger = logging.getLogger(__name__)
 
 CFG = Config()
 
@@ -1,11 +1,11 @@
 import datetime
 import traceback
 import warnings
+import logging
 from abc import ABC, abstractmethod
 from typing import Any, List, Dict
 
 from pilot.configs.config import Config
-from pilot.configs.model_config import LOGDIR
 from pilot.component import ComponentType
 from pilot.memory.chat_history.base import BaseChatHistoryMemory
 from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory
@@ -14,10 +14,10 @@ from pilot.memory.chat_history.mem_history import MemHistoryMemory
 from pilot.prompts.prompt_new import PromptTemplate
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 from pilot.scene.message import OnceConversation
-from pilot.utils import build_logger, get_or_create_event_loop
+from pilot.utils import get_or_create_event_loop
 from pydantic import Extra
 
-logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
+logger = logging.getLogger(__name__)
 headers = {"User-Agent": "dbgpt Client"}
 CFG = Config()
 
@@ -1,13 +1,12 @@
 from typing import List
 from decimal import Decimal
+import logging
 
 from pilot.configs.config import Config
-from pilot.configs.model_config import LOGDIR
-from pilot.utils import build_logger
 from pilot.scene.chat_dashboard.data_preparation.report_schma import ValueItem
 
 CFG = Config()
-logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")
+logger = logging.getLogger(__name__)
 
 
 class DashboardDataLoader:
@@ -1,8 +1,8 @@
 import json
+import logging
 
 from typing import NamedTuple, List
-from pilot.utils import build_logger
 from pilot.out_parser.base import BaseOutputParser, T
-from pilot.configs.model_config import LOGDIR
 from pilot.scene.base import ChatScene
 
 
@@ -13,7 +13,7 @@ class ChartItem(NamedTuple):
     showcase: str
 
 
-logger = build_logger("webserver", LOGDIR + "ChatDashboardOutputParser.log")
+logger = logging.getLogger(__name__)
 
 
 class ChatDashboardOutputParser(BaseOutputParser):
@@ -1,11 +1,7 @@
 import json
-import re
-from abc import ABC, abstractmethod
+import logging
 from typing import Dict, NamedTuple, List
-import pandas as pd
-from pilot.utils import build_logger
 from pilot.out_parser.base import BaseOutputParser, T
-from pilot.configs.model_config import LOGDIR
 from pilot.configs.config import Config
 
 CFG = Config()
@@ -17,7 +13,7 @@ class ExcelAnalyzeResponse(NamedTuple):
     display: str
 
 
-logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log")
+logger = logging.getLogger(__name__)
 
 
 class ChatExcelOutputParser(BaseOutputParser):
@@ -1,11 +1,7 @@
 import json
-import re
-from abc import ABC, abstractmethod
+import logging
 from typing import Dict, NamedTuple, List
-import pandas as pd
-from pilot.utils import build_logger
 from pilot.out_parser.base import BaseOutputParser, T
-from pilot.configs.model_config import LOGDIR
 from pilot.configs.config import Config
 
 CFG = Config()
@@ -17,7 +13,7 @@ class ExcelResponse(NamedTuple):
     plans: List
 
 
-logger = build_logger("chat_excel", LOGDIR + "ChatExcel.log")
+logger = logging.getLogger(__name__)
 
 
 class LearningExcelOutputParser(BaseOutputParser):
@@ -1,8 +1,7 @@
 import json
 from typing import Dict, NamedTuple
-from pilot.utils import build_logger
+import logging
 from pilot.out_parser.base import BaseOutputParser, T
-from pilot.configs.model_config import LOGDIR
 from pilot.configs.config import Config
 from pilot.scene.chat_db.data_loader import DbDataLoader
 
@@ -14,7 +13,7 @@ class SqlAction(NamedTuple):
     thoughts: Dict
 
 
-logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+logger = logging.getLogger(__name__)
 
 
 class DbChatOutputParser(BaseOutputParser):
@@ -1,8 +1,4 @@
-from pilot.configs.model_config import LOGDIR
 from pilot.out_parser.base import BaseOutputParser, T
-from pilot.utils import build_logger
 
-logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 
 
 class NormalChatOutputParser(BaseOutputParser):
@@ -1,11 +1,10 @@
 import json
+import logging
 from typing import Dict, NamedTuple
-from pilot.utils import build_logger
 from pilot.out_parser.base import BaseOutputParser, T
-from pilot.configs.model_config import LOGDIR
 
 
-logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+logger = logging.getLogger(__name__)
 
 
 class PluginAction(NamedTuple):
@@ -1,9 +1,7 @@
-from pilot.utils import build_logger
+import logging
 from pilot.out_parser.base import BaseOutputParser, T
-from pilot.configs.model_config import LOGDIR
 
-logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+logger = logging.getLogger(__name__)
 
 
 class NormalChatOutputParser(BaseOutputParser):
@@ -60,17 +60,24 @@ class ChatKnowledge(BaseChat):
 
     async def stream_call(self):
         input_values = self.generate_input_values()
-        async for output in super().stream_call():
         # Source of knowledge file
         relations = input_values.get("relations")
+        last_output = None
+        async for output in super().stream_call():
+            last_output = output
+            yield output
+
         if (
             CFG.KNOWLEDGE_CHAT_SHOW_RELATIONS
+            and last_output
             and type(relations) == list
             and len(relations) > 0
-            and hasattr(output, "text")
+            and hasattr(last_output, "text")
         ):
-            output.text = output.text + "\trelations:" + ",".join(relations)
-            yield output
+            last_output.text = (
+                last_output.text + "\n\nrelations:\n\n" + ",".join(relations)
+            )
+            yield last_output
 
     def generate_input_values(self):
         if self.space_context:
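Note: the reworked `stream_call` no longer appends the sources to every chunk; it yields each chunk as it arrives, remembers only the most recent one, and after the stream ends re-emits that final chunk with the knowledge sources appended. A stripped-down sketch of the pattern (types simplified to strings):

    from typing import AsyncIterator, List

    async def stream_with_relations(
        stream: AsyncIterator[str], relations: List[str]
    ) -> AsyncIterator[str]:
        last = None
        async for chunk in stream:
            last = chunk  # remember the newest chunk
            yield chunk   # pass everything through unchanged
        if last is not None and relations:
            # one extra, final chunk carrying the source documents
            yield last + "\n\nrelations:\n\n" + ",".join(relations)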
@@ -1,9 +1,8 @@
-from pilot.utils import build_logger
+import logging
 from pilot.out_parser.base import BaseOutputParser, T
-from pilot.configs.model_config import LOGDIR
 
 
-logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+logger = logging.getLogger(__name__)
 
 
 class NormalChatOutputParser(BaseOutputParser):
@@ -1,9 +1,8 @@
-from pilot.utils import build_logger
+import logging
 from pilot.out_parser.base import BaseOutputParser, T
-from pilot.configs.model_config import LOGDIR
 
 
-logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
+logger = logging.getLogger(__name__)
 
 
 class NormalChatOutputParser(BaseOutputParser):
@@ -108,7 +108,7 @@ class WebWerverParameters(BaseParameters):
         },
     )
     log_level: Optional[str] = field(
-        default="INFO",
+        default=None,
         metadata={
             "help": "Logging level",
             "valid_values": [
@@ -33,7 +33,11 @@ from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route
 from pilot.openapi.api_v1.feedback.api_fb_v1 import router as api_fb_v1
 from pilot.commands.disply_type.show_chart_gen import static_message_img_path
 from pilot.model.cluster import initialize_worker_manager_in_client
-from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
+from pilot.utils.utils import (
+    setup_logging,
+    _get_logging_level,
+    logging_str_to_uvicorn_level,
+)
 
 static_file_path = os.path.join(os.getcwd(), "server/static")
 
@@ -111,7 +115,11 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
     )
     param = WebWerverParameters(**vars(parser.parse_args(args=args)))
 
-    setup_logging(logging_level=param.log_level)
+    if not param.log_level:
+        param.log_level = _get_logging_level()
+    setup_logging(
+        "pilot", logging_level=param.log_level, logger_filename="dbgpt_webserver.log"
+    )
     # Before start
     system_app.before_start()
 
@@ -1,9 +1,12 @@
 import click
 import logging
+import os
+import functools
 
 from pilot.configs.model_config import DATASETS_DIR
 
-API_ADDRESS: str = "http://127.0.0.1:5000"
+_DEFAULT_API_ADDRESS: str = "http://127.0.0.1:5000"
+API_ADDRESS: str = _DEFAULT_API_ADDRESS
 
 logger = logging.getLogger("dbgpt_cli")
 
@@ -20,33 +23,44 @@ logger = logging.getLogger("dbgpt_cli")
 def knowledge_cli_group(address: str):
     """Knowledge command line tool"""
     global API_ADDRESS
+    if address == _DEFAULT_API_ADDRESS:
+        address = os.getenv("API_ADDRESS", _DEFAULT_API_ADDRESS)
     API_ADDRESS = address
 
 
-@knowledge_cli_group.command()
-@click.option(
-    "--vector_name",
-    required=False,
-    type=str,
-    default="default",
-    show_default=True,
-    help="Your vector store name",
-)
+def add_knowledge_options(func):
+    @click.option(
+        "--space_name",
+        required=False,
+        type=str,
+        default="default",
+        show_default=True,
+        help="Your knowledge space name",
+    )
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+@knowledge_cli_group.command()
+@add_knowledge_options
 @click.option(
     "--vector_store_type",
     required=False,
     type=str,
     default="Chroma",
     show_default=True,
-    help="Vector store type",
+    help="Vector store type.",
 )
 @click.option(
-    "--local_doc_dir",
+    "--local_doc_path",
     required=False,
     type=str,
     default=DATASETS_DIR,
     show_default=True,
-    help="Your document directory",
+    help="Your document directory or document file path.",
 )
 @click.option(
     "--skip_wrong_doc",
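Note: `add_knowledge_options` exists so several subcommands (`load`, `delete`) can share the `--space_name` option instead of repeating it. A usage sketch — the `status` subcommand below is hypothetical, purely to show the decorator in use:

    @knowledge_cli_group.command()
    @add_knowledge_options
    def status(space_name: str):
        # Hypothetical subcommand; it inherits --space_name from the decorator.
        click.echo(f"Current space: {space_name}")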
@@ -54,31 +68,165 @@ def knowledge_cli_group(address: str):
     type=bool,
     default=False,
     is_flag=True,
-    show_default=True,
-    help="Skip wrong document",
+    help="Skip wrong document.",
+)
+@click.option(
+    "--overwrite",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Overwrite existing documents with the same name.",
 )
 @click.option(
     "--max_workers",
     required=False,
     type=int,
     default=None,
-    help="The maximum number of threads that can be used to upload document",
+    help="The maximum number of threads that can be used to upload documents.",
+)
+@click.option(
+    "--pre_separator",
+    required=False,
+    type=str,
+    default=None,
+    help="Pre-separator, used for pre-splitting before the document is "
+    "actually split by the text splitter. The pre-separator is not included in the vectorized text.",
+)
+@click.option(
+    "--separator",
+    required=False,
+    type=str,
+    default=None,
+    help="This is the document separator. Currently, only one separator is supported.",
+)
+@click.option(
+    "--chunk_size",
+    required=False,
+    type=int,
+    default=None,
+    help="Maximum size of chunks to split.",
+)
+@click.option(
+    "--chunk_overlap",
+    required=False,
+    type=int,
+    default=None,
+    help="Overlap in characters between chunks.",
 )
 def load(
-    vector_name: str,
+    space_name: str,
     vector_store_type: str,
-    local_doc_dir: str,
+    local_doc_path: str,
     skip_wrong_doc: bool,
+    overwrite: bool,
     max_workers: int,
+    pre_separator: str,
+    separator: str,
+    chunk_size: int,
+    chunk_overlap: int,
 ):
     """Load your local knowledge to DB-GPT"""
     from pilot.server.knowledge._cli.knowledge_client import knowledge_init
 
     knowledge_init(
         API_ADDRESS,
-        vector_name,
+        space_name,
         vector_store_type,
-        local_doc_dir,
+        local_doc_path,
         skip_wrong_doc,
+        overwrite,
         max_workers,
+        pre_separator,
+        separator,
+        chunk_size,
+        chunk_overlap,
     )
+
+
+@knowledge_cli_group.command()
+@add_knowledge_options
+@click.option(
+    "--doc_name",
+    required=False,
+    type=str,
+    default=None,
+    help="The document name you want to delete. If doc_name is None, this command will delete the whole space.",
+)
+@click.option(
+    "-y",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Confirm your choice",
+)
+def delete(space_name: str, doc_name: str, y: bool):
+    """Delete your knowledge space or document in space"""
+    from pilot.server.knowledge._cli.knowledge_client import knowledge_delete
+
+    knowledge_delete(API_ADDRESS, space_name, doc_name, confirm=y)
+
+
+@knowledge_cli_group.command()
+@click.option(
+    "--space_name",
+    required=False,
+    type=str,
+    default=None,
+    show_default=True,
+    help="Your knowledge space name. If None, list all spaces.",
+)
+@click.option(
+    "--doc_id",
+    required=False,
+    type=int,
+    default=None,
+    show_default=True,
+    help="Your document id in the knowledge space. If not None, list all chunks in the current document.",
+)
+@click.option(
+    "--page",
+    required=False,
+    type=int,
+    default=1,
+    show_default=True,
+    help="The page for every query",
+)
+@click.option(
+    "--page_size",
+    required=False,
+    type=int,
+    default=20,
+    show_default=True,
+    help="The page size for every query",
+)
+@click.option(
+    "--show_content",
+    required=False,
+    type=bool,
+    default=False,
+    is_flag=True,
+    help="Query the document content of chunks",
+)
+@click.option(
+    "--output",
+    required=False,
+    type=click.Choice(["text", "html", "csv", "latex", "json"]),
+    default="text",
+    help="The output format",
+)
+def list(
+    space_name: str,
+    doc_id: int,
+    page: int,
+    page_size: int,
+    show_content: bool,
+    output: str,
+):
+    """List knowledge space"""
+    from pilot.server.knowledge._cli.knowledge_client import knowledge_list
+
+    knowledge_list(
+        API_ADDRESS, space_name, page, page_size, doc_id, show_content, output
+    )
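Note: the new flags can be driven from Python via click's test runner (this exercises option parsing and then calls the client, so a running API is still required). The module path of the CLI group and the sample values below are assumptions, not taken from this diff:

    from click.testing import CliRunner
    # Assumed location of the group; adjust to wherever knowledge_cli_group is defined.
    from pilot.server.knowledge._cli.knowledge_cli import knowledge_cli_group

    runner = CliRunner()
    result = runner.invoke(
        knowledge_cli_group,
        [
            "load",
            "--space_name", "default",
            "--local_doc_path", "./docs/guide.md",
            "--pre_separator", "\f",
            "--chunk_size", "500",
            "--chunk_overlap", "50",
            "--overwrite",
        ],
    )
    print(result.output)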
@@ -62,6 +62,9 @@ class KnowledgeApiClient(ApiClient):
         else:
             raise e
 
+    def space_delete(self, request: KnowledgeSpaceRequest):
+        return self._post("/knowledge/space/delete", data=request)
+
     def space_list(self, request: KnowledgeSpaceRequest):
         return self._post("/knowledge/space/list", data=request)
 
@@ -69,6 +72,10 @@ class KnowledgeApiClient(ApiClient):
         url = f"/knowledge/{space_name}/document/add"
         return self._post(url, data=request)
 
+    def document_delete(self, space_name: str, request: KnowledgeDocumentRequest):
+        url = f"/knowledge/{space_name}/document/delete"
+        return self._post(url, data=request)
+
     def document_list(self, space_name: str, query_request: DocumentQueryRequest):
         url = f"/knowledge/{space_name}/document/list"
         return self._post(url, data=query_request)
@@ -97,15 +104,20 @@ class KnowledgeApiClient(ApiClient):
 
 def knowledge_init(
     api_address: str,
-    vector_name: str,
+    space_name: str,
     vector_store_type: str,
-    local_doc_dir: str,
+    local_doc_path: str,
     skip_wrong_doc: bool,
-    max_workers: int = None,
+    overwrite: bool,
+    max_workers: int,
+    pre_separator: str,
+    separator: str,
+    chunk_size: int,
+    chunk_overlap: int,
 ):
     client = KnowledgeApiClient(api_address)
     space = KnowledgeSpaceRequest()
-    space.name = vector_name
+    space.name = space_name
     space.desc = "DB-GPT cli"
     space.vector_type = vector_store_type
     space.owner = "DB-GPT"
@@ -124,24 +136,260 @@ def knowledge_init(
     def upload(filename: str):
         try:
             logger.info(f"Begin upload document: {filename} to {space.name}")
-            doc_id = client.document_upload(
-                space.name, filename, KnowledgeType.DOCUMENT.value, filename
-            )
-            client.document_sync(space.name, DocumentSyncRequest(doc_ids=[doc_id]))
+            doc_id = None
+            try:
+                doc_id = client.document_upload(
+                    space.name, filename, KnowledgeType.DOCUMENT.value, filename
+                )
+            except Exception as ex:
+                if overwrite and "have already named" in str(ex):
+                    logger.warn(
+                        f"Document {filename} already exists in space {space.name}, overwriting it"
+                    )
+                    client.document_delete(
+                        space.name, KnowledgeDocumentRequest(doc_name=filename)
+                    )
+                    doc_id = client.document_upload(
+                        space.name, filename, KnowledgeType.DOCUMENT.value, filename
+                    )
+                else:
+                    raise ex
+            sync_req = DocumentSyncRequest(doc_ids=[doc_id])
+            if pre_separator:
+                sync_req.pre_separator = pre_separator
+            if separator:
+                sync_req.separators = [separator]
+            if chunk_size:
+                sync_req.chunk_size = chunk_size
+            if chunk_overlap:
+                sync_req.chunk_overlap = chunk_overlap
+
+            client.document_sync(space.name, sync_req)
+            return doc_id
         except Exception as e:
             if skip_wrong_doc:
                 logger.warn(f"Upload {filename} to {space.name} failed: {str(e)}")
             else:
                 raise e
 
+    if not os.path.exists(local_doc_path):
+        raise Exception(f"{local_doc_path} does not exist")
+
     with ThreadPoolExecutor(max_workers=max_workers) as pool:
         tasks = []
-        for root, _, files in os.walk(local_doc_dir, topdown=False):
-            for file in files:
-                filename = os.path.join(root, file)
-                tasks.append(pool.submit(upload, filename))
+        file_names = []
+        if os.path.isdir(local_doc_path):
+            for root, _, files in os.walk(local_doc_path, topdown=False):
+                for file in files:
+                    file_names.append(os.path.join(root, file))
+        else:
+            # Single file
+            file_names.append(local_doc_path)
+
+        for filename in file_names:
+            tasks.append(pool.submit(upload, filename))
+
         doc_ids = [r.result() for r in as_completed(tasks)]
         doc_ids = list(filter(lambda x: x, doc_ids))
         if not doc_ids:
             logger.warn("Warning: no document to sync")
             return
|
from prettytable import PrettyTable
|
||||||
|
|
||||||
|
|
||||||
|
class _KnowledgeVisualizer:
|
||||||
|
def __init__(self, api_address: str, out_format: str):
|
||||||
|
self.client = KnowledgeApiClient(api_address)
|
||||||
|
self.out_format = out_format
|
||||||
|
self.out_kwargs = {}
|
||||||
|
if out_format == "json":
|
||||||
|
self.out_kwargs["ensure_ascii"] = False
|
||||||
|
|
||||||
|
def print_table(self, table):
|
||||||
|
print(table.get_formatted_string(out_format=self.out_format, **self.out_kwargs))
|
||||||
|
|
||||||
|
def list_spaces(self):
|
||||||
|
spaces = self.client.space_list(KnowledgeSpaceRequest())
|
||||||
|
table = PrettyTable(
|
||||||
|
["Space ID", "Space Name", "Vector Type", "Owner", "Description"],
|
||||||
|
title="All knowledge spaces",
|
||||||
|
)
|
||||||
|
for sp in spaces:
|
||||||
|
context = sp.get("context")
|
||||||
|
table.add_row(
|
||||||
|
[
|
||||||
|
sp.get("id"),
|
||||||
|
sp.get("name"),
|
||||||
|
sp.get("vector_type"),
|
||||||
|
sp.get("owner"),
|
||||||
|
sp.get("desc"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.print_table(table)
|
||||||
|
|
||||||
|
def list_documents(self, space_name: str, page: int, page_size: int):
|
||||||
|
space_data = self.client.document_list(
|
||||||
|
space_name, DocumentQueryRequest(page=page, page_size=page_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
space_table = PrettyTable(
|
||||||
|
[
|
||||||
|
"Space Name",
|
||||||
|
"Total Documents",
|
||||||
|
"Current Page",
|
||||||
|
"Current Size",
|
||||||
|
"Page Size",
|
||||||
|
],
|
||||||
|
title=f"Space {space_name} description",
|
||||||
|
)
|
||||||
|
space_table.add_row(
|
||||||
|
[space_name, space_data["total"], page, len(space_data["data"]), page_size]
|
||||||
|
)
|
||||||
|
|
||||||
|
table = PrettyTable(
|
||||||
|
[
|
||||||
|
"Space Name",
|
||||||
|
"Document ID",
|
||||||
|
"Document Name",
|
||||||
|
"Type",
|
||||||
|
"Chunks",
|
||||||
|
"Last Sync",
|
||||||
|
"Status",
|
||||||
|
"Result",
|
||||||
|
],
|
||||||
|
title=f"Documents of space {space_name}",
|
||||||
|
)
|
||||||
|
for doc in space_data["data"]:
|
||||||
|
table.add_row(
|
||||||
|
[
|
||||||
|
space_name,
|
||||||
|
doc.get("id"),
|
||||||
|
doc.get("doc_name"),
|
||||||
|
doc.get("doc_type"),
|
||||||
|
doc.get("chunk_size"),
|
||||||
|
doc.get("last_sync"),
|
||||||
|
doc.get("status"),
|
||||||
|
doc.get("result"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.out_format == "text":
|
||||||
|
self.print_table(space_table)
|
||||||
|
print("")
|
||||||
|
self.print_table(table)
|
||||||
|
|
||||||
|
def list_chunks(
|
||||||
|
self,
|
||||||
|
space_name: str,
|
||||||
|
doc_id: int,
|
||||||
|
page: int,
|
||||||
|
page_size: int,
|
||||||
|
show_content: bool,
|
||||||
|
):
|
||||||
|
doc_data = self.client.chunk_list(
|
||||||
|
space_name,
|
||||||
|
ChunkQueryRequest(document_id=doc_id, page=page, page_size=page_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_table = PrettyTable(
|
||||||
|
[
|
||||||
|
"Space Name",
|
||||||
|
"Document ID",
|
||||||
|
"Total Chunks",
|
||||||
|
"Current Page",
|
||||||
|
"Current Size",
|
||||||
|
"Page Size",
|
||||||
|
],
|
||||||
|
title=f"Document {doc_id} in {space_name} description",
|
||||||
|
)
|
||||||
|
doc_table.add_row(
|
||||||
|
[
|
||||||
|
space_name,
|
||||||
|
doc_id,
|
||||||
|
doc_data["total"],
|
||||||
|
page,
|
||||||
|
len(doc_data["data"]),
|
||||||
|
page_size,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
table = PrettyTable(
|
||||||
|
["Space Name", "Document ID", "Document Name", "Content", "Meta Data"],
|
||||||
|
title=f"chunks of document id {doc_id} in space {space_name}",
|
||||||
|
)
|
||||||
|
for chunk in doc_data["data"]:
|
||||||
|
table.add_row(
|
||||||
|
[
|
||||||
|
space_name,
|
||||||
|
doc_id,
|
||||||
|
chunk.get("doc_name"),
|
||||||
|
chunk.get("content") if show_content else "[Hidden]",
|
||||||
|
chunk.get("meta_info"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.out_format == "text":
|
||||||
|
self.print_table(doc_table)
|
||||||
|
print("")
|
||||||
|
self.print_table(table)
|
||||||
|
|
||||||
|
|
||||||
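print_table defers all formatting to PrettyTable's get_formatted_string, so one code path serves both human-readable text and machine-readable JSON, and ensure_ascii=False keeps non-ASCII (for example Chinese) space names readable in the JSON output. A minimal sketch of that call, assuming a prettytable version that provides get_formatted_string:

    from prettytable import PrettyTable

    table = PrettyTable(["Space ID", "Space Name"], title="All knowledge spaces")
    table.add_row([1, "default"])

    # Text for terminals, JSON for scripts; same table object either way.
    print(table.get_formatted_string(out_format="text"))
    print(table.get_formatted_string(out_format="json", ensure_ascii=False))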
+def knowledge_list(
+    api_address: str,
+    space_name: str,
+    page: int,
+    page_size: int,
+    doc_id: int,
+    show_content: bool,
+    out_format: str,
+):
+    visualizer = _KnowledgeVisualizer(api_address, out_format)
+    if not space_name:
+        visualizer.list_spaces()
+    elif not doc_id:
+        visualizer.list_documents(space_name, page, page_size)
+    else:
+        visualizer.list_chunks(space_name, doc_id, page, page_size, show_content)
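knowledge_list dispatches on which arguments are present: no space name lists all spaces, a space name alone lists that space's documents, and a space name plus a document id lists the document's chunks. For example (address and values are illustrative):

    # List every knowledge space.
    knowledge_list("http://127.0.0.1:5000", None, 1, 20, None, False, "text")
    # List documents in one space.
    knowledge_list("http://127.0.0.1:5000", "my_space", 1, 20, None, False, "text")
    # List the chunks of document 42 as JSON, hiding chunk content.
    knowledge_list("http://127.0.0.1:5000", "my_space", 1, 20, 42, False, "json")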
+def knowledge_delete(
+    api_address: str, space_name: str, doc_name: str, confirm: bool = False
+):
+    client = KnowledgeApiClient(api_address)
+    space = KnowledgeSpaceRequest()
+    space.name = space_name
+    space_list = client.space_list(KnowledgeSpaceRequest(name=space.name))
+    if not space_list:
+        raise Exception(f"No knowledge space named {space_name}")
+
+    if not doc_name:
+        if not confirm:
+            # Confirm by user
+            user_input = (
+                input(
+                    f"Are you sure you want to delete the whole knowledge space {space_name}? Type 'yes' to confirm: "
+                )
+                .strip()
+                .lower()
+            )
+            if user_input != "yes":
+                logger.warning("Delete operation cancelled.")
+                return
+        client.space_delete(space)
+        logger.info("Delete the whole knowledge space successfully!")
+    else:
+        if not confirm:
+            # Confirm by user
+            user_input = (
+                input(
+                    f"Are you sure you want to delete the document {doc_name} in knowledge space {space_name}? Type 'yes' to confirm: "
+                )
+                .strip()
+                .lower()
+            )
+            if user_input != "yes":
+                logger.warning("Delete operation cancelled.")
+                return
+        client.document_delete(space_name, KnowledgeDocumentRequest(doc_name=doc_name))
+        logger.info(
+            f"Delete the document {doc_name} in knowledge space {space_name} successfully!"
+        )
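Passing confirm=True skips the interactive prompt, which matters when the deletion is driven from a script or a non-interactive CLI. For example (address and names are illustrative):

    # Delete a single document without prompting.
    knowledge_delete("http://127.0.0.1:5000", "my_space", "report.pdf", confirm=True)
    # Delete the whole space; the prompt appears because confirm defaults to False.
    knowledge_delete("http://127.0.0.1:5000", "my_space", None)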
@@ -1,6 +1,7 @@
 import os
 import shutil
 import tempfile
+import logging

 from fastapi import APIRouter, File, UploadFile, Form

@@ -27,6 +28,8 @@ from pilot.server.knowledge.request.request import (
 from pilot.server.knowledge.request.request import KnowledgeSpaceRequest

+logger = logging.getLogger(__name__)
+
 CFG = Config()
 router = APIRouter()

@@ -159,10 +162,10 @@ async def document_upload(
 @router.post("/knowledge/{space_name}/document/sync")
 def document_sync(space_name: str, request: DocumentSyncRequest):
-    print(f"Received params: {space_name}, {request}")
+    logger.info(f"Received params: {space_name}, {request}")
     try:
         knowledge_space_service.sync_knowledge_document(
-            space_name=space_name, doc_ids=request.doc_ids
+            space_name=space_name, sync_request=request
         )
         return Result.succ([])
     except Exception as e:
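With the route now receiving the whole DocumentSyncRequest, a client can tune splitting per sync call instead of relying only on space-level defaults. A hedged sketch of invoking the endpoint with requests (host, port, and any router prefix are assumptions, not part of this diff):

    import requests

    payload = {
        "doc_ids": [42],
        "pre_separator": "\n###",  # coarse pre-split marker, stripped before vectorizing
        "separators": ["\n\n"],    # handed on to the text splitter
        "chunk_size": 1000,
        "chunk_overlap": 100,
    }
    resp = requests.post(
        "http://127.0.0.1:5000/knowledge/my_space/document/sync",
        json=payload,
    )
    print(resp.json())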
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional

 from pydantic import BaseModel
 from fastapi import UploadFile

@@ -54,10 +54,25 @@ class DocumentQueryRequest(BaseModel):


 class DocumentSyncRequest(BaseModel):
-    """doc_ids: doc ids"""
+    """Sync request"""

+    """doc_ids: doc ids"""
     doc_ids: List
+
+    """Pre-separator, used to pre-split the document before the text splitter
+    actually splits it. The pre-separator is not included in the vectorized text.
+    """
+    pre_separator: Optional[str] = None
+
+    """Custom separators"""
+    separators: Optional[List[str]] = None
+
+    """Custom chunk size"""
+    chunk_size: Optional[int] = None
+
+    """Custom chunk overlap"""
+    chunk_overlap: Optional[int] = None


 class ChunkQueryRequest(BaseModel):
     """id: id"""
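The docstrings describe a two-stage split: the pre-separator carves the raw document into coarse sections first, the regular separators then drive chunking inside each section, and the pre-separator itself never reaches the vector store. The pre-split step on its own is just:

    text = "Chapter 1\nbody one\n###\nChapter 2\nbody two"
    sections = text.split("###")  # "###" is dropped by split()
    # Each section is then chunked independently by the configured text splitter.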
@@ -1,5 +1,5 @@
 import json
-import threading
+import logging
 from datetime import datetime

 from pilot.vector_store.connector import VectorStoreConnector

@@ -12,7 +12,6 @@ from pilot.configs.model_config import (
 from pilot.component import ComponentType
 from pilot.utils.executor_utils import ExecutorFactory

-from pilot.logs import logger
 from pilot.server.knowledge.chunk_db import (
     DocumentChunkEntity,
     DocumentChunkDao,

@@ -31,6 +30,7 @@ from pilot.server.knowledge.request.request import (
     DocumentQueryRequest,
     ChunkQueryRequest,
     SpaceArgumentRequest,
+    DocumentSyncRequest,
 )
 from enum import Enum

@@ -44,6 +44,7 @@ knowledge_space_dao = KnowledgeSpaceDao()
 knowledge_document_dao = KnowledgeDocumentDao()
 document_chunk_dao = DocumentChunkDao()

+logger = logging.getLogger(__name__)
 CFG = Config()


@@ -107,7 +108,6 @@ class KnowledgeService:
         res.owner = space.owner
         res.gmt_created = space.gmt_created
         res.gmt_modified = space.gmt_modified
-        res.owner = space.owner
         res.context = space.context
         query = KnowledgeDocumentEntity(space=space.name)
         doc_count = knowledge_document_dao.get_knowledge_documents_count(query)

@@ -155,9 +155,10 @@ class KnowledgeService:

     """sync knowledge document chunk into vector store"""

-    def sync_knowledge_document(self, space_name, doc_ids):
+    def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
         from pilot.embedding_engine.embedding_engine import EmbeddingEngine
         from pilot.embedding_engine.embedding_factory import EmbeddingFactory
+        from pilot.embedding_engine.pre_text_splitter import PreTextSplitter
         from langchain.text_splitter import (
             RecursiveCharacterTextSplitter,
             SpacyTextSplitter,

@@ -165,6 +166,7 @@ class KnowledgeService:

         # import langchain is very very slow!!!

+        doc_ids = sync_request.doc_ids
         for doc_id in doc_ids:
             query = KnowledgeDocumentEntity(
                 id=doc_id,

@@ -190,24 +192,43 @@ class KnowledgeService:
             if space_context is None
             else int(space_context["embedding"]["chunk_overlap"])
         )
+        if sync_request.chunk_size:
+            chunk_size = sync_request.chunk_size
+        if sync_request.chunk_overlap:
+            chunk_overlap = sync_request.chunk_overlap
+        separators = sync_request.separators or None
         if CFG.LANGUAGE == "en":
             text_splitter = RecursiveCharacterTextSplitter(
+                separators=separators,
                 chunk_size=chunk_size,
                 chunk_overlap=chunk_overlap,
                 length_function=len,
             )
         else:
+            if separators and len(separators) > 1:
+                raise ValueError(
+                    "SpacyTextSplitter does not support multiple separators"
+                )
             try:
+                separator = "\n\n" if not separators else separators[0]
                 text_splitter = SpacyTextSplitter(
+                    separator=separator,
                     pipeline="zh_core_web_sm",
                     chunk_size=chunk_size,
                     chunk_overlap=chunk_overlap,
                 )
             except Exception:
                 text_splitter = RecursiveCharacterTextSplitter(
+                    separators=separators,
                     chunk_size=chunk_size,
                     chunk_overlap=chunk_overlap,
                 )
+        if sync_request.pre_separator:
+            logger.info(f"Use pre-separator: {sync_request.pre_separator}")
+            text_splitter = PreTextSplitter(
+                pre_separator=sync_request.pre_separator,
+                text_splitter_impl=text_splitter,
+            )
         embedding_factory = CFG.SYSTEM_APP.get_component(
             "embedding_factory", EmbeddingFactory
         )
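PreTextSplitter itself is imported from pilot.embedding_engine.pre_text_splitter, whose body is not part of this diff. A plausible minimal sketch of the wrapper, assuming it pre-splits each text by pre_separator and delegates the rest to the wrapped splitter (the class body below is an assumption; only the constructor arguments are taken from the hunk above):

    from typing import Any, List

    class PreTextSplitter:
        """Pre-split text by pre_separator, then delegate to the real splitter."""

        def __init__(self, pre_separator: str, text_splitter_impl: Any):
            self.pre_separator = pre_separator
            self.impl = text_splitter_impl

        def split_text(self, text: str) -> List[str]:
            chunks: List[str] = []
            # The pre-separator marks coarse section boundaries and is dropped
            # here, so it never appears in the vectorized chunks.
            for section in text.split(self.pre_separator):
                chunks.extend(self.impl.split_text(section))
            return chunks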
@@ -1,6 +1,6 @@
 """ElevenLabs speech module"""
 import os
+import logging
 import requests

 from pilot.configs.config import Config

@@ -8,6 +8,8 @@ from pilot.speech.base import VoiceBase

 PLACEHOLDERS = {"your-voice-id"}

+logger = logging.getLogger(__name__)
+

 class ElevenLabsSpeech(VoiceBase):
     """ElevenLabs speech class"""

@@ -68,7 +70,6 @@ class ElevenLabsSpeech(VoiceBase):
         Returns:
             bool: True if the request was successful, False otherwise
         """
-        from pilot.logs import logger
         from playsound import playsound

         tts_url = (
@@ -1,5 +1,6 @@
 import json
 import uuid
+import logging

 from pilot.common.schema import DBType
 from pilot.component import SystemApp

@@ -7,16 +8,14 @@ from pilot.configs.config import Config
 from pilot.configs.model_config import (
     KNOWLEDGE_UPLOAD_ROOT_PATH,
     EMBEDDING_MODEL_CONFIG,
-    LOGDIR,
 )

 from pilot.scene.base import ChatScene
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.chat_factory import ChatFactory
 from pilot.summary.rdbms_db_summary import RdbmsSummary
-from pilot.utils import build_logger

-logger = build_logger("db_summary", LOGDIR + "db_summary.log")
+logger = logging.getLogger(__name__)

 CFG = Config()
 chat_factory = ChatFactory()
@@ -1,6 +1,5 @@
 from .utils import (
     get_gpu_memory,
-    build_logger,
     StreamToLogger,
     disable_torch_init,
     pretty_print_semaphore,
@@ -5,6 +5,8 @@ from dataclasses import is_dataclass, asdict

 T = TypeVar("T")

+logger = logging.getLogger(__name__)
+

 def _extract_dataclass_from_generic(type_hint: Type[T]) -> Union[Type[T], None]:
     import typing_inspect

@@ -21,7 +23,7 @@ def _build_request(self, func, path, method, *args, **kwargs):
         raise TypeError("Return type must be annotated in the decorated function.")

     actual_dataclass = _extract_dataclass_from_generic(return_type)
-    logging.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
+    logger.debug(f"return_type: {return_type}, actual_dataclass: {actual_dataclass}")
     if not actual_dataclass:
         actual_dataclass = return_type
     sig = signature(func)

@@ -55,7 +57,7 @@ def _build_request(self, func, path, method, *args, **kwargs):
     else:  # For GET, DELETE, etc.
         request_params["params"] = request_data

-    logging.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
+    logger.debug(f"request_params: {request_params}, args: {args}, kwargs: {kwargs}")
     return return_type, actual_dataclass, request_params
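Replacing the bare logging.debug(...) calls with a named logger.debug(...) attributes these messages to the module and lets operators raise or lower verbosity per module rather than only at the root. For example (the module path is an assumption based on the import style used elsewhere in this commit):

    import logging

    logging.basicConfig(level=logging.INFO)
    # Turn on debug output for just this module, leaving everything else at INFO.
    logging.getLogger("pilot.utils.api_utils").setLevel(logging.DEBUG)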
@@ -20,7 +20,7 @@ def _get_logging_level() -> str:
     return os.getenv("DBGPT_LOG_LEVEL", "INFO")


-def setup_logging(logging_level=None, logger_name: str = None):
+def setup_logging_level(logging_level=None, logger_name: str = None):
     if not logging_level:
         logging_level = _get_logging_level()
     if type(logging_level) is str:

@@ -32,6 +32,19 @@ def setup_logging(logging_level=None, logger_name: str = None):
         logging.basicConfig(level=logging_level, encoding="utf-8")


+def setup_logging(logger_name: str, logging_level=None, logger_filename: str = None):
+    if not logging_level:
+        logging_level = _get_logging_level()
+    logger = _build_logger(logger_name, logging_level, logger_filename)
+    try:
+        import coloredlogs
+
+        color_level = logging_level if logging_level else "INFO"
+        coloredlogs.install(level=color_level, logger=logger)
+    except ImportError:
+        pass
+
+
 def get_gpu_memory(max_gpus=None):
     import torch

@@ -53,7 +66,7 @@ def get_gpu_memory(max_gpus=None):
     return gpu_memory


-def build_logger(logger_name, logger_filename):
+def _build_logger(logger_name, logging_level=None, logger_filename: str = None):
     global handler

     formatter = logging.Formatter(

@@ -63,7 +76,7 @@ def build_logger(logger_name, logger_filename):

     # Set the format of root handlers
     if not logging.getLogger().handlers:
-        setup_logging()
+        setup_logging_level(logging_level=logging_level)
     logging.getLogger().handlers[0].setFormatter(formatter)

     # Redirect stdout and stderr to loggers

@@ -78,7 +91,7 @@ def build_logger(logger_name, logger_filename):
     # sys.stderr = sl

     # Add a file handler for all loggers
-    if handler is None:
+    if handler is None and logger_filename:
         os.makedirs(LOGDIR, exist_ok=True)
         filename = os.path.join(LOGDIR, logger_filename)
         handler = logging.handlers.TimedRotatingFileHandler(

@@ -89,11 +102,9 @@ def build_logger(logger_name, logger_filename):
     for name, item in logging.root.manager.loggerDict.items():
         if isinstance(item, logging.Logger):
             item.addHandler(handler)
-    setup_logging()

     # Get logger
     logger = logging.getLogger(logger_name)
-    setup_logging(logger_name=logger_name)
+    setup_logging_level(logging_level=logging_level, logger_name=logger_name)

     return logger
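The refactor splits the old setup_logging in two: setup_logging_level only configures the root level (driven by DBGPT_LOG_LEVEL), while the new setup_logging builds a named logger, attaches a timed-rotating file handler when a filename is given, and installs coloredlogs when the package is available. A likely entry-point usage based on the signatures above (logger name and filename are illustrative):

    # Once, during application startup:
    setup_logging("dbgpt", logging_level="DEBUG", logger_filename="dbgpt_webserver.log")

    # Individual modules then simply do:
    import logging

    logger = logging.getLogger(__name__)
    logger.info("module loggers inherit the configured handlers")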
@@ -1,11 +1,13 @@
 import os
+import logging
 from typing import Any

 from chromadb.config import Settings
 from chromadb import PersistentClient
-from pilot.logs import logger
 from pilot.vector_store.base import VectorStoreBase

+logger = logging.getLogger(__name__)
+

 class ChromaStore(VectorStoreBase):
     """chroma database"""
@@ -1,12 +1,13 @@
 from __future__ import annotations
+import logging
 from typing import Any, Iterable, List, Optional, Tuple

 from pymilvus import Collection, DataType, connections, utility

-from pilot.logs import logger
 from pilot.vector_store.base import VectorStoreBase

+logger = logging.getLogger(__name__)

 class MilvusStore(VectorStoreBase):
     """Milvus database"""
@@ -1,5 +1,6 @@
 import os
 import json
+import logging
 import weaviate
 from langchain.schema import Document
 from langchain.vectorstores import Weaviate

@@ -7,9 +8,9 @@ from weaviate.exceptions import WeaviateBaseError

 from pilot.configs.config import Config
 from pilot.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
-from pilot.logs import logger
 from pilot.vector_store.base import VectorStoreBase

+logger = logging.getLogger(__name__)
 CFG = Config()
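The chroma, milvus, and weaviate stores all drop the shared pilot.logs logger in favor of the same module-level pattern, so log lines can name the emitting module. A small sketch of why that matters:

    import logging

    logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s %(message)s")
    logger = logging.getLogger(__name__)
    # %(name)s now identifies which vector store (or other module) spoke.
    logger.warning("connection retry scheduled")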