From ef2f7999a55f328e99a810b4fc0afa17bbd7f165 Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Thu, 20 Jul 2023 21:43:20 +0800
Subject: [PATCH] feat: Support llama-2 model

---
 docs/modules/llms.md             |   2 +-
 pilot/configs/model_config.py    |   3 +
 pilot/model/adapter.py           |  14 ++
 pilot/model/conversation.py      | 308 +++++++++++++++++++++++++++++++
 pilot/model/inference.py         | 242 ++++++++++++++++++++++++
 pilot/model/llm_out/proxy_llm.py |   8 +-
 pilot/model/llm_utils.py         |  15 +-
 pilot/out_parser/base.py         |   9 +-
 pilot/scene/base_chat.py         |   3 +-
 pilot/scene/base_message.py      |   8 +
 pilot/server/chat_adapter.py     |  64 ++++++-
 pilot/server/llmserver.py        |   8 +-
 12 files changed, 671 insertions(+), 13 deletions(-)
 create mode 100644 pilot/model/conversation.py
 create mode 100644 pilot/model/inference.py

diff --git a/docs/modules/llms.md b/docs/modules/llms.md
index bec64313e..a4baf1807 100644
--- a/docs/modules/llms.md
+++ b/docs/modules/llms.md
@@ -11,7 +11,7 @@ cp .env.template .env
 LLM_MODEL=vicuna-13b
 MODEL_SERVER=http://127.0.0.1:8000
 ```
-now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b.
+now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b.
 
 if you want use other model, such as chatglm-6b, you just need update .env config file.
 ```
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 22354b4bf..aea5c6731 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -47,6 +47,9 @@ LLM_MODEL_CONFIG = {
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
     "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
     "proxyllm": "proxyllm",
+    "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
+    "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
+    "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
 }
 
 # Load model config
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 2d420c02f..cea73c602 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -263,12 +263,26 @@ class ProxyllmAdapter(BaseLLMAdaper):
         return "proxyllm", None
 
 
+class Llama2Adapter(BaseLLMAdaper):
+    """The model adapter for llama-2"""
+
+    def match(self, model_path: str):
+        return "llama-2" in model_path.lower()
+
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
+        model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
+        model.config.eos_token_id = tokenizer.eos_token_id
+        model.config.pad_token_id = tokenizer.pad_token_id
+        return model, tokenizer
+
+
 register_llm_model_adapters(VicunaLLMAdapater)
 register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
 register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
 register_llm_model_adapters(GPT4AllAdapter)
+register_llm_model_adapters(Llama2Adapter)
 
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test_py, remove this later
diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py
new file mode 100644
index 000000000..5443992df
--- /dev/null
+++ b/pilot/model/conversation.py
@@ -0,0 +1,308 @@
+"""
+Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+
+Conversation prompt templates.
+"""
+
+import dataclasses
+from enum import auto, IntEnum
+from typing import List, Any, Dict, Callable
+
+
+class SeparatorStyle(IntEnum):
+    """Separator styles."""
+
+    ADD_COLON_SINGLE = auto()
+    ADD_COLON_TWO = auto()
+    ADD_COLON_SPACE_SINGLE = auto()
+    NO_COLON_SINGLE = auto()
+    NO_COLON_TWO = auto()
+    ADD_NEW_LINE_SINGLE = auto()
+    LLAMA2 = auto()
+    CHATGLM = auto()
+    CHATML = auto()
+    CHATINTERN = auto()
+    DOLLY = auto()
+    RWKV = auto()
+    PHOENIX = auto()
+    ROBIN = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that manages prompt templates and keeps all conversation history."""
+
+    # The name of this template
+    name: str
+    # The system prompt
+    system: str
+    # Two roles
+    roles: List[str]
+    # All messages. Each item is (role, message).
+    messages: List[List[str]]
+    # The number of few shot examples
+    offset: int
+    # Separators
+    sep_style: SeparatorStyle
+    sep: str
+    sep2: str = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: str = None
+    # Stops generation if meeting any token in this list
+    stop_token_ids: List[int] = None
+
+    # format system message
+    system_formatter: Callable = None
+
+    def get_prompt(self) -> str:
+        """Get the prompt for generation."""
+        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ": "  # must be end with a space
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
+            ret = "" if self.system == "" else self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + "\n" + message + self.sep
+                else:
+                    ret += role + "\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
+            ret = self.system
+            for role, message in self.messages:
+                if message:
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + message + seps[i % 2]
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.RWKV:
+            ret = self.system
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += (
+                        role
+                        + ": "
+                        + message.replace("\r\n", "\n").replace("\n\n", "\n")
+                    )
+                    ret += "\n\n"
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.LLAMA2:
+            seps = [self.sep, self.sep2]
+            ret = ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if i == 0:
+                        ret += self.system + message
+                    else:
+                        ret += role + " " + message + seps[i % 2]
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATGLM:
+            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
+            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
+            round_add_n = 1 if self.name == "chatglm2" else 0
+            if self.system:
+                ret = self.system + self.sep
+            else:
+                ret = ""
+
+            for i, (role, message) in enumerate(self.messages):
+                if i % 2 == 0:
+                    ret += f"[Round {i//2 + round_add_n}]{self.sep}"
+
+                if message:
+                    ret += f"{role}:{message}{self.sep}"
+                else:
+                    ret += f"{role}:"
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATML:
+            ret = "" if self.system == "" else self.system + self.sep + "\n"
+            for role, message in self.messages:
+                if message:
+                    ret += role + "\n" + message + self.sep + "\n"
+                else:
+                    ret += role + "\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATINTERN:
+            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(self.messages):
+                if i % 2 == 0:
+                    ret += "<s>"
+                if message:
+                    ret += role + ":" + message + seps[i % 2] + "\n"
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.DOLLY:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ":\n" + message + seps[i % 2]
+                    if i % 2 == 1:
+                        ret += "\n\n"
+                else:
+                    ret += role + ":\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.PHOENIX:
+            ret = self.system
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + "<s>" + message + "</s>"
+                else:
+                    ret += role + ": " + "<s>"
+            return ret
+        elif self.sep_style == SeparatorStyle.ROBIN:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ":\n" + message + self.sep
+                else:
+                    ret += role + ":\n"
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def append_message(self, role: str, message: str):
+        """Append a new message."""
+        self.messages.append([role, message])
+
+    def update_last_message(self, message: str):
+        """Update the last output.
+
+        The last message is typically set to be None when constructing the prompt,
+        so we need to update it in-place after getting the response from a model.
+        """
+        self.messages[-1][1] = message
+
+    def update_system_message(self, system_message: str):
+        """Update system message"""
+        if self.system_formatter:
+            self.system = self.system_formatter(system_message)
+        else:
+            self.system = system_message
+
+    def to_gradio_chatbot(self):
+        """Convert the conversation to gradio chatbot format."""
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def to_openai_api_messages(self):
+        """Convert the conversation to OpenAI chat completion format."""
+        ret = [{"role": "system", "content": self.system}]
+
+        for i, (_, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                ret.append({"role": "user", "content": msg})
+            else:
+                if msg is not None:
+                    ret.append({"role": "assistant", "content": msg})
+        return ret
+
+    def copy(self):
+        return Conversation(
+            name=self.name,
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            stop_str=self.stop_str,
+            stop_token_ids=self.stop_token_ids,
+            system_formatter=self.system_formatter,
+        )
+
+    def dict(self):
+        return {
+            "template_name": self.name,
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+        }
+
+
+# A global registry for all conversation templates
+conv_templates: Dict[str, Conversation] = {}
+
+
+def register_conv_template(template: Conversation, override: bool = False):
+    """Register a new conversation template."""
+    if not override:
+        assert (
+            template.name not in conv_templates
+        ), f"{template.name} has been registered."
+
+    conv_templates[template.name] = template
+
+
+def get_conv_template(name: str) -> Conversation:
+    """Get a conversation template."""
+    return conv_templates[name].copy()
+
+
+# llama2 template
+# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
+register_conv_template(
+    Conversation(
+        name="llama-2",
+        system="<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
+        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
+        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
+        roles=("[INST]", "[/INST]"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+        stop_token_ids=[2],
+        system_formatter=lambda msg: f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
+    )
+)
+
+# TODO Support other model conversation template
diff --git a/pilot/model/inference.py b/pilot/model/inference.py
new file mode 100644
index 000000000..cbbd04339
--- /dev/null
+++ b/pilot/model/inference.py
@@ -0,0 +1,242 @@
+"""
+Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py
+
+"""
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import gc
+from typing import Iterable, Dict
+
+import torch
+
+import torch
+
+from transformers.generation.logits_process import (
+    LogitsProcessorList,
+    RepetitionPenaltyLogitsProcessor,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+)
+
+from pilot.model.llm_utils import is_sentence_complete, is_partial_stop
+
+
+def prepare_logits_processor(
+    temperature: float, repetition_penalty: float, top_p: float, top_k: int
+) -> LogitsProcessorList:
+    processor_list = LogitsProcessorList()
+    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
+    if temperature >= 1e-5 and temperature != 1.0:
+        processor_list.append(TemperatureLogitsWarper(temperature))
+    if repetition_penalty > 1.0:
+        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
+    if 1e-8 <= top_p < 1.0:
+        processor_list.append(TopPLogitsWarper(top_p))
+    if top_k > 0:
+        processor_list.append(TopKLogitsWarper(top_k))
+    return processor_list
+
+
+@torch.inference_mode()
+def generate_stream(
+    model,
+    tokenizer,
+    params: Dict,
+    device: str,
+    context_len: int,
+    stream_interval: int = 2,
+    judge_sent_end: bool = False,
+):
+    # Read parameters
+    prompt = params["prompt"]
+    print(f"Prompt of model: \n{prompt}")
+    len_prompt = len(prompt)
+    temperature = float(params.get("temperature", 1.0))
+    repetition_penalty = float(params.get("repetition_penalty", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    top_k = int(params.get("top_k", -1))  # -1 means disable
+    max_new_tokens = int(params.get("max_new_tokens", 2048))
+    echo = bool(params.get("echo", True))
+    stop_str = params.get("stop", None)
+    stop_token_ids = params.get("stop_token_ids", None) or []
+    stop_token_ids.append(tokenizer.eos_token_id)
+
+    logits_processor = prepare_logits_processor(
+        temperature, repetition_penalty, top_p, top_k
+    )
+    input_ids = tokenizer(prompt).input_ids
+
+    if model.config.is_encoder_decoder:
+        max_src_len = context_len
+    else:  # truncate
+        max_src_len = context_len - max_new_tokens - 1
+
+    input_ids = input_ids[-max_src_len:]
+    output_ids = list(input_ids)
+    input_echo_len = len(input_ids)
+
+    if model.config.is_encoder_decoder:
+        encoder_output = model.encoder(
+            input_ids=torch.as_tensor([input_ids], device=device)
+        )[0]
+        start_ids = torch.as_tensor(
+            [[model.generation_config.decoder_start_token_id]],
+            dtype=torch.int64,
+            device=device,
+        )
+
+    past_key_values = out = None
+    sent_interrupt = False
+    for i in range(max_new_tokens):
+        if i == 0:  # prefill
+            if model.config.is_encoder_decoder:
+                out = model.decoder(
+                    input_ids=start_ids,
+                    encoder_hidden_states=encoder_output,
+                    use_cache=True,
+                )
+                logits = model.lm_head(out[0])
+            else:
+                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
+                logits = out.logits
+            past_key_values = out.past_key_values
+        else:  # decoding
+            if model.config.is_encoder_decoder:
+                out = model.decoder(
+                    input_ids=torch.as_tensor(
+                        [[token] if not sent_interrupt else output_ids], device=device
+                    ),
+                    encoder_hidden_states=encoder_output,
+                    use_cache=True,
+                    past_key_values=past_key_values if not sent_interrupt else None,
+                )
+                sent_interrupt = False
+
+                logits = model.lm_head(out[0])
+            else:
+                out = model(
+                    input_ids=torch.as_tensor(
+                        [[token] if not sent_interrupt else output_ids], device=device
+                    ),
+                    use_cache=True,
+                    past_key_values=past_key_values if not sent_interrupt else None,
+                )
+                sent_interrupt = False
+                logits = out.logits
+            past_key_values = out.past_key_values
+
+        if logits_processor:
+            if repetition_penalty > 1.0:
+                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
+            else:
+                tmp_output_ids = None
+            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
+        else:
+            last_token_logits = logits[0, -1, :]
+
+        if device == "mps":
+            # Switch to CPU by avoiding some bugs in mps backend.
+            last_token_logits = last_token_logits.float().to("cpu")
+
+        if temperature < 1e-5 or top_p < 1e-8:  # greedy
+            _, indices = torch.topk(last_token_logits, 2)
+            tokens = [int(index) for index in indices.tolist()]
+        else:
+            probs = torch.softmax(last_token_logits, dim=-1)
+            indices = torch.multinomial(probs, num_samples=2)
+            tokens = [int(token) for token in indices.tolist()]
+        token = tokens[0]
+        output_ids.append(token)
+
+        if token in stop_token_ids:
+            stopped = True
+        else:
+            stopped = False
+
+        # Yield the output tokens
+        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
+            if echo:
+                tmp_output_ids = output_ids
+                rfind_start = len_prompt
+            else:
+                tmp_output_ids = output_ids[input_echo_len:]
+                rfind_start = 0
+
+            output = tokenizer.decode(
+                tmp_output_ids,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+                clean_up_tokenization_spaces=True,
+            )
+            # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
+            if judge_sent_end and stopped and not is_sentence_complete(output):
+                if len(tokens) > 1:
+                    token = tokens[1]
+                    output_ids[-1] = token
+                else:
+                    output_ids.pop()
+                stopped = False
+                sent_interrupt = True
+
+            partially_stopped = False
+            if stop_str:
+                if isinstance(stop_str, str):
+                    pos = output.rfind(stop_str, rfind_start)
+                    if pos != -1:
+                        output = output[:pos]
+                        stopped = True
+                    else:
+                        partially_stopped = is_partial_stop(output, stop_str)
+                elif isinstance(stop_str, Iterable):
+                    for each_stop in stop_str:
+                        pos = output.rfind(each_stop, rfind_start)
+                        if pos != -1:
+                            output = output[:pos]
+                            stopped = True
+                            break
+                        else:
+                            partially_stopped = is_partial_stop(output, each_stop)
+                            if partially_stopped:
+                                break
+                else:
+                    raise ValueError("Invalid stop field type.")
+
+            # Prevent yielding partial stop sequence
+            if not partially_stopped:
+                yield output
+                # yield {
+                #     "text": output,
+                #     "usage": {
+                #         "prompt_tokens": input_echo_len,
+                #         "completion_tokens": i,
+                #         "total_tokens": input_echo_len + i,
+                #     },
+                #     "finish_reason": None,
+                # }
+
+        if stopped:
+            break
+
+    # Finish stream event, which contains finish reason
+    if i == max_new_tokens - 1:
+        finish_reason = "length"
+    elif stopped:
+        finish_reason = "stop"
+    else:
+        finish_reason = None
+    yield output
+    # yield {
+    #     "text": output,
+    #     "usage": {
+    #         "prompt_tokens": input_echo_len,
+    #         "completion_tokens": i,
+    #         "total_tokens": input_echo_len + i,
+    #     },
+    #     "finish_reason": finish_reason,
+    # }
+
+    # Clean
+    del past_key_values, out
+    gc.collect()
+    torch.cuda.empty_cache()
diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py
index c353426d2..79cb28adf 100644
--- a/pilot/model/llm_out/proxy_llm.py
+++ b/pilot/model/llm_out/proxy_llm.py
@@ -6,7 +6,7 @@ import requests
 from typing import List
 from pilot.configs.config import Config
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
-from pilot.scene.base_message import ModelMessage
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
 CFG = Config()
 
@@ -25,11 +25,11 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
     messages: List[ModelMessage] = params["messages"]
     # Add history conversation
     for message in messages:
-        if message.role == "human":
+        if message.role == ModelMessageRoleType.HUMAN:
             history.append({"role": "user", "content": message.content})
-        elif message.role == "system":
+        elif message.role == ModelMessageRoleType.SYSTEM:
             history.append({"role": "system", "content": message.content})
-        elif message.role == "ai":
+        elif message.role == ModelMessageRoleType.AI:
             history.append({"role": "assistant", "content": message.content})
         else:
             pass
diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py
index 0dfd8d2b5..ec50a7d34 100644
--- a/pilot/model/llm_utils.py
+++ b/pilot/model/llm_utils.py
@@ -10,7 +10,6 @@ from typing import List, Optional
 
 from pilot.configs.config import Config
 from pilot.model.base import Message
-from pilot.server.llmserver import generate_output
 
 
 def create_chat_completion(
@@ -115,3 +114,17 @@ class Iteratorize:
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop_now = True
+
+
+def is_sentence_complete(output: str):
+    """Check whether the output is a complete sentence."""
+    end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
+    return output.endswith(end_symbols)
+
+
+def is_partial_stop(output: str, stop_str: str):
+    """Check whether the output contains a partial stop str."""
+    for i in range(0, min(len(output), len(stop_str))):
+        if stop_str.startswith(output[-i:]):
+            return True
+    return False
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 058de71ab..000d92709 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -53,8 +53,15 @@ class BaseOutputParser(ABC):
         """
         TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
+        model_context = data.get("model_context")
+        if model_context and "prompt_echo_len_char" in model_context:
+            prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
+            if prompt_echo_len_char != -1:
+                skip_echo_len = prompt_echo_len_char
+
         if data.get("error_code", 0) == 0:
-            if "vicuna" in CFG.LLM_MODEL:
+            if "vicuna" in CFG.LLM_MODEL or "llama-2" in CFG.LLM_MODEL:
+                # TODO Judging from model_context
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()
             elif "guanaco" in CFG.LLM_MODEL:
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index bd660a0cd..e6b9bb9f4 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -39,6 +39,7 @@ from pilot.scene.base_message import (
     AIMessage,
     ViewMessage,
     ModelMessage,
+    ModelMessageRoleType,
 )
 
 from pilot.configs.config import Config
@@ -258,7 +259,7 @@ class BaseChat(ABC):
         if self.prompt_template.template_define:
             messages.append(
                 ModelMessage(
-                    role="system",
+                    role=ModelMessageRoleType.SYSTEM,
                     content=self.prompt_template.template_define,
                 )
             )
diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py
index 20d513c39..09ea9695d 100644
--- a/pilot/scene/base_message.py
+++ b/pilot/scene/base_message.py
@@ -89,6 +89,14 @@ class ModelMessage(BaseModel):
     content: str
 
 
+class ModelMessageRoleType:
+    """Type of ModelMessage role"""
+
+    SYSTEM = "system"
+    HUMAN = "human"
+    AI = "ai"
+
+
 class Generation(BaseModel):
     """Output of a single generation."""
 
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index ebab2d2d4..1c0d73bac 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -2,8 +2,10 @@
 # -*- coding: utf-8 -*-
 
 from functools import cache
-from typing import List
+from typing import List, Dict, Tuple
 from pilot.model.llm_out.vicuna_base_llm import generate_stream
+from pilot.model.conversation import Conversation, get_conv_template
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
 
 class BaseChatAdpter:
@@ -17,6 +19,52 @@ class BaseChatAdpter:
         """Return the generate stream handler func"""
         pass
 
+    def get_conv_template(self) -> Conversation:
+        return None
+
+    def model_adaptation(self, params: Dict) -> Tuple[Dict, Dict]:
+        """Params adaptation"""
+        conv = self.get_conv_template()
+        messages = params.get("messages")
+        # Some model context to dbgpt server
+        model_context = {"prompt_echo_len_char": -1}
+        if not conv or not messages:
+            # Nothing to do
+            return params, model_context
+        conv = conv.copy()
+        system_messages = []
+        for message in messages:
+            role, content = None, None
+            if isinstance(message, ModelMessage):
+                role = message.role
+                content = message.content
+            elif isinstance(message, dict):
+                role = message["role"]
+                content = message["content"]
+            else:
+                raise ValueError(f"Invalid message type: {message}")
+
+            if role == ModelMessageRoleType.SYSTEM:
+                # Support for multiple system messages
+                system_messages.append(content)
+            elif role == ModelMessageRoleType.HUMAN:
+                conv.append_message(conv.roles[0], content)
+            elif role == ModelMessageRoleType.AI:
+                conv.append_message(conv.roles[1], content)
+            else:
+                raise ValueError(f"Unknown role: {role}")
+        if system_messages:
+            conv.update_system_message("".join(system_messages))
+        # Add a blank message for the assistant.
+        conv.append_message(conv.roles[1], None)
+        new_prompt = conv.get_prompt()
+        # Overwrite the original prompt
+        # TODO remove bos token and eos token from tokenizer_config.json of model
+        prompt_echo_len_char = len(new_prompt.replace("<s>", "").replace("</s>", ""))
+        model_context["prompt_echo_len_char"] = prompt_echo_len_char
+        params["prompt"] = new_prompt
+        return params, model_context
+
 
 llm_model_chat_adapters: List[BaseChatAdpter] = []
 
@@ -134,12 +182,26 @@ class GPT4AllChatAdapter(BaseChatAdpter):
         return gpt4all_generate_stream
 
 
+class Llama2ChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "llama-2" in model_path.lower()
+
+    def get_conv_template(self) -> Conversation:
+        return get_conv_template("llama-2")
+
+    def get_generate_stream_func(self):
+        from pilot.model.inference import generate_stream
+
+        return generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
 register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
 register_llm_model_chat_adapter(GPT4AllChatAdapter)
+register_llm_model_chat_adapter(Llama2ChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 910a97573..43f5a1f94 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -77,6 +77,8 @@ class ModelWorker:
 
     def generate_stream_gate(self, params):
         try:
+            # params adaptation
+            params, model_context = self.llm_chat_adapter.model_adaptation(params)
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):
@@ -84,10 +86,8 @@
                 # The gpt4all thread shares stdout with the parent process,
                 # and opening it may affect the frontend output.
                 print("output: ", output)
-                ret = {
-                    "text": output,
-                    "error_code": 0,
-                }
+                # return some model context to dbgpt server
+                ret = {"text": output, "error_code": 0, "model_context": model_context}
                 yield json.dumps(ret).encode() + b"\0"
 
         except torch.cuda.CudaError:
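
A minimal usage sketch of the new prompt-adaptation path, assuming this patch is applied; the example messages and the printed fields are illustrative only, not part of the patch:

    from pilot.server.chat_adapter import Llama2ChatAdapter

    adapter = Llama2ChatAdapter()
    params = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "human", "content": "What is DB-GPT?"},
        ],
    }
    # model_adaptation() builds the llama-2 prompt from the registered
    # conversation template and returns prompt_echo_len_char, which
    # BaseOutputParser uses as skip_echo_len to strip the echoed prompt.
    params, model_context = adapter.model_adaptation(params)
    print(params["prompt"])
    print(model_context["prompt_echo_len_char"])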