fix problem

This commit is contained in:
csunny
2023-04-30 23:58:32 +08:00
parent c34e722412
commit 7861fc28ce
5 changed files with 60 additions and 108 deletions

View File

@@ -3,7 +3,7 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any
from typing import List, Any
class SeparatorStyle(Enum):
@@ -29,12 +29,12 @@ class Conversation:
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system
ret = self.system + self.sep
for role, message in self.messages:
if message:
ret += self.sep + " " + role + ": " + message
ret += role + ": " + message + self.sep
else:
ret += self.sep + " " + role + ":"
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
@@ -56,7 +56,7 @@ class Conversation:
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
ret.append([msg, None])
else:
@@ -133,19 +133,9 @@ conv_vicuna_v1 = Conversation(
sep2="</s>",
)
conv_template = {
default_conversation = conv_one_shot
conv_templates = {
"conv_one_shot": conv_one_shot,
"vicuna_v1": conv_vicuna_v1
}
def get_default_conv_template(model_name: str = "vicuna-13b"):
model_name = model_name.lower()
if "vicuna" in model_name:
return conv_vicuna_v1
return conv_one_shot
def compute_skip_echo_len(prompt):
skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
return skip_echo_len