From a88e8aa51bf27bf5540655d73bd579cf7cfae7bd Mon Sep 17 00:00:00 2001
From: sheri528
Date: Tue, 13 Jun 2023 14:11:23 +0800
Subject: [PATCH] update stream output

---
 pilot/model/adapter.py             |  6 +++---
 pilot/model/llm_out/gpt4all_llm.py | 20 +++++++++++++-------
 pilot/out_parser/base.py           |  2 +-
 pilot/server/chat_adapter.py       |  1 -
 pilot/server/llmserver.py          | 11 ++++++++---
 5 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 435be1142..407d11127 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -187,8 +187,8 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
 
 class GPT4AllAdapter(BaseLLMAdaper):
     """
-    A light version for someone who want practise LLM use laptop.
-    All model names see: https://gpt4all.io/models/models.json
+    A lightweight option for anyone who wants to experiment with LLMs on a laptop.
+    For all model names, see: https://gpt4all.io/models/models.json
     """
 
     def match(self, model_path: str):
@@ -197,7 +197,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         import gpt4all
 
-        if model_path is None and from_pretrained_kwargs.get('model_name') is None:
+        if model_path is None and from_pretrained_kwargs.get("model_name") is None:
             model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
         else:
             path, file = os.path.split(model_path)
diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py
index 4cc1f067f..5ea72f911 100644
--- a/pilot/model/llm_out/gpt4all_llm.py
+++ b/pilot/model/llm_out/gpt4all_llm.py
@@ -1,5 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+import threading
+import sys
+import time
+
 
 def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     stop = params.get("stop", "###")
@@ -7,11 +11,13 @@ def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embed
     role, query = prompt.split(stop)[1].split(":")
     print(f"gpt4all, role: {role}, query: {query}")
 
-    messages = [{"role": "user", "content": query}]
-    res = model.chat_completion(messages)
-    if res.get('choices') and len(res.get('choices')) > 0 and res.get('choices')[0].get('message') and \
-        res.get('choices')[0].get('message').get('content'):
-        yield res.get('choices')[0].get('message').get('content')
-    else:
-        yield "error response"
+    def worker():
+        model.generate(prompt=query, streaming=True)
+
+    t = threading.Thread(target=worker)
+    t.start()
+
+    while t.is_alive():
+        yield sys.stdout.output
+        time.sleep(0.1)
+    t.join()
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 6f08d93fe..6406f30dd 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
         """
         TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
""" - if data.get('error_code', 0) == 0: + if data.get("error_code", 0) == 0: if "vicuna" in CFG.LLM_MODEL: # output = data["text"][skip_echo_len + 11:].strip() output = data["text"][skip_echo_len:].strip() diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 3598b16b3..ebab2d2d4 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -125,7 +125,6 @@ class GorillaChatAdapter(BaseChatAdpter): class GPT4AllChatAdapter(BaseChatAdpter): - def match(self, model_path: str): return "gpt4all" in model_path diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index ad4627afa..66180a406 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -39,9 +39,13 @@ class ModelWorker: ) if not isinstance(self.model, str): - if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"): + if hasattr(self.model, "config") and hasattr( + self.model.config, "max_sequence_length" + ): self.context_len = self.model.config.max_sequence_length - elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"): + elif hasattr(self.model, "config") and hasattr( + self.model.config, "max_position_embeddings" + ): self.context_len = self.model.config.max_position_embeddings else: @@ -69,7 +73,8 @@ class ModelWorker: for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): - print("output: ", output) + # 生产请不要打开输出!gpt4all线程与父进程共享stdout, 打开会影响前端输出 + # print("output: ", output) ret = { "text": output, "error_code": 0,