mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00
feat(core): Multi-module dependency splitting
This commit is contained in:
parent
d31a220c6e
commit
33506b062b
10
.github/release-drafter.yml
vendored
10
.github/release-drafter.yml
vendored
@ -76,6 +76,16 @@ autolabeler:
|
||||
# feat(connection): Support xxxx
|
||||
# fix(connection): Fix xxx
|
||||
- '/^(build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)\(.*connection.*\)/'
|
||||
- label: core
|
||||
title:
|
||||
# feat(core): Support xxxx
|
||||
# fix(core): Fix xxx
|
||||
- '/^(build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)\(.*core.*\)/'
|
||||
- label: web
|
||||
title:
|
||||
# feat(web): Support xxxx
|
||||
# fix(web): Fix xxx
|
||||
- '/^(build|chore|ci|depr|docs|feat|fix|perf|refactor|release|test)\(.*web.*\)/'
|
||||
- label: build
|
||||
title:
|
||||
- '/^build/'
|
||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -28,6 +28,8 @@ sdist/
|
||||
var/
|
||||
wheels/
|
||||
models/
|
||||
# Soft link
|
||||
models
|
||||
plugins/
|
||||
|
||||
pip-wheel-metadata/
|
||||
|
@ -86,7 +86,7 @@ Currently, we have released multiple key features, which are listed below to dem
|
||||
- Unified vector storage/indexing of knowledge base
|
||||
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
|
||||
- Multi LLMs Support, Supports multiple large language models, currently supporting
|
||||
- 🔥 InternLM(7b)
|
||||
- 🔥 InternLM(7b,20b)
|
||||
- 🔥 Baichuan2(7b,13b)
|
||||
- 🔥 Vicuna-v1.5(7b,13b)
|
||||
- 🔥 llama-2(7b,13b,70b)
|
||||
|
@ -119,7 +119,7 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
|
||||
- 非结构化数据支持包括PDF、MarkDown、CSV、WebURL
|
||||
- 多模型支持
|
||||
- 支持多种大语言模型, 当前已支持如下模型:
|
||||
- 🔥 InternLM(7b)
|
||||
- 🔥 InternLM(7b,20b)
|
||||
- 🔥 Baichuan2(7b,13b)
|
||||
- 🔥 Vicuna-v1.5(7b,13b)
|
||||
- 🔥 llama-2(7b,13b,70b)
|
||||
|
@ -11,10 +11,12 @@ ARG LANGUAGE="en"
|
||||
ARG PIP_INDEX_URL="https://pypi.org/simple"
|
||||
ENV PIP_INDEX_URL=$PIP_INDEX_URL
|
||||
|
||||
ARG DB_GPT_INSTALL_MODEL="default"
|
||||
ENV DB_GPT_INSTALL_MODEL=$DB_GPT_INSTALL_MODEL
|
||||
|
||||
RUN mkdir -p /app
|
||||
|
||||
# COPY only requirements.txt first to leverage Docker cache
|
||||
COPY ./requirements.txt /app/requirements.txt
|
||||
COPY ./setup.py /app/setup.py
|
||||
COPY ./README.md /app/README.md
|
||||
|
||||
@ -26,9 +28,9 @@ WORKDIR /app
|
||||
# RUN pip3 install -i $PIP_INDEX_URL ".[all]"
|
||||
|
||||
RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \
|
||||
&& pip3 install -i $PIP_INDEX_URL . \
|
||||
# && pip3 install -i $PIP_INDEX_URL ".[llama_cpp]" \
|
||||
&& (if [ "${LANGUAGE}" = "zh" ]; \
|
||||
&& pip3 install -i $PIP_INDEX_URL ".[$DB_GPT_INSTALL_MODEL]"
|
||||
|
||||
RUN (if [ "${LANGUAGE}" = "zh" ]; \
|
||||
# language is zh, download zh_core_web_sm from github
|
||||
then wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.5.0/zh_core_web_sm-3.5.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
|
||||
&& pip3 install /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl -i $PIP_INDEX_URL \
|
||||
@ -58,4 +60,4 @@ RUN (if [ "${LOAD_EXAMPLES}" = "true" ]; \
|
||||
ENV PYTHONPATH "/app:$PYTHONPATH"
|
||||
EXPOSE 5000
|
||||
|
||||
CMD ["python3", "pilot/server/dbgpt_server.py"]
|
||||
CMD ["dbgpt", "start", "webserver"]
|
@ -4,14 +4,21 @@ SCRIPT_LOCATION=$0
|
||||
cd "$(dirname "$SCRIPT_LOCATION")"
|
||||
WORK_DIR=$(pwd)
|
||||
|
||||
BASE_IMAGE="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
|
||||
BASE_IMAGE_DEFAULT="nvidia/cuda:11.8.0-runtime-ubuntu22.04"
|
||||
BASE_IMAGE_DEFAULT_CPU="ubuntu:22.04"
|
||||
|
||||
BASE_IMAGE=$BASE_IMAGE_DEFAULT
|
||||
IMAGE_NAME="eosphorosai/dbgpt"
|
||||
IMAGE_NAME_ARGS=""
|
||||
|
||||
# zh: https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
PIP_INDEX_URL="https://pypi.org/simple"
|
||||
# en or zh
|
||||
LANGUAGE="en"
|
||||
BUILD_LOCAL_CODE="false"
|
||||
LOAD_EXAMPLES="true"
|
||||
BUILD_NETWORK=""
|
||||
DB_GPT_INSTALL_MODEL="default"
|
||||
|
||||
usage () {
|
||||
echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-runtime-ubuntu22.04] [--image-name db-gpt]"
|
||||
@ -21,6 +28,8 @@ usage () {
|
||||
echo " [--language en or zh] You language, default: en"
|
||||
echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: false"
|
||||
echo " [--load-examples true or false] Whether to load examples to default database default: true"
|
||||
echo " [--network network name] The network of docker build"
|
||||
echo " [--install-mode mode name] Installation mode name, default: default, If you completely use openai's service, you can set the mode name to 'openai'"
|
||||
echo " [-h|--help] Usage message"
|
||||
}
|
||||
|
||||
@ -33,7 +42,7 @@ while [[ $# -gt 0 ]]; do
|
||||
shift # past value
|
||||
;;
|
||||
-n|--image-name)
|
||||
IMAGE_NAME="$2"
|
||||
IMAGE_NAME_ARGS="$2"
|
||||
shift # past argument
|
||||
shift # past value
|
||||
;;
|
||||
@ -57,6 +66,20 @@ while [[ $# -gt 0 ]]; do
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
--network)
|
||||
BUILD_NETWORK=" --network $2 "
|
||||
shift # past argument
|
||||
shift # past value
|
||||
;;
|
||||
-h|--help)
|
||||
help="true"
|
||||
shift
|
||||
;;
|
||||
--install-mode)
|
||||
DB_GPT_INSTALL_MODEL="$2"
|
||||
shift # past argument
|
||||
shift # past value
|
||||
;;
|
||||
-h|--help)
|
||||
help="true"
|
||||
shift
|
||||
@ -73,11 +96,29 @@ if [[ $help ]]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
docker build \
|
||||
if [ "$DB_GPT_INSTALL_MODEL" != "default" ]; then
|
||||
IMAGE_NAME="$IMAGE_NAME-$DB_GPT_INSTALL_MODEL"
|
||||
echo "install mode is not 'default', set image name to: ${IMAGE_NAME}"
|
||||
fi
|
||||
|
||||
if [ -z "$IMAGE_NAME_ARGS" ]; then
|
||||
if [ "$DB_GPT_INSTALL_MODEL" == "openai" ]; then
|
||||
# Use cpu image
|
||||
BASE_IMAGE=$BASE_IMAGE_DEFAULT_CPU
|
||||
fi
|
||||
else
|
||||
# User input image is not empty
|
||||
BASE_IMAGE=$IMAGE_NAME_ARGS
|
||||
fi
|
||||
|
||||
echo "Begin build docker image, base image: ${BASE_IMAGE}, target image name: ${IMAGE_NAME}"
|
||||
|
||||
docker build $BUILD_NETWORK \
|
||||
--build-arg BASE_IMAGE=$BASE_IMAGE \
|
||||
--build-arg PIP_INDEX_URL=$PIP_INDEX_URL \
|
||||
--build-arg LANGUAGE=$LANGUAGE \
|
||||
--build-arg BUILD_LOCAL_CODE=$BUILD_LOCAL_CODE \
|
||||
--build-arg LOAD_EXAMPLES=$LOAD_EXAMPLES \
|
||||
--build-arg DB_GPT_INSTALL_MODEL=$DB_GPT_INSTALL_MODEL \
|
||||
-f Dockerfile \
|
||||
-t $IMAGE_NAME $WORK_DIR/../../
|
||||
|
@ -6,7 +6,7 @@ Local cluster deployment
|
||||
|
||||
**Installing Command-Line Tool**
|
||||
|
||||
All operations below are performed using the `dbgpt` command. To use the `dbgpt` command, you need to install the DB-GPT project with `pip install -e .`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` as a substitute for the `dbgpt` command.
|
||||
All operations below are performed using the `dbgpt` command. To use the `dbgpt` command, you need to install the DB-GPT project with `pip install -e ".[default]"`. Alternatively, you can use `python pilot/scripts/cli_scripts.py` as a substitute for the `dbgpt` command.
|
||||
|
||||
### Launch Model Controller
|
||||
|
||||
|
@ -49,7 +49,7 @@ For the entire installation process of DB-GPT, we use the miniconda3 virtual env
|
||||
python>=3.10
|
||||
conda create -n dbgpt_env python=3.10
|
||||
conda activate dbgpt_env
|
||||
pip install -e .
|
||||
pip install -e ".[default]"
|
||||
```
|
||||
Before use DB-GPT Knowledge
|
||||
```bash
|
||||
|
@ -6,7 +6,7 @@ DB-GPT provides a management and deployment solution for multiple models. This c
|
||||
|
||||
|
||||
Multi LLMs Support, Supports multiple large language models, currently supporting
|
||||
- 🔥 InternLM(7b)
|
||||
- 🔥 InternLM(7b,20b)
|
||||
- 🔥 Baichuan2(7b,13b)
|
||||
- 🔥 Vicuna-v1.5(7b,13b)
|
||||
- 🔥 llama-2(7b,13b,70b)
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""加载组件"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
@ -8,17 +9,19 @@ import requests
|
||||
import threading
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
from zipimport import zipimporter
|
||||
|
||||
import requests
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import PLUGINS_DIR
|
||||
from pilot.logs import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
|
||||
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
|
||||
"""
|
||||
@ -115,7 +118,7 @@ def load_native_plugins(cfg: Config):
|
||||
t.start()
|
||||
|
||||
|
||||
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
|
||||
def scan_plugins(cfg: Config, debug: bool = False) -> List["AutoGPTPluginTemplate"]:
|
||||
"""Scan the plugins directory for plugins and loads them.
|
||||
|
||||
Args:
|
||||
|
@ -1,11 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
from pilot.singleton import Singleton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from pilot.component import SystemApp
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
"""Configuration class to store the state of bools for different scripts access"""
|
||||
@ -99,9 +104,8 @@ class Config(metaclass=Singleton):
|
||||
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
|
||||
|
||||
### The associated configuration parameters of the plug-in control the loading and use of the plug-in
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
|
||||
self.plugins: List[AutoGPTPluginTemplate] = []
|
||||
self.plugins: List["AutoGPTPluginTemplate"] = []
|
||||
self.plugins_openai = []
|
||||
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
|
||||
|
||||
@ -189,9 +193,7 @@ class Config(metaclass=Singleton):
|
||||
### Log level
|
||||
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
|
||||
|
||||
from pilot.component import SystemApp
|
||||
|
||||
self.SYSTEM_APP: SystemApp = None
|
||||
self.SYSTEM_APP: Optional["SystemApp"] = None
|
||||
|
||||
def set_debug_mode(self, value: bool) -> None:
|
||||
"""Set the debug mode value"""
|
||||
|
@ -23,15 +23,18 @@ os.chdir(new_directory)
|
||||
|
||||
|
||||
def get_device() -> str:
|
||||
import torch
|
||||
try:
|
||||
import torch
|
||||
|
||||
return (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
return (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return "cpu"
|
||||
|
||||
|
||||
LLM_MODEL_CONFIG = {
|
||||
@ -70,8 +73,9 @@ LLM_MODEL_CONFIG = {
|
||||
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
|
||||
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
|
||||
# https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
|
||||
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b-v1_1"),
|
||||
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
|
||||
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
|
||||
"internlm-20b": os.path.join(MODEL_PATH, "internlm-20b-chat"),
|
||||
}
|
||||
|
||||
EMBEDDING_MODEL_CONFIG = {
|
||||
|
@ -1,6 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.errors import NotEnoughElementsException
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
from pilot.embedding_engine.embedding_factory import (
|
||||
@ -69,10 +68,10 @@ class EmbeddingEngine:
|
||||
vector_client = VectorStoreConnector(
|
||||
self.vector_store_config["vector_store_type"], self.vector_store_config
|
||||
)
|
||||
try:
|
||||
ans = vector_client.similar_search(text, topk)
|
||||
except NotEnoughElementsException:
|
||||
ans = vector_client.similar_search(text, 1)
|
||||
# https://github.com/chroma-core/chroma/issues/657
|
||||
ans = vector_client.similar_search(text, topk)
|
||||
# except NotEnoughElementsException:
|
||||
# ans = vector_client.similar_search(text, 1)
|
||||
return ans
|
||||
|
||||
def vector_exist(self):
|
||||
|
@ -1,6 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import markdown
|
||||
|
@ -3,7 +3,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from chromadb.errors import NotEnoughElementsException
|
||||
from langchain.text_splitter import TextSplitter
|
||||
|
||||
from pilot.vector_store.connector import VectorStoreConnector
|
||||
@ -71,10 +70,9 @@ class SourceEmbedding(ABC):
|
||||
self.vector_client = VectorStoreConnector(
|
||||
self.vector_store_config["vector_store_type"], self.vector_store_config
|
||||
)
|
||||
try:
|
||||
ans = self.vector_client.similar_search(doc, topk)
|
||||
except NotEnoughElementsException:
|
||||
ans = self.vector_client.similar_search(doc, 1)
|
||||
# https://github.com/chroma-core/chroma/issues/657
|
||||
ans = self.vector_client.similar_search(doc, topk)
|
||||
# ans = self.vector_client.similar_search(doc, 1)
|
||||
return ans
|
||||
|
||||
def vector_name_exist(self):
|
||||
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import platform
|
||||
from typing import Dict, Iterator, List
|
||||
|
||||
from pilot.configs.model_config import get_device
|
||||
@ -12,7 +11,7 @@ from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
logger = logging.getLogger("model_worker")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DefaultModelWorker(ModelWorker):
|
||||
@ -91,8 +90,13 @@ class DefaultModelWorker(ModelWorker):
|
||||
_clear_torch_cache(self._model_params.device)
|
||||
|
||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
import torch
|
||||
torch_imported = False
|
||||
try:
|
||||
import torch
|
||||
|
||||
torch_imported = True
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
# params adaptation
|
||||
params, model_context = self.llm_chat_adapter.model_adaptation(
|
||||
@ -117,16 +121,17 @@ class DefaultModelWorker(ModelWorker):
|
||||
)
|
||||
yield model_output
|
||||
print(f"\n\nfull stream output:\n{previous_response}")
|
||||
except torch.cuda.CudaError:
|
||||
model_output = ModelOutput(
|
||||
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
|
||||
)
|
||||
yield model_output
|
||||
except Exception as e:
|
||||
model_output = ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
error_code=0,
|
||||
)
|
||||
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
|
||||
if torch_imported and isinstance(e, torch.cuda.CudaError):
|
||||
model_output = ModelOutput(
|
||||
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
|
||||
)
|
||||
else:
|
||||
model_output = ModelOutput(
|
||||
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
|
||||
error_code=0,
|
||||
)
|
||||
yield model_output
|
||||
|
||||
def generate(self, params: Dict) -> ModelOutput:
|
||||
|
@ -5,6 +5,7 @@ import os
|
||||
import sys
|
||||
import random
|
||||
import time
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import asdict
|
||||
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
||||
@ -12,7 +13,6 @@ from typing import Awaitable, Callable, Dict, Iterator, List, Optional
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pilot.component import SystemApp
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.model.base import (
|
||||
ModelInstance,
|
||||
ModelOutput,
|
||||
@ -30,15 +30,13 @@ from pilot.model.cluster.manager_base import (
|
||||
WorkerManagerFactory,
|
||||
)
|
||||
from pilot.model.cluster.base import *
|
||||
from pilot.utils import build_logger
|
||||
from pilot.utils.parameter_utils import (
|
||||
EnvArgumentParser,
|
||||
ParameterDescription,
|
||||
_dict_to_command_args,
|
||||
)
|
||||
|
||||
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
|
||||
DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
|
||||
|
@ -1,4 +1,3 @@
|
||||
import bardapi
|
||||
import requests
|
||||
from typing import List
|
||||
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
@ -52,6 +51,8 @@ def bard_generate_stream(
|
||||
else:
|
||||
yield f"bard proxy url request failed!, response = {str(response)}"
|
||||
else:
|
||||
import bardapi
|
||||
|
||||
response = bardapi.core.Bard(proxy_api_key).get_answer("\n".join(msgs))
|
||||
|
||||
if response is not None and response.get("content") is not None:
|
||||
|
@ -10,9 +10,6 @@ from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
from pilot.common.markdown_text import (
|
||||
generate_htm_table,
|
||||
)
|
||||
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
|
||||
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||
from pilot.json_utils.utilities import DateTimeEncoder
|
||||
|
@ -1,7 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from chromadb.errors import NoIndexException
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
@ -59,22 +57,19 @@ class ChatKnowledge(BaseChat):
|
||||
)
|
||||
|
||||
def generate_input_values(self):
|
||||
try:
|
||||
if self.space_context:
|
||||
self.prompt_template.template_define = self.space_context["prompt"][
|
||||
"scene"
|
||||
]
|
||||
self.prompt_template.template = self.space_context["prompt"]["template"]
|
||||
docs = self.knowledge_embedding_client.similar_search(
|
||||
self.current_user_input, self.top_k
|
||||
)
|
||||
context = [d.page_content for d in docs]
|
||||
context = context[: self.max_token]
|
||||
input_values = {"context": context, "question": self.current_user_input}
|
||||
except NoIndexException:
|
||||
if self.space_context:
|
||||
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
|
||||
self.prompt_template.template = self.space_context["prompt"]["template"]
|
||||
docs = self.knowledge_embedding_client.similar_search(
|
||||
self.current_user_input, self.top_k
|
||||
)
|
||||
if not docs:
|
||||
raise ValueError(
|
||||
"you have no knowledge space, please add your knowledge space"
|
||||
)
|
||||
context = [d.page_content for d in docs]
|
||||
context = context[: self.max_token]
|
||||
input_values = {"context": context, "question": self.current_user_input}
|
||||
return input_values
|
||||
|
||||
@property
|
||||
|
@ -71,7 +71,7 @@ def load(
|
||||
skip_wrong_doc: bool,
|
||||
max_workers: int,
|
||||
):
|
||||
"""Load you local knowledge to DB-GPT"""
|
||||
"""Load your local knowledge to DB-GPT"""
|
||||
from pilot.server.knowledge._cli.knowledge_client import knowledge_init
|
||||
|
||||
knowledge_init(
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -170,8 +170,9 @@ class DBSummaryClient:
|
||||
def init_db_profile(self, db_summary_client, dbname, embeddings):
|
||||
from pilot.embedding_engine.string_embedding import StringEmbedding
|
||||
|
||||
vector_store_name = dbname + "_profile"
|
||||
profile_store_config = {
|
||||
"vector_store_name": dbname + "_profile",
|
||||
"vector_store_name": vector_store_name,
|
||||
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
|
||||
"vector_store_type": CFG.VECTOR_STORE_TYPE,
|
||||
"embeddings": embeddings,
|
||||
@ -190,6 +191,8 @@ class DBSummaryClient:
|
||||
)
|
||||
docs.extend(embedding.read_batch())
|
||||
embedding.index_to_store(docs)
|
||||
else:
|
||||
logger.info(f"Vector store name {vector_store_name} exist")
|
||||
logger.info("init db profile success...")
|
||||
|
||||
|
||||
|
@ -2,6 +2,7 @@ import os
|
||||
from typing import Any
|
||||
|
||||
from chromadb.config import Settings
|
||||
from chromadb import PersistentClient
|
||||
from pilot.logs import logger
|
||||
from pilot.vector_store.base import VectorStoreBase
|
||||
|
||||
@ -18,15 +19,18 @@ class ChromaStore(VectorStoreBase):
|
||||
ctx["chroma_persist_path"], ctx["vector_store_name"] + ".vectordb"
|
||||
)
|
||||
chroma_settings = Settings(
|
||||
chroma_db_impl="duckdb+parquet",
|
||||
# chroma_db_impl="duckdb+parquet", => deprecated configuration of Chroma
|
||||
persist_directory=self.persist_dir,
|
||||
anonymized_telemetry=False,
|
||||
)
|
||||
client = PersistentClient(path=self.persist_dir, settings=chroma_settings)
|
||||
|
||||
collection_metadata = {"hnsw:space": "cosine"}
|
||||
self.vector_store_client = Chroma(
|
||||
persist_directory=self.persist_dir,
|
||||
embedding_function=self.embeddings,
|
||||
client_settings=chroma_settings,
|
||||
# client_settings=chroma_settings,
|
||||
client=client,
|
||||
collection_metadata=collection_metadata,
|
||||
)
|
||||
|
||||
@ -35,9 +39,13 @@ class ChromaStore(VectorStoreBase):
|
||||
return self.vector_store_client.similarity_search(text, topk)
|
||||
|
||||
def vector_name_exists(self):
|
||||
return (
|
||||
os.path.exists(self.persist_dir) and len(os.listdir(self.persist_dir)) > 0
|
||||
)
|
||||
logger.info(f"Check persist_dir: {self.persist_dir}")
|
||||
if not os.path.exists(self.persist_dir):
|
||||
return False
|
||||
files = os.listdir(self.persist_dir)
|
||||
# Skip default file: chroma.sqlite3
|
||||
files = list(filter(lambda f: f != "chroma.sqlite3", files))
|
||||
return len(files) > 0
|
||||
|
||||
def load_document(self, documents):
|
||||
logger.info("ChromaStore load document")
|
||||
|
@ -1,84 +0,0 @@
|
||||
# torch==2.0.0
|
||||
aiohttp==3.8.4
|
||||
aiosignal==1.3.1
|
||||
async-timeout==4.0.2
|
||||
attrs==22.2.0
|
||||
cchardet==2.1.7
|
||||
chardet==5.1.0
|
||||
# contourpy==1.0.7
|
||||
# cycler==0.11.0
|
||||
filelock==3.9.0
|
||||
fonttools==4.38.0
|
||||
frozenlist==1.3.3
|
||||
huggingface-hub==0.14.1
|
||||
importlib-resources==5.12.0
|
||||
|
||||
sqlparse==0.4.4
|
||||
# kiwisolver==1.4.4
|
||||
# matplotlib==3.7.1
|
||||
multidict==6.0.4
|
||||
packaging==23.0
|
||||
psutil==5.9.4
|
||||
# pycocotools==2.0.6
|
||||
# pyparsing==3.0.9
|
||||
python-dateutil==2.8.2
|
||||
pyyaml==6.0
|
||||
tokenizers==0.13.2
|
||||
tqdm==4.64.1
|
||||
transformers>=4.31.0
|
||||
transformers_stream_generator
|
||||
# timm==0.6.13
|
||||
spacy==3.5.3
|
||||
webdataset==0.2.48
|
||||
yarl==1.8.2
|
||||
zipp==3.14.0
|
||||
omegaconf==2.3.0
|
||||
opencv-python==4.7.0.72
|
||||
iopath==0.1.10
|
||||
tenacity==8.2.2
|
||||
peft
|
||||
# TODO remove pycocoevalcap
|
||||
pycocoevalcap
|
||||
cpm_kernels
|
||||
umap-learn
|
||||
# notebook
|
||||
gradio==3.23
|
||||
gradio-client==0.0.8
|
||||
# wandb
|
||||
# llama-index==0.5.27
|
||||
|
||||
# TODO move bitsandbytes to optional
|
||||
# bitsandbytes
|
||||
accelerate>=0.20.3
|
||||
|
||||
unstructured==0.6.3
|
||||
gpt4all==0.3.0
|
||||
diskcache==5.6.1
|
||||
seaborn
|
||||
auto-gpt-plugin-template
|
||||
pymdown-extensions
|
||||
gTTS==2.3.1
|
||||
langchain>=0.0.286
|
||||
nltk
|
||||
python-dotenv==1.0.0
|
||||
|
||||
vcrpy
|
||||
chromadb==0.3.22
|
||||
markdown2
|
||||
colorama
|
||||
playsound
|
||||
distro
|
||||
pypdf
|
||||
weaviate-client
|
||||
bardapi==0.1.29
|
||||
|
||||
# database
|
||||
|
||||
# TODO moved to optional dependencies
|
||||
pymysql
|
||||
duckdb
|
||||
duckdb-engine
|
||||
|
||||
# cli
|
||||
prettytable
|
||||
cachetools
|
@ -1,4 +1,4 @@
|
||||
# Testing dependencies
|
||||
# Testing and dev dependencies
|
||||
pytest
|
||||
asynctest
|
||||
pytest-asyncio
|
||||
@ -7,4 +7,6 @@ pytest-cov
|
||||
pytest-integration
|
||||
pytest-mock
|
||||
pytest-recording
|
||||
pytesseract==0.3.10
|
||||
pytesseract==0.3.10
|
||||
# python code format
|
||||
black
|
117
setup.py
117
setup.py
@ -10,7 +10,6 @@ from urllib.parse import urlparse, quote
|
||||
import re
|
||||
from pip._internal.utils.appdirs import user_cache_dir
|
||||
import shutil
|
||||
import tempfile
|
||||
from setuptools import find_packages
|
||||
|
||||
with open("README.md", mode="r", encoding="utf-8") as fh:
|
||||
@ -74,7 +73,6 @@ def cache_package(package_url: str, package_name: str, is_windows: bool = False)
|
||||
|
||||
local_path = os.path.join(cache_dir, filename)
|
||||
if not os.path.exists(local_path):
|
||||
# temp_file, temp_path = tempfile.mkstemp()
|
||||
temp_path = local_path + ".tmp"
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
@ -204,23 +202,16 @@ def torch_requires(
|
||||
torchvision_version: str = "0.15.1",
|
||||
torchaudio_version: str = "2.0.1",
|
||||
):
|
||||
torch_pkgs = []
|
||||
torch_pkgs = [
|
||||
f"torch=={torch_version}",
|
||||
f"torchvision=={torchvision_version}",
|
||||
f"torchaudio=={torchaudio_version}",
|
||||
]
|
||||
torch_cuda_pkgs = []
|
||||
os_type, _ = get_cpu_avx_support()
|
||||
if os_type == OSType.DARWIN:
|
||||
torch_pkgs = [
|
||||
f"torch=={torch_version}",
|
||||
f"torchvision=={torchvision_version}",
|
||||
f"torchaudio=={torchaudio_version}",
|
||||
]
|
||||
else:
|
||||
if os_type != OSType.DARWIN:
|
||||
cuda_version = get_cuda_version()
|
||||
if not cuda_version:
|
||||
torch_pkgs = [
|
||||
f"torch=={torch_version}",
|
||||
f"torchvision=={torchvision_version}",
|
||||
f"torchaudio=={torchaudio_version}",
|
||||
]
|
||||
else:
|
||||
if cuda_version:
|
||||
supported_versions = ["11.7", "11.8"]
|
||||
if cuda_version not in supported_versions:
|
||||
print(
|
||||
@ -238,12 +229,16 @@ def torch_requires(
|
||||
torchvision_url_cached = cache_package(
|
||||
torchvision_url, "torchvision", os_type == OSType.WINDOWS
|
||||
)
|
||||
torch_pkgs = [
|
||||
|
||||
torch_cuda_pkgs = [
|
||||
f"torch @ {torch_url_cached}",
|
||||
f"torchvision @ {torchvision_url_cached}",
|
||||
f"torchaudio=={torchaudio_version}",
|
||||
]
|
||||
|
||||
setup_spec.extras["torch"] = torch_pkgs
|
||||
setup_spec.extras["torch_cpu"] = torch_pkgs
|
||||
setup_spec.extras["torch_cuda"] = torch_cuda_pkgs
|
||||
|
||||
|
||||
def llama_cpp_python_cuda_requires():
|
||||
@ -274,6 +269,57 @@ def llama_cpp_python_cuda_requires():
|
||||
setup_spec.extras["llama_cpp"].append(f"llama_cpp_python_cuda @ {extra_index_url}")
|
||||
|
||||
|
||||
def core_requires():
|
||||
"""
|
||||
pip install db-gpt or pip install "db-gpt[core]"
|
||||
"""
|
||||
setup_spec.extras["core"] = [
|
||||
"aiohttp==3.8.4",
|
||||
"chardet==5.1.0",
|
||||
"importlib-resources==5.12.0",
|
||||
"psutil==5.9.4",
|
||||
"python-dotenv==1.0.0",
|
||||
"colorama",
|
||||
"prettytable",
|
||||
"cachetools",
|
||||
]
|
||||
|
||||
setup_spec.extras["framework"] = [
|
||||
"httpx",
|
||||
"sqlparse==0.4.4",
|
||||
"seaborn",
|
||||
# https://github.com/eosphoros-ai/DB-GPT/issues/551
|
||||
"pandas==2.0.3",
|
||||
"auto-gpt-plugin-template",
|
||||
"gTTS==2.3.1",
|
||||
"langchain>=0.0.286",
|
||||
"SQLAlchemy",
|
||||
"pymysql",
|
||||
"duckdb",
|
||||
"duckdb-engine",
|
||||
"jsonschema",
|
||||
# TODO move transformers to default
|
||||
"transformers>=4.31.0",
|
||||
]
|
||||
|
||||
|
||||
def knowledge_requires():
|
||||
"""
|
||||
pip install "db-gpt[knowledge]"
|
||||
"""
|
||||
setup_spec.extras["knowledge"] = [
|
||||
"spacy==3.5.3",
|
||||
# "chromadb==0.3.22",
|
||||
"chromadb",
|
||||
"markdown",
|
||||
"bs4",
|
||||
"python-pptx",
|
||||
"python-docx",
|
||||
"pypdf",
|
||||
"python-multipart",
|
||||
]
|
||||
|
||||
|
||||
def llama_cpp_requires():
|
||||
"""
|
||||
pip install "db-gpt[llama_cpp]"
|
||||
@ -309,6 +355,7 @@ def all_vector_store_requires():
|
||||
setup_spec.extras["vstore"] = [
|
||||
"grpcio==1.47.5", # maybe delete it
|
||||
"pymilvus==2.2.1",
|
||||
"weaviate-client",
|
||||
]
|
||||
|
||||
|
||||
@ -324,6 +371,31 @@ def openai_requires():
|
||||
pip install "db-gpt[openai]"
|
||||
"""
|
||||
setup_spec.extras["openai"] = ["openai", "tiktoken"]
|
||||
setup_spec.extras["openai"] += setup_spec.extras["framework"]
|
||||
setup_spec.extras["openai"] += setup_spec.extras["knowledge"]
|
||||
|
||||
|
||||
def gpt4all_requires():
|
||||
"""
|
||||
pip install "db-gpt[gpt4all]"
|
||||
"""
|
||||
setup_spec.extras["gpt4all"] = ["gpt4all"]
|
||||
|
||||
|
||||
def default_requires():
|
||||
"""
|
||||
pip install "db-gpt[default]"
|
||||
"""
|
||||
setup_spec.extras["default"] = [
|
||||
"tokenizers==0.13.2",
|
||||
"accelerate>=0.20.3",
|
||||
"sentence-transformers",
|
||||
"protobuf==3.20.3",
|
||||
]
|
||||
setup_spec.extras["default"] += setup_spec.extras["framework"]
|
||||
setup_spec.extras["default"] += setup_spec.extras["knowledge"]
|
||||
setup_spec.extras["default"] += setup_spec.extras["torch"]
|
||||
setup_spec.extras["default"] += setup_spec.extras["quantization"]
|
||||
|
||||
|
||||
def all_requires():
|
||||
@ -335,20 +407,23 @@ def all_requires():
|
||||
|
||||
|
||||
def init_install_requires():
|
||||
setup_spec.install_requires += parse_requirements("requirements.txt")
|
||||
setup_spec.install_requires += setup_spec.extras["torch"]
|
||||
setup_spec.install_requires += setup_spec.extras["quantization"]
|
||||
setup_spec.install_requires += setup_spec.extras["core"]
|
||||
print(f"Install requires: \n{','.join(setup_spec.install_requires)}")
|
||||
|
||||
|
||||
core_requires()
|
||||
torch_requires()
|
||||
knowledge_requires()
|
||||
llama_cpp_requires()
|
||||
quantization_requires()
|
||||
|
||||
all_vector_store_requires()
|
||||
all_datasource_requires()
|
||||
openai_requires()
|
||||
gpt4all_requires()
|
||||
|
||||
# must be last
|
||||
default_requires()
|
||||
all_requires()
|
||||
init_install_requires()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user