feat: multi-llm add model param

This commit is contained in:
aries_ckt 2023-09-07 20:43:53 +08:00
commit 1a2bf96767
19 changed files with 118 additions and 115 deletions

View File

@ -7,8 +7,8 @@ categories:
labels: enhancement labels: enhancement
- title: 🐞 Bug fixes - title: 🐞 Bug fixes
labels: fix labels: fix
- title: ⚠️ Deprecations # - title: ⚠️ Deprecations
labels: deprecation # labels: deprecation
- title: 🛠️ Other improvements - title: 🛠️ Other improvements
labels: labels:
- documentation - documentation

20
.github/workflows/pr-labeler.yml vendored Normal file
View File

@ -0,0 +1,20 @@
name: Pull request labeler
on:
pull_request_target:
types: [opened, edited]
permissions:
contents: read
pull-requests: write
jobs:
main:
runs-on: ubuntu-latest
steps:
- name: Label pull request
uses: release-drafter/release-drafter@v5
with:
disable-releaser: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

22
.github/workflows/release-drafter.yml vendored Normal file
View File

@ -0,0 +1,22 @@
name: Update draft releases
on:
push:
branches:
- main
workflow_dispatch:
permissions:
contents: write
pull-requests: read
jobs:
main:
runs-on: ubuntu-latest
steps:
- name: Draft DB-GPT release
uses: release-drafter/release-drafter@v5
with:
config-name: release-drafter.yml
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -1,23 +0,0 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
WORKDIR /app
RUN apt-get update && apt-get install -y \
git \
python3 \
pip
# upgrade pip
RUN pip3 install --upgrade pip
COPY ./requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
RUN python3 -m spacy download zh_core_web_sm
COPY . /app
EXPOSE 7860
EXPOSE 8000

View File

@ -86,6 +86,7 @@ Currently, we have released multiple key features, which are listed below to dem
- Unified vector storage/indexing of knowledge base - Unified vector storage/indexing of knowledge base
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL - Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
- Multi LLMs Support, Supports multiple large language models, currently supporting - Multi LLMs Support, Supports multiple large language models, currently supporting
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b) - 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b) - 🔥 llama-2(7b,13b,70b)
- WizardLM-v1.2(13b) - WizardLM-v1.2(13b)

View File

@ -30,12 +30,7 @@
</a> </a>
</p> </p>
[**English**](README.md)|[**Discord**](https://discord.gg/FMGwbRQrM) |[**Documents**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**Community**](https://github.com/eosphoros-ai/community) [**English**](README.md)|[**Discord**](https://discord.gg/FMGwbRQrM) |[**文档**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**社区**](https://github.com/eosphoros-ai/community)
</div>
## DB-GPT 是什么? ## DB-GPT 是什么?
@ -119,6 +114,7 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目使用本地
- 非结构化数据支持包括PDF、MarkDown、CSV、WebURL - 非结构化数据支持包括PDF、MarkDown、CSV、WebURL
- 多模型支持 - 多模型支持
- 支持多种大语言模型, 当前已支持如下模型: - 支持多种大语言模型, 当前已支持如下模型:
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b) - 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b) - 🔥 llama-2(7b,13b,70b)
- WizardLM-v1.2(13b) - WizardLM-v1.2(13b)

View File

@ -1,19 +0,0 @@
# Security Policy
## Supported Versions
Use this section to tell people about which versions of your project are
currently being supported with security updates.
| Version | Supported |
| ------- | ------------------ |
| v0.0.4 | :no new features |
| v0.0.3 | :documents QA |
| v0.0.2 | :sql generator |
| v0.0.1 | :minst runable |
## Reporting a Vulnerability
Use this section to tell people how to report a vulnerability.
we will build something small.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 380 KiB

After

Width:  |  Height:  |  Size: 256 KiB

View File

@ -77,6 +77,8 @@ LLM_MODEL_CONFIG = {
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"), "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b" # please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"), "baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
"baichuan2-7b": os.path.join(MODEL_PATH, "Baichuan2-7B-Chat"),
"baichuan2-13b": os.path.join(MODEL_PATH, "Baichuan2-13B-Chat"),
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2 # (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"), "wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"), "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),

View File

@ -54,6 +54,10 @@ class ConversationVo(BaseModel):
chat scene select param chat scene select param
""" """
select_param: str = None select_param: str = None
"""
llm model name
"""
model_name: str = None
class MessageVo(BaseModel): class MessageVo(BaseModel):

