feat(multi-llm): add model param

aries_ckt committed on 2023-09-07 20:43:53 +08:00
19 changed files with 118 additions and 115 deletions

View File

@@ -77,6 +77,8 @@ LLM_MODEL_CONFIG = {
     "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
     # please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
     "baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
+    "baichuan2-7b": os.path.join(MODEL_PATH, "Baichuan2-7B-Chat"),
+    "baichuan2-13b": os.path.join(MODEL_PATH, "Baichuan2-13B-Chat"),
     # (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
     "wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
     "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),

View File

@@ -54,6 +54,10 @@ class ConversationVo(BaseModel):
     chat scene select param
     """
     select_param: str = None
+    """
+    llm model name
+    """
+    model_name: str = None

 class MessageVo(BaseModel):
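With the new field, a client can name the serving model per conversation. A hypothetical payload, in which every key except select_param and model_name is an assumption for illustration:

# Hypothetical chat request body; only select_param and model_name
# come from this diff, the other keys are illustrative.
conversation_vo = {
    "user_input": "Show the top 5 customers by revenue",
    "select_param": "orders_db",
    "model_name": "baichuan2-13b",  # new: which LLM answers this conversation
}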

View File

@@ -2,7 +2,7 @@ import datetime
 import traceback
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, List
+from typing import Any, List, Dict

 from pilot.configs.config import Config
 from pilot.configs.model_config import LOGDIR
@@ -33,12 +33,12 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True

     def __init__(
-        self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
+        self, chat_param: Dict
     ):
-        self.chat_session_id = chat_session_id
-        self.chat_mode = chat_mode
-        self.current_user_input: str = current_user_input
-        self.llm_model = CFG.LLM_MODEL
+        self.chat_session_id = chat_param["chat_session_id"]
+        self.chat_mode = chat_param["chat_mode"]
+        self.current_user_input: str = chat_param["current_user_input"]
+        self.llm_model = chat_param["model_name"]
         self.llm_echo = False
         ### load prompt template
@@ -55,14 +55,14 @@ class BaseChat(ABC):
         )
         ### can configurable storage methods
-        self.memory = DuckdbHistoryMemory(chat_session_id)
+        self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])

         self.history_message: List[OnceConversation] = self.memory.messages()
-        self.current_message: OnceConversation = OnceConversation(chat_mode.value())
-        if select_param:
-            if len(chat_mode.param_types()) > 0:
-                self.current_message.param_type = chat_mode.param_types()[0]
-            self.current_message.param_value = select_param
+        self.current_message: OnceConversation = OnceConversation(self.chat_mode.value())
+        if chat_param["select_param"]:
+            if len(self.chat_mode.param_types()) > 0:
+                self.current_message.param_type = self.chat_mode.param_types()[0]
+            self.current_message.param_value = chat_param["select_param"]
         self.current_tokens_used: int = 0

     class Config:
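A usage sketch of the new constructor contract: callers now assemble a single chat_param dict instead of passing positional arguments. The values below are placeholders, the ChatNormal import path is omitted because it is not shown in this diff, and the dict mirrors the one ExcelLearning builds later in this commit:

from pilot.scene.base import ChatScene  # as imported elsewhere in this diff
# from ... import ChatNormal            # module path not shown here, so omitted

chat_param = {
    "chat_mode": ChatScene.ChatNormal,   # scene-specific chats overwrite this themselves
    "chat_session_id": "3f6c9c6e-demo",
    "current_user_input": "Summarize yesterday's sales",
    "select_param": None,
    "model_name": "baichuan2-13b",       # the new per-request model choice
}
chat = ChatNormal(chat_param=chat_param)  # any BaseChat subclass refactored in this commit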

View File

@@ -1,7 +1,7 @@
 import json
 import os
 import uuid
-from typing import List
+from typing import List, Dict

 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
@@ -23,28 +23,23 @@ class ChatDashboard(BaseChat):
     def __init__(
         self,
-        chat_session_id,
-        user_input,
-        select_param: str = "",
-        report_name: str = "report",
+        chat_param: Dict
     ):
         """ """
-        self.db_name = select_param
+        self.db_name = chat_param["select_param"]
+        chat_param["chat_mode"] = ChatScene.ChatDashboard
         super().__init__(
-            chat_mode=ChatScene.ChatDashboard,
-            chat_session_id=chat_session_id,
-            current_user_input=user_input,
-            select_param=self.db_name,
+            chat_param=chat_param
         )
         if not self.db_name:
             raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
         self.db_name = self.db_name
-        self.report_name = report_name
+        self.report_name = chat_param["report_name"] or "report"
         self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
         self.top_k: int = 5
-        self.dashboard_template = self.__load_dashboard_template(report_name)
+        self.dashboard_template = self.__load_dashboard_template(self.report_name)

     def __load_dashboard_template(self, template_name):
         current_dir = os.getcwd()
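One caveat: chat_param["report_name"] raises a KeyError whenever the caller omits that key, which the old report_name: str = "report" default used to absorb. A hypothetical fallback that preserves the old behavior, not part of the commit:

def _resolve_report_name(chat_param: dict) -> str:
    # Assumed variant: tolerate a missing or empty "report_name" key.
    return chat_param.get("report_name") or "report"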

View File

@@ -22,24 +22,23 @@ class ChatExcel(BaseChat):
     chat_scene: str = ChatScene.ChatExcel.value()
     chat_retention_rounds = 1

-    def __init__(self, chat_session_id, user_input, select_param: str = ""):
+    def __init__(self, chat_param: Dict):
         chat_mode = ChatScene.ChatExcel
-        self.select_param = select_param
-        if has_path(select_param):
-            self.excel_reader = ExcelReader(select_param)
+        self.select_param = chat_param["select_param"]
+        self.model_name = chat_param["model_name"]
+        chat_param["chat_mode"] = ChatScene.ChatExcel
+        if has_path(self.select_param):
+            self.excel_reader = ExcelReader(self.select_param)
         else:
             self.excel_reader = ExcelReader(
                 os.path.join(
-                    KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
+                    KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
                 )
             )
         super().__init__(
-            chat_mode=chat_mode,
-            chat_session_id=chat_session_id,
-            current_user_input=user_input,
-            select_param=select_param,
+            chat_param=chat_param
         )

     def _generate_command_string(self, command: Dict[str, Any]) -> str:
@@ -85,6 +84,7 @@ class ChatExcel(BaseChat):
             "parent_mode": self.chat_mode,
             "select_param": self.excel_reader.excel_file_name,
             "excel_reader": self.excel_reader,
+            "model_name": self.model_name,
         }
         learn_chat = ExcelLearning(**chat_param)
         result = await learn_chat.nostream_call()

View File

@@ -30,16 +30,21 @@ class ExcelLearning(BaseChat):
         parent_mode: Any = None,
         select_param: str = None,
         excel_reader: Any = None,
+        model_name: str = None,
     ):
         chat_mode = ChatScene.ExcelLearning
         """ """
         self.excel_file_path = select_param
         self.excel_reader = excel_reader
+        chat_param = {
+            "chat_mode": chat_mode,
+            "chat_session_id": chat_session_id,
+            "current_user_input": user_input,
+            "select_param": select_param,
+            "model_name": model_name,
+        }
         super().__init__(
-            chat_mode=chat_mode,
-            chat_session_id=chat_session_id,
-            current_user_input=user_input,
-            select_param=select_param,
+            chat_param=chat_param
         )
         if parent_mode:
             self.current_message.chat_mode = parent_mode.value()

View File

@@ -1,3 +1,5 @@
+from typing import Dict
+
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -12,15 +14,13 @@ class ChatWithDbAutoExecute(BaseChat):
     """Number of results to return from the query"""

-    def __init__(self, chat_session_id, user_input, select_param: str = ""):
+    def __init__(self, chat_param: Dict):
         chat_mode = ChatScene.ChatWithDbExecute
-        self.db_name = select_param
+        self.db_name = chat_param["select_param"]
+        chat_param["chat_mode"] = chat_mode
         """ """
         super().__init__(
-            chat_mode=chat_mode,
-            chat_session_id=chat_session_id,
-            current_user_input=user_input,
-            select_param=self.db_name,
+            chat_param=chat_param,
         )
         if not self.db_name:
             raise ValueError(

View File

@@ -1,3 +1,5 @@
+from typing import Dict
+
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -12,14 +14,12 @@ class ChatWithDbQA(BaseChat):
     """Number of results to return from the query"""

-    def __init__(self, chat_session_id, user_input, select_param: str = ""):
+    def __init__(self, chat_param: Dict):
         """ """
-        self.db_name = select_param
+        self.db_name = chat_param["select_param"]
+        chat_param["chat_mode"] = ChatScene.ChatWithDbQA
         super().__init__(
-            chat_mode=ChatScene.ChatWithDbQA,
-            chat_session_id=chat_session_id,
-            current_user_input=user_input,
-            select_param=self.db_name,
+            chat_param=chat_param
         )
         if self.db_name:

View File

@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Dict

 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
@@ -15,13 +15,11 @@ class ChatWithPlugin(BaseChat):
     plugins_prompt_generator: PluginPromptGenerator
     select_plugin: str = None

-    def __init__(self, chat_session_id, user_input, select_param: str = None):
-        self.plugin_selector = select_param
+    def __init__(self, chat_param: Dict):
+        self.plugin_selector = chat_param["select_param"]
+        chat_param["chat_mode"] = ChatScene.ChatExecution
         super().__init__(
-            chat_mode=ChatScene.ChatExecution,
-            chat_session_id=chat_session_id,
-            current_user_input=user_input,
-            select_param=self.plugin_selector,
+            chat_param=chat_param
         )
         self.plugins_prompt_generator = PluginPromptGenerator()
         self.plugins_prompt_generator.command_registry = CFG.command_registry

View File

@@ -1,3 +1,5 @@
+from typing import Dict
+
 from chromadb.errors import NoIndexException

 from pilot.scene.base_chat import BaseChat
@@ -20,15 +22,14 @@ class ChatKnowledge(BaseChat):
     """Number of results to return from the query"""

-    def __init__(self, chat_session_id, user_input, select_param: str = None):
+    def __init__(self, chat_param: Dict):
         """ """
         from pilot.embedding_engine.embedding_engine import EmbeddingEngine

-        self.knowledge_space = select_param
+        self.knowledge_space = chat_param["select_param"]
+        chat_param["chat_mode"] = ChatScene.ChatKnowledge
         super().__init__(
-            chat_mode=ChatScene.ChatKnowledge,
-            chat_session_id=chat_session_id,
-            current_user_input=user_input,
+            chat_param=chat_param,
         )
         self.space_context = self.get_space_context(self.knowledge_space)
         self.top_k = (

View File

@@ -1,3 +1,5 @@
+from typing import Dict
+
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.configs.config import Config
@@ -12,12 +14,11 @@ class ChatNormal(BaseChat):
     """Number of results to return from the query"""

-    def __init__(self, chat_session_id, user_input, select_param: str = None):
+    def __init__(self, chat_param: Dict):
         """ """
+        chat_param["chat_mode"] = ChatScene.ChatNormal
         super().__init__(
-            chat_mode=ChatScene.ChatNormal,
-            chat_session_id=chat_session_id,
-            current_user_input=user_input,
+            chat_param=chat_param,
         )

     def generate_input_values(self):