feat:multi-llm add model param

This commit is contained in:
aries_ckt 2023-09-07 20:43:53 +08:00
commit 1a2bf96767
19 changed files with 118 additions and 115 deletions

View File

@ -7,8 +7,8 @@ categories:
labels: enhancement
- title: 🐞 Bug fixes
labels: fix
- title: ⚠️ Deprecations
labels: deprecation
# - title: ⚠️ Deprecations
# labels: deprecation
- title: 🛠️ Other improvements
labels:
- documentation

20
.github/workflows/pr-labeler.yml vendored Normal file
View File

@ -0,0 +1,20 @@
name: Pull request labeler
on:
pull_request_target:
types: [opened, edited]
permissions:
contents: read
pull-requests: write
jobs:
main:
runs-on: ubuntu-latest
steps:
- name: Label pull request
uses: release-drafter/release-drafter@v5
with:
disable-releaser: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

22
.github/workflows/release-drafter.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: Update draft releases
on:
push:
branches:
- main
workflow_dispatch:
permissions:
contents: write
pull-requests: read
jobs:
main:
runs-on: ubuntu-latest
steps:
- name: Draft DB-GPT release
uses: release-drafter/release-drafter@v5
with:
config-name: release-drafter.yml
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -1,23 +0,0 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
WORKDIR /app
RUN apt-get update && apt-get install -y \
git \
python3 \
pip
# upgrade pip
RUN pip3 install --upgrade pip
COPY ./requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
RUN python3 -m spacy download zh_core_web_sm
COPY . /app
EXPOSE 7860
EXPOSE 8000

View File

@ -86,6 +86,7 @@ Currently, we have released multiple key features, which are listed below to dem
- Unified vector storage/indexing of knowledge base
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
- Multi LLMs Support, Supports multiple large language models, currently supporting
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)
- WizardLM-v1.2(13b)

View File

