Mirror of https://github.com/csunny/DB-GPT.git
feat: Support baichuan-7B model
Parent: 01074660bc
Commit: 35d0c17ae8
@@ -31,6 +31,9 @@ QUANTIZE_QLORA=True
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
# SMART_LLM_MODEL=vicuna-13b
# FAST_LLM_MODEL=chatglm-6b
## Proxy LLM backend. This configuration is only valid when "LLM_MODEL=proxyllm". When we use the REST API provided by a deployment framework like fastchat as the proxyllm,
## "PROXYLLM_BACKEND" is the model it actually deploys. We can use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
# PROXYLLM_BACKEND=


#*******************************************************************#
@@ -11,7 +11,7 @@ cp .env.template .env
LLM_MODEL=vicuna-13b
MODEL_SERVER=http://127.0.0.1:8000
```
now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b.
now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b, baichuan-7b, baichuan-13b

If you want to use another model, such as chatglm-6b, you just need to update the .env config file.
```
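For example, switching to chatglm-6b only requires changing LLM_MODEL in .env (the values below are illustrative):
```
LLM_MODEL=chatglm-6b
MODEL_SERVER=http://127.0.0.1:8000
```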
@@ -36,7 +36,19 @@ class StrictFormatter(Formatter):
        super().format(format_string, **dummy_inputs)


class NoStrictFormatter(StrictFormatter):
    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Do not check unused args."""
        pass


formatter = StrictFormatter()
no_strict_formatter = NoStrictFormatter()


class MyEncoder(json.JSONEncoder):
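For context, a small self-contained sketch of the strict / non-strict formatting pattern introduced here. The classes below are stand-ins built on Python's `string.Formatter`, not the project's exact implementations:

```python
from string import Formatter
from typing import Any, Mapping, Sequence, Union


class DemoStrictFormatter(Formatter):
    """Reject keyword arguments the template does not use."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        extra = set(kwargs) - set(used_args)
        if extra:
            raise KeyError(f"Unused arguments: {extra}")


class DemoNoStrictFormatter(DemoStrictFormatter):
    """Silently ignore unused keyword arguments, like NoStrictFormatter above."""

    def check_unused_args(self, used_args, args, kwargs) -> None:
        pass


strict = DemoStrictFormatter()
lenient = DemoNoStrictFormatter()

print(lenient.format("Hello {name}", name="DB-GPT", unused="ignored"))  # works
# strict.format("Hello {name}", name="DB-GPT", unused="ignored")        # would raise KeyError


def pick(is_strict: bool):
    """Mirror the two-stage pattern of DEFAULT_FORMATTER_MAPPING in prompt_new.py:
    select a formatter by strictness first, then format the template."""
    return strict.format if is_strict else lenient.format


print(pick(False)("Question: {input}", input="show tables", table_info="ignored"))
```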
@@ -131,6 +131,13 @@ class Config(metaclass=Singleton):

        ### LLM Model Service Configuration
        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
        ### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
        ### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
        ### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
        self.PROXYLLM_BACKEND = None
        if self.LLM_MODEL == "proxyllm":
            self.PROXYLLM_BACKEND = os.getenv("PROXYLLM_BACKEND")

        self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
        self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
        self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
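For context, a minimal sketch (not the project's code) of how these two settings interact, with illustrative environment values:

```python
# Minimal sketch of the PROXYLLM_BACKEND resolution above (values are examples only).
import os

os.environ["LLM_MODEL"] = "proxyllm"
os.environ["PROXYLLM_BACKEND"] = "chatglm-6b"  # the model actually served behind the proxy

llm_model = os.getenv("LLM_MODEL", "vicuna-13b")
proxyllm_backend = os.getenv("PROXYLLM_BACKEND") if llm_model == "proxyllm" else None
print(proxyllm_backend)  # -> "chatglm-6b"; later used to pick the prompt template
```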
@@ -51,6 +51,8 @@ LLM_MODEL_CONFIG = {
    "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
    "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
    "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
    # please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
    "baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
}

# Load model config
@@ -12,8 +12,6 @@ from transformers import (
    LlamaTokenizer,
    BitsAndBytesConfig,
)
from transformers.generation.utils import GenerationConfig

from pilot.configs.model_config import DEVICE
from pilot.configs.config import Config
@@ -285,8 +283,9 @@ class BaichuanAdapter(BaseLLMAdaper):
        return "baichuan" in model_path.lower()

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        # revision = from_pretrained_kwargs.get("revision", "main")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, use_fast=False
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
@@ -2,8 +2,6 @@
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py

Conversation prompt templates.

"""

import dataclasses
@@ -286,6 +284,21 @@ def get_conv_template(name: str) -> Conversation:
    return conv_templates[name].copy()


# A template similar to the "one_shot" template above, but with the example removed.
register_conv_template(
    Conversation(
        name="zero_shot",
        system="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n### ",
        stop_str="###",
    )
)

# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
@@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pydantic import BaseModel, Extra, Field, root_validator

from pilot.common.formatting import formatter
from pilot.common.formatting import formatter, no_strict_formatter
from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
from pilot.prompts.example_base import ExampleSelector
@@ -24,8 +24,10 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:


DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
    "f-string": formatter.format,
    "jinja2": jinja2_formatter,
    "f-string": lambda is_strict: formatter.format
    if is_strict
    else no_strict_formatter.format,
    "jinja2": lambda is_strict: jinja2_formatter,
}
@@ -38,6 +40,8 @@ class PromptTemplate(BaseModel, ABC):
    template: Optional[str]
    """The prompt template."""
    template_format: str = "f-string"
    """A strict template checks the template args."""
    template_is_strict: bool = True
    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
    response_format: Optional[str]
    """default use stream out"""
@@ -68,10 +72,12 @@ class PromptTemplate(BaseModel, ABC):
        """Format the prompt with the inputs."""
        if self.template:
            if self.response_format:
                kwargs["response"] = json.dumps(self.response_format, indent=4)
                kwargs["response"] = json.dumps(
                    self.response_format, ensure_ascii=False, indent=4
                )
            return DEFAULT_FORMATTER_MAPPING[self.template_format](
                self.template, **kwargs
            )
                self.template_is_strict
            )(self.template, **kwargs)

    def add_goals(self, goal: str) -> None:
        self.goals.append(goal)
@@ -58,13 +58,26 @@ class PromptTemplateRegistry:
                scene_registry, prompt_template, language, [_DEFAULT_MODEL_KEY]
            )

    def get_prompt_template(self, scene_name: str, language: str, model_name: str):
        """Get prompt template with scene name, language and model name"""
    def get_prompt_template(
        self,
        scene_name: str,
        language: str,
        model_name: str,
        proxyllm_backend: str = None,
    ):
        """Get prompt template with scene name, language and model name

        proxyllm_backend: see CFG.PROXYLLM_BACKEND
        """
        scene_registry = self.registry[scene_name]
        registry = scene_registry.get(model_name)

        print(
            f"Get prompt template of scene_name: {scene_name} with model_name: {model_name} language: {language}"
            f"Get prompt template of scene_name: {scene_name} with model_name: {model_name}, proxyllm_backend: {proxyllm_backend}, language: {language}"
        )
        registry = None
        if proxyllm_backend:
            registry = scene_registry.get(proxyllm_backend)
        if not registry:
            registry = scene_registry.get(model_name)
        if not registry:
            registry = scene_registry.get(_DEFAULT_MODEL_KEY)
        if not registry:
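The lookup order introduced here is proxy backend first, then model name, then the default key. A condensed, self-contained sketch of that fallback; the registry contents and the sentinel value are invented for illustration:

```python
# Hypothetical data; only the fallback order mirrors get_prompt_template above.
_DEFAULT_MODEL_KEY = "___default___"  # assumed sentinel, not the real constant's value

scene_registry = {
    "chatglm-6b": ["chatglm prompt"],
    _DEFAULT_MODEL_KEY: ["default prompt"],
}


def resolve(model_name, proxyllm_backend=None):
    registry = None
    if proxyllm_backend:
        registry = scene_registry.get(proxyllm_backend)
    if not registry:
        registry = scene_registry.get(model_name)
    if not registry:
        registry = scene_registry.get(_DEFAULT_MODEL_KEY)
    return registry


print(resolve("proxyllm", proxyllm_backend="chatglm-6b"))  # -> ["chatglm prompt"]
print(resolve("vicuna-13b"))                               # -> ["default prompt"]
```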
@@ -79,7 +79,10 @@ class BaseChat(ABC):
        # ]
        self.prompt_template: PromptTemplate = (
            CFG.prompt_template_registry.get_prompt_template(
                self.chat_mode.value(), language=CFG.LANGUAGE, model_name=CFG.LLM_MODEL
                self.chat_mode.value(),
                language=CFG.LANGUAGE,
                model_name=CFG.LLM_MODEL,
                proxyllm_backend=CFG.PROXYLLM_BACKEND,
            )
        )
        self.history_message: List[OnceConversation] = self.memory.messages()
@@ -9,10 +9,7 @@ EXAMPLES = [
        {
            "type": "ai",
            "data": {
                "content": """{
\"thoughts\": \"thought text\",
\"sql\": \"SELECT city FROM user where user_name='test1'\",
}""",
                "content": """{\n\"thoughts\": \"直接查询用户表中用户名为'test1'的记录即可\",\n\"sql\": \"SELECT city FROM user where user_name='test1'\"}""",
                "example": True,
            },
        },
@@ -24,10 +21,7 @@ EXAMPLES = [
        {
            "type": "ai",
            "data": {
                "content": """{
\"thoughts\": \"thought text\",
\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
}""",
                "content": """{\n\"thoughts\": \"根据订单表的用户名和用户表的用户名关联用户表和订单表,再通过用户表的城市为'成都'的过滤即可\",\n\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\"}""",
                "example": True,
            },
        },
@@ -43,7 +43,7 @@ PROMPT_TEMPERATURE = 0.5
prompt = PromptTemplate(
    template_scene=ChatScene.ChatWithDbExecute.value(),
    input_variables=["input", "table_info", "dialect", "top_k", "response"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
@@ -54,3 +54,4 @@ prompt = PromptTemplate(
    temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)
from . import prompt_baichuan
pilot/scene/chat_db/auto_execute/prompt_baichuan.py (new file, 66 lines)
@@ -0,0 +1,66 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_db.auto_execute.example import sql_data_example

CFG = Config()

PROMPT_SCENE_DEFINE = None

_DEFAULT_TEMPLATE = """
你是一个 SQL 专家,给你一个用户的问题,你会生成一条对应的 {dialect} 语法的 SQL 语句。

如果用户没有在问题中指定 sql 返回多少条数据,那么你生成的 sql 最多返回 {top_k} 条数据。
你应该尽可能少地使用表。

已知表结构信息如下:
{table_info}

注意:
1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
2. 不要查询不存在的列,注意哪一列位于哪张表中。
3. 使用 json 格式回答,确保你的回答是必须是正确的 json 格式,并且能被 python 语言的 `json.loads` 库解析, 格式如下:
{response}
"""

RESPONSE_FORMAT_SIMPLE = {
    "thoughts": "对用户说的想法摘要",
    "sql": "生成的将被执行的 SQL",
}

PROMPT_SEP = SeparatorStyle.SINGLE.value

PROMPT_NEED_NEED_STREAM_OUT = False

# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5

prompt = PromptTemplate(
    template_scene=ChatScene.ChatWithDbExecute.value(),
    input_variables=["input", "table_info", "dialect", "top_k", "response"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    template_is_strict=False,
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=DbChatOutputParser(
        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
    ),
    # example_selector=sql_data_example,
    temperature=PROMPT_TEMPERATURE,
)

CFG.prompt_template_registry.register(
    prompt,
    language=CFG.LANGUAGE,
    is_default=False,
    model_names=["baichuan-13b", "baichuan-7b"],
)
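For context, a hedged example of the kind of reply the template above asks the model to produce; the values are invented and only the shape follows RESPONSE_FORMAT_SIMPLE:

```python
# Illustrative only: a reply shaped like RESPONSE_FORMAT_SIMPLE above.
import json

sample_reply = '{"thoughts": "查询用户表中城市为成都的记录", "sql": "SELECT * FROM user WHERE city = \'成都\' LIMIT 50"}'
parsed = json.loads(sample_reply)
assert set(parsed) == {"thoughts", "sql"}
print(parsed["sql"])
```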
@@ -19,17 +19,20 @@ class BaseChatAdpter:
        """Return the generate stream handler func"""
        pass

    def get_conv_template(self) -> Conversation:
    def get_conv_template(self, model_path: str) -> Conversation:
        return None

    def model_adaptation(self, params: Dict) -> Tuple[Dict, Dict]:
    def model_adaptation(self, params: Dict, model_path: str) -> Tuple[Dict, Dict]:
        """Params adaptation"""
        conv = self.get_conv_template()
        conv = self.get_conv_template(model_path)
        messages = params.get("messages")
        # Some model context to pass back to the dbgpt server
        model_context = {"prompt_echo_len_char": -1}
        if not conv or not messages:
            # Nothing to do
            print(
                f"No conv from model_path {model_path} or no messages in params, {self}"
            )
            return params, model_context
        conv = conv.copy()
        system_messages = []
@@ -84,6 +87,7 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
    """Get a chat generate func for a model"""
    for adapter in llm_model_chat_adapters:
        if adapter.match(model_path):
            print(f"Get model path: {model_path} adapter {adapter}")
            return adapter

    raise ValueError(f"Invalid model for chat adapter {model_path}")
@@ -191,7 +195,7 @@ class Llama2ChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "llama-2" in model_path.lower()

    def get_conv_template(self) -> Conversation:
    def get_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("llama-2")

    def get_generate_stream_func(self):
@@ -204,8 +208,10 @@ class BaichuanChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "baichuan" in model_path.lower()

    def get_conv_template(self) -> Conversation:
        return get_conv_template("baichuan-chat")
    def get_conv_template(self, model_path: str) -> Conversation:
        if "chat" in model_path.lower():
            return get_conv_template("baichuan-chat")
        return get_conv_template("zero_shot")

    def get_generate_stream_func(self):
        from pilot.model.inference import generate_stream
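A short sketch of the selection rule added above: chat-tuned checkpoints (paths containing "chat") get the baichuan-chat template, while base models fall back to zero_shot. The paths below are examples only:

```python
# Illustrative only; mirrors BaichuanChatAdapter.get_conv_template above.
def pick_baichuan_template(model_path: str) -> str:
    if "chat" in model_path.lower():
        return "baichuan-chat"
    return "zero_shot"


print(pick_baichuan_template("/data/models/Baichuan-13B-Chat"))  # -> baichuan-chat
print(pick_baichuan_template("/data/models/baichuan-7b"))        # -> zero_shot
```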
@@ -78,7 +78,9 @@ class ModelWorker:
    def generate_stream_gate(self, params):
        try:
            # params adaptation
            params, model_context = self.llm_chat_adapter.model_adaptation(params)
            params, model_context = self.llm_chat_adapter.model_adaptation(
                params, self.ml.model_path
            )
            for output in self.generate_stream_func(
                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):