diff --git a/assets/wechat.jpg b/assets/wechat.jpg index 1d56229f0..31c562943 100644 Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index d97d2cb2b..05f9ffdcb 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -233,18 +233,10 @@ class GorillaAdapter(BaseLLMAdaper): return model, tokenizer -class CodeGenAdapter(BaseLLMAdaper): - pass - - class StarCoderAdapter(BaseLLMAdaper): pass -class T5CodeAdapter(BaseLLMAdaper): - pass - - class KoalaLLMAdapter(BaseLLMAdaper): """Koala LLM Adapter which Based LLaMA""" @@ -270,7 +262,7 @@ class GPT4AllAdapter(BaseLLMAdaper): """ def match(self, model_path: str): - return "gpt4all" in model_path + return "gptj-6b" in model_path def loader(self, model_path: str, from_pretrained_kwargs: dict): import gpt4all diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py index 7a39a8012..3ea6b8206 100644 --- a/pilot/model/llm_out/gpt4all_llm.py +++ b/pilot/model/llm_out/gpt4all_llm.py @@ -1,23 +1,10 @@ #!/usr/bin/env python3 # -*- coding:utf-8 -*- -import threading -import sys -import time def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings): stop = params.get("stop", "###") prompt = params["prompt"] - role, query = prompt.split(stop)[1].split(":") + role, query = prompt.split(stop)[0].split(":") print(f"gpt4all, role: {role}, query: {query}") - - def worker(): - model.generate(prompt=query, streaming=True) - - t = threading.Thread(target=worker) - t.start() - - while t.is_alive(): - yield sys.stdout.output - time.sleep(0.01) - t.join() + yield model.generate(prompt=query, streaming=True) diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 07b44b28c..0bc53e8fd 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -148,28 +148,6 @@ class ChatGLMChatAdapter(BaseChatAdpter): return chatglm_generate_stream -class CodeT5ChatAdapter(BaseChatAdpter): - """Model chat adapter for CodeT5""" - - def match(self, model_path: str): - return "codet5" in model_path - - def get_generate_stream_func(self, model_path: str): - # TODO - pass - - -class CodeGenChatAdapter(BaseChatAdpter): - """Model chat adapter for CodeGen""" - - def match(self, model_path: str): - return "codegen" in model_path - - def get_generate_stream_func(self, model_path: str): - # TODO - pass - - class GuanacoChatAdapter(BaseChatAdpter): """Model chat adapter for Guanaco""" @@ -216,7 +194,7 @@ class GorillaChatAdapter(BaseChatAdpter): class GPT4AllChatAdapter(BaseChatAdpter): def match(self, model_path: str): - return "gpt4all" in model_path + return "gptj-6b" in model_path def get_generate_stream_func(self, model_path: str): from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index a0c3b6cd8..54e5d0694 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -90,6 +90,7 @@ class ModelWorker: params, model_context = self.llm_chat_adapter.model_adaptation( params, self.ml.model_path, prompt_template=self.ml.prompt_template ) + for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ):