feat: Support baichuan-7B model

FangYin Cheng 2023-07-25 18:39:07 +08:00
parent 01074660bc
commit 35d0c17ae8
15 changed files with 161 additions and 34 deletions

View File

@ -31,6 +31,9 @@ QUANTIZE_QLORA=True
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
# SMART_LLM_MODEL=vicuna-13b
# FAST_LLM_MODEL=chatglm-6b
## Proxy LLM backend. This configuration is only valid when "LLM_MODEL=proxyllm". When we use the REST API provided by a deployment framework such as FastChat as a proxy LLM,
## "PROXYLLM_BACKEND" is the model it actually deploys. We can use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
# PROXYLLM_BACKEND=
#*******************************************************************#

View File

@ -11,7 +11,7 @@ cp .env.template .env
LLM_MODEL=vicuna-13b
MODEL_SERVER=http://127.0.0.1:8000
```
Now we support the models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b.
Now we support the models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b, baichuan-7b, and baichuan-13b.
If you want to use another model, such as chatglm-6b, you just need to update the .env config file.
```

View File

@ -36,7 +36,19 @@ class StrictFormatter(Formatter):
super().format(format_string, **dummy_inputs)
class NoStrictFormatter(StrictFormatter):
def check_unused_args(
self,
used_args: Sequence[Union[int, str]],
args: Sequence,
kwargs: Mapping[str, Any],
) -> None:
"""Not check unused args"""
pass
formatter = StrictFormatter()
no_strict_formatter = NoStrictFormatter()
class MyEncoder(json.JSONEncoder):
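A minimal sketch of the behavioral difference, assuming `formatter` and `no_strict_formatter` are exported from the formatting module as shown above: the strict formatter is expected to reject keyword arguments the template never uses, while the non-strict one silently ignores them.
```
from pilot.common.formatting import formatter, no_strict_formatter

template = "Generate {dialect} SQL for: {input}"

# Non-strict: the extra "response" kwarg is tolerated and simply ignored.
print(no_strict_formatter.format(template, dialect="mysql", input="top 5 users", response="unused"))

# Strict: the same call is expected to raise, because "response" is never used by the template.
formatter.format(template, dialect="mysql", input="top 5 users", response="unused")
```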

View File

@ -131,6 +131,13 @@ class Config(metaclass=Singleton):
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
### When we use the REST API provided by a deployment framework such as FastChat as a proxy LLM, "PROXYLLM_BACKEND" is the model it actually deploys.
### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
self.PROXYLLM_BACKEND = None
if self.LLM_MODEL == "proxyllm":
self.PROXYLLM_BACKEND = os.getenv("PROXYLLM_BACKEND")
self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
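A short sketch of how this resolves at runtime, assuming `Config` is the singleton above and is first constructed after the environment is set; the model names are illustrative.
```
import os

os.environ["LLM_MODEL"] = "proxyllm"
os.environ["PROXYLLM_BACKEND"] = "baichuan-7b"

from pilot.configs.config import Config

cfg = Config()
assert cfg.LLM_MODEL == "proxyllm"
# Only populated because LLM_MODEL is "proxyllm"; later used to pick the prompt template.
assert cfg.PROXYLLM_BACKEND == "baichuan-7b"
```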

View File

@ -51,6 +51,8 @@ LLM_MODEL_CONFIG = {
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
}
# Load model config
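The new entry simply maps the model name from `.env` to a local weight directory. A tiny usage sketch, assuming the renamed `fireballoon/baichuan-vicuna-chinese-7b` weights sit under `MODEL_PATH/baichuan-7b`:
```
from pilot.configs.model_config import LLM_MODEL_CONFIG

# Resolves to <MODEL_PATH>/baichuan-7b, the directory holding the renamed weights.
baichuan_7b_path = LLM_MODEL_CONFIG["baichuan-7b"]
```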

View File

@ -12,8 +12,6 @@ from transformers import (
LlamaTokenizer,
BitsAndBytesConfig,
)
from transformers.generation.utils import GenerationConfig
from pilot.configs.model_config import DEVICE
from pilot.configs.config import Config
@ -285,8 +283,9 @@ class BaichuanAdapter(BaseLLMAdaper):
return "baichuan" in model_path.lower()
def loader(self, model_path: str, from_pretrained_kwargs: dict):
# revision = from_pretrained_kwargs.get("revision", "main")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True, use_fast=False
)
model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
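For reference, a minimal sketch of how a Baichuan checkpoint is commonly loaded with Hugging Face transformers, mirroring the adapter above; the local path is hypothetical. `use_fast=False` keeps the slow tokenizer, since the remote-code Baichuan tokenizer may not provide a fast implementation.
```
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./models/baichuan-7b"  # hypothetical local path
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
```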

View File

@ -2,8 +2,6 @@
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
Conversation prompt templates.
"""
import dataclasses
@ -286,6 +284,21 @@ def get_conv_template(name: str) -> Conversation:
return conv_templates[name].copy()
# A template similar to the "one_shot" template above, but without the example.
register_conv_template(
Conversation(
name="zero_shot",
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
sep="\n### ",
stop_str="###",
)
)
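A rough sketch of what the new "zero_shot" template renders, assuming FastChat-style `Conversation` semantics (ADD_COLON_SINGLE joins turns with the separator); the import path and the prompt text are illustrative.
```
from pilot.model.conversation import get_conv_template  # module path assumed

conv = get_conv_template("zero_shot")
conv.append_message(conv.roles[0], "Show the top 5 users by order count")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
# Roughly:
# <system text>
# ### Human: Show the top 5 users by order count
# ### Assistant:
```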
# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(

View File

@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.common.formatting import formatter
from pilot.common.formatting import formatter, no_strict_formatter
from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle
from pilot.prompts.example_base import ExampleSelector
@ -24,8 +24,10 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"jinja2": jinja2_formatter,
"f-string": lambda is_strict: formatter.format
if is_strict
else no_strict_formatter.format,
"jinja2": lambda is_strict: jinja2_formatter,
}
@ -38,6 +40,8 @@ class PromptTemplate(BaseModel, ABC):
template: Optional[str]
"""The prompt template."""
template_format: str = "f-string"
"""strict template will check template args"""
template_is_strict: bool = True
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_format: Optional[str]
"""default use stream out"""
@ -68,10 +72,12 @@ class PromptTemplate(BaseModel, ABC):
"""Format the prompt with the inputs."""
if self.template:
if self.response_format:
kwargs["response"] = json.dumps(self.response_format, indent=4)
return DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template, **kwargs
kwargs["response"] = json.dumps(
self.response_format, ensure_ascii=False, indent=4
)
return DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template_is_strict
)(self.template, **kwargs)
def add_goals(self, goal: str) -> None:
self.goals.append(goal)
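A small sketch of the two-step lookup introduced above: the mapping now returns the concrete formatter only after being told whether the template is strict. The import path is assumed from the other imports in this commit.
```
from pilot.prompts.prompt_new import DEFAULT_FORMATTER_MAPPING  # module path assumed

fmt = DEFAULT_FORMATTER_MAPPING["f-string"](False)  # non-strict path
fmt("SELECT * FROM {table} LIMIT {top_k}", table="user", top_k=5, extra="ignored")

fmt = DEFAULT_FORMATTER_MAPPING["f-string"](True)   # strict path: extra kwargs would be rejected
fmt("SELECT * FROM {table} LIMIT {top_k}", table="user", top_k=5)
```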

View File

@ -58,13 +58,26 @@ class PromptTemplateRegistry:
scene_registry, prompt_template, language, [_DEFAULT_MODEL_KEY]
)
def get_prompt_template(self, scene_name: str, language: str, model_name: str):
"""Get prompt template with scene name, language and model name"""
def get_prompt_template(
self,
scene_name: str,
language: str,
model_name: str,
proxyllm_backend: str = None,
):
"""Get prompt template with scene name, language and model name
proxyllm_backend: see CFG.PROXYLLM_BACKEND
"""
scene_registry = self.registry[scene_name]
registry = scene_registry.get(model_name)
print(
f"Get prompt template of scene_name: {scene_name} with model_name: {model_name} language: {language}"
f"Get prompt template of scene_name: {scene_name} with model_name: {model_name}, proxyllm_backend: {proxyllm_backend}, language: {language}"
)
registry = None
if proxyllm_backend:
registry = scene_registry.get(proxyllm_backend)
if not registry:
registry = scene_registry.get(model_name)
if not registry:
registry = scene_registry.get(_DEFAULT_MODEL_KEY)
if not registry:
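With this change the lookup order becomes: proxyllm_backend (if set), then model_name, then the default key. A hedged usage sketch; the scene name is shown only for illustration.
```
from pilot.configs.config import Config

CFG = Config()
template = CFG.prompt_template_registry.get_prompt_template(
    "chat_with_db_execute",                 # illustrative scene name
    language=CFG.LANGUAGE,
    model_name=CFG.LLM_MODEL,               # e.g. "proxyllm"
    proxyllm_backend=CFG.PROXYLLM_BACKEND,  # e.g. "baichuan-7b", checked first
)
```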

View File

@ -79,7 +79,10 @@ class BaseChat(ABC):
# ]
self.prompt_template: PromptTemplate = (
CFG.prompt_template_registry.get_prompt_template(
self.chat_mode.value(), language=CFG.LANGUAGE, model_name=CFG.LLM_MODEL
self.chat_mode.value(),
language=CFG.LANGUAGE,
model_name=CFG.LLM_MODEL,
proxyllm_backend=CFG.PROXYLLM_BACKEND,
)
)
self.history_message: List[OnceConversation] = self.memory.messages()

View File

@ -9,10 +9,7 @@ EXAMPLES = [
{
"type": "ai",
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"sql\": \"SELECT city FROM user where user_name='test1'\",
}""",
"content": """{\n\"thoughts\": \"直接查询用户表中用户名为'test1'的记录即可\",\n\"sql\": \"SELECT city FROM user where user_name='test1'\"}""",
"example": True,
},
},
@ -24,10 +21,7 @@ EXAMPLES = [
{
"type": "ai",
"data": {
"content": """{
\"thoughts\": \"thought text\",
\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
}""",
"content": """{\n\"thoughts\": \"根据订单表的用户名和用户表的用户名关联用户表和订单表,再通过用户表的城市为'成都'的过滤即可\",\n\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\"}""",
"example": True,
},
},

View File

@ -43,7 +43,7 @@ PROMPT_TEMPERATURE = 0.5
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value(),
input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
@ -54,3 +54,4 @@ prompt = PromptTemplate(
temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)
from . import prompt_baichuan

View File

@ -0,0 +1,66 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_db.auto_execute.example import sql_data_example
CFG = Config()
PROMPT_SCENE_DEFINE = None
_DEFAULT_TEMPLATE = """
你是一个 SQL 专家，给你一个用户的问题，你会生成一条对应的 {dialect} 语法的 SQL 语句。
如果用户没有在问题中指定 sql 返回多少条数据，那么你生成的 sql 最多返回 {top_k} 条数据。
你应该尽可能少地使用表。
已知表结构信息如下：
{table_info}
注意：
1. 只能使用表结构信息中提供的表来生成 sql，如果无法根据提供的表结构中生成 sql，请说：“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
2. 不要查询不存在的列，注意哪一列位于哪张表中。
3. 使用 json 格式回答，确保你的回答是必须是正确的 json 格式，并且能被 python 语言的 `json.loads` 库解析, 格式如下：
{response}
"""
RESPONSE_FORMAT_SIMPLE = {
"thoughts": "对用户说的想法摘要",
"sql": "生成的将被执行的 SQL",
}
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
PROMPT_TEMPERATURE = 0.5
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value(),
input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
template_is_strict=False,
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=DbChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
# example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(
prompt,
language=CFG.LANGUAGE,
is_default=False,
model_names=["baichuan-13b", "baichuan-7b"],
)

View File

@ -19,17 +19,20 @@ class BaseChatAdpter:
"""Return the generate stream handler func"""
pass
def get_conv_template(self) -> Conversation:
def get_conv_template(self, model_path: str) -> Conversation:
return None
def model_adaptation(self, params: Dict) -> Tuple[Dict, Dict]:
def model_adaptation(self, params: Dict, model_path: str) -> Tuple[Dict, Dict]:
"""Params adaptation"""
conv = self.get_conv_template()
conv = self.get_conv_template(model_path)
messages = params.get("messages")
# Some model context to dbgpt server
model_context = {"prompt_echo_len_char": -1}
if not conv or not messages:
# Nothing to do
print(
f"No conv from model_path {model_path} or no messages in params, {self}"
)
return params, model_context
conv = conv.copy()
system_messages = []
@ -84,6 +87,7 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
"""Get a chat generate func for a model"""
for adapter in llm_model_chat_adapters:
if adapter.match(model_path):
print(f"Get model path: {model_path} adapter {adapter}")
return adapter
raise ValueError(f"Invalid model for chat adapter {model_path}")
@ -191,7 +195,7 @@ class Llama2ChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "llama-2" in model_path.lower()
def get_conv_template(self) -> Conversation:
def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("llama-2")
def get_generate_stream_func(self):
@ -204,8 +208,10 @@ class BaichuanChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "baichuan" in model_path.lower()
def get_conv_template(self) -> Conversation:
def get_conv_template(self, model_path: str) -> Conversation:
if "chat" in model_path.lower():
return get_conv_template("baichuan-chat")
return get_conv_template("zero_shot")
def get_generate_stream_func(self):
from pilot.model.inference import generate_stream
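An illustrative check of the branch above, with hypothetical paths and an assumed import: a "-Chat" checkpoint gets the baichuan-chat template, while a base checkpoint such as baichuan-7b falls back to the new zero_shot template.
```
from pilot.server.chat_adapter import BaichuanChatAdapter  # module path assumed

adapter = BaichuanChatAdapter()
adapter.get_conv_template("./models/Baichuan-13B-Chat")  # -> "baichuan-chat" template
adapter.get_conv_template("./models/baichuan-7b")        # -> "zero_shot" template
```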

View File

@ -78,7 +78,9 @@ class ModelWorker:
def generate_stream_gate(self, params):
try:
# params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation(params)
params, model_context = self.llm_chat_adapter.model_adaptation(
params, self.ml.model_path
)
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):