feat: Support llama-2 model
This commit is contained in:
parent 412b104797, commit 168c754a3f
@@ -11,7 +11,7 @@ cp .env.template .env
LLM_MODEL=vicuna-13b
MODEL_SERVER=http://127.0.0.1:8000
```
-now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b.
+Now we support the models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b and llama-2-13b.

If you want to use another model, such as chatglm-6b, you just need to update the .env config file.
```
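For example (a hedged sketch, not part of the diff), switching to one of the new Llama-2 checkpoints only requires changing the model name in the .env file:

```
# Hypothetical .env snippet; the weights must exist under MODEL_PATH (see the config hunk below).
LLM_MODEL=llama-2-13b
MODEL_SERVER=http://127.0.0.1:8000
```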
@@ -47,6 +47,9 @@ LLM_MODEL_CONFIG = {
    "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
    "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
    "proxyllm": "proxyllm",
+    "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
+    "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
+    "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
}

# Load model config
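The new entries assume the Hugging Face chat checkpoints sit in directories named after the upstream repos under MODEL_PATH (a hedged illustration; the concrete MODEL_PATH value comes from your local configuration):

```python
# Hypothetical lookup: resolve the local path of the 13B chat model.
model_path = LLM_MODEL_CONFIG["llama-2-13b"]
# -> os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"); the directory name must match exactly.
```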
@@ -263,12 +263,26 @@ class ProxyllmAdapter(BaseLLMAdaper):
        return "proxyllm", None


+class Llama2Adapter(BaseLLMAdaper):
+    """The model adapter for llama-2"""
+
+    def match(self, model_path: str):
+        return "llama-2" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        return model, tokenizer
+
+
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
+register_llm_model_adapters(Llama2Adapter)
# TODO: vicuna is supported by default; other models still need testing and evaluation

# Just for test_py; remove this later
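As a quick sanity check (hedged example; the paths are made up), the matching rule routes any model path containing "llama-2", case-insensitively:

```python
adapter = Llama2Adapter()
adapter.match("/data/models/Llama-2-13b-chat-hf")  # True: handled by Llama2Adapter
adapter.match("/data/models/vicuna-13b")           # False: falls through to another adapter
```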
pilot/model/conversation.py (new file, 308 lines)
@@ -0,0 +1,308 @@
"""
|
||||
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
|
||||
Conversation prompt templates.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from enum import auto, IntEnum
|
||||
from typing import List, Any, Dict, Callable
|
||||
|
||||
|
||||
class SeparatorStyle(IntEnum):
|
||||
"""Separator styles."""
|
||||
|
||||
ADD_COLON_SINGLE = auto()
|
||||
ADD_COLON_TWO = auto()
|
||||
ADD_COLON_SPACE_SINGLE = auto()
|
||||
NO_COLON_SINGLE = auto()
|
||||
NO_COLON_TWO = auto()
|
||||
ADD_NEW_LINE_SINGLE = auto()
|
||||
LLAMA2 = auto()
|
||||
CHATGLM = auto()
|
||||
CHATML = auto()
|
||||
CHATINTERN = auto()
|
||||
DOLLY = auto()
|
||||
RWKV = auto()
|
||||
PHOENIX = auto()
|
||||
ROBIN = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that manages prompt templates and keeps all conversation history."""
|
||||
|
||||
# The name of this template
|
||||
name: str
|
||||
# The system prompt
|
||||
system: str
|
||||
# Two roles
|
||||
roles: List[str]
|
||||
# All messages. Each item is (role, message).
|
||||
messages: List[List[str]]
|
||||
# The number of few shot examples
|
||||
offset: int
|
||||
# Separators
|
||||
sep_style: SeparatorStyle
|
||||
sep: str
|
||||
sep2: str = None
|
||||
# Stop criteria (the default one is EOS token)
|
||||
stop_str: str = None
|
||||
# Stops generation if meeting any token in this list
|
||||
stop_token_ids: List[int] = None
|
||||
|
||||
# format system message
|
||||
system_formatter: Callable = None
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt for generation."""
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ": " + message + self.sep
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + ": " + message + seps[i % 2]
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ": " + message + self.sep
|
||||
else:
|
||||
ret += role + ": " # must be end with a space
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
||||
ret = "" if self.system == "" else self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + "\n" + message + self.sep
|
||||
else:
|
||||
ret += role + "\n"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||
ret = self.system
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + message + self.sep
|
||||
else:
|
||||
ret += role
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + message + seps[i % 2]
|
||||
else:
|
||||
ret += role
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.RWKV:
|
||||
ret = self.system
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += (
|
||||
role
|
||||
+ ": "
|
||||
+ message.replace("\r\n", "\n").replace("\n\n", "\n")
|
||||
)
|
||||
ret += "\n\n"
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.LLAMA2:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = ""
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
if i == 0:
|
||||
ret += self.system + message
|
||||
else:
|
||||
ret += role + " " + message + seps[i % 2]
|
||||
else:
|
||||
ret += role
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.CHATGLM:
|
||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
round_add_n = 1 if self.name == "chatglm2" else 0
|
||||
if self.system:
|
||||
ret = self.system + self.sep
|
||||
else:
|
||||
ret = ""
|
||||
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
ret += f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||
|
||||
if message:
|
||||
ret += f"{role}:{message}{self.sep}"
|
||||
else:
|
||||
ret += f"{role}:"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.CHATML:
|
||||
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + "\n" + message + self.sep + "\n"
|
||||
else:
|
||||
ret += role + "\n"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.CHATINTERN:
|
||||
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
ret += "<s>"
|
||||
if message:
|
||||
ret += role + ":" + message + seps[i % 2] + "\n"
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.DOLLY:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + ":\n" + message + seps[i % 2]
|
||||
if i % 2 == 1:
|
||||
ret += "\n\n"
|
||||
else:
|
||||
ret += role + ":\n"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.PHOENIX:
|
||||
ret = self.system
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ": " + "<s>" + message + "</s>"
|
||||
else:
|
||||
ret += role + ": " + "<s>"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.ROBIN:
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ":\n" + message + self.sep
|
||||
else:
|
||||
ret += role + ":\n"
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
def append_message(self, role: str, message: str):
|
||||
"""Append a new message."""
|
||||
self.messages.append([role, message])
|
||||
|
||||
def update_last_message(self, message: str):
|
||||
"""Update the last output.
|
||||
|
||||
The last message is typically set to be None when constructing the prompt,
|
||||
so we need to update it in-place after getting the response from a model.
|
||||
"""
|
||||
self.messages[-1][1] = message
|
||||
|
||||
def update_system_message(self, system_message: str):
|
||||
"""Update system message"""
|
||||
if self.system_formatter:
|
||||
self.system = self.system_formatter(system_message)
|
||||
else:
|
||||
self.system = system_message
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
"""Convert the conversation to gradio chatbot format."""
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
||||
if i % 2 == 0:
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
ret[-1][-1] = msg
|
||||
return ret
|
||||
|
||||
def to_openai_api_messages(self):
|
||||
"""Convert the conversation to OpenAI chat completion format."""
|
||||
ret = [{"role": "system", "content": self.system}]
|
||||
|
||||
for i, (_, msg) in enumerate(self.messages[self.offset :]):
|
||||
if i % 2 == 0:
|
||||
ret.append({"role": "user", "content": msg})
|
||||
else:
|
||||
if msg is not None:
|
||||
ret.append({"role": "assistant", "content": msg})
|
||||
return ret
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
name=self.name,
|
||||
system=self.system,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
stop_str=self.stop_str,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
system_formatter=self.system_formatter,
|
||||
)
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"template_name": self.name,
|
||||
"system": self.system,
|
||||
"roles": self.roles,
|
||||
"messages": self.messages,
|
||||
"offset": self.offset,
|
||||
}
|
||||
|
||||
|
||||
# A global registry for all conversation templates
|
||||
conv_templates: Dict[str, Conversation] = {}
|
||||
|
||||
|
||||
def register_conv_template(template: Conversation, override: bool = False):
|
||||
"""Register a new conversation template."""
|
||||
if not override:
|
||||
assert (
|
||||
template.name not in conv_templates
|
||||
), f"{template.name} has been registered."
|
||||
|
||||
conv_templates[template.name] = template
|
||||
|
||||
|
||||
def get_conv_template(name: str) -> Conversation:
|
||||
"""Get a conversation template."""
|
||||
return conv_templates[name].copy()
|
||||
|
||||
|
||||
# llama2 template
|
||||
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llama-2",
|
||||
system="<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
|
||||
roles=("[INST]", "[/INST]"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.LLAMA2,
|
||||
sep=" ",
|
||||
sep2=" </s><s>",
|
||||
stop_token_ids=[2],
|
||||
system_formatter=lambda msg: f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
|
||||
)
|
||||
)
|
||||
|
||||
# TODO Support other model conversation template
|
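For orientation (a hedged usage sketch, not part of the commit), building a single-turn Llama-2 prompt from the template registered above looks roughly like this:

```python
from pilot.model.conversation import get_conv_template

conv = get_conv_template("llama-2")
conv.append_message(conv.roles[0], "What is DB-GPT?")  # user turn ("[INST]")
conv.append_message(conv.roles[1], None)               # placeholder for the model reply
prompt = conv.get_prompt()
# The prompt starts with the "<s>[INST] <<SYS>> ..." system block, then the user
# message, and ends with "[/INST]" so the chat model knows to start answering.
```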
pilot/model/inference.py (new file, 242 lines)
@@ -0,0 +1,242 @@
"""
|
||||
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py
|
||||
|
||||
"""
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import gc
|
||||
from typing import Iterable, Dict
|
||||
|
||||
import torch
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.generation.logits_process import (
|
||||
LogitsProcessorList,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
|
||||
from pilot.model.llm_utils import is_sentence_complete, is_partial_stop
|
||||
|
||||
|
||||
def prepare_logits_processor(
|
||||
temperature: float, repetition_penalty: float, top_p: float, top_k: int
|
||||
) -> LogitsProcessorList:
|
||||
processor_list = LogitsProcessorList()
|
||||
# TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
|
||||
if temperature >= 1e-5 and temperature != 1.0:
|
||||
processor_list.append(TemperatureLogitsWarper(temperature))
|
||||
if repetition_penalty > 1.0:
|
||||
processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
|
||||
if 1e-8 <= top_p < 1.0:
|
||||
processor_list.append(TopPLogitsWarper(top_p))
|
||||
if top_k > 0:
|
||||
processor_list.append(TopKLogitsWarper(top_k))
|
||||
return processor_list
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_stream(
|
||||
model,
|
||||
tokenizer,
|
||||
params: Dict,
|
||||
device: str,
|
||||
context_len: int,
|
||||
stream_interval: int = 2,
|
||||
judge_sent_end: bool = False,
|
||||
):
|
||||
# Read parameters
|
||||
prompt = params["prompt"]
|
||||
print(f"Prompt of model: \n{prompt}")
|
||||
len_prompt = len(prompt)
|
||||
temperature = float(params.get("temperature", 1.0))
|
||||
repetition_penalty = float(params.get("repetition_penalty", 1.0))
|
||||
top_p = float(params.get("top_p", 1.0))
|
||||
top_k = int(params.get("top_k", -1)) # -1 means disable
|
||||
max_new_tokens = int(params.get("max_new_tokens", 2048))
|
||||
echo = bool(params.get("echo", True))
|
||||
stop_str = params.get("stop", None)
|
||||
stop_token_ids = params.get("stop_token_ids", None) or []
|
||||
stop_token_ids.append(tokenizer.eos_token_id)
|
||||
|
||||
logits_processor = prepare_logits_processor(
|
||||
temperature, repetition_penalty, top_p, top_k
|
||||
)
|
||||
input_ids = tokenizer(prompt).input_ids
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
max_src_len = context_len
|
||||
else: # truncate
|
||||
max_src_len = context_len - max_new_tokens - 1
|
||||
|
||||
input_ids = input_ids[-max_src_len:]
|
||||
output_ids = list(input_ids)
|
||||
input_echo_len = len(input_ids)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_output = model.encoder(
|
||||
input_ids=torch.as_tensor([input_ids], device=device)
|
||||
)[0]
|
||||
start_ids = torch.as_tensor(
|
||||
[[model.generation_config.decoder_start_token_id]],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
past_key_values = out = None
|
||||
sent_interrupt = False
|
||||
for i in range(max_new_tokens):
|
||||
if i == 0: # prefill
|
||||
if model.config.is_encoder_decoder:
|
||||
out = model.decoder(
|
||||
input_ids=start_ids,
|
||||
encoder_hidden_states=encoder_output,
|
||||
use_cache=True,
|
||||
)
|
||||
logits = model.lm_head(out[0])
|
||||
else:
|
||||
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
else: # decoding
|
||||
if model.config.is_encoder_decoder:
|
||||
out = model.decoder(
|
||||
input_ids=torch.as_tensor(
|
||||
[[token] if not sent_interrupt else output_ids], device=device
|
||||
),
|
||||
encoder_hidden_states=encoder_output,
|
||||
use_cache=True,
|
||||
past_key_values=past_key_values if not sent_interrupt else None,
|
||||
)
|
||||
sent_interrupt = False
|
||||
|
||||
logits = model.lm_head(out[0])
|
||||
else:
|
||||
out = model(
|
||||
input_ids=torch.as_tensor(
|
||||
[[token] if not sent_interrupt else output_ids], device=device
|
||||
),
|
||||
use_cache=True,
|
||||
past_key_values=past_key_values if not sent_interrupt else None,
|
||||
)
|
||||
sent_interrupt = False
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
|
||||
if logits_processor:
|
||||
if repetition_penalty > 1.0:
|
||||
tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
|
||||
else:
|
||||
tmp_output_ids = None
|
||||
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
|
||||
else:
|
||||
last_token_logits = logits[0, -1, :]
|
||||
|
||||
if device == "mps":
|
||||
# Switch to CPU by avoiding some bugs in mps backend.
|
||||
last_token_logits = last_token_logits.float().to("cpu")
|
||||
|
||||
if temperature < 1e-5 or top_p < 1e-8: # greedy
|
||||
_, indices = torch.topk(last_token_logits, 2)
|
||||
tokens = [int(index) for index in indices.tolist()]
|
||||
else:
|
||||
probs = torch.softmax(last_token_logits, dim=-1)
|
||||
indices = torch.multinomial(probs, num_samples=2)
|
||||
tokens = [int(token) for token in indices.tolist()]
|
||||
token = tokens[0]
|
||||
output_ids.append(token)
|
||||
|
||||
if token in stop_token_ids:
|
||||
stopped = True
|
||||
else:
|
||||
stopped = False
|
||||
|
||||
# Yield the output tokens
|
||||
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
|
||||
if echo:
|
||||
tmp_output_ids = output_ids
|
||||
rfind_start = len_prompt
|
||||
else:
|
||||
tmp_output_ids = output_ids[input_echo_len:]
|
||||
rfind_start = 0
|
||||
|
||||
output = tokenizer.decode(
|
||||
tmp_output_ids,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
# TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
|
||||
if judge_sent_end and stopped and not is_sentence_complete(output):
|
||||
if len(tokens) > 1:
|
||||
token = tokens[1]
|
||||
output_ids[-1] = token
|
||||
else:
|
||||
output_ids.pop()
|
||||
stopped = False
|
||||
sent_interrupt = True
|
||||
|
||||
partially_stopped = False
|
||||
if stop_str:
|
||||
if isinstance(stop_str, str):
|
||||
pos = output.rfind(stop_str, rfind_start)
|
||||
if pos != -1:
|
||||
output = output[:pos]
|
||||
stopped = True
|
||||
else:
|
||||
partially_stopped = is_partial_stop(output, stop_str)
|
||||
elif isinstance(stop_str, Iterable):
|
||||
for each_stop in stop_str:
|
||||
pos = output.rfind(each_stop, rfind_start)
|
||||
if pos != -1:
|
||||
output = output[:pos]
|
||||
stopped = True
|
||||
break
|
||||
else:
|
||||
partially_stopped = is_partial_stop(output, each_stop)
|
||||
if partially_stopped:
|
||||
break
|
||||
else:
|
||||
raise ValueError("Invalid stop field type.")
|
||||
|
||||
# Prevent yielding partial stop sequence
|
||||
if not partially_stopped:
|
||||
yield output
|
||||
# yield {
|
||||
# "text": output,
|
||||
# "usage": {
|
||||
# "prompt_tokens": input_echo_len,
|
||||
# "completion_tokens": i,
|
||||
# "total_tokens": input_echo_len + i,
|
||||
# },
|
||||
# "finish_reason": None,
|
||||
# }
|
||||
|
||||
if stopped:
|
||||
break
|
||||
|
||||
# Finish stream event, which contains finish reason
|
||||
if i == max_new_tokens - 1:
|
||||
finish_reason = "length"
|
||||
elif stopped:
|
||||
finish_reason = "stop"
|
||||
else:
|
||||
finish_reason = None
|
||||
yield output
|
||||
# yield {
|
||||
# "text": output,
|
||||
# "usage": {
|
||||
# "prompt_tokens": input_echo_len,
|
||||
# "completion_tokens": i,
|
||||
# "total_tokens": input_echo_len + i,
|
||||
# },
|
||||
# "finish_reason": finish_reason,
|
||||
# }
|
||||
|
||||
# Clean
|
||||
del past_key_values, out
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
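For context (a hedged illustration), the sampling parameters read above map onto Hugging Face logits processors as follows; values of 1.0 (and top_k <= 0) are treated as no-ops and skipped:

```python
from pilot.model.inference import prepare_logits_processor

procs = prepare_logits_processor(
    temperature=0.7, repetition_penalty=1.0, top_p=0.9, top_k=40
)
# -> [TemperatureLogitsWarper(0.7), TopPLogitsWarper(0.9), TopKLogitsWarper(40)];
#    the repetition penalty processor is skipped because 1.0 is a no-op.
```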
@@ -6,7 +6,7 @@ import requests
from typing import List
from pilot.configs.config import Config
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
-from pilot.scene.base_message import ModelMessage
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

CFG = Config()

@@ -25,11 +25,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
    messages: List[ModelMessage] = params["messages"]
    # Add history conversation
    for message in messages:
-        if message.role == "human":
+        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
-        elif message.role == "system":
+        elif message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "system", "content": message.content})
-        elif message.role == "ai":
+        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass
@@ -10,7 +10,6 @@ from typing import List, Optional

from pilot.configs.config import Config
from pilot.model.base import Message
-from pilot.server.llmserver import generate_output


def create_chat_completion(

@@ -115,3 +114,17 @@ class Iteratorize:

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True


+def is_sentence_complete(output: str):
+    """Check whether the output is a complete sentence."""
+    end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
+    return output.endswith(end_symbols)
+
+
+def is_partial_stop(output: str, stop_str: str):
+    """Check whether the output contains a partial stop str."""
+    for i in range(0, min(len(output), len(stop_str))):
+        if stop_str.startswith(output[-i:]):
+            return True
+    return False
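A brief illustration of the two helpers (hedged example values):

```python
from pilot.model.llm_utils import is_sentence_complete, is_partial_stop

is_sentence_complete("DB-GPT now supports Llama-2.")  # True: ends with a sentence mark
is_sentence_complete("DB-GPT now supports Llama-")    # False: generation was cut mid-word

# A decoded tail that could be the prefix of a stop string is held back
# instead of being streamed out as a partial stop marker.
is_partial_stop("Observation: 42\n##", stop_str="###")  # True
is_partial_stop("Observation: 42", stop_str="###")      # False
```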
@@ -53,8 +53,15 @@ class BaseOutputParser(ABC):

        """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
        """
+        model_context = data.get("model_context")
+        if model_context and "prompt_echo_len_char" in model_context:
+            prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
+            if prompt_echo_len_char != -1:
+                skip_echo_len = prompt_echo_len_char
+
        if data.get("error_code", 0) == 0:
-            if "vicuna" in CFG.LLM_MODEL:
+            if "vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL:
                # TODO Judging from model_context
                # output = data["text"][skip_echo_len + 11:].strip()
                output = data["text"][skip_echo_len:].strip()
            elif "guanaco" in CFG.LLM_MODEL:
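Put differently (a hedged sketch of the effect), the parser now trims the echoed prompt by the character count reported by the worker instead of a model-specific offset:

```python
# Hypothetical payload: the worker echoed the prompt and reported its length.
prompt = "[INST] What is DB-GPT? [/INST]"
answer = " DB-GPT is a database-oriented GPT project."
data = {
    "text": prompt + answer,
    "error_code": 0,
    "model_context": {"prompt_echo_len_char": len(prompt)},
}
# With skip_echo_len taken from prompt_echo_len_char, the parsed output becomes
# data["text"][len(prompt):].strip(), i.e. just the answer text.
```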
@@ -39,6 +39,7 @@ from pilot.scene.base_message import (
    AIMessage,
    ViewMessage,
    ModelMessage,
+    ModelMessageRoleType,
)
from pilot.configs.config import Config

@@ -258,7 +259,7 @@ class BaseChat(ABC):
        if self.prompt_template.template_define:
            messages.append(
                ModelMessage(
-                    role="system",
+                    role=ModelMessageRoleType.SYSTEM,
                    content=self.prompt_template.template_define,
                )
            )
@@ -89,6 +89,14 @@ class ModelMessage(BaseModel):
    content: str


+class ModelMessageRoleType:
+    """Type of ModelMessage role"""
+
+    SYSTEM = "system"
+    HUMAN = "human"
+    AI = "ai"
+
+
class Generation(BaseModel):
    """Output of a single generation."""

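A hedged example of how the new constants are meant to replace the bare role strings (assuming ModelMessage exposes role and content fields, as used elsewhere in this diff):

```python
messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a helpful assistant."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is DB-GPT?"),
]
```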
@@ -2,8 +2,10 @@
# -*- coding: utf-8 -*-

from functools import cache
-from typing import List
+from typing import List, Dict, Tuple
from pilot.model.llm_out.vicuna_base_llm import generate_stream
+from pilot.model.conversation import Conversation, get_conv_template
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType


class BaseChatAdpter:

@@ -17,6 +19,52 @@ class BaseChatAdpter:
        """Return the generate stream handler func"""
        pass

+    def get_conv_template(self) -> Conversation:
+        return None
+
+    def model_adaptation(self, params: Dict) -> Tuple[Dict, Dict]:
+        """Params adaptation"""
+        conv = self.get_conv_template()
+        messages = params.get("messages")
+        # Some model context is passed back to the dbgpt server
+        model_context = {"prompt_echo_len_char": -1}
+        if not conv or not messages:
+            # Nothing to do
+            return params, model_context
+        conv = conv.copy()
+        system_messages = []
+        for message in messages:
+            role, content = None, None
+            if isinstance(message, ModelMessage):
+                role = message.role
+                content = message.content
+            elif isinstance(message, dict):
+                role = message["role"]
+                content = message["content"]
+            else:
+                raise ValueError(f"Invalid message type: {message}")
+
+            if role == ModelMessageRoleType.SYSTEM:
+                # Support for multiple system messages
+                system_messages.append(content)
+            elif role == ModelMessageRoleType.HUMAN:
+                conv.append_message(conv.roles[0], content)
+            elif role == ModelMessageRoleType.AI:
+                conv.append_message(conv.roles[1], content)
+            else:
+                raise ValueError(f"Unknown role: {role}")
+        if system_messages:
+            conv.update_system_message("".join(system_messages))
+        # Add a blank message for the assistant.
+        conv.append_message(conv.roles[1], None)
+        new_prompt = conv.get_prompt()
+        # Overwrite the original prompt
+        # TODO: strip bos/eos tokens based on the model's tokenizer_config.json
+        prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
+        model_context["prompt_echo_len_char"] = prompt_echo_len_char
+        params["prompt"] = new_prompt
+        return params, model_context
+

llm_model_chat_adapters: List[BaseChatAdpter] = []

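To make the flow concrete (a hedged sketch; Llama2ChatAdapter is defined in the next hunk), adapting a request for a Llama-2 model looks roughly like this:

```python
params = {
    "messages": [ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is DB-GPT?")],
    "temperature": 0.7,
}
adapter = Llama2ChatAdapter()
params, model_context = adapter.model_adaptation(params)
# params["prompt"] now holds the "[INST] ... [/INST]"-style Llama-2 prompt, and
# model_context["prompt_echo_len_char"] lets the output parser strip the echoed prompt.
```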
@@ -134,12 +182,26 @@ class GPT4AllChatAdapter(BaseChatAdpter):
        return gpt4all_generate_stream


+class Llama2ChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "llama-2" in model_path.lower()
+
+    def get_conv_template(self) -> Conversation:
+        return get_conv_template("llama-2")
+
+    def get_generate_stream_func(self):
+        from pilot.model.inference import generate_stream
+
+        return generate_stream
+
+
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
+register_llm_model_chat_adapter(Llama2ChatAdapter)

# Proxy model for testing and development; it's cheap for us for now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
@@ -77,6 +77,8 @@ class ModelWorker:

    def generate_stream_gate(self, params):
        try:
+            # Params adaptation
+            params, model_context = self.llm_chat_adapter.model_adaptation(params)
            for output in self.generate_stream_func(
                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):

@@ -84,10 +86,8 @@ class ModelWorker:
                # The gpt4all thread shares stdout with the parent process,
                # and opening it may affect the frontend output.
                print("output: ", output)
-                ret = {
-                    "text": output,
-                    "error_code": 0,
-                }
+                # Return some model context to the dbgpt server
+                ret = {"text": output, "error_code": 0, "model_context": model_context}
                yield json.dumps(ret).encode() + b"\0"

        except torch.cuda.CudaError:
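Each streamed chunk is a JSON object terminated by a b"\0" delimiter, so a client is expected to split on that byte (hedged sketch with made-up values):

```python
import json

# Hypothetical chunk, as emitted by generate_stream_gate above.
chunk = (
    b'{"text": "Llama-2 is now supported", "error_code": 0, '
    b'"model_context": {"prompt_echo_len_char": 123}}\x00'
)
data = json.loads(chunk.rstrip(b"\x00").decode())
assert data["model_context"]["prompt_echo_len_char"] == 123
```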