View File

@ -2,7 +2,7 @@ import datetime
import traceback import traceback
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List from typing import Any, List, Dict
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
@ -33,12 +33,12 @@ class BaseChat(ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __init__( def __init__(
self, chat_mode, chat_session_id, current_user_input, select_param: Any = None self, chat_param: Dict
): ):
self.chat_session_id = chat_session_id self.chat_session_id = chat_param["chat_session_id"]
self.chat_mode = chat_mode self.chat_mode = chat_param["chat_mode"]
self.current_user_input: str = current_user_input self.current_user_input: str = chat_param["current_user_input"]
self.llm_model = CFG.LLM_MODEL self.llm_model = chat_param["model_name"]
self.llm_echo = False self.llm_echo = False
### load prompt template ### load prompt template
@ -55,14 +55,14 @@ class BaseChat(ABC):
) )
### can configurable storage methods ### can configurable storage methods
self.memory = DuckdbHistoryMemory(chat_session_id) self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
self.history_message: List[OnceConversation] = self.memory.messages() self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(chat_mode.value()) self.current_message: OnceConversation = OnceConversation(self.chat_mode.value())
if select_param: if chat_param["select_param"]:
if len(chat_mode.param_types()) > 0: if len(self.chat_mode.param_types()) > 0:
self.current_message.param_type = chat_mode.param_types()[0] self.current_message.param_type = self.chat_mode.param_types()[0]
self.current_message.param_value = select_param self.current_message.param_value = chat_param["select_param"]
self.current_tokens_used: int = 0 self.current_tokens_used: int = 0
class Config: class Config:

View File

@ -1,7 +1,7 @@
import json import json
import os import os
import uuid import uuid
from typing import List from typing import List, Dict
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
@ -23,28 +23,23 @@ class ChatDashboard(BaseChat):
def __init__( def __init__(
self, self,
chat_session_id, chat_param: Dict
user_input,
select_param: str = "",
report_name: str = "report",
): ):
""" """ """ """
self.db_name = select_param self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatDashboard
super().__init__( super().__init__(
chat_mode=ChatScene.ChatDashboard, chat_param=chat_param
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.db_name,
) )
if not self.db_name: if not self.db_name:
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!") raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
self.db_name = self.db_name self.db_name = self.db_name
self.report_name = report_name self.report_name = chat_param["report_name"] or "report"
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 5 self.top_k: int = 5
self.dashboard_template = self.__load_dashboard_template(report_name) self.dashboard_template = self.__load_dashboard_template(self.report_name)
def __load_dashboard_template(self, template_name): def __load_dashboard_template(self, template_name):
current_dir = os.getcwd() current_dir = os.getcwd()

View File

@ -22,24 +22,23 @@ class ChatExcel(BaseChat):
chat_scene: str = ChatScene.ChatExcel.value() chat_scene: str = ChatScene.ChatExcel.value()
chat_retention_rounds = 1 chat_retention_rounds = 1
def __init__(self, chat_session_id, user_input, select_param: str = ""): def __init__(self, chat_param: Dict):
chat_mode = ChatScene.ChatExcel chat_mode = ChatScene.ChatExcel
self.select_param = select_param self.select_param = chat_param["select_param"]
if has_path(select_param): self.model_name = chat_param["model_name"]
self.excel_reader = ExcelReader(select_param) chat_param["chat_mode"] = ChatScene.ChatExcel
if has_path(self.select_param):
self.excel_reader = ExcelReader(self.select_param)
else: else:
self.excel_reader = ExcelReader( self.excel_reader = ExcelReader(
os.path.join( os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
) )
) )
super().__init__( super().__init__(
chat_mode=chat_mode, chat_param=chat_param
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=select_param,
) )
def _generate_command_string(self, command: Dict[str, Any]) -> str: def _generate_command_string(self, command: Dict[str, Any]) -> str:
@ -85,6 +84,7 @@ class ChatExcel(BaseChat):
"parent_mode": self.chat_mode, "parent_mode": self.chat_mode,
"select_param": self.excel_reader.excel_file_name, "select_param": self.excel_reader.excel_file_name,
"excel_reader": self.excel_reader, "excel_reader": self.excel_reader,
"model_name": self.model_name,
} }
learn_chat = ExcelLearning(**chat_param) learn_chat = ExcelLearning(**chat_param)
result = await learn_chat.nostream_call() result = await learn_chat.nostream_call()

