update stream output

sheri528 2023-06-13 14:11:23 +08:00
parent fc4fd6aa77
commit a88e8aa51b
5 changed files with 25 additions and 15 deletions

View File

@@ -187,8 +187,8 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
 class GPT4AllAdapter(BaseLLMAdaper):
     """
     A light version for someone who want practise LLM use laptop.
     All model names see: https://gpt4all.io/models/models.json
     """

     def match(self, model_path: str):
@@ -197,7 +197,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         import gpt4all

-        if model_path is None and from_pretrained_kwargs.get('model_name') is None:
+        if model_path is None and from_pretrained_kwargs.get("model_name") is None:
             model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
         else:
             path, file = os.path.split(model_path)
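The loader falls back to the bundled default model when neither a path nor a model_name is given; otherwise it splits model_path so the library can locate the file by directory and name. A minimal sketch of that resolution, assuming the gpt4all package's GPT4All(model_name, model_path=...) constructor (load_gpt4all_model is a hypothetical wrapper, not part of this commit):

import os

import gpt4all


def load_gpt4all_model(model_path=None):
    # No explicit path: let gpt4all use its bundled default model.
    if model_path is None:
        return gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
    # "/models/ggml-foo.bin" -> path="/models", file="ggml-foo.bin"
    path, file = os.path.split(model_path)
    return gpt4all.GPT4All(model_name=file, model_path=path)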

View File

@@ -1,5 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+import threading
+import sys
+import time

 def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     stop = params.get("stop", "###")
@@ -7,11 +11,13 @@ def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embed
     role, query = prompt.split(stop)[1].split(":")
     print(f"gpt4all, role: {role}, query: {query}")

-    messages = [{"role": "user", "content": query}]
-    res = model.chat_completion(messages)
-    if res.get('choices') and len(res.get('choices')) > 0 and res.get('choices')[0].get('message') and \
-            res.get('choices')[0].get('message').get('content'):
-        yield res.get('choices')[0].get('message').get('content')
-    else:
-        yield "error response"
+    def worker():
+        model.generate(prompt=query, streaming=True)
+
+    t = threading.Thread(target=worker)
+    t.start()
+    while t.is_alive():
+        yield sys.stdout.output
+        time.sleep(0.1)
+    t.join()
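The rewritten generator relies on a stdout-capture trick: the worker thread runs model.generate(..., streaming=True), and the surrounding code (outside this diff) replaces sys.stdout with a buffer object whose output attribute accumulates everything written so far; that is the only reason sys.stdout.output exists. A minimal sketch of the idea, assuming a hypothetical StdoutCapture writer and a gpt4all build that echoes generated tokens to stdout while streaming:

import sys
import threading
import time


class StdoutCapture:
    # Hypothetical file-like object; accumulates everything written to it.
    def __init__(self):
        self.output = ""

    def write(self, text):
        self.output += text

    def flush(self):
        pass


def stream_via_stdout(model, query):
    sys.stdout = StdoutCapture()  # the worker thread inherits this shared stdout
    try:
        worker = threading.Thread(
            target=lambda: model.generate(prompt=query, streaming=True)
        )
        worker.start()
        while worker.is_alive():
            yield sys.stdout.output  # snapshot of all tokens produced so far
            time.sleep(0.1)
        worker.join()
        yield sys.stdout.output  # final, complete text
    finally:
        sys.stdout = sys.__stdout__  # always restore the real stdout

This is also why the debug print in ModelWorker below stays commented out: any other write to the shared stdout would leak into the captured stream.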

View File

@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
-        if data.get('error_code', 0) == 0:
+        if data.get("error_code", 0) == 0:
             if "vicuna" in CFG.LLM_MODEL:
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()
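skip_echo_len exists because the model server echoes the prompt back in front of the completion; slicing by the prompt length keeps only the newly generated text. A toy illustration (all values hypothetical):

prompt = "### human: hello assistant:"
data = {"text": prompt + " Hi, how can I help?", "error_code": 0}
skip_echo_len = len(prompt)
output = data["text"][skip_echo_len:].strip()
print(output)  # -> "Hi, how can I help?"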

View File

@@ -125,7 +125,6 @@ class GorillaChatAdapter(BaseChatAdpter):
 class GPT4AllChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "gpt4all" in model_path

View File

@@ -39,9 +39,13 @@ class ModelWorker:
         )
         if not isinstance(self.model, str):
-            if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"):
+            if hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_sequence_length"
+            ):
                 self.context_len = self.model.config.max_sequence_length
-            elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"):
+            elif hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_position_embeddings"
+            ):
                 self.context_len = self.model.config.max_position_embeddings
             else:
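Both branches of the probe above duplicate the hasattr chain; the same lookup can be written once. A short sketch, where resolve_context_len and the 2048 fallback are hypothetical names and values, not part of this commit:

def resolve_context_len(model, default=2048):
    # Different model families expose the context window under different names.
    config = getattr(model, "config", None)
    for attr in ("max_sequence_length", "max_position_embeddings"):
        if config is not None and hasattr(config, attr):
            return getattr(config, attr)
    return default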
@@ -69,7 +73,8 @@ class ModelWorker:
         for output in self.generate_stream_func(
             self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
         ):
-            print("output: ", output)
+            # Do not enable in production: the gpt4all thread shares stdout with the parent process; printing here breaks the frontend output.
+            # print("output: ", output)
             ret = {
                 "text": output,
                 "error_code": 0,