lint: fix code style and lint

csunny 2023-06-01 23:19:45 +08:00
parent 8e556e3dd3
commit ee877a63e0
5 changed files with 75 additions and 44 deletions
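
Every hunk in this commit is a pure formatting change, consistent with a formatter such as black (an assumption; the commit message says only "lint" and names no tool). The recurring pattern, sketched below with an invented function so the snippet is self-contained: a call that exceeds the line-length limit is exploded one argument per line with a trailing comma, while short wrapped calls are collapsed back onto one line.

```python
def fetch_documents(url, chunk_size, chunk_overlap, language, verbose):
    """Invented stand-in, present only to make the snippet runnable."""
    return [url, chunk_size, chunk_overlap, language, verbose]

# Before linting: one call crammed onto a single long line.
result = fetch_documents("https://example.com", chunk_size=1000, chunk_overlap=20, language="en", verbose=True)

# After linting: one argument per line, with a trailing comma, which is the
# shape seen throughout the hunks below.
result = fetch_documents(
    "https://example.com",
    chunk_size=1000,
    chunk_overlap=20,
    language="en",
    verbose=True,
)
```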

View File

@@ -4,6 +4,7 @@ from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
 from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
+
 def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     """Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
@@ -16,10 +17,15 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
     input_ids = tokenizer(query, return_tensors="pt").input_ids
     input_ids = input_ids.to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
     stop_token_ids = [0]
+
     class StopOnTokens(StoppingCriteria):
-        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        def __call__(
+            self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+        ) -> bool:
             for stop_id in stop_token_ids:
                 if input_ids[0][-1] == stop_id:
                     return True
@@ -35,10 +41,9 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
         top_k=1,
         streamer=streamer,
         repetition_penalty=1.7,
-        stopping_criteria=StoppingCriteriaList([stop])
+        stopping_criteria=StoppingCriteriaList([stop]),
     )
     t1 = Thread(target=model.generate, kwargs=generate_kwargs)
     t1.start()
@@ -52,4 +57,3 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
         out = decoded_output.split("### Response:")[-1].strip()
         yield out

View File

@@ -69,7 +69,6 @@ class BaseOutputParser(ABC):
     # TODO: bind this to the model later
     def parse_model_stream_resp(self, response, skip_echo_len):
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 data = json.loads(chunk.decode())
@@ -115,7 +114,6 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])

     def parse_prompt_response(self, model_out_text) -> T:
         """
         parse model out text to prompt define response
@@ -145,7 +143,13 @@ class BaseOutputParser(ABC):
             cleaned_output = m.group(0)
         else:
             raise ValueError("model server out not fllow the prompt!")
-        cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
+        cleaned_output = (
+            cleaned_output.strip()
+            .replace("\n", "")
+            .replace("\\n", "")
+            .replace("\\", "")
+            .replace("\\", "")
+        )
         return cleaned_output

     def parse_view_response(self, ai_text, data) -> str:
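
Aside from the wrapping, the chained replace() above deserves a gloss: it strips real newlines, literal "\n" escapes, and stray backslashes from the model output before it is parsed, and the final .replace("\\", "") duplicates the one before it, so it is a no-op. A tiny standalone illustration (sample string invented):

```python
raw = ' {"command": "search_web",\n "speak": "Searching now"} '

cleaned = (
    raw.strip()           # drop surrounding whitespace
    .replace("\n", "")    # drop real newlines
    .replace("\\n", "")   # drop literal backslash-n escapes
    .replace("\\", "")    # drop any remaining backslashes
)
print(cleaned)  # {"command": "search_web", "speak": "Searching now"}
```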

View File

@@ -57,7 +57,14 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True

-    def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input):
+    def __init__(
+        self,
+        temperature,
+        max_new_tokens,
+        chat_mode,
+        chat_session_id,
+        current_user_input,
+    ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
         self.current_user_input: str = current_user_input
@@ -68,7 +75,9 @@ class BaseChat(ABC):
         ## TEST
         self.memory = FileHistoryMemory(chat_session_id)
         ### load prompt template
-        self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
+        self.prompt_template: PromptTemplate = CFG.prompt_templates[
+            self.chat_mode.value
+        ]
         self.history_message: List[OnceConversation] = []
         self.current_message: OnceConversation = OnceConversation()
         self.current_tokens_used: int = 0
@@ -129,7 +138,7 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
-        self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
+        self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
@@ -175,10 +184,16 @@ class BaseChat(ABC):
             ### output parse
             ai_response_text = (
-                self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep)
+                self.prompt_template.output_parser.parse_model_nostream_resp(
+                    response, self.prompt_template.sep
+                )
             )
             self.current_message.add_ai_message(ai_response_text)
-            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+            prompt_define_response = (
+                self.prompt_template.output_parser.parse_prompt_response(
+                    ai_response_text
+                )
+            )

             result = self.do_with_prompt_response(prompt_define_response)
@@ -197,7 +212,9 @@ class BaseChat(ABC):
                 speak_to_user = prompt_define_response.thoughts
             else:
                 speak_to_user = prompt_define_response
-            view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
+            view_message = self.prompt_template.output_parser.parse_view_response(
+                speak_to_user, result
+            )
             self.current_message.add_view_message(view_message)
         except Exception as e:
             print(traceback.format_exc())
@@ -288,4 +305,3 @@ class BaseChat(ABC):
         """
         pass
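
One substantive line hiding in the noise: skip_echo_len measures the prompt (with "</s>" widened to a space) so that streamed output can be sliced free of the echoed input; the +11 offset is specific to this server's response framing. A toy illustration of the slicing idea (strings invented, offset omitted):

```python
prompt = "Hello</s>"
# A server that echoes the prompt back in front of the completion.
response_text = prompt.replace("</s>", " ") + "Hi there!"

skip_echo_len = len(prompt.replace("</s>", " "))
print(response_text[skip_echo_len:])  # -> Hi there!
```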

View File

@@ -21,7 +21,12 @@ class PluginAction(NamedTuple):
 class PluginChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         response = json.loads(super().parse_prompt_response(model_out_text))
-        command, thoughts, speak, reasoning = response["command"], response["thoughts"], response["speak"], response["reasoning"]
+        command, thoughts, speak, reasoning = (
+            response["command"],
+            response["thoughts"],
+            response["speak"],
+            response["reasoning"],
+        )
         return PluginAction(command, speak, reasoning, thoughts)

     def parse_view_response(self, speak, data) -> str:
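
The tuple unpacking above pulls four fields from the parsed JSON and passes them to PluginAction in a different order (command, speak, reasoning, thoughts). A standalone sketch of that flow; the NamedTuple field order is assumed from the constructor call, since the class body is outside this diff:

```python
import json
from typing import NamedTuple


class PluginAction(NamedTuple):
    # Field order assumed from the PluginAction(...) call in the diff.
    command: dict
    speak: str
    reasoning: str
    thoughts: str


model_out_text = json.dumps(
    {
        "command": {"name": "search", "args": {"q": "db-gpt"}},
        "thoughts": "I should search the web.",
        "speak": "Let me look that up.",
        "reasoning": "The question needs fresh data.",
    }
)

response = json.loads(model_out_text)
command, thoughts, speak, reasoning = (
    response["command"],
    response["thoughts"],
    response["speak"],
    response["reasoning"],
)
action = PluginAction(command, speak, reasoning, thoughts)
print(action.speak)  # -> Let me look that up.
```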

View File

@@ -11,6 +11,8 @@ from pilot.source_embedding import SourceEmbedding, register
 from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter

 CFG = Config()
+
+
 class URLEmbedding(SourceEmbedding):
     """url embedding for read url document."""
@@ -27,12 +29,12 @@ class URLEmbedding(SourceEmbedding):
         loader = WebBaseLoader(web_path=self.file_path)
         if CFG.LANGUAGE == "en":
             text_splitter = CharacterTextSplitter(
-                chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE, chunk_overlap=20, length_function=len
+                chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
+                chunk_overlap=20,
+                length_function=len,
             )
         else:
-            text_splitter = CHNDocumentSplitter(
-                pdf=True, sentence_size=1000
-            )
+            text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000)
         return loader.load_and_split(text_splitter)

 @register
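
Behaviorally the last hunk changes nothing: the method still loads the page with WebBaseLoader and picks a text splitter by configured language. A rough sketch of that selection using only langchain's CharacterTextSplitter (the project-local CHNDocumentSplitter is stubbed with a sentence-separator splitter, and the chunk-size constant is invented):

```python
from langchain.text_splitter import CharacterTextSplitter

KNOWLEDGE_CHUNK_SPLIT_SIZE = 1000  # invented; the real value lives in pilot's config


def pick_splitter(language: str) -> CharacterTextSplitter:
    """Character-based splitting for English, sentence-based for Chinese."""
    if language == "en":
        return CharacterTextSplitter(
            chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
            chunk_overlap=20,
            length_function=len,
        )
    # Stand-in for CHNDocumentSplitter(pdf=True, sentence_size=1000): split on
    # the Chinese full stop instead of the default paragraph separator.
    return CharacterTextSplitter(separator="。", chunk_size=1000, chunk_overlap=20)


splitter = pick_splitter("en")
chunks = splitter.split_text("First paragraph.\n\nSecond paragraph.")
```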