mirror of https://github.com/csunny/DB-GPT.git
Merge remote-tracking branch 'origin/main' into feat_llm_manage

@@ -1,4 +1,5 @@
 """Load components"""
+from __future__ import annotations
 
 import json
 import os
@@ -8,17 +9,19 @@ import requests
 import threading
 import datetime
 from pathlib import Path
-from typing import List
+from typing import List, TYPE_CHECKING
 from urllib.parse import urlparse
 from zipimport import zipimporter
 
 import requests
-from auto_gpt_plugin_template import AutoGPTPluginTemplate
 
 from pilot.configs.config import Config
 from pilot.configs.model_config import PLUGINS_DIR
 from pilot.logs import logger
 
+if TYPE_CHECKING:
+    from auto_gpt_plugin_template import AutoGPTPluginTemplate
+
 
 def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
     """
@@ -115,7 +118,7 @@ def load_native_plugins(cfg: Config):
     t.start()
 
 
-def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
+def scan_plugins(cfg: Config, debug: bool = False) -> List["AutoGPTPluginTemplate"]:
     """Scan the plugins directory for plugins and loads them.
 
     Args:
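
Note: the hunks above move the auto_gpt_plugin_template import behind typing.TYPE_CHECKING, so the package is only needed by type checkers and is no longer a hard runtime dependency of the plugin loader; the scan_plugins annotation becomes a quoted forward reference for the same reason. A minimal, self-contained sketch of the pattern (not the repo's code):

    from __future__ import annotations

    from typing import TYPE_CHECKING, List

    if TYPE_CHECKING:
        # Imported only during static analysis; absent at runtime.
        from auto_gpt_plugin_template import AutoGPTPluginTemplate


    def scan_plugins() -> List["AutoGPTPluginTemplate"]:
        # The quoted annotation is never evaluated at runtime, so this module
        # imports cleanly even when the package is not installed.
        return []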

@@ -1,11 +1,16 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+from __future__ import annotations
+
 import os
-from typing import List
+from typing import List, Optional, TYPE_CHECKING
 
 from pilot.singleton import Singleton
 
+if TYPE_CHECKING:
+    from auto_gpt_plugin_template import AutoGPTPluginTemplate
+    from pilot.component import SystemApp
+
 
 class Config(metaclass=Singleton):
     """Configuration class to store the state of bools for different scripts access"""
@@ -99,9 +104,8 @@ class Config(metaclass=Singleton):
         self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
 
         ### The associated configuration parameters of the plug-in control the loading and use of the plug-in
-        from auto_gpt_plugin_template import AutoGPTPluginTemplate
-
-        self.plugins: List[AutoGPTPluginTemplate] = []
+        self.plugins: List["AutoGPTPluginTemplate"] = []
         self.plugins_openai = []
         self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
@@ -189,9 +193,7 @@ class Config(metaclass=Singleton):
         ### Log level
         self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
 
-        from pilot.component import SystemApp
-
-        self.SYSTEM_APP: SystemApp = None
+        self.SYSTEM_APP: Optional["SystemApp"] = None
 
     def set_debug_mode(self, value: bool) -> None:
         """Set the debug mode value"""
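
Note: Config gets the same TYPE_CHECKING treatment, and the SYSTEM_APP annotation is corrected while it is touched: an attribute initialized to None is Optional. A small sketch of the combined idiom (assumed names, not the repo's code):

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        from pilot.component import SystemApp  # type-check-time only


    class Config:
        def __init__(self) -> None:
            # Optional[...] is the accurate type for a slot that starts as None
            # and is wired up later by the application.
            self.SYSTEM_APP: Optional["SystemApp"] = None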

@@ -23,15 +23,18 @@ os.chdir(new_directory)
 
 
 def get_device() -> str:
-    import torch
+    try:
+        import torch
 
-    return (
-        "cuda"
-        if torch.cuda.is_available()
-        else "mps"
-        if torch.backends.mps.is_available()
-        else "cpu"
-    )
+        return (
+            "cuda"
+            if torch.cuda.is_available()
+            else "mps"
+            if torch.backends.mps.is_available()
+            else "cpu"
+        )
+    except ModuleNotFoundError:
+        return "cpu"
 
 
 LLM_MODEL_CONFIG = {
@@ -70,8 +73,9 @@ LLM_MODEL_CONFIG = {
     "wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
     "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
-    "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b-v1_1"),
+    # https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
+    "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
     "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
     "internlm-20b": os.path.join(MODEL_PATH, "internlm-20b-chat"),
 }
 
 EMBEDDING_MODEL_CONFIG = {
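
Note: get_device() now degrades to "cpu" instead of crashing when torch is not installed, matching the torch-optional worker changes further down. An equivalent, flatter formulation of the new logic (a sketch, not the repo's code):

    def get_device() -> str:
        # Proxy-only deployments may not ship torch at all.
        try:
            import torch
        except ModuleNotFoundError:
            return "cpu"

        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"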

@@ -1,6 +1,5 @@
 from typing import Optional
 
-from chromadb.errors import NotEnoughElementsException
 from langchain.text_splitter import TextSplitter
 
 from pilot.embedding_engine.embedding_factory import (
@@ -69,10 +68,10 @@ class EmbeddingEngine:
         vector_client = VectorStoreConnector(
             self.vector_store_config["vector_store_type"], self.vector_store_config
         )
-        try:
-            ans = vector_client.similar_search(text, topk)
-        except NotEnoughElementsException:
-            ans = vector_client.similar_search(text, 1)
+        # https://github.com/chroma-core/chroma/issues/657
+        ans = vector_client.similar_search(text, topk)
+        # except NotEnoughElementsException:
+        #     ans = vector_client.similar_search(text, 1)
         return ans
 
     def vector_exist(self):
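
Note: chromadb 0.4 removed NotEnoughElementsException; since chroma-core/chroma#657 a query that asks for more results than the collection holds returns whatever exists instead of raising, so the old retry with topk=1 is dead code. The same substitution appears in the SourceEmbedding hunk below. A sketch of the new behavior (assumes chromadb >= 0.4):

    import chromadb

    client = chromadb.Client()  # in-memory instance, just for the demo
    col = client.create_collection("demo")
    col.add(ids=["a"], embeddings=[[0.0, 1.0]], documents=["the only document"])
    # n_results larger than the collection no longer raises; Chroma clamps
    # it and returns the single stored document.
    print(col.query(query_embeddings=[[0.0, 1.0]], n_results=10)["ids"])  # [['a']]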

@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
-import os
 from typing import List, Optional
 
 import markdown

@@ -3,7 +3,6 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional
 
-from chromadb.errors import NotEnoughElementsException
 from langchain.text_splitter import TextSplitter
 
 from pilot.vector_store.connector import VectorStoreConnector
@@ -71,10 +70,9 @@ class SourceEmbedding(ABC):
         self.vector_client = VectorStoreConnector(
             self.vector_store_config["vector_store_type"], self.vector_store_config
         )
-        try:
-            ans = self.vector_client.similar_search(doc, topk)
-        except NotEnoughElementsException:
-            ans = self.vector_client.similar_search(doc, 1)
+        # https://github.com/chroma-core/chroma/issues/657
+        ans = self.vector_client.similar_search(doc, topk)
+        # ans = self.vector_client.similar_search(doc, 1)
         return ans
 
     def vector_name_exist(self):

@@ -1,5 +1,4 @@
 import logging
 import platform
 from typing import Dict, Iterator, List
-
 from pilot.configs.model_config import get_device
@@ -12,7 +11,7 @@ from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
 from pilot.utils.model_utils import _clear_torch_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 
-logger = logging.getLogger("model_worker")
+logger = logging.getLogger(__name__)
 
 
 class DefaultModelWorker(ModelWorker):
@@ -91,8 +90,13 @@ class DefaultModelWorker(ModelWorker):
         _clear_torch_cache(self._model_params.device)
 
     def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
-        import torch
+        torch_imported = False
+        try:
+            import torch
+
+            torch_imported = True
+        except ImportError:
+            pass
         try:
             # params adaptation
             params, model_context = self.llm_chat_adapter.model_adaptation(
@@ -117,16 +121,17 @@ class DefaultModelWorker(ModelWorker):
                 )
                 yield model_output
             print(f"\n\nfull stream output:\n{previous_response}")
-        except torch.cuda.CudaError:
-            model_output = ModelOutput(
-                text="**GPU OutOfMemory, Please Refresh.**", error_code=0
-            )
-            yield model_output
         except Exception as e:
-            model_output = ModelOutput(
-                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-                error_code=0,
-            )
+            # Check if the exception is a torch.cuda.CudaError and if torch was imported.
+            if torch_imported and isinstance(e, torch.cuda.CudaError):
+                model_output = ModelOutput(
+                    text="**GPU OutOfMemory, Please Refresh.**", error_code=0
+                )
+            else:
+                model_output = ModelOutput(
+                    text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
+                    error_code=0,
+                )
             yield model_output
 
     def generate(self, params: Dict) -> ModelOutput:
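
Note: the old generate_stream imported torch unconditionally and kept a dedicated except torch.cuda.CudaError clause, so the worker could not run without torch. The new code records whether the import succeeded and branches inside a single except Exception. A standalone sketch of the guard (runs with or without torch installed):

    torch_imported = False
    try:
        import torch

        torch_imported = True
    except ImportError:
        pass


    def classify(exc: Exception) -> str:
        # Only reference torch.cuda.CudaError when the import succeeded.
        if torch_imported and isinstance(exc, torch.cuda.CudaError):
            return "gpu-oom"
        return "generic-error"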

@@ -5,6 +5,7 @@ import os
 import sys
 import random
 import time
+import logging
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import asdict
 from typing import Awaitable, Callable, Dict, Iterator, List, Optional
@@ -12,7 +13,6 @@ from typing import Awaitable, Callable, Dict, Iterator, List, Optional
 from fastapi import APIRouter, FastAPI
 from fastapi.responses import StreamingResponse
 from pilot.component import SystemApp
-from pilot.configs.model_config import LOGDIR
 from pilot.model.base import (
     ModelInstance,
     ModelOutput,
@@ -30,15 +30,13 @@ from pilot.model.cluster.manager_base import (
     WorkerManagerFactory,
 )
 from pilot.model.cluster.base import *
-from pilot.utils import build_logger
 from pilot.utils.parameter_utils import (
     EnvArgumentParser,
     ParameterDescription,
     _dict_to_command_args,
 )
 
-logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
-
+logger = logging.getLogger(__name__)
 
 RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
 DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
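
Note: both worker modules switch from the project's build_logger helper (which wrote to LOGDIR/model_worker.log at import time) to a plain named logger, leaving handler, format, and destination choices to the application entry point. This is the standard library-logging idiom; a sketch:

    import logging

    logger = logging.getLogger(__name__)  # no handlers attached by the module

    if __name__ == "__main__":
        # Only the entry point decides where logs go and at what level.
        logging.basicConfig(filename="model_worker.log", level=logging.INFO)
        logger.info("worker manager started")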

@@ -1,4 +1,3 @@
-import bardapi
 import requests
 from typing import List
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
@@ -52,6 +51,8 @@ def bard_generate_stream(
         else:
             yield f"bard proxy url request failed!, response = {str(response)}"
     else:
+        import bardapi
+
         response = bardapi.core.Bard(proxy_api_key).get_answer("\n".join(msgs))
 
         if response is not None and response.get("content") is not None:
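
Note: the bardapi import moves from module level into the only branch that uses it, so deployments that route through a Bard proxy URL never need the package installed. A sketch of the deferred-import pattern mirroring the hunk (hypothetical wrapper, response shape as used above):

    def answer_via_bard(api_key: str, prompt: str) -> str:
        import bardapi  # deferred: required only on this code path

        response = bardapi.core.Bard(api_key).get_answer(prompt)
        return response.get("content", "")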

@@ -10,9 +10,6 @@ from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
 from pilot.configs.config import Config
-from pilot.common.markdown_text import (
-    generate_htm_table,
-)
 from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
 from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
 from pilot.json_utils.utilities import DateTimeEncoder

@@ -1,7 +1,5 @@
 from typing import Dict
 
-from chromadb.errors import NoIndexException
-
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.configs.config import Config
@@ -59,22 +57,19 @@ class ChatKnowledge(BaseChat):
         )
 
     def generate_input_values(self):
-        try:
-            if self.space_context:
-                self.prompt_template.template_define = self.space_context["prompt"][
-                    "scene"
-                ]
-                self.prompt_template.template = self.space_context["prompt"]["template"]
-            docs = self.knowledge_embedding_client.similar_search(
-                self.current_user_input, self.top_k
-            )
-            context = [d.page_content for d in docs]
-            context = context[: self.max_token]
-            input_values = {"context": context, "question": self.current_user_input}
-        except NoIndexException:
-            raise ValueError(
-                "you have no knowledge space, please add your knowledge space"
-            )
+        if self.space_context:
+            self.prompt_template.template_define = self.space_context["prompt"]["scene"]
+            self.prompt_template.template = self.space_context["prompt"]["template"]
+        docs = self.knowledge_embedding_client.similar_search(
+            self.current_user_input, self.top_k
+        )
+        if not docs:
+            raise ValueError(
+                "you have no knowledge space, please add your knowledge space"
+            )
+        context = [d.page_content for d in docs]
+        context = context[: self.max_token]
+        input_values = {"context": context, "question": self.current_user_input}
         return input_values
 
     @property
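
Note: with NoIndexException gone from chromadb 0.4, "no knowledge space yet" is detected by an empty search result rather than an exception, so the try/except wrapper around template setup and search goes away. A sketch of the new control flow (client stands for any object exposing similar_search(query, k)):

    def build_input_values(client, query: str, top_k: int, max_token: int) -> dict:
        docs = client.similar_search(query, top_k)
        if not docs:  # empty result replaces the old NoIndexException signal
            raise ValueError(
                "you have no knowledge space, please add your knowledge space"
            )
        # Mirrors the repo's list truncation of the retrieved documents.
        context = [d.page_content for d in docs][:max_token]
        return {"context": context, "question": query}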

@@ -71,7 +71,7 @@ def load(
     skip_wrong_doc: bool,
     max_workers: int,
 ):
-    """Load you local knowledge to DB-GPT"""
+    """Load your local knowledge to DB-GPT"""
     from pilot.server.knowledge._cli.knowledge_client import knowledge_init
 
     knowledge_init(

File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long

@@ -170,8 +170,9 @@ class DBSummaryClient:
     def init_db_profile(self, db_summary_client, dbname, embeddings):
         from pilot.embedding_engine.string_embedding import StringEmbedding
 
+        vector_store_name = dbname + "_profile"
         profile_store_config = {
-            "vector_store_name": dbname + "_profile",
+            "vector_store_name": vector_store_name,
             "chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
             "vector_store_type": CFG.VECTOR_STORE_TYPE,
             "embeddings": embeddings,
@@ -190,6 +191,8 @@ class DBSummaryClient:
             )
             docs.extend(embedding.read_batch())
             embedding.index_to_store(docs)
+        else:
+            logger.info(f"Vector store name {vector_store_name} exist")
         logger.info("init db profile success...")

@@ -2,6 +2,7 @@ import os
 from typing import Any
 
 from chromadb.config import Settings
+from chromadb import PersistentClient
 from pilot.logs import logger
 from pilot.vector_store.base import VectorStoreBase
@@ -18,15 +19,18 @@ class ChromaStore(VectorStoreBase):
             ctx["chroma_persist_path"], ctx["vector_store_name"] + ".vectordb"
         )
         chroma_settings = Settings(
-            chroma_db_impl="duckdb+parquet",
+            # chroma_db_impl="duckdb+parquet", => deprecated configuration of Chroma
            persist_directory=self.persist_dir,
             anonymized_telemetry=False,
         )
+        client = PersistentClient(path=self.persist_dir, settings=chroma_settings)
+
+        collection_metadata = {"hnsw:space": "cosine"}
         self.vector_store_client = Chroma(
             persist_directory=self.persist_dir,
             embedding_function=self.embeddings,
-            client_settings=chroma_settings,
+            # client_settings=chroma_settings,
+            client=client,
+            collection_metadata=collection_metadata,
         )
@@ -35,9 +39,13 @@ class ChromaStore(VectorStoreBase):
         return self.vector_store_client.similarity_search(text, topk)
 
     def vector_name_exists(self):
-        return (
-            os.path.exists(self.persist_dir) and len(os.listdir(self.persist_dir)) > 0
-        )
+        logger.info(f"Check persist_dir: {self.persist_dir}")
+        if not os.path.exists(self.persist_dir):
+            return False
+        files = os.listdir(self.persist_dir)
+        # Skip default file: chroma.sqlite3
+        files = list(filter(lambda f: f != "chroma.sqlite3", files))
+        return len(files) > 0
 
     def load_document(self, documents):
         logger.info("ChromaStore load document")
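
Note: the ChromaStore hunks carry the core of the chromadb 0.4 migration: the chroma_db_impl="duckdb+parquet" setting was removed upstream, persistence now goes through PersistentClient, and the distance metric is selected per collection ("hnsw:space": "cosine" instead of the l2 default). vector_name_exists() is rewritten because PersistentClient eagerly creates chroma.sqlite3, so a directory containing only that file must still count as "no store yet". A minimal sketch of the new setup (assumes chromadb >= 0.4; the langchain Chroma wrapper then receives this client):

    from chromadb import PersistentClient
    from chromadb.config import Settings

    client = PersistentClient(
        path="/tmp/demo_vectordb",  # hypothetical location for the demo
        settings=Settings(anonymized_telemetry=False),
    )
    collection = client.get_or_create_collection(
        "demo",
        metadata={"hnsw:space": "cosine"},  # cosine rather than the default l2
    )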