lint: fix code style and lint

commit ee877a63e0 (parent 8e556e3dd3)
csunny committed 2023-06-01 23:19:45 +08:00
5 changed files with 75 additions and 44 deletions

View File

@@ -4,6 +4,7 @@ from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
+
 def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
@@ -16,15 +17,20 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     input_ids = tokenizer(query, return_tensors="pt").input_ids
     input_ids = input_ids.to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
     stop_token_ids = [0]
+
     class StopOnTokens(StoppingCriteria):
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
             for stop_id in stop_token_ids:
                 if input_ids[0][-1] == stop_id:
                     return True
             return False

     stop = StopOnTokens()

     generate_kwargs = dict(
@@ -32,17 +38,16 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
         max_new_tokens=512,
         temperature=1.0,
         do_sample=True,
         top_k=1,
         streamer=streamer,
         repetition_penalty=1.7,
-        stopping_criteria=StoppingCriteriaList([stop])
+        stopping_criteria=StoppingCriteriaList([stop]),
     )
-
     t1 = Thread(target=model.generate, kwargs=generate_kwargs)
     t1.start()

     generator = model.generate(**generate_kwargs)
     for output in generator:
         # new_tokens = len(output) - len(input_ids[0])
         decoded_output = tokenizer.decode(output)
@@ -52,4 +57,3 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
             out = decoded_output.split("### Response:")[-1].strip()
             yield out
-
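
Two things stand out in this file for review: StopOnTokens only checks the last generated token against stop_token_ids, and the code both starts a Thread around model.generate and then calls model.generate(**generate_kwargs) again in the foreground, so generation effectively runs twice. For reference, a minimal sketch of the usual single-path TextIteratorStreamer pattern, with "gpt2" standing in for the guanaco checkpoint and an invented prompt:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# "gpt2" is only a placeholder; the file above loads a guanaco model elsewhere.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("### Human: hello\n### Response:", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(
    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)

# generate() blocks, so it runs on a worker thread while the caller
# consumes partial text from the streamer queue.
thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, max_new_tokens=64, streamer=streamer),
)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()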

View File

@@ -68,8 +68,7 @@ class BaseOutputParser(ABC):
         return output
-
     # TODO 后续和模型绑定 (bind this to the model later)
     def parse_model_stream_resp(self, response, skip_echo_len):
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
@@ -77,7 +76,7 @@ class BaseOutputParser(ABC):
                 """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
                 """
                 if data["error_code"] == 0:
                     if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
                         output = data["text"][skip_echo_len:].strip()
                     else:
                         output = data["text"].strip()
@@ -115,7 +114,6 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])
-
     def parse_prompt_response(self, model_out_text) -> T:
         """
         parse model out text to prompt define response
@@ -131,9 +129,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
         #     cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json"):]
+            cleaned_output = cleaned_output[len("```json") :]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```"):]
+            cleaned_output = cleaned_output[len("```") :]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()
@@ -145,7 +143,13 @@ class BaseOutputParser(ABC):
             cleaned_output = m.group(0)
         else:
             raise ValueError("model server out not fllow the prompt!")
-        cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
+        cleaned_output = (
+            cleaned_output.strip()
+            .replace("\n", "")
+            .replace("\\n", "")
+            .replace("\\", "")
+            .replace("\\", "")
+        )
         return cleaned_output

     def parse_view_response(self, ai_text, data) -> str:
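
The fence handling in parse_prompt_response is the subtle part of this file, so here is a self-contained sketch of just that step (the helper name and sample string are mine; the later replace() chain that drops newlines and backslashes is deliberately left out):

import json

def strip_json_fences(model_out_text: str) -> dict:
    # Mirrors the fence stripping in parse_prompt_response above.
    cleaned = model_out_text.strip()
    if cleaned.startswith("```json"):
        cleaned = cleaned[len("```json") :]
    if cleaned.startswith("```"):
        cleaned = cleaned[len("```") :]
    if cleaned.endswith("```"):
        cleaned = cleaned[: -len("```")]
    return json.loads(cleaned.strip())

print(strip_json_fences('```json\n{"speak": "hi"}\n```'))  # {'speak': 'hi'}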

View File

@@ -57,7 +57,14 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True

-    def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input):
+    def __init__(
+        self,
+        temperature,
+        max_new_tokens,
+        chat_mode,
+        chat_session_id,
+        current_user_input,
+    ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
         self.current_user_input: str = current_user_input
@@ -68,7 +75,9 @@ class BaseChat(ABC):
         ## TEST
         self.memory = FileHistoryMemory(chat_session_id)
         ### load prompt template
-        self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
+        self.prompt_template: PromptTemplate = CFG.prompt_templates[
+            self.chat_mode.value
+        ]
         self.history_message: List[OnceConversation] = []
         self.current_message: OnceConversation = OnceConversation()
         self.current_tokens_used: int = 0
@@ -129,7 +138,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()

-        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
+        self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
@ -175,29 +184,37 @@ class BaseChat(ABC):
### output parse ### output parse
ai_response_text = ( ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep) self.prompt_template.output_parser.parse_model_nostream_resp(
response, self.prompt_template.sep
)
) )
self.current_message.add_ai_message(ai_response_text) self.current_message.add_ai_message(ai_response_text)
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) prompt_define_response = (
self.prompt_template.output_parser.parse_prompt_response(
ai_response_text
)
)
result = self.do_with_prompt_response(prompt_define_response) result = self.do_with_prompt_response(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"): if hasattr(prompt_define_response, "thoughts"):
if isinstance(prompt_define_response.thoughts, dict): if isinstance(prompt_define_response.thoughts, dict):
if "speak" in prompt_define_response.thoughts: if "speak" in prompt_define_response.thoughts:
speak_to_user = prompt_define_response.thoughts.get("speak") speak_to_user = prompt_define_response.thoughts.get("speak")
else: else:
speak_to_user = str(prompt_define_response.thoughts) speak_to_user = str(prompt_define_response.thoughts)
else: else:
if hasattr(prompt_define_response.thoughts, "speak"): if hasattr(prompt_define_response.thoughts, "speak"):
speak_to_user = prompt_define_response.thoughts.get("speak") speak_to_user = prompt_define_response.thoughts.get("speak")
elif hasattr(prompt_define_response.thoughts, "reasoning"): elif hasattr(prompt_define_response.thoughts, "reasoning"):
speak_to_user = prompt_define_response.thoughts.get("reasoning") speak_to_user = prompt_define_response.thoughts.get("reasoning")
else: else:
speak_to_user = prompt_define_response.thoughts speak_to_user = prompt_define_response.thoughts
else: else:
speak_to_user = prompt_define_response speak_to_user = prompt_define_response
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result) view_message = self.prompt_template.output_parser.parse_view_response(
speak_to_user, result
)
self.current_message.add_view_message(view_message) self.current_message.add_view_message(view_message)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
@@ -226,20 +243,20 @@ class BaseChat(ABC):
             for first_message in self.history_message[0].messages:
                 if not isinstance(first_message, ViewMessage):
                     text += (
                         first_message.type
                         + ":"
                         + first_message.content
                         + self.prompt_template.sep
                     )

             index = self.chat_retention_rounds - 1
             for last_message in self.history_message[-index:].messages:
                 if not isinstance(last_message, ViewMessage):
                     text += (
                         last_message.type
                         + ":"
                         + last_message.content
                         + self.prompt_template.sep
                     )
         else:
@@ -248,16 +265,16 @@ class BaseChat(ABC):
             for message in conversation.messages:
                 if not isinstance(message, ViewMessage):
                     text += (
                         message.type
                         + ":"
                         + message.content
                         + self.prompt_template.sep
                     )

         ### current conversation
         for now_message in self.current_message.messages:
             text += (
                 now_message.type + ":" + now_message.content + self.prompt_template.sep
             )
         return text
@@ -288,4 +305,3 @@ class BaseChat(ABC):
         """
         pass
-
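
One caveat for reviewers of the thoughts-handling block above: in the non-dict branch, prompt_define_response.thoughts.get("speak") calls .get() on an object already known not to be a dict, which raises AttributeError for most objects. A getattr-based sketch of the same selection order (function and class names are mine):

def pick_speak(prompt_define_response):
    # Same precedence as the hunk above: dict "speak" key, then a speak
    # attribute, then reasoning, then the raw thoughts / response itself.
    thoughts = getattr(prompt_define_response, "thoughts", None)
    if thoughts is None:
        return prompt_define_response
    if isinstance(thoughts, dict):
        return thoughts.get("speak", str(thoughts))
    return getattr(thoughts, "speak", getattr(thoughts, "reasoning", thoughts))

class Resp:  # toy stand-in for a parsed prompt response
    thoughts = {"speak": "hello"}

print(pick_speak(Resp()))  # -> hello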

View File

@@ -14,14 +14,19 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
 class PluginAction(NamedTuple):
     command: Dict
     speak: str
-    reasoning:str
+    reasoning: str
     thoughts: str

 class PluginChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         response = json.loads(super().parse_prompt_response(model_out_text))
-        command, thoughts, speak, reasoning = response["command"], response["thoughts"], response["speak"], response["reasoning"]
+        command, thoughts, speak, reasoning = (
+            response["command"],
+            response["thoughts"],
+            response["speak"],
+            response["reasoning"],
+        )
         return PluginAction(command, speak, reasoning, thoughts)

     def parse_view_response(self, speak, data) -> str:
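
Since PluginAction is a NamedTuple, constructor arguments must follow the declaration order (command, speak, reasoning, thoughts) even though the unpack above reads the JSON keys in a different order; the code does get this right. A toy round trip with invented JSON:

import json
from typing import Dict, NamedTuple

class PluginAction(NamedTuple):
    command: Dict
    speak: str
    reasoning: str
    thoughts: str

raw = '{"command": {"name": "noop"}, "thoughts": "t", "speak": "s", "reasoning": "r"}'
response = json.loads(raw)
action = PluginAction(
    response["command"], response["speak"], response["reasoning"], response["thoughts"]
)
print(action.speak)  # -> s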

View File

@@ -11,6 +11,8 @@ from pilot.source_embedding import SourceEmbedding, register
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

 CFG = Config()
+
+
 class URLEmbedding(SourceEmbedding):
     """url embedding for read url document."""
@@ -27,12 +29,12 @@ class URLEmbedding(SourceEmbedding):
         loader = WebBaseLoader(web_path=self.file_path)
         if CFG.LANGUAGE == "en":
             text_splitter = CharacterTextSplitter(
-                chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE, chunk_overlap=20, length_function=len
+                chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
+                chunk_overlap=20,
+                length_function=len,
             )
         else:
-            text_splitter = CHNDocumentSplitter(
-                pdf=True, sentence_size=1000
-            )
+            text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000)
         return loader.load_and_split(text_splitter)

 @register
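
For context on the branch above, a sketch of the language-dependent splitter choice. The KNOWLEDGE_CHUNK_SPLIT_SIZE value is an assumption (the real one lives in the project's config), and only the "en" branch is runnable here, since CHNDocumentSplitter is project-internal:

from langchain.text_splitter import CharacterTextSplitter

KNOWLEDGE_CHUNK_SPLIT_SIZE = 500  # assumption: actual value comes from pilot.configs

def make_splitter(language: str):
    # English text: plain character splitting, as in the hunk above.
    if language == "en":
        return CharacterTextSplitter(
            chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
            chunk_overlap=20,
            length_function=len,
        )
    # Chinese text: the project uses its own sentence-aware CHNDocumentSplitter
    # (pdf=True, sentence_size=1000), which is not importable in this sketch.
    raise ValueError(f"no splitter sketched for language {language!r}")

chunks = make_splitter("en").split_text("some long document text " * 100)
print(len(chunks))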