mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-09 12:18:12 +00:00
feat:multi-llm add model param
This commit is contained in:
commit
1a2bf96767
4
.github/release-drafter.yml
vendored
4
.github/release-drafter.yml
vendored
@ -7,8 +7,8 @@ categories:
|
|||||||
labels: enhancement
|
labels: enhancement
|
||||||
- title: 🐞 Bug fixes
|
- title: 🐞 Bug fixes
|
||||||
labels: fix
|
labels: fix
|
||||||
- title: ⚠️ Deprecations
|
# - title: ⚠️ Deprecations
|
||||||
labels: deprecation
|
# labels: deprecation
|
||||||
- title: 🛠️ Other improvements
|
- title: 🛠️ Other improvements
|
||||||
labels:
|
labels:
|
||||||
- documentation
|
- documentation
|
||||||
|
20
.github/workflows/pr-labeler.yml
vendored
Normal file
20
.github/workflows/pr-labeler.yml
vendored
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
name: Pull request labeler
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request_target:
|
||||||
|
types: [opened, edited]
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
main:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Label pull request
|
||||||
|
uses: release-drafter/release-drafter@v5
|
||||||
|
with:
|
||||||
|
disable-releaser: true
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
22
.github/workflows/release-drafter.yml
vendored
Normal file
22
.github/workflows/release-drafter.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
name: Update draft releases
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
pull-requests: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
main:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Draft DB-GPT release
|
||||||
|
uses: release-drafter/release-drafter@v5
|
||||||
|
with:
|
||||||
|
config-name: release-drafter.yml
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
23
Dockerfile
23
Dockerfile
@ -1,23 +0,0 @@
|
|||||||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y \
|
|
||||||
git \
|
|
||||||
python3 \
|
|
||||||
pip
|
|
||||||
|
|
||||||
# upgrade pip
|
|
||||||
RUN pip3 install --upgrade pip
|
|
||||||
|
|
||||||
COPY ./requirements.txt /app/requirements.txt
|
|
||||||
|
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
|
||||||
|
|
||||||
RUN python3 -m spacy download zh_core_web_sm
|
|
||||||
|
|
||||||
|
|
||||||
COPY . /app
|
|
||||||
|
|
||||||
EXPOSE 7860
|
|
||||||
EXPOSE 8000
|
|
@ -86,6 +86,7 @@ Currently, we have released multiple key features, which are listed below to dem
|
|||||||
- Unified vector storage/indexing of knowledge base
|
- Unified vector storage/indexing of knowledge base
|
||||||
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
|
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
|
||||||
- Multi LLMs Support, Supports multiple large language models, currently supporting
|
- Multi LLMs Support, Supports multiple large language models, currently supporting
|
||||||
|
- 🔥 Baichuan2(7b,13b)
|
||||||
- 🔥 Vicuna-v1.5(7b,13b)
|
- 🔥 Vicuna-v1.5(7b,13b)
|
||||||
- 🔥 llama-2(7b,13b,70b)
|
- 🔥 llama-2(7b,13b,70b)
|
||||||
- WizardLM-v1.2(13b)
|
- WizardLM-v1.2(13b)
|
||||||
|
@ -30,12 +30,7 @@
|
|||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
[**English**](README.md)|[**Discord**](https://discord.gg/FMGwbRQrM) |[**Documents
|
[**English**](README.md)|[**Discord**](https://discord.gg/FMGwbRQrM) |[**文档**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**社区**](https://github.com/eosphoros-ai/community)
|
||||||
**](https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/)|[**微信
|
|
||||||
**](https://github.com/csunny/DB-GPT/blob/main/README.zh.md#%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC)|[**Community
|
|
||||||
**](https://github.com/eosphoros-ai/community)
|
|
||||||
|
|
||||||
</div>
|
|
||||||
|
|
||||||
## DB-GPT 是什么?
|
## DB-GPT 是什么?
|
||||||
|
|
||||||
@ -119,6 +114,7 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
|
|||||||
- 非结构化数据支持包括PDF、MarkDown、CSV、WebURL
|
- 非结构化数据支持包括PDF、MarkDown、CSV、WebURL
|
||||||
- 多模型支持
|
- 多模型支持
|
||||||
- 支持多种大语言模型, 当前已支持如下模型:
|
- 支持多种大语言模型, 当前已支持如下模型:
|
||||||
|
- 🔥 Baichuan2(7b,13b)
|
||||||
- 🔥 Vicuna-v1.5(7b,13b)
|
- 🔥 Vicuna-v1.5(7b,13b)
|
||||||
- 🔥 llama-2(7b,13b,70b)
|
- 🔥 llama-2(7b,13b,70b)
|
||||||
- WizardLM-v1.2(13b)
|
- WizardLM-v1.2(13b)
|
||||||
|
19
SECURITY.md
19
SECURITY.md
@ -1,19 +0,0 @@
|
|||||||
# Security Policy
|
|
||||||
|
|
||||||
## Supported Versions
|
|
||||||
|
|
||||||
Use this section to tell people about which versions of your project are
|
|
||||||
currently being supported with security updates.
|
|
||||||
|
|
||||||
| Version | Supported |
|
|
||||||
| ------- | ------------------ |
|
|
||||||
| v0.0.4 | :no new features |
|
|
||||||
| v0.0.3 | :documents QA |
|
|
||||||
| v0.0.2 | :sql generator |
|
|
||||||
| v0.0.1 | :minst runable |
|
|
||||||
|
|
||||||
## Reporting a Vulnerability
|
|
||||||
|
|
||||||
Use this section to tell people how to report a vulnerability.
|
|
||||||
|
|
||||||
we will build somethings small.
|
|
Binary file not shown.
Before Width: | Height: | Size: 380 KiB After Width: | Height: | Size: 256 KiB |
@ -77,6 +77,8 @@ LLM_MODEL_CONFIG = {
|
|||||||
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
|
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
|
||||||
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
|
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
|
||||||
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
|
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
|
||||||
|
"baichuan2-7b": os.path.join(MODEL_PATH, "Baichuan2-7B-Chat"),
|
||||||
|
"baichuan2-13b": os.path.join(MODEL_PATH, "Baichuan2-13B-Chat"),
|
||||||
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
|
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
|
||||||
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
|
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
|
||||||
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
|
"llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
|
||||||
|
@ -54,6 +54,10 @@ class ConversationVo(BaseModel):
|
|||||||
chat scene select param
|
chat scene select param
|
||||||
"""
|
"""
|
||||||
select_param: str = None
|
select_param: str = None
|
||||||
|
"""
|
||||||
|
llm model name
|
||||||
|
"""
|
||||||
|
model_name: str = None
|
||||||
|
|
||||||
|
|
||||||
class MessageVo(BaseModel):
|
class MessageVo(BaseModel):
|
||||||
|
@ -2,7 +2,7 @@ import datetime
|
|||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List
|
from typing import Any, List, Dict
|
||||||
|
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
@ -33,12 +33,12 @@ class BaseChat(ABC):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
|
self, chat_param: Dict
|
||||||
):
|
):
|
||||||
self.chat_session_id = chat_session_id
|
self.chat_session_id = chat_param["chat_session_id"]
|
||||||
self.chat_mode = chat_mode
|
self.chat_mode = chat_param["chat_mode"]
|
||||||
self.current_user_input: str = current_user_input
|
self.current_user_input: str = chat_param["current_user_input"]
|
||||||
self.llm_model = CFG.LLM_MODEL
|
self.llm_model = chat_param["model_name"]
|
||||||
self.llm_echo = False
|
self.llm_echo = False
|
||||||
|
|
||||||
### load prompt template
|
### load prompt template
|
||||||
@ -55,14 +55,14 @@ class BaseChat(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
### can configurable storage methods
|
### can configurable storage methods
|
||||||
self.memory = DuckdbHistoryMemory(chat_session_id)
|
self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"])
|
||||||
|
|
||||||
self.history_message: List[OnceConversation] = self.memory.messages()
|
self.history_message: List[OnceConversation] = self.memory.messages()
|
||||||
self.current_message: OnceConversation = OnceConversation(chat_mode.value())
|
self.current_message: OnceConversation = OnceConversation(self.chat_mode.value())
|
||||||
if select_param:
|
if chat_param["select_param"]:
|
||||||
if len(chat_mode.param_types()) > 0:
|
if len(self.chat_mode.param_types()) > 0:
|
||||||
self.current_message.param_type = chat_mode.param_types()[0]
|
self.current_message.param_type = self.chat_mode.param_types()[0]
|
||||||
self.current_message.param_value = select_param
|
self.current_message.param_value = chat_param["select_param"]
|
||||||
self.current_tokens_used: int = 0
|
self.current_tokens_used: int = 0
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
@ -23,28 +23,23 @@ class ChatDashboard(BaseChat):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
chat_session_id,
|
chat_param: Dict
|
||||||
user_input,
|
|
||||||
select_param: str = "",
|
|
||||||
report_name: str = "report",
|
|
||||||
):
|
):
|
||||||
""" """
|
""" """
|
||||||
self.db_name = select_param
|
self.db_name = chat_param["select_param"]
|
||||||
|
chat_param["chat_mode"] = ChatScene.ChatDashboard
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatDashboard,
|
chat_param=chat_param
|
||||||
chat_session_id=chat_session_id,
|
|
||||||
current_user_input=user_input,
|
|
||||||
select_param=self.db_name,
|
|
||||||
)
|
)
|
||||||
if not self.db_name:
|
if not self.db_name:
|
||||||
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
|
raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!")
|
||||||
self.db_name = self.db_name
|
self.db_name = self.db_name
|
||||||
self.report_name = report_name
|
self.report_name = chat_param["report_name"] or "report"
|
||||||
|
|
||||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||||
|
|
||||||
self.top_k: int = 5
|
self.top_k: int = 5
|
||||||
self.dashboard_template = self.__load_dashboard_template(report_name)
|
self.dashboard_template = self.__load_dashboard_template(self.report_name)
|
||||||
|
|
||||||
def __load_dashboard_template(self, template_name):
|
def __load_dashboard_template(self, template_name):
|
||||||
current_dir = os.getcwd()
|
current_dir = os.getcwd()
|
||||||
|
@ -22,24 +22,23 @@ class ChatExcel(BaseChat):
|
|||||||
chat_scene: str = ChatScene.ChatExcel.value()
|
chat_scene: str = ChatScene.ChatExcel.value()
|
||||||
chat_retention_rounds = 1
|
chat_retention_rounds = 1
|
||||||
|
|
||||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
def __init__(self, chat_param: Dict):
|
||||||
chat_mode = ChatScene.ChatExcel
|
chat_mode = ChatScene.ChatExcel
|
||||||
|
|
||||||
self.select_param = select_param
|
self.select_param = chat_param["select_param"]
|
||||||
if has_path(select_param):
|
self.model_name = chat_param["model_name"]
|
||||||
self.excel_reader = ExcelReader(select_param)
|
chat_param["chat_mode"] = ChatScene.ChatExcel
|
||||||
|
if has_path(self.select_param):
|
||||||
|
self.excel_reader = ExcelReader(self.select_param)
|
||||||
else:
|
else:
|
||||||
self.excel_reader = ExcelReader(
|
self.excel_reader = ExcelReader(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
|
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=chat_mode,
|
chat_param=chat_param
|
||||||
chat_session_id=chat_session_id,
|
|
||||||
current_user_input=user_input,
|
|
||||||
select_param=select_param,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_command_string(self, command: Dict[str, Any]) -> str:
|
def _generate_command_string(self, command: Dict[str, Any]) -> str:
|
||||||
@ -85,6 +84,7 @@ class ChatExcel(BaseChat):
|
|||||||
"parent_mode": self.chat_mode,
|
"parent_mode": self.chat_mode,
|
||||||
"select_param": self.excel_reader.excel_file_name,
|
"select_param": self.excel_reader.excel_file_name,
|
||||||
"excel_reader": self.excel_reader,
|
"excel_reader": self.excel_reader,
|
||||||
|
"model_name": self.model_name,
|
||||||
}
|
}
|
||||||
learn_chat = ExcelLearning(**chat_param)
|
learn_chat = ExcelLearning(**chat_param)
|
||||||
result = await learn_chat.nostream_call()
|
result = await learn_chat.nostream_call()
|
||||||
|
@ -30,16 +30,21 @@ class ExcelLearning(BaseChat):
|
|||||||
parent_mode: Any = None,
|
parent_mode: Any = None,
|
||||||
select_param: str = None,
|
select_param: str = None,
|
||||||
excel_reader: Any = None,
|
excel_reader: Any = None,
|
||||||
|
model_name: str = None,
|
||||||
):
|
):
|
||||||
chat_mode = ChatScene.ExcelLearning
|
chat_mode = ChatScene.ExcelLearning
|
||||||
""" """
|
""" """
|
||||||
self.excel_file_path = select_param
|
self.excel_file_path = select_param
|
||||||
self.excel_reader = excel_reader
|
self.excel_reader = excel_reader
|
||||||
|
chat_param = {
|
||||||
|
"chat_mode": chat_mode,
|
||||||
|
"chat_session_id": chat_session_id,
|
||||||
|
"current_user_input": user_input,
|
||||||
|
"select_param": select_param,
|
||||||
|
"model_name": model_name,
|
||||||
|
}
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=chat_mode,
|
chat_param=chat_param
|
||||||
chat_session_id=chat_session_id,
|
|
||||||
current_user_input=user_input,
|
|
||||||
select_param=select_param,
|
|
||||||
)
|
)
|
||||||
if parent_mode:
|
if parent_mode:
|
||||||
self.current_message.chat_mode = parent_mode.value()
|
self.current_message.chat_mode = parent_mode.value()
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.common.sql_database import Database
|
from pilot.common.sql_database import Database
|
||||||
@ -12,15 +14,13 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
def __init__(self, chat_param: Dict):
|
||||||
chat_mode = ChatScene.ChatWithDbExecute
|
chat_mode = ChatScene.ChatWithDbExecute
|
||||||
self.db_name = select_param
|
self.db_name = chat_param["select_param"]
|
||||||
|
chat_param["chat_mode"] = chat_mode
|
||||||
""" """
|
""" """
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=chat_mode,
|
chat_param=chat_param,
|
||||||
chat_session_id=chat_session_id,
|
|
||||||
current_user_input=user_input,
|
|
||||||
select_param=self.db_name,
|
|
||||||
)
|
)
|
||||||
if not self.db_name:
|
if not self.db_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.common.sql_database import Database
|
from pilot.common.sql_database import Database
|
||||||
@ -12,14 +14,12 @@ class ChatWithDbQA(BaseChat):
|
|||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(self, chat_session_id, user_input, select_param: str = ""):
|
def __init__(self, chat_param: Dict):
|
||||||
""" """
|
""" """
|
||||||
self.db_name = select_param
|
self.db_name = chat_param["select_param"]
|
||||||
|
chat_param["chat_mode"] = ChatScene.ChatWithDbQA
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatWithDbQA,
|
chat_param=chat_param
|
||||||
chat_session_id=chat_session_id,
|
|
||||||
current_user_input=user_input,
|
|
||||||
select_param=self.db_name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.db_name:
|
if self.db_name:
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import List, Dict
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
@ -15,13 +15,11 @@ class ChatWithPlugin(BaseChat):
|
|||||||
plugins_prompt_generator: PluginPromptGenerator
|
plugins_prompt_generator: PluginPromptGenerator
|
||||||
select_plugin: str = None
|
select_plugin: str = None
|
||||||
|
|
||||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
def __init__(self, chat_param: Dict):
|
||||||
self.plugin_selector = select_param
|
self.plugin_selector = chat_param.select_param
|
||||||
|
chat_param["chat_mode"] = ChatScene.ChatExecution
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatExecution,
|
chat_param=chat_param
|
||||||
chat_session_id=chat_session_id,
|
|
||||||
current_user_input=user_input,
|
|
||||||
select_param=self.plugin_selector,
|
|
||||||
)
|
)
|
||||||
self.plugins_prompt_generator = PluginPromptGenerator()
|
self.plugins_prompt_generator = PluginPromptGenerator()
|
||||||
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
self.plugins_prompt_generator.command_registry = CFG.command_registry
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
from chromadb.errors import NoIndexException
|
from chromadb.errors import NoIndexException
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
@ -20,15 +22,14 @@ class ChatKnowledge(BaseChat):
|
|||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
def __init__(self, chat_param: Dict):
|
||||||
""" """
|
""" """
|
||||||
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
from pilot.embedding_engine.embedding_engine import EmbeddingEngine
|
||||||
|
|
||||||
self.knowledge_space = select_param
|
self.knowledge_space = chat_param["select_param"]
|
||||||
|
chat_param["chat_mode"] = ChatScene.ChatKnowledge
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatKnowledge,
|
chat_param=chat_param,
|
||||||
chat_session_id=chat_session_id,
|
|
||||||
current_user_input=user_input,
|
|
||||||
)
|
)
|
||||||
self.space_context = self.get_space_context(self.knowledge_space)
|
self.space_context = self.get_space_context(self.knowledge_space)
|
||||||
self.top_k = (
|
self.top_k = (
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
from pilot.scene.base_chat import BaseChat
|
from pilot.scene.base_chat import BaseChat
|
||||||
from pilot.scene.base import ChatScene
|
from pilot.scene.base import ChatScene
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
@ -12,12 +14,11 @@ class ChatNormal(BaseChat):
|
|||||||
|
|
||||||
"""Number of results to return from the query"""
|
"""Number of results to return from the query"""
|
||||||
|
|
||||||
def __init__(self, chat_session_id, user_input, select_param: str = None):
|
def __init__(self, chat_param: Dict):
|
||||||
""" """
|
""" """
|
||||||
|
chat_param["chat_mode"] = ChatScene.ChatNormal
|
||||||
super().__init__(
|
super().__init__(
|
||||||
chat_mode=ChatScene.ChatNormal,
|
chat_param=chat_param,
|
||||||
chat_session_id=chat_session_id,
|
|
||||||
current_user_input=user_input,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_input_values(self):
|
def generate_input_values(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user