lint: fix code style and lint

csunny 2023-06-01 23:19:45 +08:00
parent 8e556e3dd3
commit ee877a63e0
5 changed files with 75 additions and 44 deletions

View File

@@ -4,6 +4,7 @@ from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
"""Fork from: https://github.com/KohakuBlueleaf/guanaco-lora/blob/main/generate.py"""
@@ -16,10 +17,15 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
input_ids = tokenizer(query, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
-streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+streamer = TextIteratorStreamer(
+    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+)
stop_token_ids = [0]
class StopOnTokens(StoppingCriteria):
-def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+def __call__(
+    self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+) -> bool:
for stop_id in stop_token_ids:
if input_ids[0][-1] == stop_id:
return True
@@ -35,14 +41,13 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
top_k=1,
streamer=streamer,
repetition_penalty=1.7,
-stopping_criteria=StoppingCriteriaList([stop])
+stopping_criteria=StoppingCriteriaList([stop]),
)
t1 = Thread(target=model.generate, kwargs=generate_kwargs)
t1.start()
-generator = model.generate(**generate_kwargs)
+generator = model.generate(**generate_kwargs)
for output in generator:
# new_tokens = len(output) - len(input_ids[0])
decoded_output = tokenizer.decode(output)
@@ -52,4 +57,3 @@ def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
out = decoded_output.split("### Response:")[-1].strip()
yield out
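
For context on the pattern in this file: `TextIteratorStreamer` is designed to be iterated by the caller while `model.generate` runs on a background thread, which is what the `Thread` setup above is for. A minimal, self-contained sketch of that pattern follows; the `gpt2` checkpoint and the prompt text are placeholders, not part of this commit.

```python
from threading import Thread

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last emitted token is a stop id."""

    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        return input_ids[0][-1].item() in self.stop_token_ids


input_ids = tokenizer("### Human: hi\n### Response:", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(
    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
    input_ids=input_ids,
    max_new_tokens=64,
    streamer=streamer,
    stopping_criteria=StoppingCriteriaList([StopOnTokens([tokenizer.eos_token_id])]),
)

# generate() blocks until done, so it runs on a worker thread while the
# caller iterates the streamer and receives partial text as tokens arrive.
Thread(target=model.generate, kwargs=generate_kwargs).start()
for new_text in streamer:
    print(new_text, end="", flush=True)
```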

View File

@@ -68,8 +68,7 @@ class BaseOutputParser(ABC):
return output
# TODO Bind this to the specific model later
-def parse_model_stream_resp(self, response, skip_echo_len):
+def parse_model_stream_resp(self, response, skip_echo_len):
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
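
The loop above consumes a NUL-delimited stream of JSON blobs from the model server. A minimal sketch of that framing with `requests`; the endpoint URL and payload are placeholders, not values from this commit:

```python
import json

import requests

# Placeholder endpoint and payload; the real values come from the app config.
response = requests.post(
    "http://localhost:8000/generate_stream",
    json={"prompt": "hi"},
    stream=True,
)
skip_echo_len = 0  # length of the echoed prompt to drop, if any
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    data = json.loads(chunk.decode())
    if data["error_code"] == 0:
        print(data["text"][skip_echo_len:].strip())
```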
@@ -77,7 +76,7 @@ class BaseOutputParser(ABC):
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
output = data["text"][skip_echo_len:].strip()
else:
output = data["text"].strip()
@@ -115,7 +114,6 @@ class BaseOutputParser(ABC):
else:
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def parse_prompt_response(self, model_out_text) -> T:
"""
parse model out text to prompt define response
@@ -131,9 +129,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
-cleaned_output = cleaned_output[len("```json"):]
+cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
-cleaned_output = cleaned_output[len("```"):]
+cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
@@ -145,7 +143,13 @@ class BaseOutputParser(ABC):
cleaned_output = m.group(0)
else:
raise ValueError("model server out not fllow the prompt!")
-cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
+cleaned_output = (
+    cleaned_output.strip()
+    .replace("\n", "")
+    .replace("\\n", "")
+    .replace("\\", "")
+    .replace("\\", "")
+)
return cleaned_output
def parse_view_response(self, ai_text, data) -> str:
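
The fence-stripping logic reflowed above reads more clearly as a standalone helper. A sketch of the same idea; the function name and the final `json.loads` are illustrative, not part of the commit:

```python
import json
import re


def extract_json_payload(model_out_text: str) -> dict:
    """Strip markdown code fences and escapes, then parse the JSON body."""
    cleaned = model_out_text.strip()
    if cleaned.startswith("```json"):
        cleaned = cleaned[len("```json") :]
    if cleaned.startswith("```"):
        cleaned = cleaned[len("```") :]
    if cleaned.endswith("```"):
        cleaned = cleaned[: -len("```")]
    cleaned = cleaned.strip()
    # Fall back to the outermost {...} block when the model wraps
    # the JSON in extra prose.
    if not cleaned.startswith("{"):
        m = re.search(r"\{.*\}", cleaned, re.DOTALL)
        if m is None:
            raise ValueError("model output does not follow the prompt!")
        cleaned = m.group(0)
    cleaned = cleaned.replace("\n", "").replace("\\n", "").replace("\\", "")
    return json.loads(cleaned)


print(extract_json_payload('```json\n{"speak": "hello"}\n```'))
```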

View File

@@ -57,7 +57,14 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
-def __init__(self,temperature, max_new_tokens, chat_mode, chat_session_id, current_user_input):
+def __init__(
+    self,
+    temperature,
+    max_new_tokens,
+    chat_mode,
+    chat_session_id,
+    current_user_input,
+):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
self.current_user_input: str = current_user_input
@@ -68,7 +75,9 @@ class BaseChat(ABC):
## TEST
self.memory = FileHistoryMemory(chat_session_id)
### load prompt template
-self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
+self.prompt_template: PromptTemplate = CFG.prompt_templates[
+    self.chat_mode.value
+]
self.history_message: List[OnceConversation] = []
self.current_message: OnceConversation = OnceConversation()
self.current_tokens_used: int = 0
@@ -129,7 +138,7 @@ class BaseChat(ABC):
def stream_call(self):
payload = self.__call_base()
-self.skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 11
+self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
@@ -175,29 +184,37 @@ class BaseChat(ABC):
### output parse
ai_response_text = (
-    self.prompt_template.output_parser.parse_model_nostream_resp(response, self.prompt_template.sep)
+    self.prompt_template.output_parser.parse_model_nostream_resp(
+        response, self.prompt_template.sep
+    )
)
self.current_message.add_ai_message(ai_response_text)
-prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+prompt_define_response = (
+    self.prompt_template.output_parser.parse_prompt_response(
+        ai_response_text
+    )
+)
result = self.do_with_prompt_response(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"):
-if isinstance(prompt_define_response.thoughts, dict):
+if isinstance(prompt_define_response.thoughts, dict):
if "speak" in prompt_define_response.thoughts:
speak_to_user = prompt_define_response.thoughts.get("speak")
else:
speak_to_user = str(prompt_define_response.thoughts)
else:
-if hasattr(prompt_define_response.thoughts, "speak"):
+if hasattr(prompt_define_response.thoughts, "speak"):
speak_to_user = prompt_define_response.thoughts.get("speak")
-elif hasattr(prompt_define_response.thoughts, "reasoning"):
+elif hasattr(prompt_define_response.thoughts, "reasoning"):
speak_to_user = prompt_define_response.thoughts.get("reasoning")
else:
speak_to_user = prompt_define_response.thoughts
else:
speak_to_user = prompt_define_response
-view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
+view_message = self.prompt_template.output_parser.parse_view_response(
+    speak_to_user, result
+)
self.current_message.add_view_message(view_message)
except Exception as e:
print(traceback.format_exc())
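
The branching above normalizes a `thoughts` value that may arrive as a dict or as an object. A compact sketch of that normalization; the helper name is illustrative, and attribute access stands in for the object branch:

```python
def speak_from_thoughts(thoughts) -> str:
    """Prefer an explicit 'speak' field, then 'reasoning', else stringify."""
    if isinstance(thoughts, dict):
        if "speak" in thoughts:
            return str(thoughts["speak"])
        return str(thoughts)
    for field in ("speak", "reasoning"):
        if hasattr(thoughts, field):
            return str(getattr(thoughts, field))
    return str(thoughts)


print(speak_from_thoughts({"speak": "Working on it.", "reasoning": "..."}))
```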
@@ -226,20 +243,20 @@ class BaseChat(ABC):
for first_message in self.history_message[0].messages:
if not isinstance(first_message, ViewMessage):
text += (
-    first_message.type
-    + ":"
-    + first_message.content
-    + self.prompt_template.sep
+    first_message.type
+    + ":"
+    + first_message.content
+    + self.prompt_template.sep
)
index = self.chat_retention_rounds - 1
for last_message in self.history_message[-index:].messages:
if not isinstance(last_message, ViewMessage):
text += (
-    last_message.type
-    + ":"
-    + last_message.content
-    + self.prompt_template.sep
+    last_message.type
+    + ":"
+    + last_message.content
+    + self.prompt_template.sep
)
else:
@@ -248,16 +265,16 @@ class BaseChat(ABC):
for message in conversation.messages:
if not isinstance(message, ViewMessage):
text += (
-    message.type
-    + ":"
-    + message.content
-    + self.prompt_template.sep
+    message.type
+    + ":"
+    + message.content
+    + self.prompt_template.sep
)
### current conversation
for now_message in self.current_message.messages:
text += (
-    now_message.type + ":" + now_message.content + self.prompt_template.sep
+    now_message.type + ":" + now_message.content + self.prompt_template.sep
)
return text
@@ -288,4 +305,3 @@ class BaseChat(ABC):
"""
pass
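
The reindented blocks above flatten chat history into `type:content` lines, keeping the first conversation plus the last `chat_retention_rounds - 1` rounds when history is truncated. A condensed sketch of that flattening; the `Message` shape is assumed, not taken from the commit:

```python
from typing import List, NamedTuple


class Message(NamedTuple):
    type: str  # e.g. "human", "ai", or "view"
    content: str


def flatten_history(
    history: List[List[Message]], retention_rounds: int, sep: str = "\n"
) -> str:
    """Keep the first round plus the last N-1 rounds, skipping view messages."""
    if retention_rounds and len(history) > retention_rounds:
        tail = history[-(retention_rounds - 1) :] if retention_rounds > 1 else []
        kept = [history[0]] + tail
    else:
        kept = history
    text = ""
    for conversation in kept:
        for message in conversation:
            if message.type != "view":
                text += message.type + ":" + message.content + sep
    return text


# The middle round is dropped once history exceeds the retention window.
history = [
    [Message("human", "hi"), Message("ai", "hello")],
    [Message("human", "old question")],
    [Message("human", "latest"), Message("ai", "answer")],
]
print(flatten_history(history, retention_rounds=2))
```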

View File

@@ -14,14 +14,19 @@ logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
class PluginAction(NamedTuple):
command: Dict
speak: str
-reasoning:str
+reasoning: str
thoughts: str
class PluginChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
response = json.loads(super().parse_prompt_response(model_out_text))
-command, thoughts, speak, reasoning = response["command"], response["thoughts"], response["speak"], response["reasoning"]
+command, thoughts, speak, reasoning = (
+    response["command"],
+    response["thoughts"],
+    response["speak"],
+    response["reasoning"],
+)
return PluginAction(command, speak, reasoning, thoughts)
def parse_view_response(self, speak, data) -> str:
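
The reformatted unpacking above maps a parsed JSON reply onto the `PluginAction` fields, in the order command, speak, reasoning, thoughts. A small round-trip sketch; the JSON sample is illustrative, not from the commit:

```python
import json
from typing import Dict, NamedTuple


class PluginAction(NamedTuple):
    command: Dict
    speak: str
    reasoning: str
    thoughts: str


# Illustrative model reply; real replies come through parse_prompt_response.
raw = """{"command": {"name": "search", "args": {"q": "weather"}},
"thoughts": "User wants current weather.",
"speak": "Let me look that up.",
"reasoning": "A live query needs the search plugin."}"""

response = json.loads(raw)
action = PluginAction(
    response["command"],
    response["speak"],
    response["reasoning"],
    response["thoughts"],
)
print(action.command["name"])  # -> search
```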

View File

@@ -11,6 +11,8 @@ from pilot.source_embedding import SourceEmbedding, register
from pilot.source_embedding.chn_document_splitter import CHNDocumentSplitter
CFG = Config()
class URLEmbedding(SourceEmbedding):
"""url embedding for read url document."""
@@ -27,12 +29,12 @@ class URLEmbedding(SourceEmbedding):
loader = WebBaseLoader(web_path=self.file_path)
if CFG.LANGUAGE == "en":
text_splitter = CharacterTextSplitter(
-    chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE, chunk_overlap=20, length_function=len
+    chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
+    chunk_overlap=20,
+    length_function=len,
)
else:
-text_splitter = CHNDocumentSplitter(
-    pdf=True, sentence_size=1000
-)
+text_splitter = CHNDocumentSplitter(pdf=True, sentence_size=1000)
return loader.load_and_split(text_splitter)
@register
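
The reflowed branch above selects a splitter by configured language: fixed-size character chunks with overlap for English, a Chinese-aware sentence splitter otherwise. A sketch of that dispatch; `CHNDocumentSplitter` is project-internal, so it is stubbed here, and the chunk size constant is assumed:

```python
from langchain.text_splitter import CharacterTextSplitter, TextSplitter

KNOWLEDGE_CHUNK_SPLIT_SIZE = 100  # assumed value; configured elsewhere in pilot


def choose_splitter(language: str) -> TextSplitter:
    """Pick a splitter that matches the configured document language."""
    if language == "en":
        # Fixed-size character chunks with a small overlap between chunks.
        return CharacterTextSplitter(
            chunk_size=KNOWLEDGE_CHUNK_SPLIT_SIZE,
            chunk_overlap=20,
            length_function=len,
        )
    # Stand-in for the project's CHNDocumentSplitter(pdf=True, sentence_size=1000),
    # which splits on Chinese sentence boundaries rather than raw character counts.
    raise NotImplementedError("wire in CHNDocumentSplitter for non-English text")
```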