mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-16 14:40:56 +00:00
feat:multi-llm add model param
This commit is contained in:
@@ -77,6 +77,8 @@ LLM_MODEL_CONFIG = {
|
||||
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
|
||||
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
|
||||
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
|
||||
"baichuan2-7b": os.path.join(MODEL_PATH, "Baichuan2-7B-Chat"),
|
||||
"baichuan2-13b": os.path.join(MODEL_PATH, "Baichuan2-13B-Chat"),
|
||||
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
|
||||
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
|
||||
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
|
||||
|
@@ -54,6 +54,10 @@ class ConversationVo(BaseModel):
|
||||
chat scene select param
|
||||
"""
|
||||
select_param: str = None
|
||||
"""
|
||||
llm model name
|
||||
"""
|
||||
model_name: str = None
|
||||
|
||||
|
||||
class MessageVo(BaseModel):
|
||||
|
@@ -2,7 +2,7 @@ import datetime
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
@@ -33,12 +33,12 @@ class BaseChat(ABC):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
|
||||
self, chat_param: Dict
|
||||
):
|
||||
self.chat_session_id = chat_session_id
|
||||
self.chat_mode = chat_mode
|
||||
self.current_user_input: str = current_user_input
|
||||
self.llm_model = CFG.LLM_MODEL
|
||||
self.chat_session_id = chat_param["chat_session_id"]
|
||||
self.chat_mode = chat_param["chat_mode"]
|
||||
self.current_user_input: str = chat_param["current_user_input"]
|
||||
self.llm_model = chat_param["model_name"]
|
||||
self.llm_echo = False
|
||||
|
||||
### load prompt template
|
||||
@@ -55,14 +55,14 @@ class BaseChat(ABC):
|
||||
)
|
||||
|
||||
### can configurable storage methods
|
||||
self.memory = DuckdbHistoryMemory(chat_session_id)
|
||||
self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
|
||||
|
||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
|
||||
if select_param:
|
||||
if len(chat_mode.param_types()) > 0:
|
||||
self.current_message.param_type = chat_mode.param_types()[0]
|
||||
self.current_message.param_value = select_param
|
||||
self.current_message: OnceConversation = OnceConversation(self.chat_mode.value())
|
||||
if chat_param["select_param"]:
|
||||
if len(self.chat_mode.param_types()) > 0:
|
||||
self.current_message.param_type = self.chat_mode.param_types()[0]
|
||||
self.current_message.param_value = chat_param["select_param"]
|
||||
self.current_tokens_used: int = 0
|
||||
|
||||
class Config:
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
@@ -23,28 +23,23 @@ class ChatDashboard(BaseChat):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_session_id,
|
||||
user_input,
|
||||
select_param: str = "",
|
||||
report_name: str = "report",
|
||||
chat_param: Dict
|
||||
):
|
||||
""" """
|
||||
self.db_name = select_param
|
||||
self.db_name = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatDashboard
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatDashboard,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=self.db_name,
|
||||
chat_param=chat_param
|
||||
)
|
||||
if not self.db_name:
|
||||
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
|
||||
self.db_name = self.db_name
|
||||
self.report_name = report_name
|
||||
self.report_name = chat_param["report_name"] or "report"
|
||||
|
||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||
|
||||
self.top_k: int = 5
|
||||
self.dashboard_template = self.__load_dashboard_template(report_name)
|
||||
self.dashboard_template = self.__load_dashboard_template(self.report_name)
|
||||
|
||||
def __load_dashboard_template(self, template_name):
|
||||
current_dir = os.getcwd()
|
||||
|
@@ -22,24 +22,23 @@ class ChatExcel(BaseChat):
|
||||
chat_scene: str = ChatScene.ChatExcel.value()
|
||||
chat_retention_rounds = 1
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||
def __init__(self, chat_param: Dict):
|
||||
chat_mode = ChatScene.ChatExcel
|
||||
|
||||
self.select_param = select_param
|
||||
if has_path(select_param):
|
||||
self.excel_reader = ExcelReader(select_param)
|
||||
self.select_param = chat_param["select_param"]
|
||||
self.model_name = chat_param["model_name"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatExcel
|
||||
if has_path(self.select_param):
|
||||
self.excel_reader = ExcelReader(self.select_param)
|
||||
else:
|
||||
self.excel_reader = ExcelReader(
|
||||
os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
chat_mode=chat_mode,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=select_param,
|
||||
chat_param=chat_param
|
||||
)
|
||||
|
||||
def _generate_command_string(self, command: Dict[str, Any]) -> str:
|
||||
@@ -85,6 +84,7 @@ class ChatExcel(BaseChat):
|
||||
"parent_mode": self.chat_mode,
|
||||
"select_param": self.excel_reader.excel_file_name,
|
||||
"excel_reader": self.excel_reader,
|
||||
"model_name": self.model_name,
|
||||
}
|
||||
learn_chat = ExcelLearning(**chat_param)
|
||||
result = await learn_chat.nostream_call()
|
||||
|
@@ -30,16 +30,21 @@ class ExcelLearning(BaseChat):
|
||||
parent_mode: Any = None,
|
||||
select_param: str = None,
|
||||
excel_reader: Any = None,
|
||||
model_name: str = None,
|
||||
):
|
||||
chat_mode = ChatScene.ExcelLearning
|
||||
""" """
|
||||
self.excel_file_path = select_param
|
||||
self.excel_reader = excel_reader
|
||||
chat_param = {
|
||||
"chat_mode": chat_mode,
|
||||
"chat_session_id": chat_session_id,
|
||||
"current_user_input": user_input,
|
||||
"select_param": select_param,
|
||||
"model_name": model_name,
|
||||
}
|
||||
super().__init__(
|
||||
chat_mode=chat_mode,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=select_param,
|
||||
chat_param=chat_param
|
||||
)
|
||||
if parent_mode:
|
||||
self.current_message.chat_mode = parent_mode.value()
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
@@ -12,15 +14,13 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||
def __init__(self, chat_param: Dict):
|
||||
chat_mode = ChatScene.ChatWithDbExecute
|
||||
self.db_name = select_param
|
||||
self.db_name = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = chat_mode
|
||||
""" """
|
||||
super().__init__(
|
||||
chat_mode=chat_mode,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=self.db_name,
|
||||
chat_param=chat_param,
|
||||
)
|
||||
if not self.db_name:
|
||||
raise ValueError(
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
@@ -12,14 +14,12 @@ class ChatWithDbQA(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||
def __init__(self, chat_param: Dict):
|
||||
""" """
|
||||
self.db_name = select_param
|
||||
self.db_name = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatWithDbQA
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatWithDbQA,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=self.db_name,
|
||||
chat_param=chat_param
|
||||
)
|
||||
|
||||
if self.db_name:
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
@@ -15,13 +15,11 @@ class ChatWithPlugin(BaseChat):
|
||||
plugins_prompt_generator: PluginPromptGenerator
|
||||
select_plugin: str = None
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
self.plugin_selector = select_param
|
||||
def __init__(self, chat_param: Dict):
|
||||
self.plugin_selector = chat_param.select_param
|
||||
chat_param["chat_mode"] = ChatScene.ChatExecution
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatExecution,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=self.plugin_selector,
|
||||
chat_param=chat_param
|
||||
)
|
||||
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from chromadb.errors import NoIndexException
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
@@ -20,15 +22,14 @@ class ChatKnowledge(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
def __init__(self, chat_param: Dict):
|
||||
""" """
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
|
||||
self.knowledge_space = select_param
|
||||
self.knowledge_space = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatKnowledge
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatKnowledge,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
chat_param=chat_param,
|
||||
)
|
||||
self.space_context = self.get_space_context(self.knowledge_space)
|
||||
self.top_k = (
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
@@ -12,12 +14,11 @@ class ChatNormal(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
def __init__(self, chat_param: Dict):
|
||||
""" """
|
||||
chat_param["chat_mode"] = ChatScene.ChatNormal
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatNormal,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
chat_param=chat_param,
|
||||
)
|
||||
|
||||
def generate_input_values(self):
|
||||
|
Reference in New Issue
Block a user