@ -30,12 +30,7 @@
</a>
</p>
[**English**](README.md)|[**Discord**](https://discord.gg/FMGwbRQrM) |[**Documents
**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信
**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**Community
**](https://github.com/eosphoros-ai/community)
</div>
[**English**](README.md)|[**Discord**](https://discord.gg/FMGwbRQrM) |[**文档**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**社区**](https://github.com/eosphoros-ai/community)
## DB-GPT 是什么?
@ -119,6 +114,7 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目使用本地
- 非结构化数据支持包括PDF、MarkDown、CSV、WebURL
- 多模型支持
- 支持多种大语言模型, 当前已支持如下模型:
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)
- WizardLM-v1.2(13b)

View File

@ -1,19 +0,0 @@
# Security Policy
## Supported Versions
Use this section to tell people about which versions of your project are
currently being supported with security updates.
| Version | Supported |
| ------- | ------------------ |
| v0.0.4 | :no new features |
| v0.0.3 | :documents QA |
| v0.0.2 | :sql generator |
| v0.0.1 | :minst runable |
## Reporting a Vulnerability
Use this section to tell people how to report a vulnerability.
we will build somethings small.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 380 KiB

After

Width:  |  Height:  |  Size: 256 KiB

View File

@ -77,6 +77,8 @@ LLM_MODEL_CONFIG = {
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
"baichuan2-7b": os.path.join(MODEL_PATH, "Baichuan2-7B-Chat"),
"baichuan2-13b": os.path.join(MODEL_PATH, "Baichuan2-13B-Chat"),
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),

View File

@ -54,6 +54,10 @@ class ConversationVo(BaseModel):
chat scene select param
"""
select_param: str = None
"""
llm model name
"""
model_name: str = None
class MessageVo(BaseModel):

View File

@ -2,7 +2,7 @@ import datetime
import traceback
import warnings
from abc import ABC, abstractmethod
from typing import Any, List
from typing import Any, List, Dict
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
@ -33,12 +33,12 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
def __init__(
self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
self, chat_param: Dict
):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
self.current_user_input: str = current_user_input
self.llm_model = CFG.LLM_MODEL
self.chat_session_id = chat_param["chat_session_id"]
self.chat_mode = chat_param["chat_mode"]
self.current_user_input: str = chat_param["current_user_input"]
self.llm_model = chat_param["model_name"]
self.llm_echo = False
### load prompt template
@ -55,14 +55,14 @@ class BaseChat(ABC):
)
### can configurable storage methods
self.memory = DuckdbHistoryMemory(chat_session_id)
self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
if select_param:
if len(chat_mode.param_types()) > 0:
self.current_message.param_type = chat_mode.param_types()[0]
self.current_message.param_value = select_param
self.current_message: OnceConversation = OnceConversation(self.chat_mode.value())
if chat_param["select_param"]:
if len(self.chat_mode.param_types()) > 0:
self.current_message.param_type = self.chat_mode.param_types()[0]
self.current_message.param_value = chat_param["select_param"]
self.current_tokens_used: int = 0
class Config:

View File

@ -1,7 +1,7 @@
import json
import os
import uuid
from typing import List
from typing import List, Dict
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
@ -23,28 +23,23 @@ class ChatDashboard(BaseChat):
def __init__(
self,
chat_session_id,
user_input,
select_param: str = "",
report_name: str = "report",
chat_param: Dict
):
""" """
self.db_name = select_param
self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatDashboard
super().__init__(
chat_mode=ChatScene.ChatDashboard,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.db_name,
chat_param=chat_param
)
if not self.db_name:
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
self.db_name = self.db_name
self.report_name = report_name
self.report_name = chat_param["report_name"] or "report"
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 5
self.dashboard_template = self.__load_dashboard_template(report_name)
self.dashboard_template = self.__load_dashboard_template(self.report_name)
def __load_dashboard_template(self, template_name):
current_dir = os.getcwd()

View File

@ -22,24 +22,23 @@ class ChatExcel(BaseChat):
chat_scene: str = ChatScene.ChatExcel.value()
chat_retention_rounds = 1
def __init__(self, chat_session_id, user_input, select_param: str = ""):
def __init__(self, chat_param: Dict):
chat_mode = ChatScene.ChatExcel
self.select_param = select_param
if has_path(select_param):
self.excel_reader = ExcelReader(select_param)
self.select_param = chat_param["select_param"]
self.model_name = chat_param["model_name"]
chat_param["chat_mode"] = ChatScene.ChatExcel
if has_path(self.select_param):
self.excel_reader = ExcelReader(self.select_param)
else:
self.excel_reader = ExcelReader(
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
)
)
super().__init__(
chat_mode=chat_mode,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=select_param,
chat_param=chat_param
)
def _generate_command_string(self, command: Dict[str, Any]) -> str:
@ -85,6 +84,7 @@ class ChatExcel(BaseChat):
"parent_mode": self.chat_mode,
"select_param": self.excel_reader.excel_file_name,
"excel_reader": self.excel_reader,
"model_name": self.model_name,
}
learn_chat = ExcelLearning(**chat_param)
result = await learn_chat.nostream_call()

View File

@ -30,16 +30,21 @@ class ExcelLearning(BaseChat):
parent_mode: Any = None,
select_param: str = None,
excel_reader: Any = None,
model_name: str = None,
):
chat_mode = ChatScene.ExcelLearning
""" """
self.excel_file_path = select_param
self.excel_reader = excel_reader
chat_param = {
"chat_mode": chat_mode,
"chat_session_id": chat_session_id,
"current_user_input": user_input,
"select_param": select_param,
"model_name": model_name,
}
super().__init__(
chat_mode=chat_mode,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=select_param,
chat_param=chat_param
)
if parent_mode:
self.current_message.chat_mode = parent_mode.value()

View File

@ -1,3 +1,5 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@ -12,15 +14,13 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = ""):
def __init__(self, chat_param: Dict):
chat_mode = ChatScene.ChatWithDbExecute
self.db_name = select_param
self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = chat_mode
""" """
super().__init__(
chat_mode=chat_mode,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.db_name,
chat_param=chat_param,
)
if not self.db_name:
raise ValueError(

View File

@ -1,3 +1,5 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
@ -12,14 +14,12 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = ""):
def __init__(self, chat_param: Dict):
""" """
self.db_name = select_param
self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatWithDbQA
super().__init__(
chat_mode=ChatScene.ChatWithDbQA,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.db_name,
chat_param=chat_param
)
if self.db_name:

View File

@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
@ -15,13 +15,11 @@ class ChatWithPlugin(BaseChat):
plugins_prompt_generator: PluginPromptGenerator
select_plugin: str = None
def __init__(self, chat_session_id, user_input, select_param: str = None):
self.plugin_selector = select_param
def __init__(self, chat_param: Dict):
self.plugin_selector = chat_param.select_param
chat_param["chat_mode"] = ChatScene.ChatExecution
super().__init__(
chat_mode=ChatScene.ChatExecution,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.plugin_selector,
chat_param=chat_param
)
self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry

View File

@ -1,3 +1,5 @@
from typing import Dict
from chromadb.errors import NoIndexException
from pilot.scene.base_chat import BaseChat
@ -20,15 +22,14 @@ class ChatKnowledge(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = None):
def __init__(self, chat_param: Dict):
""" """
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
self.knowledge_space = select_param
self.knowledge_space = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatKnowledge
super().__init__(
chat_mode=ChatScene.ChatKnowledge,
chat_session_id=chat_session_id,
current_user_input=user_input,
chat_param=chat_param,
)
self.space_context = self.get_space_context(self.knowledge_space)
self.top_k = (

View File

@ -1,3 +1,5 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene
from pilot.configs.config import Config
@ -12,12 +14,11 @@ class ChatNormal(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = None):
def __init__(self, chat_param: Dict):
""" """
chat_param["chat_mode"] = ChatScene.ChatNormal
super().__init__(
chat_mode=ChatScene.ChatNormal,
chat_session_id=chat_session_id,
current_user_input=user_input,
chat_param=chat_param,
)
def generate_input_values(self):