Dev gpt4all (#178)

add model from gpt4all @csunny
magic.chen 2023-06-13 17:59:26 +08:00 committed by GitHub
commit ad8e0e353a
7 changed files with 61 additions and 11 deletions


@@ -37,6 +37,7 @@ LLM_MODEL_CONFIG = {
"guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
"falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
"gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
"proxyllm": "proxyllm",
}
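
The new "gptj-6b" entry maps a config key to a local GGML weights file whose name contains "gpt4all", which is exactly the substring the new adapters below match on. A minimal sketch of how the entry resolves (not part of the commit; the MODEL_PATH value here is hypothetical):

import os

MODEL_PATH = "./models"  # hypothetical location of downloaded weights
LLM_MODEL_CONFIG = {
    "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
}

model_path = LLM_MODEL_CONFIG["gptj-6b"]
# Both GPT4AllAdapter and GPT4AllChatAdapter below match on this substring:
assert "gpt4all" in model_path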


@@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
import torch
import os
from typing import List
from functools import cache
from transformers import (
@@ -185,18 +186,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
class GPT4AllAdapter(BaseLLMAdaper):
"""A light version for someone who want practise LLM use laptop."""
"""
A light version for someone who want practise LLM use laptop.
All model names see: https://gpt4all.io/models/models.json
"""
def match(self, model_path: str):
return "gpt4all" in model_path
def loader(self, model_path: str, from_pretrained_kwargs: dict):
# TODO
pass
import gpt4all
if model_path is None and from_pretrained_kwargs.get("model_name") is None:
model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
else:
path, file = os.path.split(model_path)
model = gpt4all.GPT4All(model_path=path, model_name=file)
return model, None
class ProxyllmAdapter(BaseLLMAdaper):
"""The model adapter for local proxy"""
def match(self, model_path: str):
@@ -211,6 +220,7 @@ register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
# TODO: vicuna is supported by default; other models still need to be tested and evaluated
# just for testing, remove this later
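
A minimal usage sketch of the loader path added above (not part of the diff; the weights location is hypothetical, and the keyword arguments mirror the gpt4all==0.3.0 call in GPT4AllAdapter.loader):

import os
import gpt4all

model_path = "./models/ggml-gpt4all-j-v1.3-groovy.bin"  # hypothetical local weights file
path, file = os.path.split(model_path)                  # ("./models", "ggml-gpt4all-j-v1.3-groovy.bin")
model = gpt4all.GPT4All(model_path=path, model_name=file)
# The adapter returns (model, None): gpt4all bundles its own tokenizer,
# so no separate Hugging Face tokenizer object is produced.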


@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import threading
import sys
import time
def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
stop = params.get("stop", "###")
prompt = params["prompt"]
role, query = prompt.split(stop)[1].split(":")
print(f"gpt4all, role: {role}, query: {query}")
def worker():
model.generate(prompt=query, streaming=True)
t = threading.Thread(target=worker)
t.start()
while t.is_alive():
yield sys.stdout.output
time.sleep(0.01)
t.join()
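
The streaming loop above yields sys.stdout.output, which only works if sys.stdout has been replaced elsewhere by an object that buffers whatever the gpt4all backend prints; a plain sys.stdout has no output attribute. A minimal sketch of such a redirector, purely illustrative and not part of this commit (the class name is hypothetical):

import sys

class StdoutBuffer:
    """Hypothetical stdout replacement exposing everything written so far via .output."""

    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.output = ""

    def write(self, text):
        self.output += text        # accumulate tokens printed by the gpt4all backend
        self.wrapped.write(text)   # still forward them to the real stdout

    def flush(self):
        self.wrapped.flush()

# Installed once, before gpt4all_generate_stream is called:
sys.stdout = StdoutBuffer(sys.stdout)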


@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
"""
if data["error_code"] == 0:
if data.get("error_code", 0) == 0:
if "vicuna" in CFG.LLM_MODEL:
# output = data["text"][skip_echo_len + 11:].strip()
output = data["text"][skip_echo_len:].strip()
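
Switching to data.get("error_code", 0) means a payload that carries no "error_code" key is treated as a success instead of raising KeyError, e.g.:

data = {"text": "hello"}               # hypothetical streamed payload without an error_code
assert data.get("error_code", 0) == 0  # data["error_code"] would raise KeyError here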


@@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
class VicunaChatAdapter(BaseChatAdpter):
"""Model chat Adapter for vicuna"""
def match(self, model_path: str):
@@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
class CodeT5ChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeT5"""
def match(self, model_path: str):
@@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter):
class CodeGenChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeGen"""
def match(self, model_path: str):
@@ -127,11 +124,22 @@ class GorillaChatAdapter(BaseChatAdpter):
return generate_stream
class GPT4AllChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "gpt4all" in model_path
def get_generate_stream_func(self):
from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
return gpt4all_generate_stream
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
# Proxy model for testing and development; it's cheap for us for now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)


@@ -39,9 +39,13 @@ class ModelWorker:
)
if not isinstance(self.model, str):
if hasattr(self.model.config, "max_sequence_length"):
if hasattr(self.model, "config") and hasattr(
self.model.config, "max_sequence_length"
):
self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model.config, "max_position_embeddings"):
elif hasattr(self.model, "config") and hasattr(
self.model.config, "max_position_embeddings"
):
self.context_len = self.model.config.max_position_embeddings
else:
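
The extra hasattr(self.model, "config") guard matters because a gpt4all.GPT4All instance exposes no config attribute, so the old checks would raise AttributeError before ever reaching the else branch. A small illustrative sketch (the stand-in class and default value are hypothetical):

class FakeGpt4AllModel:
    pass  # no .config attribute, unlike Hugging Face models

model = FakeGpt4AllModel()
context_len = 4096  # hypothetical default from the (truncated) else branch

if hasattr(model, "config") and hasattr(model.config, "max_sequence_length"):
    context_len = model.config.max_sequence_length
elif hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"):
    context_len = model.config.max_position_embeddings

print(context_len)  # falls through to the default for gpt4all-style models
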
@@ -69,7 +73,10 @@ class ModelWorker:
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):
print("output: ", output)
# Do not enable this print in production!
# The gpt4all worker thread shares stdout with the parent process,
# and printing here may corrupt the streamed frontend output.
# print("output: ", output)
ret = {
"text": output,
"error_code": 0,


@@ -49,6 +49,7 @@ llama-index==0.5.27
pymysql
unstructured==0.6.3
grpcio==1.47.5
gpt4all==0.3.0
auto-gpt-plugin-template
pymdown-extensions