mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-05 19:11:52 +00:00
fix problem
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
|
||||
import dataclasses
|
||||
from enum import auto, Enum
|
||||
from typing import List, Tuple, Any
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
@@ -29,12 +29,12 @@ class Conversation:
|
||||
|
||||
def get_prompt(self):
|
||||
if self.sep_style == SeparatorStyle.SINGLE:
|
||||
ret = self.system
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += self.sep + " " + role + ": " + message
|
||||
ret += role + ": " + message + self.sep
|
||||
else:
|
||||
ret += self.sep + " " + role + ":"
|
||||
ret += role + ":"
|
||||
return ret
|
||||
|
||||
elif self.sep_style == SeparatorStyle.TWO:
|
||||
@@ -56,7 +56,7 @@ class Conversation:
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
||||
if i % 2 == 0:
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
@@ -133,19 +133,9 @@ conv_vicuna_v1 = Conversation(
|
||||
sep2="</s>",
|
||||
)
|
||||
|
||||
conv_template = {
|
||||
default_conversation = conv_one_shot
|
||||
|
||||
conv_templates = {
|
||||
"conv_one_shot": conv_one_shot,
|
||||
"vicuna_v1": conv_vicuna_v1
|
||||
}
|
||||
|
||||
|
||||
def get_default_conv_template(model_name: str = "vicuna-13b"):
|
||||
model_name = model_name.lower()
|
||||
if "vicuna" in model_name:
|
||||
return conv_vicuna_v1
|
||||
return conv_one_shot
|
||||
|
||||
|
||||
def compute_skip_echo_len(prompt):
|
||||
skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
|
||||
return skip_echo_len
|
Reference in New Issue
Block a user