From 545c3232161b11a1e6dfccaa3bb393718a745bc5 Mon Sep 17 00:00:00 2001
From: ykgong
Date: Fri, 9 Jun 2023 11:35:06 +0800
Subject: [PATCH] add gpt4all

---
 pilot/configs/model_config.py      |  1 +
 pilot/model/adapter.py             | 23 +++++++++++++++++------
 pilot/model/llm_out/gpt4all_llm.py | 17 +++++++++++++++++
 pilot/out_parser/base.py           |  2 +-
 pilot/server/chat_adapter.py       | 15 ++++++++++++---
 pilot/server/llmserver.py          |  5 +++--
 6 files changed, 51 insertions(+), 12 deletions(-)
 create mode 100644 pilot/model/llm_out/gpt4all_llm.py

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 36d615043..b85fe6b7b 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -37,6 +37,7 @@ LLM_MODEL_CONFIG = {
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
+    "ggml-gpt4all-j-v1.3-groovy": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
     "proxyllm": "proxyllm",
 }
 
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 7892e4b1b..89ea55ec2 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -2,6 +2,8 @@
 # -*- coding: utf-8 -*-
 
 import torch
+import os
+from functools import cache
 from typing import List
 from functools import cache
 from transformers import (
@@ -92,8 +94,8 @@ class ChatGLMAdapater(BaseLLMAdaper):
         AutoModel.from_pretrained(
             model_path, trust_remote_code=True, **from_pretrained_kwargs
         )
-            .half()
-            .cuda()
+        .half()
+        .cuda()
     )
     return model, tokenizer
 
@@ -185,18 +187,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
 
 
 class GPT4AllAdapter(BaseLLMAdaper):
-    """A light version for someone who want practise LLM use laptop."""
+    """
+    A light version for anyone who wants to practise with LLMs on a laptop.
+    All supported model names are listed at https://gpt4all.io/models/models.json
+    """
 
     def match(self, model_path: str):
         return "gpt4all" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # TODO
-        pass
+        import gpt4all
+
+        if model_path is None and from_pretrained_kwargs.get("model_name") is None:
+            model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
+        else:
+            path, file = os.path.split(model_path)
+            model = gpt4all.GPT4All(model_path=path, model_name=file)
+        return model, None
 
 
 class ProxyllmAdapter(BaseLLMAdaper):
-
     """The model adapter for local proxy"""
 
     def match(self, model_path: str):
@@ -211,6 +221,7 @@ register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
 register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
+register_llm_model_adapters(GPT4AllAdapter)
 
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test, remove this later
diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py
new file mode 100644
index 000000000..4cc1f067f
--- /dev/null
+++ b/pilot/model/llm_out/gpt4all_llm.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    # Take the turn after the stop marker, e.g. "### human: <query>".
+    role, query = prompt.split(stop)[1].split(":", 1)
+    print(f"gpt4all, role: {role}, query: {query}")
+
+    messages = [{"role": "user", "content": query}]
+    res = model.chat_completion(messages)
+    content = ((res.get("choices") or [{}])[0].get("message") or {}).get("content")
+    if content:
+        yield content
+    else:
+        yield "error response"
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 513c1d300..6f08d93fe 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
         """
         TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
-        if data["error_code"] == 0:
+        if data.get("error_code", 0) == 0:
             if "vicuna" in CFG.LLM_MODEL:
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index e4f57cf46..3598b16b3 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
 
 
 class VicunaChatAdapter(BaseChatAdpter):
-
     """Model chat Adapter for vicuna"""
 
     def match(self, model_path: str):
@@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
 
 
 class CodeT5ChatAdapter(BaseChatAdpter):
-
     """Model chat adapter for CodeT5"""
 
     def match(self, model_path: str):
@@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter):
 
 
 class CodeGenChatAdapter(BaseChatAdpter):
-
     """Model chat adapter for CodeGen"""
 
     def match(self, model_path: str):
@@ -127,11 +124,23 @@ class GorillaChatAdapter(BaseChatAdpter):
         return generate_stream
 
 
+class GPT4AllChatAdapter(BaseChatAdpter):
+
+    def match(self, model_path: str):
+        return "gpt4all" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
+
+        return gpt4all_generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
 register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
+register_llm_model_chat_adapter(GPT4AllChatAdapter)
 
 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index d2730e0d5..e71872d64 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -39,9 +39,9 @@ class ModelWorker:
         )
 
         if not isinstance(self.model, str):
-            if hasattr(self.model.config, "max_sequence_length"):
+            if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"):
                 self.context_len = self.model.config.max_sequence_length
-            elif hasattr(self.model.config, "max_position_embeddings"):
+            elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"):
                 self.context_len = self.model.config.max_position_embeddings
 
         else:
@@ -66,6 +66,7 @@ class ModelWorker:
 
     def generate_stream_gate(self, params):
         try:
+            print(f"llmserver params: {params}, self: {self}")
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):