diff --git a/.env.template b/.env.template
index 3ff2d9077..7212e7c6d 100644
--- a/.env.template
+++ b/.env.template
@@ -31,6 +31,9 @@ QUANTIZE_QLORA=True
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
 # SMART_LLM_MODEL=vicuna-13b
 # FAST_LLM_MODEL=chatglm-6b
+## Proxy LLM backend: this configuration is only valid when "LLM_MODEL=proxyllm". When we use the REST API provided by a deployment framework such as FastChat as a proxy LLM,
+## "PROXYLLM_BACKEND" is the model that framework actually serves. We can use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
+# PROXYLLM_BACKEND=
 
 
 #*******************************************************************#
diff --git a/docs/modules/llms.md b/docs/modules/llms.md
index a4baf1807..1e21ecc74 100644
--- a/docs/modules/llms.md
+++ b/docs/modules/llms.md
@@ -11,7 +11,7 @@ cp .env.template .env
 LLM_MODEL=vicuna-13b
 MODEL_SERVER=http://127.0.0.1:8000
 ```
-now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b.
+Now we support the models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b, baichuan-7b, and baichuan-13b.
 
 if you want use other model, such as chatglm-6b, you just need update .env config file.
 ```
diff --git a/pilot/common/formatting.py b/pilot/common/formatting.py
index 6bf10c1b2..51f374314 100644
--- a/pilot/common/formatting.py
+++ b/pilot/common/formatting.py
@@ -36,7 +36,19 @@ class StrictFormatter(Formatter):
         super().format(format_string, **dummy_inputs)
 
 
+class NoStrictFormatter(StrictFormatter):
+    def check_unused_args(
+        self,
+        used_args: Sequence[Union[int, str]],
+        args: Sequence,
+        kwargs: Mapping[str, Any],
+    ) -> None:
+        """Do not check unused args"""
+        pass
+
+
 formatter = StrictFormatter()
+no_strict_formatter = NoStrictFormatter()
 
 
 class MyEncoder(json.JSONEncoder):
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index d9dc5e75e..932116b96 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -131,6 +131,13 @@ class Config(metaclass=Singleton):
 
         ### LLM Model Service Configuration
         self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+        ### Proxy LLM backend: this configuration is only valid when "LLM_MODEL=proxyllm".
+        ### When we use the REST API provided by a deployment framework such as FastChat as a proxy LLM, "PROXYLLM_BACKEND" is the model that framework actually serves.
+        ### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
+        self.PROXYLLM_BACKEND = None
+        if self.LLM_MODEL == "proxyllm":
+            self.PROXYLLM_BACKEND = os.getenv("PROXYLLM_BACKEND")
+
         self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
         self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
         self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index d719ad3dc..70b5280ac 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -51,6 +51,8 @@ LLM_MODEL_CONFIG = {
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
     "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
+    # please rename the downloaded "fireballoon/baichuan-vicuna-chinese-7b" directory to "baichuan-7b"
+    "baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
 }
 
 # Load model config
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index ebe7b82d5..900d51d4a 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -12,8 +12,6 @@ from transformers import (
     LlamaTokenizer,
     BitsAndBytesConfig,
 )
-from transformers.generation.utils import GenerationConfig
-
 from pilot.configs.model_config import DEVICE
 from pilot.configs.config import Config
 
@@ -285,8 +283,9 @@ class BaichuanAdapter(BaseLLMAdaper):
         return "baichuan" in model_path.lower()
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # revision = from_pretrained_kwargs.get("revision", "main")
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True, use_fast=False
+        )
         model = AutoModelForCausalLM.from_pretrained(
             model_path,
             trust_remote_code=True,
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
index fa57b2af5..11ca03ed8 100644
--- a/pilot/model/conversation.py
+++ b/pilot/model/conversation.py
@@ -2,8 +2,6 @@
 Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 
 Conversation prompt templates.
-
-
 """
 
 import dataclasses
@@ -286,6 +284,21 @@ def get_conv_template(name: str) -> Conversation:
     return conv_templates[name].copy()
 
 
+# A template similar to the "one_shot" template above, but with the example removed.
+register_conv_template(
+    Conversation(
+        name="zero_shot",
+        system="A chat between a curious human and an artificial intelligence assistant. "
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n### ", + stop_str="###", + ) +) + # llama2 template # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 register_conv_template( diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py index 78e1585ea..84943050c 100644 --- a/pilot/prompts/prompt_new.py +++ b/pilot/prompts/prompt_new.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union from pydantic import BaseModel, Extra, Field, root_validator -from pilot.common.formatting import formatter +from pilot.common.formatting import formatter, no_strict_formatter from pilot.out_parser.base import BaseOutputParser from pilot.common.schema import SeparatorStyle from pilot.prompts.example_base import ExampleSelector @@ -24,8 +24,10 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str: DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { - "f-string": formatter.format, - "jinja2": jinja2_formatter, + "f-string": lambda is_strict: formatter.format + if is_strict + else no_strict_formatter.format, + "jinja2": lambda is_strict: jinja2_formatter, } @@ -38,6 +40,8 @@ class PromptTemplate(BaseModel, ABC): template: Optional[str] """The prompt template.""" template_format: str = "f-string" + """strict template will check template args""" + template_is_strict: bool = True """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" response_format: Optional[str] """default use stream out""" @@ -68,10 +72,12 @@ class PromptTemplate(BaseModel, ABC): """Format the prompt with the inputs.""" if self.template: if self.response_format: - kwargs["response"] = json.dumps(self.response_format, indent=4) + kwargs["response"] = json.dumps( + self.response_format, ensure_ascii=False, indent=4 + ) return DEFAULT_FORMATTER_MAPPING[self.template_format]( - self.template, **kwargs - ) + self.template_is_strict + )(self.template, **kwargs) def add_goals(self, goal: str) -> None: self.goals.append(goal) diff --git a/pilot/prompts/prompt_registry.py b/pilot/prompts/prompt_registry.py index 0372e92b3..1e4aeb135 100644 --- a/pilot/prompts/prompt_registry.py +++ b/pilot/prompts/prompt_registry.py @@ -58,13 +58,26 @@ class PromptTemplateRegistry: scene_registry, prompt_template, language, [_DEFAULT_MODEL_KEY] ) - def get_prompt_template(self, scene_name: str, language: str, model_name: str): - """Get prompt template with scene name, language and model name""" + def get_prompt_template( + self, + scene_name: str, + language: str, + model_name: str, + proxyllm_backend: str = None, + ): + """Get prompt template with scene name, language and model name + proxyllm_backend: see CFG.PROXYLLM_BACKEND + """ scene_registry = self.registry[scene_name] - registry = scene_registry.get(model_name) + print( - f"Get prompt template of scene_name: {scene_name} with model_name: {model_name} language: {language}" + f"Get prompt template of scene_name: {scene_name} with model_name: {model_name}, proxyllm_backend: {proxyllm_backend}, language: {language}" ) + registry = None + if proxyllm_backend: + registry = scene_registry.get(proxyllm_backend) + if not registry: + registry = scene_registry.get(model_name) if not registry: registry = scene_registry.get(_DEFAULT_MODEL_KEY) if not registry: diff --git a/pilot/scene/base_chat.py 
index e4520931f..0c8abac4f 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -79,7 +79,10 @@ class BaseChat(ABC):
         # ]
         self.prompt_template: PromptTemplate = (
             CFG.prompt_template_registry.get_prompt_template(
-                self.chat_mode.value(), language=CFG.LANGUAGE, model_name=CFG.LLM_MODEL
+                self.chat_mode.value(),
+                language=CFG.LANGUAGE,
+                model_name=CFG.LLM_MODEL,
+                proxyllm_backend=CFG.PROXYLLM_BACKEND,
             )
         )
         self.history_message: List[OnceConversation] = self.memory.messages()
diff --git a/pilot/scene/chat_db/auto_execute/example.py b/pilot/scene/chat_db/auto_execute/example.py
index df45b4ed3..ece464cdd 100644
--- a/pilot/scene/chat_db/auto_execute/example.py
+++ b/pilot/scene/chat_db/auto_execute/example.py
@@ -9,10 +9,7 @@ EXAMPLES = [
         {
             "type": "ai",
             "data": {
-                "content": """{
-                    \"thoughts\": \"thought text\",
-                    \"sql\": \"SELECT city FROM user where user_name='test1'\",
-                }""",
+                "content": """{\n\"thoughts\": \"直接查询用户表中用户名为'test1'的记录即可\",\n\"sql\": \"SELECT city FROM user where user_name='test1'\"}""",
                 "example": True,
             },
         },
@@ -24,10 +21,7 @@ EXAMPLES = [
         {
             "type": "ai",
             "data": {
-                "content": """{
-                    \"thoughts\": \"thought text\",
-                    \"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
-                }""",
+                "content": """{\n\"thoughts\": \"根据订单表的用户名和用户表的用户名关联用户表和订单表,再通过用户表的城市为'成都'的过滤即可\",\n\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\"}""",
                 "example": True,
             },
         },
diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py
index 60b748224..6ee691a88 100644
--- a/pilot/scene/chat_db/auto_execute/prompt.py
+++ b/pilot/scene/chat_db/auto_execute/prompt.py
@@ -43,7 +43,7 @@ PROMPT_TEMPERATURE = 0.5
 prompt = PromptTemplate(
     template_scene=ChatScene.ChatWithDbExecute.value(),
     input_variables=["input", "table_info", "dialect", "top_k", "response"],
-    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
+    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
     template_define=PROMPT_SCENE_DEFINE,
     template=_DEFAULT_TEMPLATE,
     stream_out=PROMPT_NEED_NEED_STREAM_OUT,
@@ -54,3 +54,4 @@ prompt = PromptTemplate(
     temperature=PROMPT_TEMPERATURE,
 )
 CFG.prompt_template_registry.register(prompt, is_default=True)
+from . import prompt_baichuan
diff --git a/pilot/scene/chat_db/auto_execute/prompt_baichuan.py b/pilot/scene/chat_db/auto_execute/prompt_baichuan.py
new file mode 100644
index 000000000..ed0c2c8b2
--- /dev/null
+++ b/pilot/scene/chat_db/auto_execute/prompt_baichuan.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import json
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
+from pilot.common.schema import SeparatorStyle
+from pilot.scene.chat_db.auto_execute.example import sql_data_example
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = None
+
+_DEFAULT_TEMPLATE = """
+你是一个 SQL 专家,给你一个用户的问题,你会生成一条对应的 {dialect} 语法的 SQL 语句。
+
+如果用户没有在问题中指定 sql 返回多少条数据,那么你生成的 sql 最多返回 {top_k} 条数据。
+你应该尽可能少地使用表。
+
+已知表结构信息如下:
+{table_info}
+
+注意:
+1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构生成 sql,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
+2. 不要查询不存在的列,注意哪一列位于哪张表中。
+3. 使用 json 格式回答,确保你的回答必须是正确的 json 格式,并且能被 python 语言的 `json.loads` 库解析, 格式如下:
+{response}
+"""
+
+RESPONSE_FORMAT_SIMPLE = {
+    "thoughts": "对用户说的想法摘要",
+    "sql": "生成的将被执行的 SQL",
+}
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+# Temperature is a configuration hyperparameter that controls the randomness of language model output.
+# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
+# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
+PROMPT_TEMPERATURE = 0.5
+
+prompt = PromptTemplate(
+    template_scene=ChatScene.ChatWithDbExecute.value(),
+    input_variables=["input", "table_info", "dialect", "top_k", "response"],
+    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
+    template_is_strict=False,
+    template_define=PROMPT_SCENE_DEFINE,
+    template=_DEFAULT_TEMPLATE,
+    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+    output_parser=DbChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
+    # example_selector=sql_data_example,
+    temperature=PROMPT_TEMPERATURE,
+)
+
+CFG.prompt_template_registry.register(
+    prompt,
+    language=CFG.LANGUAGE,
+    is_default=False,
+    model_names=["baichuan-13b", "baichuan-7b"],
+)
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 422fc1117..0ca8f97da 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -19,17 +19,20 @@ class BaseChatAdpter:
         """Return the generate stream handler func"""
         pass
 
-    def get_conv_template(self) -> Conversation:
+    def get_conv_template(self, model_path: str) -> Conversation:
         return None
 
-    def model_adaptation(self, params: Dict) -> Tuple[Dict, Dict]:
+    def model_adaptation(self, params: Dict, model_path: str) -> Tuple[Dict, Dict]:
         """Params adaptation"""
-        conv = self.get_conv_template()
+        conv = self.get_conv_template(model_path)
         messages = params.get("messages")
         # Some model scontext to dbgpt server
         model_context = {"prompt_echo_len_char": -1}
         if not conv or not messages:
             # Nothing to do
+            print(
+                f"No conv from model_path {model_path} or no messages in params, {self}"
+            )
             return params, model_context
         conv = conv.copy()
         system_messages = []
@@ -84,6 +87,7 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
     """Get a chat generate func for a model"""
     for adapter in llm_model_chat_adapters:
         if adapter.match(model_path):
+            print(f"Get model path: {model_path} adapter {adapter}")
             return adapter
 
     raise ValueError(f"Invalid model for chat adapter {model_path}")
@@ -191,7 +195,7 @@ class Llama2ChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "llama-2" in model_path.lower()
 
-    def get_conv_template(self) -> Conversation:
+    def get_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("llama-2")
 
     def get_generate_stream_func(self):
@@ -204,8 +208,10 @@ class BaichuanChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "baichuan" in model_path.lower()
 
-    def get_conv_template(self) -> Conversation:
-        return get_conv_template("baichuan-chat")
+    def get_conv_template(self, model_path: str) -> Conversation:
+        if "chat" in model_path.lower():
+            return get_conv_template("baichuan-chat")
+        return get_conv_template("zero_shot")
 
     def get_generate_stream_func(self):
         from pilot.model.inference import generate_stream
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 3f97e3f86..9a34f7685 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -78,7 +78,9 @@ class ModelWorker:
     def generate_stream_gate(self, params):
         try:
             # params adaptation
-            params, model_context = self.llm_chat_adapter.model_adaptation(params)
+            params, model_context = self.llm_chat_adapter.model_adaptation(
+                params, self.ml.model_path
+            )
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):
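
A short, standalone sketch of the template lookup order that the new `get_prompt_template` follows: prefer the template registered for `PROXYLLM_BACKEND`, then the one for `LLM_MODEL`, then the scene default. The dict-based registry and the `_DEFAULT_MODEL_KEY` value below are simplified placeholders for illustration, not the project's real data structures.

```python
from typing import Dict, Optional

_DEFAULT_MODEL_KEY = "__default__"  # placeholder for the registry's real sentinel key


def resolve_prompt_template(
    scene_registry: Dict[str, str],
    model_name: str,
    proxyllm_backend: Optional[str] = None,
) -> Optional[str]:
    # 1. Prefer the model actually served behind the proxy (CFG.PROXYLLM_BACKEND).
    registry = scene_registry.get(proxyllm_backend) if proxyllm_backend else None
    # 2. Fall back to the configured LLM_MODEL (e.g. "proxyllm" itself).
    if not registry:
        registry = scene_registry.get(model_name)
    # 3. Finally fall back to the scene-wide default template.
    if not registry:
        registry = scene_registry.get(_DEFAULT_MODEL_KEY)
    return registry


# With LLM_MODEL=proxyllm and PROXYLLM_BACKEND=baichuan-13b, the template that
# prompt_baichuan.py registers for baichuan models wins over the scene default.
scene = {_DEFAULT_MODEL_KEY: "default sql prompt", "baichuan-13b": "baichuan sql prompt"}
print(resolve_prompt_template(scene, "proxyllm", "baichuan-13b"))  # -> baichuan sql prompt
print(resolve_prompt_template(scene, "proxyllm"))  # -> default sql prompt
```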
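prompt_baichuan.py also sets `template_is_strict=False`: its template never references `{input}` even though "input" is listed in `input_variables`, so an f-string check that rejects unused arguments would fail there. The snippet below illustrates the difference with a self-contained pair of formatters; the strict check shown here is an assumption for illustration, since the patch only adds the non-strict subclass.

```python
from string import Formatter
from typing import Any, Mapping, Sequence, Union


class StrictFormatter(Formatter):
    """Raise if kwargs contain keys the template never uses (assumed behavior)."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        extra = set(kwargs).difference(used_args)
        if extra:
            raise KeyError(extra)


class NoStrictFormatter(StrictFormatter):
    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Do not check unused args."""


template = "已知表结构信息如下:\n{table_info}\n{response}"
kwargs = {"table_info": "user(id, name, city)", "response": "{...}", "input": "unused"}

print(NoStrictFormatter().format(template, **kwargs))  # extra "input" is ignored
try:
    StrictFormatter().format(template, **kwargs)
except KeyError as err:
    print("strict formatter rejected unused args:", err)  # {'input'}
```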