feat: Support baichuan-7B model
This commit is contained in:
parent 01074660bc
commit 35d0c17ae8
@@ -31,6 +31,9 @@ QUANTIZE_QLORA=True
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
 # SMART_LLM_MODEL=vicuna-13b
 # FAST_LLM_MODEL=chatglm-6b
+## Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm". When we use the rest API provided by deployment frameworks like fastchat as a proxyllm,
+## "PROXYLLM_BACKEND" is the model they actually deploy. We can use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
+# PROXYLLM_BACKEND=


 #*******************************************************************#
@@ -11,7 +11,7 @@ cp .env.template .env
 LLM_MODEL=vicuna-13b
 MODEL_SERVER=http://127.0.0.1:8000
 ```
-now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b.
+Now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b, baichuan-7b, and baichuan-13b.

 if you want use other model, such as chatglm-6b, you just need update .env config file.
 ```
@@ -36,7 +36,19 @@ class StrictFormatter(Formatter):
         super().format(format_string, **dummy_inputs)


+class NoStrictFormatter(StrictFormatter):
+    def check_unused_args(
+        self,
+        used_args: Sequence[Union[int, str]],
+        args: Sequence,
+        kwargs: Mapping[str, Any],
+    ) -> None:
+        """Not check unused args"""
+        pass
+
+
 formatter = StrictFormatter()
+no_strict_formatter = NoStrictFormatter()


 class MyEncoder(json.JSONEncoder):
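The new NoStrictFormatter exists so that a template can be rendered even when extra keyword arguments are passed that the template never references. A minimal sketch of the behavioural difference; the StrictFormatter body below is a stand-in, assumed to reject unused kwargs the way the existing class in pilot.common.formatting does:

```python
from string import Formatter
from typing import Any, Mapping, Sequence, Union


class StrictFormatter(Formatter):
    """Stand-in for the existing strict formatter: reject kwargs the template never uses."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        extra = set(kwargs).difference(used_args)
        if extra:
            raise KeyError(extra)


class NoStrictFormatter(StrictFormatter):
    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Do not check unused args."""
        pass


template = "Question: {input}\nDialect: {dialect}"
values = {"input": "list all users", "dialect": "mysql", "response": "never used"}

# The strict variant raises KeyError because "response" is never referenced ...
try:
    StrictFormatter().format(template, **values)
except KeyError as e:
    print("strict:", e)

# ... while the non-strict variant simply renders the template.
print("no-strict:", NoStrictFormatter().format(template, **values))
```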
@@ -131,6 +131,13 @@ class Config(metaclass=Singleton):

         ### LLM Model Service Configuration
         self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+        ### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
+        ### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
+        ### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
+        self.PROXYLLM_BACKEND = None
+        if self.LLM_MODEL == "proxyllm":
+            self.PROXYLLM_BACKEND = os.getenv("PROXYLLM_BACKEND")
+
         self.LIMIT_MODEL_CONCURRENCY = int(os.getenv("LIMIT_MODEL_CONCURRENCY", 5))
         self.MAX_POSITION_EMBEDDINGS = int(os.getenv("MAX_POSITION_EMBEDDINGS", 4096))
         self.MODEL_PORT = os.getenv("MODEL_PORT", 8000)
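A small self-contained illustration of how the two settings resolve: PROXYLLM_BACKEND is only read when the active model is the proxy. The environment values below are examples, not taken from the commit:

```python
import os

# Example environment only.
os.environ["LLM_MODEL"] = "proxyllm"
os.environ["PROXYLLM_BACKEND"] = "vicuna-13b"

llm_model = os.getenv("LLM_MODEL", "vicuna-13b")
proxyllm_backend = None
if llm_model == "proxyllm":
    # The model actually served behind the proxy endpoint; used later to pick prompts.
    proxyllm_backend = os.getenv("PROXYLLM_BACKEND")

print(llm_model, proxyllm_backend)  # proxyllm vicuna-13b
```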
@@ -51,6 +51,8 @@ LLM_MODEL_CONFIG = {
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
     "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
+    # please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
+    "baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
 }

 # Load model config
@@ -12,8 +12,6 @@ from transformers import (
     LlamaTokenizer,
     BitsAndBytesConfig,
 )
-from transformers.generation.utils import GenerationConfig
-
 from pilot.configs.model_config import DEVICE
 from pilot.configs.config import Config

@@ -285,8 +283,9 @@ class BaichuanAdapter(BaseLLMAdaper):
         return "baichuan" in model_path.lower()

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # revision = from_pretrained_kwargs.get("revision", "main")
-        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True, use_fast=False
+        )
         model = AutoModelForCausalLM.from_pretrained(
             model_path,
             trust_remote_code=True,
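For reference, the loading pattern the adapter applies to Baichuan checkpoints reduces to roughly the sketch below; the local path is illustrative and any extra kwargs the hunk truncates are omitted:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/data/models/baichuan-7b"  # illustrative path

# Baichuan ships custom modeling/tokenization code, hence trust_remote_code=True;
# use_fast=False keeps the slow tokenizer, which this commit selects explicitly.
tokenizer = AutoTokenizer.from_pretrained(
    model_path, trust_remote_code=True, use_fast=False
)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
```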
@@ -2,8 +2,6 @@
 Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py

 Conversation prompt templates.
-
-
 """

 import dataclasses
@@ -286,6 +284,21 @@ def get_conv_template(name: str) -> Conversation:
     return conv_templates[name].copy()


+# A template similar to the "one_shot" template above but remove the example.
+register_conv_template(
+    Conversation(
+        name="zero_shot",
+        system="A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+        roles=("Human", "Assistant"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
+        sep="\n### ",
+        stop_str="###",
+    )
+)
+
 # llama2 template
 # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
 register_conv_template(
@@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
 from pydantic import BaseModel, Extra, Field, root_validator


-from pilot.common.formatting import formatter
+from pilot.common.formatting import formatter, no_strict_formatter
 from pilot.out_parser.base import BaseOutputParser
 from pilot.common.schema import SeparatorStyle
 from pilot.prompts.example_base import ExampleSelector
@@ -24,8 +24,10 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:


 DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
-    "f-string": formatter.format,
-    "jinja2": jinja2_formatter,
+    "f-string": lambda is_strict: formatter.format
+    if is_strict
+    else no_strict_formatter.format,
+    "jinja2": lambda is_strict: jinja2_formatter,
 }


@@ -38,6 +40,8 @@ class PromptTemplate(BaseModel, ABC):
     template: Optional[str]
     """The prompt template."""
     template_format: str = "f-string"
+    """strict template will check template args"""
+    template_is_strict: bool = True
     """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
     response_format: Optional[str]
     """default use stream out"""
@@ -68,10 +72,12 @@ class PromptTemplate(BaseModel, ABC):
         """Format the prompt with the inputs."""
         if self.template:
             if self.response_format:
-                kwargs["response"] = json.dumps(self.response_format, indent=4)
+                kwargs["response"] = json.dumps(
+                    self.response_format, ensure_ascii=False, indent=4
+                )
             return DEFAULT_FORMATTER_MAPPING[self.template_format](
-                self.template, **kwargs
-            )
+                self.template_is_strict
+            )(self.template, **kwargs)

     def add_goals(self, goal: str) -> None:
         self.goals.append(goal)
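Together with the DEFAULT_FORMATTER_MAPPING change above, rendering becomes a two-step call: the template format selects a factory, template_is_strict selects the concrete formatter, and only then are the template and kwargs applied. A sketch of the call shape, assuming DEFAULT_FORMATTER_MAPPING and the formatters from the hunks above are in scope; the template and values are illustrative:

```python
template_format = "f-string"
template_is_strict = False  # the new baichuan prompt file below sets this to False
template = "Question: {input}\nDialect: {dialect}"
values = {"input": "list all users", "dialect": "mysql", "response": "never used"}

# Step 1: pick the formatter for this strictness level; step 2: render the template.
render = DEFAULT_FORMATTER_MAPPING[template_format](template_is_strict)
print(render(template, **values))  # the unused "response" kwarg is tolerated
```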
@@ -58,13 +58,26 @@ class PromptTemplateRegistry:
             scene_registry, prompt_template, language, [_DEFAULT_MODEL_KEY]
         )

-    def get_prompt_template(self, scene_name: str, language: str, model_name: str):
-        """Get prompt template with scene name, language and model name"""
+    def get_prompt_template(
+        self,
+        scene_name: str,
+        language: str,
+        model_name: str,
+        proxyllm_backend: str = None,
+    ):
+        """Get prompt template with scene name, language and model name
+
+        proxyllm_backend: see CFG.PROXYLLM_BACKEND
+        """
         scene_registry = self.registry[scene_name]
-        registry = scene_registry.get(model_name)
         print(
-            f"Get prompt template of scene_name: {scene_name} with model_name: {model_name} language: {language}"
+            f"Get prompt template of scene_name: {scene_name} with model_name: {model_name}, proxyllm_backend: {proxyllm_backend}, language: {language}"
         )
+        registry = None
+        if proxyllm_backend:
+            registry = scene_registry.get(proxyllm_backend)
+        if not registry:
+            registry = scene_registry.get(model_name)
         if not registry:
             registry = scene_registry.get(_DEFAULT_MODEL_KEY)
             if not registry:
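The new parameter changes the lookup order: the prompt registered for the proxy's real backend is preferred, then the configured model name, then the scene default. A self-contained sketch of that fallback chain; the registry contents and the default key name are illustrative:

```python
_DEFAULT_MODEL_KEY = "___default___"  # illustrative stand-in for the real sentinel key


def resolve_registry(scene_registry: dict, model_name: str, proxyllm_backend: str = None):
    """Sketch of the fallback order: proxy backend, then model name, then default."""
    registry = None
    if proxyllm_backend:
        registry = scene_registry.get(proxyllm_backend)
    if not registry:
        registry = scene_registry.get(model_name)
    if not registry:
        registry = scene_registry.get(_DEFAULT_MODEL_KEY)
    return registry


scene_registry = {
    "vicuna-13b": ["vicuna prompt"],
    _DEFAULT_MODEL_KEY: ["default prompt"],
}
# With LLM_MODEL=proxyllm and PROXYLLM_BACKEND=vicuna-13b, the vicuna prompt wins.
print(resolve_registry(scene_registry, model_name="proxyllm", proxyllm_backend="vicuna-13b"))
```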
@@ -79,7 +79,10 @@ class BaseChat(ABC):
         # ]
         self.prompt_template: PromptTemplate = (
             CFG.prompt_template_registry.get_prompt_template(
-                self.chat_mode.value(), language=CFG.LANGUAGE, model_name=CFG.LLM_MODEL
+                self.chat_mode.value(),
+                language=CFG.LANGUAGE,
+                model_name=CFG.LLM_MODEL,
+                proxyllm_backend=CFG.PROXYLLM_BACKEND,
             )
         )
         self.history_message: List[OnceConversation] = self.memory.messages()
@@ -9,10 +9,7 @@ EXAMPLES = [
         {
             "type": "ai",
             "data": {
-                "content": """{
-\"thoughts\": \"thought text\",
-\"sql\": \"SELECT city FROM user where user_name='test1'\",
-}""",
+                "content": """{\n\"thoughts\": \"直接查询用户表中用户名为'test1'的记录即可\",\n\"sql\": \"SELECT city FROM user where user_name='test1'\"}""",
                 "example": True,
             },
         },
@@ -24,10 +21,7 @@ EXAMPLES = [
         {
             "type": "ai",
             "data": {
-                "content": """{
-\"thoughts\": \"thought text\",
-\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\",
-}""",
+                "content": """{\n\"thoughts\": \"根据订单表的用户名和用户表的用户名关联用户表和订单表,再通过用户表的城市为'成都'的过滤即可\",\n\"sql\": \"SELECT b.* FROM user a LEFT JOIN tran_order b ON a.user_name=b.user_name where a.city='成都'\"}""",
                 "example": True,
             },
         },
@@ -43,7 +43,7 @@ PROMPT_TEMPERATURE = 0.5
 prompt = PromptTemplate(
     template_scene=ChatScene.ChatWithDbExecute.value(),
     input_variables=["input", "table_info", "dialect", "top_k", "response"],
-    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4),
+    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
     template_define=PROMPT_SCENE_DEFINE,
     template=_DEFAULT_TEMPLATE,
     stream_out=PROMPT_NEED_NEED_STREAM_OUT,
@@ -54,3 +54,4 @@ prompt = PromptTemplate(
     temperature=PROMPT_TEMPERATURE,
 )
 CFG.prompt_template_registry.register(prompt, is_default=True)
+from . import prompt_baichuan
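The ensure_ascii=False additions matter because the JSON rendered into the prompt can contain Chinese text, which the default json.dumps escapes to \uXXXX sequences. A quick comparison using the RESPONSE_FORMAT_SIMPLE values defined in the new prompt_baichuan.py file below:

```python
import json

RESPONSE_FORMAT_SIMPLE = {
    "thoughts": "对用户说的想法摘要",
    "sql": "生成的将被执行的 SQL",
}

print(json.dumps(RESPONSE_FORMAT_SIMPLE, indent=4))
# "thoughts": "\u5bf9\u7528\u6237\u8bf4\u7684\u60f3\u6cd5\u6458\u8981", ...
print(json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4))
# "thoughts": "对用户说的想法摘要", ...
```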
pilot/scene/chat_db/auto_execute/prompt_baichuan.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import json
+from pilot.prompts.prompt_new import PromptTemplate
+from pilot.configs.config import Config
+from pilot.scene.base import ChatScene
+from pilot.scene.chat_db.auto_execute.out_parser import DbChatOutputParser, SqlAction
+from pilot.common.schema import SeparatorStyle
+from pilot.scene.chat_db.auto_execute.example import sql_data_example
+
+CFG = Config()
+
+PROMPT_SCENE_DEFINE = None
+
+_DEFAULT_TEMPLATE = """
+你是一个 SQL 专家,给你一个用户的问题,你会生成一条对应的 {dialect} 语法的 SQL 语句。
+
+如果用户没有在问题中指定 sql 返回多少条数据,那么你生成的 sql 最多返回 {top_k} 条数据。
+你应该尽可能少地使用表。
+
+已知表结构信息如下:
+{table_info}
+
+注意:
+1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:"提供的表结构信息不足以生成 sql 查询。" 禁止随意捏造信息。
+2. 不要查询不存在的列,注意哪一列位于哪张表中。
+3. 使用 json 格式回答,确保你的回答是必须是正确的 json 格式,并且能被 python 语言的 `json.loads` 库解析, 格式如下:
+{response}
+"""
+
+RESPONSE_FORMAT_SIMPLE = {
+    "thoughts": "对用户说的想法摘要",
+    "sql": "生成的将被执行的 SQL",
+}
+
+PROMPT_SEP = SeparatorStyle.SINGLE.value
+
+PROMPT_NEED_NEED_STREAM_OUT = False
+
+# Temperature is a configuration hyperparameter that controls the randomness of language model output.
+# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
+# For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
+PROMPT_TEMPERATURE = 0.5
+
+prompt = PromptTemplate(
+    template_scene=ChatScene.ChatWithDbExecute.value(),
+    input_variables=["input", "table_info", "dialect", "top_k", "response"],
+    response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
+    template_is_strict=False,
+    template_define=PROMPT_SCENE_DEFINE,
+    template=_DEFAULT_TEMPLATE,
+    stream_out=PROMPT_NEED_NEED_STREAM_OUT,
+    output_parser=DbChatOutputParser(
+        sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
+    ),
+    # example_selector=sql_data_example,
+    temperature=PROMPT_TEMPERATURE,
+)
+
+CFG.prompt_template_registry.register(
+    prompt,
+    language=CFG.LANGUAGE,
+    is_default=False,
+    model_names=["baichuan-13b", "baichuan-7b"],
+)
@@ -19,17 +19,20 @@ class BaseChatAdpter:
         """Return the generate stream handler func"""
         pass

-    def get_conv_template(self) -> Conversation:
+    def get_conv_template(self, model_path: str) -> Conversation:
         return None

-    def model_adaptation(self, params: Dict) -> Tuple[Dict, Dict]:
+    def model_adaptation(self, params: Dict, model_path: str) -> Tuple[Dict, Dict]:
         """Params adaptation"""
-        conv = self.get_conv_template()
+        conv = self.get_conv_template(model_path)
         messages = params.get("messages")
         # Some model scontext to dbgpt server
         model_context = {"prompt_echo_len_char": -1}
         if not conv or not messages:
             # Nothing to do
+            print(
+                f"No conv from model_path {model_path} or no messages in params, {self}"
+            )
             return params, model_context
         conv = conv.copy()
         system_messages = []
@@ -84,6 +87,7 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
     """Get a chat generate func for a model"""
     for adapter in llm_model_chat_adapters:
         if adapter.match(model_path):
+            print(f"Get model path: {model_path} adapter {adapter}")
             return adapter

     raise ValueError(f"Invalid model for chat adapter {model_path}")
@@ -191,7 +195,7 @@ class Llama2ChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "llama-2" in model_path.lower()

-    def get_conv_template(self) -> Conversation:
+    def get_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("llama-2")

     def get_generate_stream_func(self):
@@ -204,8 +208,10 @@ class BaichuanChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "baichuan" in model_path.lower()

-    def get_conv_template(self) -> Conversation:
-        return get_conv_template("baichuan-chat")
+    def get_conv_template(self, model_path: str) -> Conversation:
+        if "chat" in model_path.lower():
+            return get_conv_template("baichuan-chat")
+        return get_conv_template("zero_shot")

     def get_generate_stream_func(self):
         from pilot.model.inference import generate_stream
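Because get_conv_template now receives the model path, chat-tuned Baichuan checkpoints keep the baichuan-chat template while the base baichuan-7b checkpoint gets the new zero_shot template. The selection rule reduces to a one-liner; the paths below are illustrative:

```python
def pick_baichuan_template(model_path: str) -> str:
    # Mirrors the adapter logic: chat checkpoints -> "baichuan-chat", otherwise "zero_shot".
    return "baichuan-chat" if "chat" in model_path.lower() else "zero_shot"


print(pick_baichuan_template("/data/models/Baichuan-13B-Chat"))  # baichuan-chat
print(pick_baichuan_template("/data/models/baichuan-7b"))        # zero_shot
```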
@@ -78,7 +78,9 @@ class ModelWorker:
     def generate_stream_gate(self, params):
         try:
             # params adaptation
-            params, model_context = self.llm_chat_adapter.model_adaptation(params)
+            params, model_context = self.llm_chat_adapter.model_adaptation(
+                params, self.ml.model_path
+            )
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
|
Loading…
Reference in New Issue
Block a user