Merge remote-tracking branch 'origin/main' into feat_llm_manage

aries_ckt
2023-09-20 20:12:55 +08:00
30 changed files with 326 additions and 262 deletions

View File

@@ -1,4 +1,5 @@
"""加载组件"""
from __future__ import annotations
import json
import os
@@ -8,17 +9,19 @@ import requests
import threading
import datetime
from pathlib import Path
from typing import List
from typing import List, TYPE_CHECKING
from urllib.parse import urlparse
from zipimport import zipimporter
import requests
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.configs.config import Config
from pilot.configs.model_config import PLUGINS_DIR
from pilot.logs import logger
if TYPE_CHECKING:
from auto_gpt_plugin_template import AutoGPTPluginTemplate
def inspect_zip_for_modules(zip_path: str, debug: bool = False) -> list[str]:
"""
@@ -115,7 +118,7 @@ def load_native_plugins(cfg: Config):
t.start()
def scan_plugins(cfg: Config, debug: bool = False) -> List[AutoGPTPluginTemplate]:
def scan_plugins(cfg: Config, debug: bool = False) -> List["AutoGPTPluginTemplate"]:
"""Scan the plugins directory for plugins and loads them.
Args:
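Both hunks above defer the auto_gpt_plugin_template import to type-checking time and quote the class name in the return annotation, so importing the plugin loader no longer pulls in that package at runtime. A minimal sketch of the pattern, reduced to a stub (the empty body is illustrative only):

from typing import List, TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never imported at runtime.
    from auto_gpt_plugin_template import AutoGPTPluginTemplate

def scan_plugins(debug: bool = False) -> List["AutoGPTPluginTemplate"]:
    # The quoted annotation lets mypy/pyright resolve the class without
    # requiring the import when the module is executed.
    return []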

View File

@@ -1,11 +1,16 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import annotations
import os
from typing import List
from typing import List, Optional, TYPE_CHECKING
from pilot.singleton import Singleton
if TYPE_CHECKING:
from auto_gpt_plugin_template import AutoGPTPluginTemplate
from pilot.component import SystemApp
class Config(metaclass=Singleton):
"""Configuration class to store the state of bools for different scripts access"""
@@ -99,9 +104,8 @@ class Config(metaclass=Singleton):
self.message_dir = os.getenv("MESSAGE_HISTORY_DIR", "../../message")
### Plugin-related configuration parameters that control how plugins are loaded and used
from auto_gpt_plugin_template import AutoGPTPluginTemplate
self.plugins: List[AutoGPTPluginTemplate] = []
self.plugins: List["AutoGPTPluginTemplate"] = []
self.plugins_openai = []
self.plugins_auto_load = os.getenv("AUTO_LOAD_PLUGIN", "True") == "True"
@@ -189,9 +193,7 @@ class Config(metaclass=Singleton):
### Log level
self.DBGPT_LOG_LEVEL = os.getenv("DBGPT_LOG_LEVEL", "INFO")
from pilot.component import SystemApp
self.SYSTEM_APP: SystemApp = None
self.SYSTEM_APP: Optional["SystemApp"] = None
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""

View File

@@ -23,15 +23,18 @@ os.chdir(new_directory)
def get_device() -> str:
import torch
try:
import torch
return (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
return (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
except ModuleNotFoundError:
return "cpu"
LLM_MODEL_CONFIG = {
@@ -70,8 +73,9 @@ LLM_MODEL_CONFIG = {
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
# https://huggingface.co/internlm/internlm-chat-7b-v1_1, 7b vs 7b-v1.1: https://github.com/InternLM/InternLM/issues/288
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b-v1_1"),
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-20b-chat"),
}
EMBEDDING_MODEL_CONFIG = {
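Moving the torch import inside the try block means a machine without torch now falls back to "cpu" instead of raising at import time. An illustrative caller, using only names defined above:

model_path = LLM_MODEL_CONFIG["internlm-7b"]  # resolves to .../internlm-chat-7b after this change
device = get_device()                         # "cuda", "mps", or "cpu" when torch is missing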

View File

@@ -1,6 +1,5 @@
from typing import Optional
from chromadb.errors import NotEnoughElementsException
from langchain.text_splitter import TextSplitter
from pilot.embedding_engine.embedding_factory import (
@@ -69,10 +68,10 @@ class EmbeddingEngine:
vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
try:
ans = vector_client.similar_search(text, topk)
except NotEnoughElementsException:
ans = vector_client.similar_search(text, 1)
# https://github.com/chroma-core/chroma/issues/657
ans = vector_client.similar_search(text, topk)
# except NotEnoughElementsException:
# ans = vector_client.similar_search(text, 1)
return ans
def vector_exist(self):
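The deleted try/except existed because older chromadb raised NotEnoughElementsException when topk exceeded the collection size; since chroma-core/chroma#657 the query just returns however many documents are stored, so the fallback is redundant. A small standalone check against chromadb itself (collection name and contents are invented for the demo, and the default embedding function must be available):

import chromadb

client = chromadb.EphemeralClient()
collection = client.get_or_create_collection("demo")
collection.add(ids=["doc-1"], documents=["the only stored document"])

# Requesting more neighbours than exist no longer raises; one hit comes back.
result = collection.query(query_texts=["stored document"], n_results=5)
print(len(result["ids"][0]))  # 1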

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from typing import List, Optional
import markdown

View File

@@ -3,7 +3,6 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from chromadb.errors import NotEnoughElementsException
from langchain.text_splitter import TextSplitter
from pilot.vector_store.connector import VectorStoreConnector
@@ -71,10 +70,9 @@ class SourceEmbedding(ABC):
self.vector_client = VectorStoreConnector(
self.vector_store_config["vector_store_type"], self.vector_store_config
)
try:
ans = self.vector_client.similar_search(doc, topk)
except NotEnoughElementsException:
ans = self.vector_client.similar_search(doc, 1)
# https://github.com/chroma-core/chroma/issues/657
ans = self.vector_client.similar_search(doc, topk)
# ans = self.vector_client.similar_search(doc, 1)
return ans
def vector_name_exist(self):

View File

@@ -1,5 +1,4 @@
import logging
import platform
from typing import Dict, Iterator, List
from pilot.configs.model_config import get_device
@@ -12,7 +11,7 @@ from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger("model_worker")
logger = logging.getLogger(__name__)
class DefaultModelWorker(ModelWorker):
@@ -91,8 +90,13 @@ class DefaultModelWorker(ModelWorker):
_clear_torch_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
import torch
torch_imported = False
try:
import torch
torch_imported = True
except ImportError:
pass
try:
# params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation(
@@ -117,16 +121,17 @@ class DefaultModelWorker(ModelWorker):
)
yield model_output
print(f"\n\nfull stream output:\n{previous_response}")
except torch.cuda.CudaError:
model_output = ModelOutput(
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
)
yield model_output
except Exception as e:
model_output = ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
if torch_imported and isinstance(e, torch.cuda.CudaError):
model_output = ModelOutput(
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
)
else:
model_output = ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
yield model_output
def generate(self, params: Dict) -> ModelOutput:
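The torch_imported flag lets the worker keep a single except clause while only touching torch.cuda types when torch actually imported. The pattern generalises to any optional dependency; a reduced sketch (classify_error is a placeholder, not project code):

try:
    import torch
    torch_imported = True
except ImportError:
    torch_imported = False

def classify_error(e: Exception) -> str:
    # Check the CUDA error type only when torch is known to be importable.
    if torch_imported and isinstance(e, torch.cuda.CudaError):
        return "gpu-out-of-memory"
    return "generic-llm-error"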

View File

@@ -5,6 +5,7 @@ import os
import sys
import random
import time
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from typing import Awaitable, Callable, Dict, Iterator, List, Optional
@@ -12,7 +13,6 @@ from typing import Awaitable, Callable, Dict, Iterator, List, Optional
from fastapi import APIRouter, FastAPI
from fastapi.responses import StreamingResponse
from pilot.component import SystemApp
from pilot.configs.model_config import LOGDIR
from pilot.model.base import (
ModelInstance,
ModelOutput,
@@ -30,15 +30,13 @@ from pilot.model.cluster.manager_base import (
WorkerManagerFactory,
)
from pilot.model.cluster.base import *
from pilot.utils import build_logger
from pilot.utils.parameter_utils import (
EnvArgumentParser,
ParameterDescription,
_dict_to_command_args,
)
logger = build_logger("model_worker", LOGDIR + "/model_worker.log")
logger = logging.getLogger(__name__)
RegisterFunc = Callable[[WorkerRunData], Awaitable[None]]
DeregisterFunc = Callable[[WorkerRunData], Awaitable[None]]
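Swapping build_logger for logging.getLogger(__name__) moves handler and log-file decisions out of the module and into whatever code configures logging for the process. A minimal sketch of such a setup (the path and format are illustrative, not the project's actual configuration):

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(name)s | %(levelname)s | %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("model_worker.log"),  # illustrative log path
    ],
)

logging.getLogger(__name__).info("worker manager started")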

View File

@@ -1,4 +1,3 @@
import bardapi
import requests
from typing import List
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
@@ -52,6 +51,8 @@ def bard_generate_stream(
else:
yield f"bard proxy url request failed!, response = {str(response)}"
else:
import bardapi
response = bardapi.core.Bard(proxy_api_key).get_answer("\n".join(msgs))
if response is not None and response.get("content") is not None:

View File

@@ -10,9 +10,6 @@ from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.common.markdown_text import (
generate_htm_table,
)
from pilot.scene.chat_data.chat_excel.excel_learning.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.json_utils.utilities import DateTimeEncoder

View File

@@ -1,7 +1,5 @@
from typing import Dict
from chromadb.errors import NoIndexException
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
@@ -59,22 +57,19 @@ class ChatKnowledge(BaseChat):
)
def generate_input_values(self):
try:
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"][
"scene"
]
self.prompt_template.template = self.space_context["prompt"]["template"]
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, self.top_k
)
context = [d.page_content for d in docs]
context = context[: self.max_token]
input_values = {"context": context, "question": self.current_user_input}
except NoIndexException:
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, self.top_k
)
if not docs:
raise ValueError(
"you have no knowledge space, please add your knowledge space"
)
context = [d.page_content for d in docs]
context = context[: self.max_token]
input_values = {"context": context, "question": self.current_user_input}
return input_values
@property

View File

@@ -71,7 +71,7 @@ def load(
skip_wrong_doc: bool,
max_workers: int,
):
"""Load you local knowledge to DB-GPT"""
"""Load your local knowledge to DB-GPT"""
from pilot.server.knowledge._cli.knowledge_client import knowledge_init
knowledge_init(

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -170,8 +170,9 @@ class DBSummaryClient:
def init_db_profile(self, db_summary_client, dbname, embeddings):
from pilot.embedding_engine.string_embedding import StringEmbedding
vector_store_name = dbname + "_profile"
profile_store_config = {
"vector_store_name": dbname + "_profile",
"vector_store_name": vector_store_name,
"chroma_persist_path": KNOWLEDGE_UPLOAD_ROOT_PATH,
"vector_store_type": CFG.VECTOR_STORE_TYPE,
"embeddings": embeddings,
@@ -190,6 +191,8 @@ class DBSummaryClient:
)
docs.extend(embedding.read_batch())
embedding.index_to_store(docs)
else:
logger.info(f"Vector store name {vector_store_name} exist")
logger.info("init db profile success...")

View File

@@ -2,6 +2,7 @@ import os
from typing import Any
from chromadb.config import Settings
from chromadb import PersistentClient
from pilot.logs import logger
from pilot.vector_store.base import VectorStoreBase
@@ -18,15 +19,18 @@ class ChromaStore(VectorStoreBase):
ctx["chroma_persist_path"], ctx["vector_store_name"] + ".vectordb"
)
chroma_settings = Settings(
chroma_db_impl="duckdb+parquet",
# chroma_db_impl="duckdb+parquet", => deprecated configuration of Chroma
persist_directory=self.persist_dir,
anonymized_telemetry=False,
)
client = PersistentClient(path=self.persist_dir, settings=chroma_settings)
collection_metadata = {"hnsw:space": "cosine"}
self.vector_store_client = Chroma(
persist_directory=self.persist_dir,
embedding_function=self.embeddings,
client_settings=chroma_settings,
# client_settings=chroma_settings,
client=client,
collection_metadata=collection_metadata,
)
@@ -35,9 +39,13 @@ class ChromaStore(VectorStoreBase):
return self.vector_store_client.similarity_search(text, topk)
def vector_name_exists(self):
return (
os.path.exists(self.persist_dir) and len(os.listdir(self.persist_dir)) > 0
)
logger.info(f"Check persist_dir: {self.persist_dir}")
if not os.path.exists(self.persist_dir):
return False
files = os.listdir(self.persist_dir)
# Skip default file: chroma.sqlite3
files = list(filter(lambda f: f != "chroma.sqlite3", files))
return len(files) > 0
def load_document(self, documents):
logger.info("ChromaStore load document")