feat: Support baichuan-7B model

This commit is contained in:
FangYin Cheng 2023-07-25 18:39:07 +08:00
parent 01074660bc
commit 35d0c17ae8
15 changed files with 161 additions and 34 deletions

View File

@ -31,6 +31,9 @@ QUANTIZE_QLORA=True
## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b) ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
# SMART_LLM_MODEL=vicuna-13b # SMART_LLM_MODEL=vicuna-13b
# FAST_LLM_MODEL=chatglm-6b # FAST_LLM_MODEL=chatglm-6b
## Proxy LLM backend. This setting only takes effect when "LLM_MODEL=proxyllm". When the REST API provided by a deployment framework such as FastChat is used as the proxy LLM,
## "PROXYLLM_BACKEND" is the model that framework actually deploys. We use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
# PROXYLLM_BACKEND=
#*******************************************************************# #*******************************************************************#

View File

@ -11,7 +11,7 @@ cp .env.template .env
LLM_MODEL=vicuna-13b LLM_MODEL=vicuna-13b
MODEL_SERVER=http://127.0.0.1:8000 MODEL_SERVER=http://127.0.0.1:8000
``` ```
now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b. now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b, baichuan-7b, baichuan-13b
if you want use other model, such as chatglm-6b, you just need update .env config file. if you want use other model, such as chatglm-6b, you just need update .env config file.
``` ```

View File

@ -36,7 +36,19 @@ class StrictFormatter(Formatter):
super().format(format_string, **dummy_inputs) super().format(format_string, **dummy_inputs)
class NoStrictFormatter(StrictFormatter):
    """StrictFormatter variant whose unused-argument check is a no-op."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Accept any arguments without checking whether they were used."""
        return None
formatter = StrictFormatter() formatter = StrictFormatter()
no_strict_formatter = NoStrictFormatter()
class MyEncoder(json.JSONEncoder): class MyEncoder(json.JSONEncoder):

View File

@ -131,6 +131,13 @@ class Config(metaclass=Singleton):
### LLM Model Service Configuration ### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b") self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
self.PROXYLLM_BACKEND = None
if self.LLM_MODEL == "proxyllm":
self.PROXYLLM_BACKEND = os.getenv("PROXYLLM_BACKEND")
self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5)) self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096)) self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
self.MODEL_PORT = os.getenv("MODEL_PORT", 8000) self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)

View File

@ -51,6 +51,8 @@ LLM_MODEL_CONFIG = {
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"), "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"), "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"), "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
} }
# Load model config # Load model config

View File

@ -12,8 +12,6 @@ from transformers import (
LlamaTokenizer, LlamaTokenizer,
BitsAndBytesConfig, BitsAndBytesConfig,
) )
from transformers.generation.utils import GenerationConfig
from pilot.configs.model_config import DEVICE from pilot.configs.model_config import DEVICE
from pilot.configs.config import Config from pilot.configs.config import Config
@ -285,8 +283,9 @@ class BaichuanAdapter(BaseLLMAdaper):
return "baichuan" in model_path.lower() return "baichuan" in model_path.lower()
def loader(self, model_path: str, from_pretrained_kwargs: dict): def loader(self, model_path: str, from_pretrained_kwargs: dict):
# revision = from_pretrained_kwargs.get("revision", "main") tokenizer = AutoTokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) model_path, trust_remote_code=True, use_fast=False
)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_path, model_path,
trust_remote_code=True, trust_remote_code=True,

View File

@ -2,8 +2,6 @@
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
Conversation prompt templates. Conversation prompt templates.
""" """
import dataclasses import dataclasses
@ -286,6 +284,21 @@ def get_conv_template(name: str) -> Conversation:
return conv_templates[name].copy() return conv_templates[name].copy()
# Like the "one_shot" template above, but without the example conversation.
register_conv_template(
    Conversation(
        name="zero_shot",
        system=(
            "A chat between a curious human and an artificial intelligence "
            "assistant. The assistant gives helpful, detailed, and polite "
            "answers to the human's questions."
        ),
        roles=("Human", "Assistant"),
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n### ",
        stop_str="###",
        # Starts empty: no seeded messages and therefore no offset.
        messages=(),
        offset=0,
    )
)
# llama2 template # llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template( register_conv_template(

