From ee877a63e0a6382cda4c555ecece5b6f5738264c Mon Sep 17 00:00:00 2001
From: csunny
Date: Thu, 1 Jun 2023 23:19:45 +0800
Subject: [PATCH] lint: fix code style and lint

---
 pilot/model/llm_out/guanaco_llm.py       | 20 +++++---
 pilot/out_parser/base.py                 | 18 ++++---
 pilot/scene/base_chat.py                 | 62 +++++++++++++++---------
 pilot/scene/chat_execution/out_parser.py |  9 +++-
 pilot/source_embedding/url_embedding.py  | 10 ++--
 5 files changed, 75 insertions(+), 44 deletions(-)

diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py
index 0ed42cf65..37c4c423b 100644
--- a/pilot/model/llm_out/guanaco_llm.py
+++ b/pilot/model/llm_out/guanaco_llm.py
@@ -4,6 +4,7 @@ from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 
+
 def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
 
@@ -16,15 +17,20 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     input_ids = tokenizer(query, return_tensors="pt").input_ids
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
 
     stop_token_ids = [0]
 
+
     class StopOnTokens(StoppingCriteria):
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
             for stop_id in stop_token_ids:
                 if input_ids[0][-1] == stop_id:
                     return True
             return False
-    
+
     stop = StopOnTokens()
@@ -32,17 +38,16 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     generate_kwargs = dict(
         max_new_tokens=512,
         temperature=1.0,
         do_sample=True,
-        top_k=1, 
+        top_k=1,
         streamer=streamer,
         repetition_penalty=1.7,
-        stopping_criteria=StoppingCriteriaList([stop])
+        stopping_criteria=StoppingCriteriaList([stop]),
     )
-
     t1 = Thread(target=model.generate, kwargs=generate_kwargs)
     t1.start()
 
-    generator = model.generate(**generate_kwargs) 
+    generator = model.generate(**generate_kwargs)
     for output in generator:
         # new_tokens = len(output) - len(input_ids[0])
         decoded_output = tokenizer.decode(output)
@@ -52,4 +57,3 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
         out = decoded_output.split("### Response:")[-1].strip()
 
         yield out
-
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 46a7dde8b..0538aa54c 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -68,8 +68,7 @@ class BaseOutputParser(ABC):
             return output
 
     # TODO: bind this to the specific model later
