mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-14 14:34:28 +00:00
lint: fix code style and lint
This commit is contained in:
parent
8e556e3dd3
commit
ee877a63e0
@ -4,6 +4,7 @@ from threading import Thread
|
||||
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
|
||||
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
|
||||
|
||||
|
||||
def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
|
||||
|
||||
@ -16,10 +17,15 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
input_ids = tokenizer(query, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
stop_token_ids = [0]
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
for stop_id in stop_token_ids:
|
||||
if input_ids[0][-1] == stop_id:
|
||||
return True
|
||||
@ -35,14 +41,13 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
top_k=1,
|
||||
streamer=streamer,
|
||||
repetition_penalty=1.7,
|
||||
stopping_criteria=StoppingCriteriaList([stop])
|
||||
stopping_criteria=StoppingCriteriaList([stop]),
|
||||
)
|
||||
|
||||
|
||||
t1 = Thread(target=model.generate, kwargs=generate_kwargs)
|
||||
t1.start()
|
||||
|
||||
generator = model.generate(**generate_kwargs)
|
||||
generator = model.generate(**generate_kwargs)
|
||||
for output in generator:
|
||||
# new_tokens = len(output) - len(input_ids[0])
|
||||
decoded_output = tokenizer.decode(output)
|
||||
@ -52,4 +57,3 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
|
||||
out = decoded_output.split("### Response:")[-1].strip()
|
||||
|
||||
yield out
|
||||
|
||||
|
@ -68,8 +68,7 @@ class BaseOutputParser(ABC):
|
||||
return output
|
||||
|
||||
# TODO 后续和模型绑定
|
||||
def parse_model_stream_resp(self, response, skip_echo_len):
|
||||
|
||||
def parse_model_stream_resp(self, response, skip_echo_len):
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode())
|
||||
@ -77,7 +76,7 @@ class BaseOutputParser(ABC):
|
||||
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
|
||||
"""
|
||||
if data["error_code"] == 0:
|
||||
if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
|
||||
if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
@ -115,7 +114,6 @@ class BaseOutputParser(ABC):
|
||||
else:
|
||||
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
|
||||
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
"""
|
||||
parse model out text to prompt define response
|
||||
@ -131,9 +129,9 @@ class BaseOutputParser(ABC):
|
||||
# if "```" in cleaned_output:
|
||||
# cleaned_output, _ = cleaned_output.split("```")
|
||||
if cleaned_output.startswith("```json"):
|
||||
cleaned_output = cleaned_output[len("```json"):]
|
||||
cleaned_output = cleaned_output[len("```json") :]
|
||||
if cleaned_output.startswith("```"):
|
||||
cleaned_output = cleaned_output[len("```"):]
|
||||
cleaned_output = cleaned_output[len("```") :]
|
||||
if cleaned_output.endswith("```"):
|
||||
cleaned_output = cleaned_output[: -len("```")]
|
||||
cleaned_output = cleaned_output.strip()
|
||||
@ -145,7 +143,13 @@ class BaseOutputParser(ABC):
|
||||
cleaned_output = m.group(0)
|
||||
else:
|
||||
raise ValueError("model server out not fllow the prompt!")
|
||||
cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
|
||||
cleaned_output = (
|
||||
cleaned_output.strip()
|
||||
.replace("\n", "")
|
||||
.replace("\\n", "")
|
||||
.replace("\\", "")
|
||||
.replace("\\", "")
|
||||
)
|
||||
return cleaned_output
|
||||
|
||||
def parse_view_response(self, ai_text, data) -> str:
|
||||
|
@ -57,7 +57,14 @@ class BaseChat(ABC):
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input):
|
||||
def __init__(
|
||||
self,
|
||||
temperature,
|
||||
max_new_tokens,
|
||||
chat_mode,
|
||||
chat_session_id,
|
||||
current_user_input,
|
||||
):
|
||||
self.chat_session_id = chat_session_id
|
||||
self.chat_mode = chat_mode
|
||||
self.current_user_input: str = current_user_input
|
||||
@ -68,7 +75,9 @@ class BaseChat(ABC):
|
||||
## TEST
|
||||
self.memory = FileHistoryMemory(chat_session_id)
|
||||
### load prompt template
|
||||
self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
|
||||
self.prompt_template: PromptTemplate = CFG.prompt_templates[
|
||||
self.chat_mode.value
|
||||
]
|
||||
self.history_message: List[OnceConversation] = []
|
||||
self.current_message: OnceConversation = OnceConversation()
|
||||
self.current_tokens_used: int = 0
|
||||
@ -129,7 +138,7 @@ class BaseChat(ABC):
|
||||
def stream_call(self):
|
||||
payload = self.__call_base()
|
||||
|
||||
self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
|
||||
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
ai_response_text = ""
|
||||
try:
|
||||
@ -175,29 +184,37 @@ class BaseChat(ABC):
|
||||
|
||||
### output parse
|
||||
ai_response_text = (
|
||||
self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep)
|
||||
self.prompt_template.output_parser.parse_model_nostream_resp(
|
||||
response, self.prompt_template.sep
|
||||
)
|
||||
)
|
||||
self.current_message.add_ai_message(ai_response_text)
|
||||
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
|
||||
prompt_define_response = (
|
||||
self.prompt_template.output_parser.parse_prompt_response(
|
||||
ai_response_text
|
||||
)
|
||||
)
|
||||
|
||||
result = self.do_with_prompt_response(prompt_define_response)
|
||||
|
||||
if hasattr(prompt_define_response, "thoughts"):
|
||||
if isinstance(prompt_define_response.thoughts, dict):
|
||||
if isinstance(prompt_define_response.thoughts, dict):
|
||||
if "speak" in prompt_define_response.thoughts:
|
||||
speak_to_user = prompt_define_response.thoughts.get("speak")
|
||||
else:
|
||||
speak_to_user = str(prompt_define_response.thoughts)
|
||||
else:
|
||||
if hasattr(prompt_define_response.thoughts, "speak"):
|
||||
if hasattr(prompt_define_response.thoughts, "speak"):
|
||||
speak_to_user = prompt_define_response.thoughts.get("speak")
|
||||
elif hasattr(prompt_define_response.thoughts, "reasoning"):
|
||||
elif hasattr(prompt_define_response.thoughts, "reasoning"):
|
||||
speak_to_user = prompt_define_response.thoughts.get("reasoning")
|
||||
else:
|
||||
speak_to_user = prompt_define_response.thoughts
|
||||
else:
|
||||
speak_to_user = prompt_define_response
|
||||
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
|
||||
view_message = self.prompt_template.output_parser.parse_view_response(
|
||||
speak_to_user, result
|
||||
)
|
||||
self.current_message.add_view_message(view_message)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
@ -226,20 +243,20 @@ class BaseChat(ABC):
|
||||
for first_message in self.history_message[0].messages:
|
||||
if not isinstance(first_message, ViewMessage):
|
||||
text += (
|
||||
first_message.type
|
||||
+ ":"
|
||||
+ first_message.content
|
||||
+ self.prompt_template.sep
|
||||
first_message.type
|
||||
+ ":"
|
||||
+ first_message.content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
|
||||
index = self.chat_retention_rounds - 1
|
||||
for last_message in self.history_message[-index:].messages:
|
||||
if not isinstance(last_message, ViewMessage):
|
||||
text += (
|
||||
last_message.type
|
||||
+ ":"
|
||||
+ last_message.content
|
||||
+ self.prompt_template.sep
|
||||
last_message.type
|
||||
+ ":"
|
||||
+ last_message.content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
|
||||
else:
|
||||
@ -248,16 +265,16 @@ class BaseChat(ABC):
|
||||
for message in conversation.messages:
|
||||
if not isinstance(message, ViewMessage):
|
||||
text += (
|
||||
message.type
|
||||
+ ":"
|
||||
+ message.content
|
||||
+ self.prompt_template.sep
|
||||
message.type
|
||||
+ ":"
|
||||
+ message.content
|
||||
+ self.prompt_template.sep
|
||||
)
|
||||
### current conversation
|
||||
|
||||
for now_message in self.current_message.messages:
|
||||
text += (
|
||||
now_message.type + ":" + now_message.content + self.prompt_template.sep
|
||||
now_message.type + ":" + now_message.content + self.prompt_template.sep
|
||||
)
|
||||
|
||||
return text
|
||||
@ -288,4 +305,3 @@ class BaseChat(ABC):
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -14,14 +14,19 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
|
||||
class PluginAction(NamedTuple):
|
||||
command: Dict
|
||||
speak: str
|
||||
reasoning:str
|
||||
reasoning: str
|
||||
thoughts: str
|
||||
|
||||
|
||||
class PluginChatOutputParser(BaseOutputParser):
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
response = json.loads(super().parse_prompt_response(model_out_text))
|
||||
command, thoughts, speak, reasoning = response["command"], response["thoughts"], response["speak"], response["reasoning"]
|
||||
command, thoughts, speak, reasoning = (
|
||||
response["command"],
|
||||
response["thoughts"],
|
||||
response["speak"],
|
||||
response["reasoning"],
|
||||
)
|
||||
return PluginAction(command, speak, reasoning, thoughts)
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
|
@ -11,6 +11,8 @@ from pilot.source_embedding import SourceEmbedding, register
|
||||
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class URLEmbedding(SourceEmbedding):
|
||||
"""url embedding for read url document."""
|
||||
|
||||
@ -27,12 +29,12 @@ class URLEmbedding(SourceEmbedding):
|
||||
loader = WebBaseLoader(web_path=self.file_path)
|
||||
if CFG.LANGUAGE == "en":
|
||||
text_splitter = CharacterTextSplitter(
|
||||
chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE, chunk_overlap=20, length_function=len
|
||||
chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
|
||||
chunk_overlap=20,
|
||||
length_function=len,
|
||||
)
|
||||
else:
|
||||
text_splitter = CHNDocumentSplitter(
|
||||
pdf=True, sentence_size=1000
|
||||
)
|
||||
text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000)
|
||||
return loader.load_and_split(text_splitter)
|
||||
|
||||
@register
|
||||
|
Loading…
Reference in New Issue
Block a user