View File

@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
from pilot.common.formatting import formatter from pilot.common.formatting import formatter, no_strict_formatter
from pilot.out_parser.base import BaseOutputParser from pilot.out_parser.base import BaseOutputParser
from pilot.common.schema import SeparatorStyle from pilot.common.schema import SeparatorStyle
from pilot.prompts.example_base import ExampleSelector from pilot.prompts.example_base import ExampleSelector
@ -24,8 +24,10 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format, "f-string": lambda is_strict: formatter.format
"jinja2": jinja2_formatter, if is_strict
else no_strict_formatter.format,
"jinja2": lambda is_strict: jinja2_formatter,
} }
@ -38,6 +40,8 @@ class PromptTemplate(BaseModel, ABC):
template: Optional[str] template: Optional[str]
"""The prompt template.""" """The prompt template."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
template_is_strict: bool = True
"""Whether 'f-string' templates use the strict formatter, which checks template args; when False the unused-args check is skipped."""
response_format: Optional[str] response_format: Optional[str]
"""default use stream out""" """default use stream out"""
@ -68,10 +72,12 @@ class PromptTemplate(BaseModel, ABC):
"""Format the prompt with the inputs.""" """Format the prompt with the inputs."""
if self.template: if self.template:
if self.response_format: if self.response_format:
kwargs["response"] = json.dumps(self.response_format, indent=4) kwargs["response"] = json.dumps(
return DEFAULT_FORMATTER_MAPPING[self.template_format]( self.response_format, ensure_ascii=False, indent=4
self.template, **kwargs
) )
return DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template_is_strict
)(self.template, **kwargs)
def add_goals(self, goal: str) -> None: def add_goals(self, goal: str) -> None:
self.goals.append(goal) self.goals.append(goal)

View File

@ -58,13 +58,26 @@ class PromptTemplateRegistry:
scene_registry, prompt_template, language, [_DEFAULT_MODEL_KEY] scene_registry, prompt_template, language, [_DEFAULT_MODEL_KEY]
) )
def get_prompt_template(self, scene_name: str, language: str, model_name: str): def get_prompt_template(
"""Get prompt template with scene name, language and model name""" self,
scene_name: str,
language: str,
model_name: str,
proxyllm_backend: str = None,
):
"""Get prompt template with scene name, language and model name
proxyllm_backend: see CFG.PROXYLLM_BACKEND
"""
scene_registry = self.registry[scene_name] scene_registry = self.registry[scene_name]
registry = scene_registry.get(model_name)
print( print(
f"Get prompt template of scene_name: {scene_name} with model_name: {model_name} language: {language}" f"Get prompt template of scene_name: {scene_name} with model_name: {model_name}, proxyllm_backend: {proxyllm_backend}, language: {language}"
) )
registry = None
if proxyllm_backend:
registry = scene_registry.get(proxyllm_backend)
if not registry:
registry = scene_registry.get(model_name)
if not registry: if not registry:
registry = scene_registry.get(_DEFAULT_MODEL_KEY) registry = scene_registry.get(_DEFAULT_MODEL_KEY)
if not registry: if not registry:

View File

@ -79,7 +79,10 @@ class BaseChat(ABC):
# ] # ]
self.prompt_template: PromptTemplate = ( self.prompt_template: PromptTemplate = (
CFG.prompt_template_registry.get_prompt_template( CFG.prompt_template_registry.get_prompt_template(
self.chat_mode.value(), language=CFG.LANGUAGE, model_name=CFG.LLM_MODEL self.chat_mode.value(),
language=CFG.LANGUAGE,
model_name=CFG.LLM_MODEL,
proxyllm_backend=CFG.PROXYLLM_BACKEND,
) )
) )
self.history_message: List[OnceConversation] = self.memory.messages() self.history_message: List[OnceConversation] = self.memory.messages()

View File

@ -9,10 +9,7 @@ EXAMPLES = [
{ {
"type": "ai", "type": "ai",
"data": { "data": {
"content": """{ "content": """{\n\"thoughts\": \"直接查询用户表中用户名为'test1'的记录即可\",\n\"sql\": \"SELECT city FROM user where user_name='test1'\"}""",
\"thoughts\": \"thought text\",
\"sql\": \"SELECT city FROM user where user_name='test1'\",
}""",
"example": True, "example": True,
}, },
}, },
@ -24,10 +21,7 @@ EXAMPLES = [
{ {
"type": "ai", "type": "ai",
"data": { "data": {
"content": """{ "content": """{\n\"thoughts\": \"根据订单表的用户名和用户表的用户名关联用户表和订单表,再通过用户表的城市为'成都'的过滤即可\",\n\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\"}""",
\"thoughts\": \"thought text\",
\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
}""",
"example": True, "example": True,
}, },
}, },

View File

@ -43,7 +43,7 @@ PROMPT_TEMPERATURE = 0.5
prompt = PromptTemplate( prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value(), template_scene=ChatScene.ChatWithDbExecute.value(),
input_variables=["input", "table_info", "dialect", "top_k", "response"], input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4), response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
template_define=PROMPT_SCENE_DEFINE, template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT, stream_out=PROMPT_NEED_NEED_STREAM_OUT,
@ -54,3 +54,4 @@ prompt = PromptTemplate(
temperature=PROMPT_TEMPERATURE, temperature=PROMPT_TEMPERATURE,
) )
CFG.prompt_template_registry.register(prompt, is_default=True) CFG.prompt_template_registry.register(prompt, is_default=True)
from . import prompt_baichuan

View File

@ -0,0 +1,66 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_db.auto_execute.example import sql_data_example
# Config is declared with a Singleton metaclass; presumably this returns the shared settings object.
CFG = Config()
# This prompt has no scene definition text; the value is passed as `template_define`.
PROMPT_SCENE_DEFINE = None
# Chinese-language prompt asking the model to act as a SQL expert and answer in JSON.
# Placeholders used: {dialect}, {top_k}, {table_info}, {response}. Note that "input" is
# listed in the prompt's input_variables but has no placeholder in this template.
_DEFAULT_TEMPLATE = """
你是一个 SQL 专家给你一个用户的问题你会生成一条对应的 {dialect} 语法的 SQL 语句
如果用户没有在问题中指定 sql 返回多少条数据那么你生成的 sql 最多返回 {top_k} 条数据
你应该尽可能少地使用表
已知表结构信息如下
{table_info}
注意
1. 只能使用表结构信息中提供的表来生成 sql如果无法根据提供的表结构中生成 sql 请说提供的表结构信息不足以生成 sql 查询 禁止随意捏造信息
2. 不要查询不存在的列注意哪一列位于哪张表中
3. 使用 json 格式回答确保你的回答是必须是正确的 json 格式并且能被 python 语言的 `json.loads` 库解析, 格式如下
{response}
"""
# Shape of the JSON answer requested from the model: each key maps to a Chinese
# description of what that field should contain.
RESPONSE_FORMAT_SIMPLE = {"thoughts": "对用户说的想法摘要", "sql": "生成的将被执行的 SQL"}
# Separator value handed to the output parser.
PROMPT_SEP = SeparatorStyle.SINGLE.value

# Passed as both `stream_out` and `is_stream_out`: this prompt does not stream.
PROMPT_NEED_NEED_STREAM_OUT = False

# Temperature controls how random the language model's output is. Higher values
# give more unpredictable, creative text; lower values give more common,
# conservative text. For example, 0.5 is usually more predictable and less
# creative than 1.0.
PROMPT_TEMPERATURE = 0.5
# Prompt for the ChatWithDbExecute scene. No example_selector is configured
# (sql_data_example is imported but not passed here).
prompt = PromptTemplate(
    # Scene and template text.
    template_scene=ChatScene.ChatWithDbExecute.value(),
    template_define=PROMPT_SCENE_DEFINE,
    template=_DEFAULT_TEMPLATE,
    template_is_strict=False,
    # Inputs and the JSON answer format shown to the model.
    input_variables=["input", "table_info", "dialect", "top_k", "response"],
    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
    # Generation and output handling.
    temperature=PROMPT_TEMPERATURE,
    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
    output_parser=DbChatOutputParser(
        is_stream_out=PROMPT_NEED_NEED_STREAM_OUT,
        sep=PROMPT_SEP,
    ),
)
# Register this prompt as a non-default template for the Baichuan models only,
# under the currently configured language.
CFG.prompt_template_registry.register(
    prompt,
    model_names=["baichuan-13b", "baichuan-7b"],
    is_default=False,
    language=CFG.LANGUAGE,
)

View File

@ -19,17 +19,20 @@ class BaseChatAdpter:
"""Return the generate stream handler func""" """Return the generate stream handler func"""
pass pass
def get_conv_template(self) -> Conversation: def get_conv_template(self, model_path: str) -> Conversation:
return None return None
def model_adaptation(self, params: Dict) -> Tuple[Dict, Dict]: def model_adaptation(self, params: Dict, model_path: str) -> Tuple[Dict, Dict]:
"""Params adaptation""" """Params adaptation"""
conv = self.get_conv_template() conv = self.get_conv_template(model_path)
messages = params.get("messages") messages = params.get("messages")
# Some model scontext to dbgpt server # Some model scontext to dbgpt server
model_context = {"prompt_echo_len_char": -1} model_context = {"prompt_echo_len_char": -1}
if not conv or not messages: if not conv or not messages:
# Nothing to do # Nothing to do
print(
f"No conv from model_path {model_path} or no messages in params, {self}"
)
return params, model_context return params, model_context
conv = conv.copy() conv = conv.copy()
system_messages = [] system_messages = []
@ -84,6 +87,7 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
"""Get a chat generate func for a model""" """Get a chat generate func for a model"""
for adapter in llm_model_chat_adapters: for adapter in llm_model_chat_adapters:
if adapter.match(model_path): if adapter.match(model_path):
print(f"Get model path: {model_path} adapter {adapter}")
return adapter return adapter
raise ValueError(f"Invalid model for chat adapter {model_path}") raise ValueError(f"Invalid model for chat adapter {model_path}")
@ -191,7 +195,7 @@ class Llama2ChatAdapter(BaseChatAdpter):
def match(self, model_path: str): def match(self, model_path: str):
return "llama-2" in model_path.lower() return "llama-2" in model_path.lower()
def get_conv_template(self) -> Conversation: def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("llama-2") return get_conv_template("llama-2")
def get_generate_stream_func(self): def get_generate_stream_func(self):
@ -204,8 +208,10 @@ class BaichuanChatAdapter(BaseChatAdpter):
def match(self, model_path: str): def match(self, model_path: str):
return "baichuan" in model_path.lower() return "baichuan" in model_path.lower()
def get_conv_template(self) -> Conversation: def get_conv_template(self, model_path: str) -> Conversation:
if "chat" in model_path.lower():
return get_conv_template("baichuan-chat") return get_conv_template("baichuan-chat")
return get_conv_template("zero_shot")
def get_generate_stream_func(self): def get_generate_stream_func(self):
from pilot.model.inference import generate_stream from pilot.model.inference import generate_stream

View File

@ -78,7 +78,9 @@ class ModelWorker:
def generate_stream_gate(self, params): def generate_stream_gate(self, params):
try: try:
# params adaptation # params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation(params) params, model_context = self.llm_chat_adapter.model_adaptation(
params, self.ml.model_path
)
for output in self.generate_stream_func( for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
): ):