-    def parse_model_stream_resp(self, response, skip_echo_len): 
-
+    def parse_model_stream_resp(self, response, skip_echo_len):
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
@@ -77,7 +76,7 @@ class BaseOutputParser(ABC):
                 """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
""" if data["error_code"] == 0: - if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL: + if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL: output = data["text"][skip_echo_len:].strip() else: output = data["text"].strip() @@ -115,7 +114,6 @@ class BaseOutputParser(ABC): else: raise ValueError("Model server error!code=" + respObj_ex["error_code"]) - def parse_prompt_response(self, model_out_text) -> T: """ parse model out text to prompt define response @@ -131,9 +129,9 @@ class BaseOutputParser(ABC): # if "```" in cleaned_output: # cleaned_output, _ = cleaned_output.split("```") if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json"):] + cleaned_output = cleaned_output[len("```json") :] if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```"):] + cleaned_output = cleaned_output[len("```") :] if cleaned_output.endswith("```"): cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output.strip() @@ -145,7 +143,13 @@ class BaseOutputParser(ABC): cleaned_output = m.group(0) else: raise ValueError("model server out not fllow the prompt!") - cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '') + cleaned_output = ( + cleaned_output.strip() + .replace("\n", "") + .replace("\\n", "") + .replace("\\", "") + .replace("\\", "") + ) return cleaned_output def parse_view_response(self, ai_text, data) -> str: diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 84eeb3c15..c1d831d0d 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -57,7 +57,14 @@ class BaseChat(ABC): arbitrary_types_allowed = True - def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input): + def __init__( + self, + temperature, + max_new_tokens, + chat_mode, + chat_session_id, + current_user_input, + ): self.chat_session_id = chat_session_id self.chat_mode = chat_mode self.current_user_input: str = current_user_input @@ -68,7 +75,9 @@ class BaseChat(ABC): ## TEST self.memory = FileHistoryMemory(chat_session_id) ### load prompt template - self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value] + self.prompt_template: PromptTemplate = CFG.prompt_templates[ + self.chat_mode.value + ] self.history_message: List[OnceConversation] = [] self.current_message: OnceConversation = OnceConversation() self.current_tokens_used: int = 0 @@ -129,7 +138,7 @@ class BaseChat(ABC): def stream_call(self): payload = self.__call_base() - self.skip_echo_len = len(payload.get('prompt').replace("", " ")) + 11 + self.skip_echo_len = len(payload.get("prompt").replace("", " ")) + 11 logger.info(f"Requert: \n{payload}") ai_response_text = "" try: @@ -175,29 +184,37 @@ class BaseChat(ABC): ### output parse ai_response_text = ( - self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep) + self.prompt_template.output_parser.parse_model_nostream_resp( + response, self.prompt_template.sep + ) ) self.current_message.add_ai_message(ai_response_text) - prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) + prompt_define_response = ( + self.prompt_template.output_parser.parse_prompt_response( + ai_response_text + ) + ) result = self.do_with_prompt_response(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): - if isinstance(prompt_define_response.thoughts, dict): + if isinstance(prompt_define_response.thoughts, dict): if 
"speak" in prompt_define_response.thoughts: speak_to_user = prompt_define_response.thoughts.get("speak") else: speak_to_user = str(prompt_define_response.thoughts) else: - if hasattr(prompt_define_response.thoughts, "speak"): + if hasattr(prompt_define_response.thoughts, "speak"): speak_to_user = prompt_define_response.thoughts.get("speak") - elif hasattr(prompt_define_response.thoughts, "reasoning"): + elif hasattr(prompt_define_response.thoughts, "reasoning"): speak_to_user = prompt_define_response.thoughts.get("reasoning") else: speak_to_user = prompt_define_response.thoughts else: speak_to_user = prompt_define_response - view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result) + view_message = self.prompt_template.output_parser.parse_view_response( + speak_to_user, result + ) self.current_message.add_view_message(view_message) except Exception as e: print(traceback.format_exc()) @@ -226,20 +243,20 @@ class BaseChat(ABC): for first_message in self.history_message[0].messages: if not isinstance(first_message, ViewMessage): text += ( - first_message.type - + ":" - + first_message.content - + self.prompt_template.sep + first_message.type + + ":" + + first_message.content + + self.prompt_template.sep ) index = self.chat_retention_rounds - 1 for last_message in self.history_message[-index:].messages: if not isinstance(last_message, ViewMessage): text += ( - last_message.type - + ":" - + last_message.content - + self.prompt_template.sep + last_message.type + + ":" + + last_message.content + + self.prompt_template.sep ) else: @@ -248,16 +265,16 @@ class BaseChat(ABC): for message in conversation.messages: if not isinstance(message, ViewMessage): text += ( - message.type - + ":" - + message.content - + self.prompt_template.sep + message.type + + ":" + + message.content + + self.prompt_template.sep ) ### current conversation for now_message in self.current_message.messages: text += ( - now_message.type + ":" + now_message.content + self.prompt_template.sep + now_message.type + ":" + now_message.content + self.prompt_template.sep ) return text @@ -288,4 +305,3 @@ class BaseChat(ABC): """ pass - diff --git a/pilot/scene/chat_execution/out_parser.py b/pilot/scene/chat_execution/out_parser.py index 2f8dd00b7..7b7abbc09 100644 --- a/pilot/scene/chat_execution/out_parser.py +++ b/pilot/scene/chat_execution/out_parser.py @@ -14,14 +14,19 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log") class PluginAction(NamedTuple): command: Dict speak: str - reasoning:str + reasoning: str thoughts: str class PluginChatOutputParser(BaseOutputParser): def parse_prompt_response(self, model_out_text) -> T: response = json.loads(super().parse_prompt_response(model_out_text)) - command, thoughts, speak, reasoning = response["command"], response["thoughts"], response["speak"], response["reasoning"] + command, thoughts, speak, reasoning = ( + response["command"], + response["thoughts"], + response["speak"], + response["reasoning"], + ) return PluginAction(command, speak, reasoning, thoughts) def parse_view_response(self, speak, data) -> str: diff --git a/pilot/source_embedding/url_embedding.py b/pilot/source_embedding/url_embedding.py index 7acfaf961..39224a9f4 100644 --- a/pilot/source_embedding/url_embedding.py +++ b/pilot/source_embedding/url_embedding.py @@ -11,6 +11,8 @@ from pilot.source_embedding import SourceEmbedding, register from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter CFG = Config() + + class URLEmbedding(SourceEmbedding): 
"""url embedding for read url document.""" @@ -27,12 +29,12 @@ class URLEmbedding(SourceEmbedding): loader = WebBaseLoader(web_path=self.file_path) if CFG.LANGUAGE == "en": text_splitter = CharacterTextSplitter( - chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE, chunk_overlap=20, length_function=len + chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE, + chunk_overlap=20, + length_function=len, ) else: - text_splitter = CHNDocumentSplitter( - pdf=True, sentence_size=1000 - ) + text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000) return loader.load_and_split(text_splitter) @register