mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
feat:multi-llm add model param
This commit is contained in:
commit
1a2bf96767
4
.github/release-drafter.yml
vendored
4
.github/release-drafter.yml
vendored
@ -7,8 +7,8 @@ categories:
|
||||
labels: enhancement
|
||||
- title: 🐞 Bug fixes
|
||||
labels: fix
|
||||
- title: ⚠️ Deprecations
|
||||
labels: deprecation
|
||||
# - title: ⚠️ Deprecations
|
||||
# labels: deprecation
|
||||
- title: 🛠️ Other improvements
|
||||
labels:
|
||||
- documentation
|
||||
|
20
.github/workflows/pr-labeler.yml
vendored
Normal file
20
.github/workflows/pr-labeler.yml
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
name: Pull request labeler
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, edited]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
main:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Label pull request
|
||||
uses: release-drafter/release-drafter@v5
|
||||
with:
|
||||
disable-releaser: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
22
.github/workflows/release-drafter.yml
vendored
Normal file
22
.github/workflows/release-drafter.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: Update draft releases
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
main:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Draft DB-GPT release
|
||||
uses: release-drafter/release-drafter@v5
|
||||
with:
|
||||
config-name: release-drafter.yml
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
23
Dockerfile
23
Dockerfile
@ -1,23 +0,0 @@
|
||||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
git \
|
||||
python3 \
|
||||
pip
|
||||
|
||||
# upgrade pip
|
||||
RUN pip3 install --upgrade pip
|
||||
|
||||
COPY ./requirements.txt /app/requirements.txt
|
||||
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
RUN python3 -m spacy download zh_core_web_sm
|
||||
|
||||
|
||||
COPY . /app
|
||||
|
||||
EXPOSE 7860
|
||||
EXPOSE 8000
|
@ -86,6 +86,7 @@ Currently, we have released multiple key features, which are listed below to dem
|
||||
- Unified vector storage/indexing of knowledge base
|
||||
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
|
||||
- Multi LLMs Support, Supports multiple large language models, currently supporting
|
||||
- 🔥 Baichuan2(7b,13b)
|
||||
- 🔥 Vicuna-v1.5(7b,13b)
|
||||
- 🔥 llama-2(7b,13b,70b)
|
||||
- WizardLM-v1.2(13b)
|
||||
|
@ -30,12 +30,7 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
[**English**](README.md)|[**Discord**](https://discord.gg/FMGwbRQrM) |[**Documents
|
||||
**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信
|
||||
**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**Community
|
||||
**](https://github.com/eosphoros-ai/community)
|
||||
|
||||
</div>
|
||||
[**English**](README.md)|[**Discord**](https://discord.gg/FMGwbRQrM) |[**文档**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**社区**](https://github.com/eosphoros-ai/community)
|
||||
|
||||
## DB-GPT 是什么?
|
||||
|
||||
@ -119,6 +114,7 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
|
||||
- 非结构化数据支持包括PDF、MarkDown、CSV、WebURL
|
||||
- 多模型支持
|
||||
- 支持多种大语言模型, 当前已支持如下模型:
|
||||
- 🔥 Baichuan2(7b,13b)
|
||||
- 🔥 Vicuna-v1.5(7b,13b)
|
||||
- 🔥 llama-2(7b,13b,70b)
|
||||
- WizardLM-v1.2(13b)
|
||||
|
19
SECURITY.md
19
SECURITY.md
@ -1,19 +0,0 @@
|
||||
# Security Policy
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Use this section to tell people about which versions of your project are
|
||||
currently being supported with security updates.
|
||||
|
||||
| Version | Supported |
|
||||
| ------- | ------------------ |
|
||||
| v0.0.4 | :no new features |
|
||||
| v0.0.3 | :documents QA |
|
||||
| v0.0.2 | :sql generator |
|
||||
| v0.0.1 | :minst runable |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
Use this section to tell people how to report a vulnerability.
|
||||
|
||||
we will build somethings small.
|
Binary file not shown.
Before Width: | Height: | Size: 380 KiB After Width: | Height: | Size: 256 KiB |
@ -77,6 +77,8 @@ LLM_MODEL_CONFIG = {
|
||||
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
|
||||
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
|
||||
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
|
||||
"baichuan2-7b": os.path.join(MODEL_PATH, "Baichuan2-7B-Chat"),
|
||||
"baichuan2-13b": os.path.join(MODEL_PATH, "Baichuan2-13B-Chat"),
|
||||
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
|
||||
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
|
||||
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
|
||||
|
@ -54,6 +54,10 @@ class ConversationVo(BaseModel):
|
||||
chat scene select param
|
||||
"""
|
||||
select_param: str = None
|
||||
"""
|
||||
llm model name
|
||||
"""
|
||||
model_name: str = None
|
||||
|
||||
|
||||
class MessageVo(BaseModel):
|
||||
|
@ -2,7 +2,7 @@ import datetime
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
@ -33,12 +33,12 @@ class BaseChat(ABC):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(
|
||||
self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
|
||||
self, chat_param: Dict
|
||||
):
|
||||
self.chat_session_id = chat_session_id
|
||||
self.chat_mode = chat_mode
|
||||
self.current_user_input: str = current_user_input
|
||||
self.llm_model = CFG.LLM_MODEL
|
||||
self.chat_session_id = chat_param["chat_session_id"]
|
||||
self.chat_mode = chat_param["chat_mode"]
|
||||
self.current_user_input: str = chat_param["current_user_input"]
|
||||
self.llm_model = chat_param["model_name"]
|
||||
self.llm_echo = False
|
||||
|
||||
### load prompt template
|
||||
@ -55,14 +55,14 @@ class BaseChat(ABC):
|
||||
)
|
||||
|
||||
### can configurable storage methods
|
||||
self.memory = DuckdbHistoryMemory(chat_session_id)
|
||||
self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
|
||||
|
||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
|
||||
if select_param:
|
||||
if len(chat_mode.param_types()) > 0:
|
||||
self.current_message.param_type = chat_mode.param_types()[0]
|
||||
self.current_message.param_value = select_param
|
||||
self.current_message: OnceConversation = OnceConversation(self.chat_mode.value())
|
||||
if chat_param["select_param"]:
|
||||
if len(self.chat_mode.param_types()) > 0:
|
||||
self.current_message.param_type = self.chat_mode.param_types()[0]
|
||||
self.current_message.param_value = chat_param["select_param"]
|
||||
self.current_tokens_used: int = 0
|
||||
|
||||
class Config:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
@ -23,28 +23,23 @@ class ChatDashboard(BaseChat):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_session_id,
|
||||
user_input,
|
||||
select_param: str = "",
|
||||
report_name: str = "report",
|
||||
chat_param: Dict
|
||||
):
|
||||
""" """
|
||||
self.db_name = select_param
|
||||
self.db_name = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatDashboard
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatDashboard,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=self.db_name,
|
||||
chat_param=chat_param
|
||||
)
|
||||
if not self.db_name:
|
||||
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
|
||||
self.db_name = self.db_name
|
||||
self.report_name = report_name
|
||||
self.report_name = chat_param["report_name"] or "report"
|
||||
|
||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||
|
||||
self.top_k: int = 5
|
||||
self.dashboard_template = self.__load_dashboard_template(report_name)
|
||||
self.dashboard_template = self.__load_dashboard_template(self.report_name)
|
||||
|
||||
def __load_dashboard_template(self, template_name):
|
||||
current_dir = os.getcwd()
|
||||
|
@ -22,24 +22,23 @@ class ChatExcel(BaseChat):
|
||||
chat_scene: str = ChatScene.ChatExcel.value()
|
||||
chat_retention_rounds = 1
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||
def __init__(self, chat_param: Dict):
|
||||
chat_mode = ChatScene.ChatExcel
|
||||
|
||||
self.select_param = select_param
|
||||
if has_path(select_param):
|
||||
self.excel_reader = ExcelReader(select_param)
|
||||
self.select_param = chat_param["select_param"]
|
||||
self.model_name = chat_param["model_name"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatExcel
|
||||
if has_path(self.select_param):
|
||||
self.excel_reader = ExcelReader(self.select_param)
|
||||
else:
|
||||
self.excel_reader = ExcelReader(
|
||||
os.path.join(
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
chat_mode=chat_mode,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=select_param,
|
||||
chat_param=chat_param
|
||||
)
|
||||
|
||||
def _generate_command_string(self, command: Dict[str, Any]) -> str:
|
||||
@ -85,6 +84,7 @@ class ChatExcel(BaseChat):
|
||||
"parent_mode": self.chat_mode,
|
||||
"select_param": self.excel_reader.excel_file_name,
|
||||
"excel_reader": self.excel_reader,
|
||||
"model_name": self.model_name,
|
||||
}
|
||||
learn_chat = ExcelLearning(**chat_param)
|
||||
result = await learn_chat.nostream_call()
|
||||
|
@ -30,16 +30,21 @@ class ExcelLearning(BaseChat):
|
||||
parent_mode: Any = None,
|
||||
select_param: str = None,
|
||||
excel_reader: Any = None,
|
||||
model_name: str = None,
|
||||
):
|
||||
chat_mode = ChatScene.ExcelLearning
|
||||
""" """
|
||||
self.excel_file_path = select_param
|
||||
self.excel_reader = excel_reader
|
||||
chat_param = {
|
||||
"chat_mode": chat_mode,
|
||||
"chat_session_id": chat_session_id,
|
||||
"current_user_input": user_input,
|
||||
"select_param": select_param,
|
||||
"model_name": model_name,
|
||||
}
|
||||
super().__init__(
|
||||
chat_mode=chat_mode,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=select_param,
|
||||
chat_param=chat_param
|
||||
)
|
||||
if parent_mode:
|
||||
self.current_message.chat_mode = parent_mode.value()
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
@ -12,15 +14,13 @@ class ChatWithDbAutoExecute(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||
def __init__(self, chat_param: Dict):
|
||||
chat_mode = ChatScene.ChatWithDbExecute
|
||||
self.db_name = select_param
|
||||
self.db_name = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = chat_mode
|
||||
""" """
|
||||
super().__init__(
|
||||
chat_mode=chat_mode,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=self.db_name,
|
||||
chat_param=chat_param,
|
||||
)
|
||||
if not self.db_name:
|
||||
raise ValueError(
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
@ -12,14 +14,12 @@ class ChatWithDbQA(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
||||
def __init__(self, chat_param: Dict):
|
||||
""" """
|
||||
self.db_name = select_param
|
||||
self.db_name = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatWithDbQA
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatWithDbQA,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=self.db_name,
|
||||
chat_param=chat_param
|
||||
)
|
||||
|
||||
if self.db_name:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
@ -15,13 +15,11 @@ class ChatWithPlugin(BaseChat):
|
||||
plugins_prompt_generator: PluginPromptGenerator
|
||||
select_plugin: str = None
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
self.plugin_selector = select_param
|
||||
def __init__(self, chat_param: Dict):
|
||||
self.plugin_selector = chat_param.select_param
|
||||
chat_param["chat_mode"] = ChatScene.ChatExecution
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatExecution,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
select_param=self.plugin_selector,
|
||||
chat_param=chat_param
|
||||
)
|
||||
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from chromadb.errors import NoIndexException
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
@ -20,15 +22,14 @@ class ChatKnowledge(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
def __init__(self, chat_param: Dict):
|
||||
""" """
|
||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||
|
||||
self.knowledge_space = select_param
|
||||
self.knowledge_space = chat_param["select_param"]
|
||||
chat_param["chat_mode"] = ChatScene.ChatKnowledge
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatKnowledge,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
chat_param=chat_param,
|
||||
)
|
||||
self.space_context = self.get_space_context(self.knowledge_space)
|
||||
self.top_k = (
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Dict
|
||||
|
||||
from pilot.scene.base_chat import BaseChat
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.configs.config import Config
|
||||
@ -12,12 +14,11 @@ class ChatNormal(BaseChat):
|
||||
|
||||
"""Number of results to return from the query"""
|
||||
|
||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
||||
def __init__(self, chat_param: Dict):
|
||||
""" """
|
||||
chat_param["chat_mode"] = ChatScene.ChatNormal
|
||||
super().__init__(
|
||||
chat_mode=ChatScene.ChatNormal,
|
||||
chat_session_id=chat_session_id,
|
||||
current_user_input=user_input,
|
||||
chat_param=chat_param,
|
||||
)
|
||||
|
||||
def generate_input_values(self):
|
||||
|
Loading…
Reference in New Issue
Block a user