diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 36d615043..851a0486d 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -37,6 +37,7 @@ LLM_MODEL_CONFIG = {
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
+    "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
     "proxyllm": "proxyllm",
 }
 
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 7892e4b1b..01d05837b 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 
 import torch
+import os
 from typing import List
 from functools import cache
 from transformers import (
@@ -185,18 +186,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
 
 
 class GPT4AllAdapter(BaseLLMAdaper):
-    """A light version for someone who want practise LLM use laptop."""
+    """
+    A lightweight adapter for anyone who wants to try LLMs on a laptop.
+    For all available model names, see: https://gpt4all.io/models/models.json
+    """
 
     def match(self, model_path: str):
         return "gpt4all" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # TODO
-        pass
+        import gpt4all
+
+        if model_path is None and from_pretrained_kwargs.get("model_name") is None:
+            model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
+        else:
+            path, file = os.path.split(model_path)
+            model = gpt4all.GPT4All(model_path=path, model_name=file)
+        return model, None
 
 
 class ProxyllmAdapter(BaseLLMAdaper):
-
     """The model adapter for local proxy"""
 
     def match(self, model_path: str):
@@ -211,6 +220,7 @@ register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
 register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
+register_llm_model_adapters(GPT4AllAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
 
 # just for test, remove this later
diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py
new file mode 100644
index 000000000..7a39a8012
--- /dev/null
+++ b/pilot/model/llm_out/gpt4all_llm.py
@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+import threading
+import sys
+import time
+
+
+def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    role, query = prompt.split(stop)[1].split(":")
+    print(f"gpt4all, role: {role}, query: {query}")
+
+    def worker():
+        model.generate(prompt=query, streaming=True)
+
+    t = threading.Thread(target=worker)
+    t.start()
+
+    while t.is_alive():
+        yield sys.stdout.output
+        time.sleep(0.01)
+    t.join()
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 513c1d300..6406f30dd 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
         """
         TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
-        if data["error_code"] == 0:
+        if data.get("error_code", 0) == 0:
             if "vicuna" in CFG.LLM_MODEL:
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index e4f57cf46..ebab2d2d4 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
 
 
 class VicunaChatAdapter(BaseChatAdpter):
-
     """Model chat Adapter for vicuna"""
 
     def match(self, model_path: str):
@@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
 
 
 class CodeT5ChatAdapter(BaseChatAdpter):
-
     """Model chat adapter for CodeT5"""
 
     def match(self, model_path: str):
@@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter):
 
 
 class CodeGenChatAdapter(BaseChatAdpter):
-
     """Model chat adapter for CodeGen"""
 
     def match(self, model_path: str):
@@ -127,11 +124,22 @@ class GorillaChatAdapter(BaseChatAdpter):
         return generate_stream
 
 
+class GPT4AllChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "gpt4all" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
+
+        return gpt4all_generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
 register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
+register_llm_model_chat_adapter(GPT4AllChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index d2730e0d5..30653a16e 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -39,9 +39,13 @@ class ModelWorker:
         )
 
         if not isinstance(self.model, str):
-            if hasattr(self.model.config, "max_sequence_length"):
+            if hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_sequence_length"
+            ):
                 self.context_len = self.model.config.max_sequence_length
-            elif hasattr(self.model.config, "max_position_embeddings"):
+            elif hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_position_embeddings"
+            ):
                 self.context_len = self.model.config.max_position_embeddings
 
         else:
@@ -69,7 +73,10 @@ class ModelWorker:
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):
-                print("output: ", output)
+                # Do not enable this print in production!
+                # The gpt4all worker thread shares stdout with the parent process,
+                # and printing here can interfere with the streamed frontend output.
+                # print("output: ", output)
                 ret = {
                     "text": output,
                     "error_code": 0,
diff --git a/requirements.txt b/requirements.txt
index 9238751ca..c6434c3ab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -49,6 +49,7 @@
 llama-index==0.5.27
 pymysql
 unstructured==0.6.3
 grpcio==1.47.5
+gpt4all==0.3.0
 auto-gpt-plugin-template
 pymdown-extensions