View File

@ -30,16 +30,21 @@ class ExcelLearning(BaseChat):
parent_mode: Any = None, parent_mode: Any = None,
select_param: str = None, select_param: str = None,
excel_reader: Any = None, excel_reader: Any = None,
model_name: str = None,
): ):
chat_mode = ChatScene.ExcelLearning chat_mode = ChatScene.ExcelLearning
""" """ """ """
self.excel_file_path = select_param self.excel_file_path = select_param
self.excel_reader = excel_reader self.excel_reader = excel_reader
chat_param = {
"chat_mode": chat_mode,
"chat_session_id": chat_session_id,
"current_user_input": user_input,
"select_param": select_param,
"model_name": model_name,
}
super().__init__( super().__init__(
chat_mode=chat_mode, chat_param=chat_param
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=select_param,
) )
if parent_mode: if parent_mode:
self.current_message.chat_mode = parent_mode.value() self.current_message.chat_mode = parent_mode.value()

View File

@ -1,3 +1,5 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
@ -12,15 +14,13 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = ""): def __init__(self, chat_param: Dict):
chat_mode = ChatScene.ChatWithDbExecute chat_mode = ChatScene.ChatWithDbExecute
self.db_name = select_param self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = chat_mode
""" """ """ """
super().__init__( super().__init__(
chat_mode=chat_mode, chat_param=chat_param,
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.db_name,
) )
if not self.db_name: if not self.db_name:
raise ValueError( raise ValueError(

View File

@ -1,3 +1,5 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
@ -12,14 +14,12 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = ""): def __init__(self, chat_param: Dict):
""" """ """ """
self.db_name = select_param self.db_name = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatWithDbQA
super().__init__( super().__init__(
chat_mode=ChatScene.ChatWithDbQA, chat_param=chat_param
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.db_name,
) )
if self.db_name: if self.db_name:

View File

@ -1,4 +1,4 @@
from typing import List from typing import List, Dict
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
@ -15,13 +15,11 @@ class ChatWithPlugin(BaseChat):
plugins_prompt_generator: PluginPromptGenerator plugins_prompt_generator: PluginPromptGenerator
select_plugin: str = None select_plugin: str = None
def __init__(self, chat_session_id, user_input, select_param: str = None): def __init__(self, chat_param: Dict):
self.plugin_selector = select_param self.plugin_selector = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatExecution
super().__init__( super().__init__(
chat_mode=ChatScene.ChatExecution, chat_param=chat_param
chat_session_id=chat_session_id,
current_user_input=user_input,
select_param=self.plugin_selector,
) )
self.plugins_prompt_generator = PluginPromptGenerator() self.plugins_prompt_generator = PluginPromptGenerator()
self.plugins_prompt_generator.command_registry = CFG.command_registry self.plugins_prompt_generator.command_registry = CFG.command_registry

View File

@ -1,3 +1,5 @@
from typing import Dict
from chromadb.errors import NoIndexException from chromadb.errors import NoIndexException
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
@ -20,15 +22,14 @@ class ChatKnowledge(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = None): def __init__(self, chat_param: Dict):
""" """ """ """
from pilot.embedding_engine.embedding_engine import EmbeddingEngine from pilot.embedding_engine.embedding_engine import EmbeddingEngine
self.knowledge_space = select_param self.knowledge_space = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatKnowledge
super().__init__( super().__init__(
chat_mode=ChatScene.ChatKnowledge, chat_param=chat_param,
chat_session_id=chat_session_id,
current_user_input=user_input,
) )
self.space_context = self.get_space_context(self.knowledge_space) self.space_context = self.get_space_context(self.knowledge_space)
self.top_k = ( self.top_k = (

View File

@ -1,3 +1,5 @@
from typing import Dict
from pilot.scene.base_chat import BaseChat from pilot.scene.base_chat import BaseChat
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.configs.config import Config from pilot.configs.config import Config
@ -12,12 +14,11 @@ class ChatNormal(BaseChat):
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = None): def __init__(self, chat_param: Dict):
""" """ """ """
chat_param["chat_mode"] = ChatScene.ChatNormal
super().__init__( super().__init__(
chat_mode=ChatScene.ChatNormal, chat_param=chat_param,
chat_session_id=chat_session_id,
current_user_input=user_input,
) )
def generate_input_values(self): def generate_input_values(self):