mirror of
				https://github.com/csunny/DB-GPT.git
				synced 2025-10-31 06:39:43 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			412 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			412 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""
Forked from FastChat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py

Conversation prompt templates.


This module will be deprecated in the future.
FastChat has since been integrated directly; for details, see pilot/model/model_adapter.py.
"""
 | ||
| 
 | ||
import dataclasses
from enum import auto, IntEnum
from typing import Callable, Dict, List, Optional
 | ||
| 
 | ||
| 
 | ||
class SeparatorStyle(IntEnum):
    """Layouts that ``Conversation.get_prompt`` can render.

    Members are numbered consecutively from 1 in declaration order.
    """

    # Role followed by ": " and the message; "_TWO" alternates sep / sep2.
    ADD_COLON_SINGLE = 1
    ADD_COLON_TWO = 2
    ADD_COLON_SPACE_SINGLE = 3
    # Role immediately followed by the message; "_TWO" alternates sep / sep2.
    NO_COLON_SINGLE = 4
    NO_COLON_TWO = 5
    # Role, a newline, then the message.
    ADD_NEW_LINE_SINGLE = 6
    # Model-family-specific layouts.
    LLAMA2 = 7
    CHATGLM = 8
    CHATML = 9
    CHATINTERN = 10
    DOLLY = 11
    RWKV = 12
    PHOENIX = 13
    ROBIN = 14
 | ||
| 
 | ||
| 
 | ||
@dataclasses.dataclass
class Conversation:
    """A prompt template that also keeps the full conversation history.

    ``get_prompt`` renders the system prompt plus every ``[role, message]``
    pair into a single string, using the layout selected by ``sep_style``.
    """

    # The name of this template
    name: str
    # The system prompt
    system: str
    # Two roles, e.g. ("USER", "ASSISTANT")
    roles: List[str]
    # All messages. Each item is [role, message]; a falsy message (None or "")
    # marks the slot the model is expected to fill in next.
    messages: List[List[str]]
    # The number of few shot examples; messages before this index are skipped
    # by to_gradio_chatbot and to_openai_api_messages.
    offset: int
    # Separators
    sep_style: SeparatorStyle
    sep: str
    sep2: Optional[str] = None
    # Stop criteria (the default one is EOS token)
    stop_str: Optional[str] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Optional[List[int]] = None

    # Optional hook that wraps a raw system message into the model-specific
    # system prompt; applied by update_system_message.
    system_formatter: Optional[Callable[[str], str]] = None

    def get_prompt(self) -> str:
        """Render the system prompt and all messages into one prompt string.

        The layout depends on ``sep_style``. In every style, a falsy message
        (``None`` or ``""``) emits only the role prefix, with no message and
        no closing separator. This leaves the prompt open for the model to
        continue as that role.

        Raises:
            ValueError: if ``sep_style`` is not one of the handled styles.
        """
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            # Even-indexed turns end with sep, odd-indexed turns with sep2.
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "  # must be end with a space
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = "" if self.system == "" else self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = self.system
            for role, message in self.messages:
                if message:
                    # Normalize line endings and collapse blank lines inside the
                    # message, since "\n\n" is used as the turn delimiter.
                    ret += (
                        role
                        + ": "
                        + message.replace("\r\n", "\n").replace("\n\n", "\n")
                    )
                    ret += "\n\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if i == 0:
                        # The system prompt already carries the opening
                        # "[INST]" markup, so the first message is appended
                        # directly without its role.
                        ret += self.system + message
                    else:
                        ret += role + " " + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM:
            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
            # chatglm2 numbers rounds from 1; other ChatGLM templates from 0.
            round_add_n = 1 if self.name == "chatglm2" else 0
            if self.system:
                ret = self.system + self.sep
            else:
                ret = ""

            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    ret += f"[Round {i//2 + round_add_n}]{self.sep}"

                if message:
                    ret += f"{role}:{message}{self.sep}"
                else:
                    ret += f"{role}:"
            return ret
        elif self.sep_style == SeparatorStyle.CHATML:
            ret = "" if self.system == "" else self.system + self.sep + "\n"
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep + "\n"
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.CHATINTERN:
            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    ret += "<s>"
                if message:
                    ret += role + ":" + message + seps[i % 2] + "\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + "<s>" + message + "</s>"
                else:
                    ret += role + ": " + "<s>"
            return ret
        elif self.sep_style == SeparatorStyle.ROBIN:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ":\n" + message + self.sep
                else:
                    ret += role + ":\n"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role: str, message: str):
        """Append a new ``[role, message]`` entry to the history.

        Args:
            role: The speaker, normally one of ``self.roles``.
            message: The message text, or ``None`` to leave a slot for the
                model's reply.
        """
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.

        Raises:
            IndexError: if there are no messages yet.
        """
        self.messages[-1][1] = message

    def update_system_message(self, system_message: str):
        """Replace the system prompt.

        If ``system_formatter`` is set, the raw message is passed through it
        first, e.g. to wrap it in model-specific markup. Otherwise it is
        stored as-is.
        """
        if self.system_formatter:
            self.system = self.system_formatter(system_message)
        else:
            self.system = system_message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Returns a list of ``[user_message, assistant_message]`` pairs built
        from the messages after ``offset``. The assistant slot stays ``None``
        until the following message exists.
        """
        ret = []
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format.

        The first entry is always the system prompt. Messages after ``offset``
        alternate user (even positions) / assistant (odd positions). An
        assistant message of ``None`` is omitted, since it is the slot still
        to be generated.
        """
        ret = [{"role": "system", "content": self.system}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        """Return a copy whose history and stop tokens can be changed independently.

        ``messages`` and ``stop_token_ids`` are rebuilt as new lists, so
        appending messages or editing stop tokens on the copy (for example one
        obtained from a registered template) never leaks back into this
        object. All other fields, including ``roles``, are shared by reference.
        """
        return Conversation(
            name=self.name,
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            # Copy the list: sharing it would let edits on one conversation
            # silently change every other copy and the registered template.
            stop_token_ids=(
                list(self.stop_token_ids)
                if self.stop_token_ids is not None
                else None
            ),
            system_formatter=self.system_formatter,
        )

    def dict(self):
        """Return a plain-dict summary of the template and its history.

        Note: ``messages`` is returned by reference, not copied.
        """
        return {
            "template_name": self.name,
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
 | ||
| 
 | ||
| 
 | ||
# Module-wide registry of conversation templates, keyed by template name.
# Written by register_conv_template() and read by get_conv_template().
conv_templates: Dict[str, Conversation] = dict()
 | ||
| 
 | ||
| 
 | ||
def register_conv_template(template: Conversation, override: bool = False):
    """Register a conversation template in the global ``conv_templates`` registry.

    Args:
        template: The template to store; it is keyed by ``template.name``.
        override: When False (the default), registering a name that is
            already present is an error. When True, the existing entry is
            replaced.

    Raises:
        AssertionError: if ``override`` is False and ``template.name`` is
            already registered.
    """
    if not override and template.name in conv_templates:
        # Raised explicitly instead of via ``assert`` so the duplicate check
        # still runs under ``python -O``. The exception type stays
        # AssertionError for compatibility with existing callers.
        raise AssertionError(f"{template.name} has been registered.")

    conv_templates[template.name] = template
 | ||
| 
 | ||
| 
 | ||
def get_conv_template(name: str) -> Conversation:
    """Look up the template registered as ``name`` and return a copy of it.

    The copy comes from ``Conversation.copy``, which rebuilds the message
    list. Appending messages to the result therefore does not touch the
    registered template.

    Raises:
        KeyError: if no template is registered under ``name``.
    """
    registered = conv_templates[name]
    return registered.copy()
 | ||
| 
 | ||
| 
 | ||
# A plain "Human"/"Assistant" chat template with no few-shot examples
# (``messages`` is empty and ``offset`` is 0).
register_conv_template(
    Conversation(
        name="zero_shot",
        system="A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n### ",
        stop_str="###",
    )
)

# Vicuna v1.1 template
register_conv_template(
    Conversation(
        name="vicuna_v1.1",
        system="A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions.",
        roles=("USER", "ASSISTANT"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)


def _format_llama2_system(msg: str) -> str:
    """Wrap ``msg`` in the Llama-2 ``<<SYS>>`` system-prompt markup.

    The Llama-2-style templates below use this both to build their default
    ``system`` and as their ``system_formatter``. A custom system message set
    via ``Conversation.update_system_message`` therefore gets the same markup.
    """
    return f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n"


# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
    Conversation(
        name="llama-2",
        system=_format_llama2_system(
            "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
            "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
            "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
            "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
            "If you don't know the answer to a question, please don't share false information."
        ),
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_token_ids=[2],
        system_formatter=_format_llama2_system,
    )
)


# codellama template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
# reference2 : https://github.com/eosphoros-ai/DB-GPT-Hub/blob/main/README.zh.md
# NOTE(review): the missing spaces after "to me." and "the request." are kept
# byte-for-byte; the prompt may need to match what the DB-GPT-Hub fine-tune
# was trained on. Confirm before "fixing" them.
register_conv_template(
    Conversation(
        name="codellama",
        system=_format_llama2_system(
            "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request."
            "If you don't know the answer to the request, please don't share false information."
        ),
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_token_ids=[2],
        system_formatter=_format_llama2_system,
    )
)


# Alpaca default template
register_conv_template(
    Conversation(
        name="alpaca",
        system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
        roles=("### Instruction", "### Response"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep="\n\n",
        sep2="</s>",
    )
)

# Baichuan-13B-Chat template
register_conv_template(
    # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
    # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
    Conversation(
        name="baichuan-chat",
        system="",
        roles=(" <reserved_102> ", " <reserved_103> "),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_TWO,
        sep="",
        sep2="</s>",
        stop_token_ids=[2, 195],
    )
)

# Internlm-chat template
register_conv_template(
    Conversation(
        name="internlm-chat",
        system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
        roles=("<|User|>", "<|Bot|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.CHATINTERN,
        sep="<eoh>",
        sep2="<eoa>",
        stop_token_ids=[1, 103028],
        stop_str="<eoa>",
    )
)


# TODO: support conversation templates for other models.
 |