Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-13 22:15:35 +00:00)

commit ee877a63e0
parent 8e556e3dd3

    lint: fix code style and lint
@@ -4,6 +4,7 @@ from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
+
 def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
 
@@ -16,15 +17,20 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     input_ids = tokenizer(query, return_tensors="pt").input_ids
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
     stop_token_ids = [0]
+
     class StopOnTokens(StoppingCriteria):
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
             for stop_id in stop_token_ids:
                 if input_ids[0][-1] == stop_id:
                     return True
             return False
 
     stop = StopOnTokens()
 
     generate_kwargs = dict(
@@ -32,17 +38,16 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
         max_new_tokens=512,
         temperature=1.0,
         do_sample=True,
         top_k=1,
         streamer=streamer,
         repetition_penalty=1.7,
-        stopping_criteria=StoppingCriteriaList([stop])
+        stopping_criteria=StoppingCriteriaList([stop]),
     )
 
-
     t1 = Thread(target=model.generate, kwargs=generate_kwargs)
     t1.start()
 
     generator = model.generate(**generate_kwargs)
     for output in generator:
         # new_tokens = len(output) - len(input_ids[0])
         decoded_output = tokenizer.decode(output)
@@ -52,4 +57,3 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
         out = decoded_output.split("### Response:")[-1].strip()
 
         yield out
-
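The three hunks above only re-wrap long lines; the underlying logic is the usual transformers streaming recipe, with generation on a worker thread and a stopping criterion driven by a stop-token list. A minimal, self-contained sketch of that recipe (the checkpoint, prompt, and stop ids are placeholders, not taken from this commit):

from threading import Thread

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Placeholder checkpoint; the real adapter receives `model` and `tokenizer` as arguments.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

stop_token_ids = [tokenizer.eos_token_id]


class StopOnTokens(StoppingCriteria):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Stop as soon as the newest token is one of the configured stop ids.
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


input_ids = tokenizer("### Human: hi\n### Response:", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(
    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
    input_ids=input_ids,
    max_new_tokens=64,
    do_sample=True,
    streamer=streamer,
    stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
)

# generate() blocks, so it runs on a worker thread while the streamer is drained here.
Thread(target=model.generate, kwargs=generate_kwargs).start()
for new_text in streamer:
    print(new_text, end="", flush=True)

The adapter in this diff additionally iterates model.generate(**generate_kwargs) and splits the decoded text on "### Response:"; the sketch keeps only the thread-plus-streamer core.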
@@ -68,8 +68,7 @@ class BaseOutputParser(ABC):
         return output
 
     # TODO 后续和模型绑定
     def parse_model_stream_resp(self, response, skip_echo_len):
-
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
@@ -77,7 +76,7 @@ class BaseOutputParser(ABC):
                 """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
                 """
                 if data["error_code"] == 0:
                     if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
                         output = data["text"][skip_echo_len:].strip()
                     else:
                         output = data["text"].strip()
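parse_model_stream_resp and the error-code branch above consume a streaming HTTP response whose chunks are NUL-separated JSON objects. A hedged sketch of that client-side loop with requests (the endpoint URL and payload are placeholders, not the project's real worker address):

import json

import requests

# Placeholder endpoint and payload; the real worker address and parameters come from config.
response = requests.post(
    "http://localhost:8000/generate_stream",
    json={"prompt": "hello", "max_new_tokens": 32},
    stream=True,
)

skip_echo_len = len("hello")
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    data = json.loads(chunk.decode())
    if data["error_code"] == 0:
        # The worker echoes the prompt, so skip_echo_len trims it before display.
        output = data["text"][skip_echo_len:].strip()
        print(output)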
@@ -115,7 +114,6 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])
 
-
     def parse_prompt_response(self, model_out_text) -> T:
         """
         parse model out text to prompt define response
@@ -131,9 +129,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
         #     cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json"):]
+            cleaned_output = cleaned_output[len("```json") :]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```"):]
+            cleaned_output = cleaned_output[len("```") :]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()
@@ -145,7 +143,13 @@ class BaseOutputParser(ABC):
             cleaned_output = m.group(0)
         else:
             raise ValueError("model server out not fllow the prompt!")
-        cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
+        cleaned_output = (
+            cleaned_output.strip()
+            .replace("\n", "")
+            .replace("\\n", "")
+            .replace("\\", "")
+            .replace("\\", "")
+        )
         return cleaned_output
 
     def parse_view_response(self, ai_text, data) -> str:
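The two hunks above are purely black's formatting of the slice expressions and the .replace() chain; the cleaning itself is unchanged. A small sketch of what that cleaning does to a typical fenced model reply (the sample text is invented, and r"\{.*\}" merely stands in for the parser's actual regex, which this hunk does not show):

import json
import re

model_out_text = '```json\n{"command": {"name": "sql_executor"},\n "speak": "Running the query"}\n```'

cleaned_output = model_out_text.strip()
# Strip a leading ```json / ``` fence and a trailing ``` fence if present.
if cleaned_output.startswith("```json"):
    cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
    cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
    cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()

# Fallback used by the parser when the text still is not a bare JSON object.
if not cleaned_output.startswith("{"):
    m = re.search(r"\{.*\}", cleaned_output, re.DOTALL)
    if m:
        cleaned_output = m.group(0)

cleaned_output = (
    cleaned_output.strip().replace("\n", "").replace("\\n", "").replace("\\", "")
)
print(json.loads(cleaned_output))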
@@ -57,7 +57,14 @@ class BaseChat(ABC):
 
         arbitrary_types_allowed = True
 
-    def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input):
+    def __init__(
+        self,
+        temperature,
+        max_new_tokens,
+        chat_mode,
+        chat_session_id,
+        current_user_input,
+    ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
         self.current_user_input: str = current_user_input
@@ -68,7 +75,9 @@ class BaseChat(ABC):
         ## TEST
         self.memory = FileHistoryMemory(chat_session_id)
         ### load prompt template
-        self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
+        self.prompt_template: PromptTemplate = CFG.prompt_templates[
+            self.chat_mode.value
+        ]
         self.history_message: List[OnceConversation] = []
         self.current_message: OnceConversation = OnceConversation()
         self.current_tokens_used: int = 0
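CFG.prompt_templates here is a registry keyed by the chat scene value, and the hunk merely wraps the subscript across lines. A toy sketch of that lookup pattern (the scene names and template texts below are illustrative, not DB-GPT's actual registry contents):

from enum import Enum
from typing import Dict


class ChatScene(Enum):
    # Illustrative scene values; the real project defines its own enum members.
    ChatWithDb = "chat_with_db"
    ChatExecution = "chat_execution"


class PromptTemplate:
    def __init__(self, template: str) -> None:
        self.template = template


prompt_templates: Dict[str, PromptTemplate] = {
    ChatScene.ChatWithDb.value: PromptTemplate("Answer with a SQL query."),
    ChatScene.ChatExecution.value: PromptTemplate("Choose a plugin command."),
}

chat_mode = ChatScene.ChatExecution
prompt_template = prompt_templates[chat_mode.value]
print(prompt_template.template)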
@@ -129,7 +138,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
 
-        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
+        self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
@@ -175,29 +184,37 @@ class BaseChat(ABC):
 
             ### output parse
             ai_response_text = (
-                self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep)
+                self.prompt_template.output_parser.parse_model_nostream_resp(
+                    response, self.prompt_template.sep
+                )
             )
             self.current_message.add_ai_message(ai_response_text)
-            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+            prompt_define_response = (
+                self.prompt_template.output_parser.parse_prompt_response(
+                    ai_response_text
+                )
+            )
 
             result = self.do_with_prompt_response(prompt_define_response)
 
             if hasattr(prompt_define_response, "thoughts"):
                 if isinstance(prompt_define_response.thoughts, dict):
                     if "speak" in prompt_define_response.thoughts:
                         speak_to_user = prompt_define_response.thoughts.get("speak")
                     else:
                         speak_to_user = str(prompt_define_response.thoughts)
                 else:
                     if hasattr(prompt_define_response.thoughts, "speak"):
                         speak_to_user = prompt_define_response.thoughts.get("speak")
                     elif hasattr(prompt_define_response.thoughts, "reasoning"):
                         speak_to_user = prompt_define_response.thoughts.get("reasoning")
                     else:
                         speak_to_user = prompt_define_response.thoughts
             else:
                 speak_to_user = prompt_define_response
-            view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
+            view_message = self.prompt_template.output_parser.parse_view_response(
+                speak_to_user, result
+            )
             self.current_message.add_view_message(view_message)
         except Exception as e:
             print(traceback.format_exc())
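The long if/else ladder above is only re-wrapped by this commit; what it does is pick the text shown to the user from the parsed response, preferring a "speak" field, then "reasoning", then the raw value. A compact sketch of the dict-shaped case (the sample response is invented, and the helper name is not from DB-GPT):

from typing import Any, NamedTuple


class ParsedResponse(NamedTuple):
    thoughts: Any


def pick_speak_text(resp: Any) -> str:
    # Dict-shaped thoughts: prefer the explicit "speak" entry, else show everything.
    thoughts = getattr(resp, "thoughts", None)
    if thoughts is None:
        return str(resp)
    if isinstance(thoughts, dict):
        return thoughts.get("speak", str(thoughts))
    # Non-dict thoughts: the original code also probes speak/reasoning attributes here.
    return str(thoughts)


resp = ParsedResponse(thoughts={"speak": "I will run the query now.", "reasoning": "the table exists"})
print(pick_speak_text(resp))  # -> I will run the query now.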
@@ -226,20 +243,20 @@ class BaseChat(ABC):
                 for first_message in self.history_message[0].messages:
                     if not isinstance(first_message, ViewMessage):
                         text += (
                             first_message.type
                             + ":"
                             + first_message.content
                             + self.prompt_template.sep
                         )
 
                 index = self.chat_retention_rounds - 1
                 for last_message in self.history_message[-index:].messages:
                     if not isinstance(last_message, ViewMessage):
                         text += (
                             last_message.type
                             + ":"
                             + last_message.content
                             + self.prompt_template.sep
                         )
 
             else:
@@ -248,16 +265,16 @@ class BaseChat(ABC):
                     for message in conversation.messages:
                         if not isinstance(message, ViewMessage):
                             text += (
                                 message.type
                                 + ":"
                                 + message.content
                                 + self.prompt_template.sep
                             )
         ### current conversation
 
         for now_message in self.current_message.messages:
             text += (
                 now_message.type + ":" + now_message.content + self.prompt_template.sep
             )
 
         return text
@@ -288,4 +305,3 @@ class BaseChat(ABC):
 
         """
         pass
-
@@ -14,14 +14,19 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 class PluginAction(NamedTuple):
     command: Dict
     speak: str
-    reasoning:str
+    reasoning: str
     thoughts: str
 
 
 class PluginChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         response = json.loads(super().parse_prompt_response(model_out_text))
-        command, thoughts, speak, reasoning = response["command"], response["thoughts"], response["speak"], response["reasoning"]
+        command, thoughts, speak, reasoning = (
+            response["command"],
+            response["thoughts"],
+            response["speak"],
+            response["reasoning"],
+        )
         return PluginAction(command, speak, reasoning, thoughts)
 
     def parse_view_response(self, speak, data) -> str:
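PluginChatOutputParser.parse_prompt_response above unpacks four fields of the model's JSON reply into the PluginAction NamedTuple; only the tuple assignment is re-wrapped. A standalone sketch of that unpacking (the JSON sample is made up):

import json
from typing import Dict, NamedTuple


class PluginAction(NamedTuple):
    command: Dict
    speak: str
    reasoning: str
    thoughts: str


model_out_text = json.dumps(
    {
        "command": {"name": "search", "args": {"query": "db-gpt"}},
        "thoughts": "the user wants a web search",
        "speak": "Searching for db-gpt",
        "reasoning": "no local data answers this question",
    }
)

response = json.loads(model_out_text)
command, thoughts, speak, reasoning = (
    response["command"],
    response["thoughts"],
    response["speak"],
    response["reasoning"],
)
# Note the field order differs from the JSON: PluginAction(command, speak, reasoning, thoughts).
action = PluginAction(command, speak, reasoning, thoughts)
print(action.speak)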
@@ -11,6 +11,8 @@ from pilot.source_embedding import SourceEmbedding, register
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
 
 CFG = Config()
+
+
 class URLEmbedding(SourceEmbedding):
     """url embedding for read url document."""
 
@@ -27,12 +29,12 @@ class URLEmbedding(SourceEmbedding):
         loader = WebBaseLoader(web_path=self.file_path)
         if CFG.LANGUAGE == "en":
             text_splitter = CharacterTextSplitter(
-                chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE, chunk_overlap=20, length_function=len
+                chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
+                chunk_overlap=20,
+                length_function=len,
             )
         else:
-            text_splitter = CHNDocumentSplitter(
-                pdf=True, sentence_size=1000
-            )
+            text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000)
         return loader.load_and_split(text_splitter)
 
 @register
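The final hunk reformats the splitter construction in URLEmbedding: English pages use LangChain's CharacterTextSplitter, Chinese ones the project's CHNDocumentSplitter. A hedged sketch of the English branch (the URL and chunk size are placeholders; the real class takes them from self.file_path and KNOWLEDGE_CHUNK_SPLIT_SIZE):

from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter

# Placeholder chunk size; DB-GPT reads it from its knowledge configuration.
KNOWLEDGE_CHUNK_SPLIT_SIZE = 100

loader = WebBaseLoader(web_path="https://example.com")
text_splitter = CharacterTextSplitter(
    chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
    chunk_overlap=20,
    length_function=len,
)
docs = loader.load_and_split(text_splitter)
print(len(docs), "chunks")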