Dev gpt4all (#178)

add model from gpt4all @csunny
magic.chen 2023-06-13 17:59:26 +08:00 committed by GitHub
commit ad8e0e353a
7 changed files with 61 additions and 11 deletions

View File

@@ -37,6 +37,7 @@ LLM_MODEL_CONFIG = {
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
+    "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
     "proxyllm": "proxyllm",
 }
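
Note how this entry is resolved: the adapters added below match on the model path rather than on the config key, so it is the "gpt4all" substring in the .bin filename that routes "gptj-6b" to the new adapter. A minimal sketch (the MODEL_PATH value is hypothetical; the real one comes from the config module above):

    import os

    MODEL_PATH = "./models"  # hypothetical; the real value is defined in the config module
    LLM_MODEL_CONFIG = {
        "gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
    }

    model_path = LLM_MODEL_CONFIG["gptj-6b"]
    assert "gpt4all" in model_path  # this substring is what selects GPT4AllAdapter below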

View File

@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 import torch
+import os
 from typing import List
 from functools import cache
 from transformers import (
@@ -185,18 +186,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
 class GPT4AllAdapter(BaseLLMAdaper):
-    """A light version for someone who want practise LLM use laptop."""
+    """
+    A light version for anyone who wants to practise using LLMs on a laptop.
+    All model names: see https://gpt4all.io/models/models.json
+    """

     def match(self, model_path: str):
         return "gpt4all" in model_path

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # TODO
-        pass
+        import gpt4all
+
+        if model_path is None and from_pretrained_kwargs.get("model_name") is None:
+            model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
+        else:
+            path, file = os.path.split(model_path)
+            model = gpt4all.GPT4All(model_path=path, model_name=file)
+        return model, None


 class ProxyllmAdapter(BaseLLMAdaper):
     """The model adapter for local proxy"""

     def match(self, model_path: str):
@@ -211,6 +220,7 @@ register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
 register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
+register_llm_model_adapters(GPT4AllAdapter)
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test, remove this later
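
A hedged sketch of exercising the new adapter end to end (the model file location is hypothetical; the GPT4All constructor calls and the path-splitting behaviour are taken from the loader above):

    from pilot.model.adapter import GPT4AllAdapter

    adapter = GPT4AllAdapter()
    model_path = "./models/ggml-gpt4all-j-v1.3-groovy.bin"  # hypothetical location
    assert adapter.match(model_path)  # matches on the "gpt4all" substring

    # loader() splits the path: the directory becomes model_path and the
    # filename becomes model_name for gpt4all.GPT4All; the tokenizer slot
    # is None because gpt4all handles tokenization internally.
    model, tokenizer = adapter.loader(model_path, from_pretrained_kwargs={})
    assert tokenizer is None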

View File

@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import threading
+import sys
+import time
+
+
+def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    role, query = prompt.split(stop)[1].split(":")
+    print(f"gpt4all, role: {role}, query: {query}")
+
+    def worker():
+        model.generate(prompt=query, streaming=True)
+
+    t = threading.Thread(target=worker)
+    t.start()
+    while t.is_alive():
+        yield sys.stdout.output
+        time.sleep(0.01)
+    t.join()
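
One subtlety worth flagging: sys.stdout.output is not an attribute of the standard stream, so this generator only works if stdout has been swapped for a capturing object before model.generate() starts printing tokens. A hedged sketch of the kind of replacement the code appears to assume (the StreamCapture name and behaviour are illustrative, not part of this commit):

    import sys

    class StreamCapture:
        """Hypothetical stdout replacement that accumulates everything written."""

        def __init__(self):
            self.output = ""

        def write(self, text):
            self.output += text

        def flush(self):
            pass

    sys.stdout = StreamCapture()  # must happen before the worker thread prints

This also explains the warning added to ModelWorker below: anything else printed to stdout would be interleaved into the captured stream.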

View File

@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
-        if data["error_code"] == 0:
+        if data.get("error_code", 0) == 0:
             if "vicuna" in CFG.LLM_MODEL:
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()
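
The switch to .get matters because not every worker response carries an error code; a hypothetical payload:

    data = {"text": "hello"}           # no "error_code" key at all
    # data["error_code"] == 0          -> raises KeyError
    data.get("error_code", 0) == 0     # -> True, treated as success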

View File

@@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
 class VicunaChatAdapter(BaseChatAdpter):
     """Model chat Adapter for vicuna"""
-
     def match(self, model_path: str):
@@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
 class CodeT5ChatAdapter(BaseChatAdpter):
     """Model chat adapter for CodeT5"""
-
     def match(self, model_path: str):
@@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter):
 class CodeGenChatAdapter(BaseChatAdpter):
     """Model chat adapter for CodeGen"""
-
     def match(self, model_path: str):
@@ -127,11 +124,22 @@ class GorillaChatAdapter(BaseChatAdpter):
         return generate_stream


+class GPT4AllChatAdapter(BaseChatAdpter):
+    def match(self, model_path: str):
+        return "gpt4all" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
+
+        return gpt4all_generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
 register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
+register_llm_model_chat_adapter(GPT4AllChatAdapter)

 # Proxy model for test and develop, it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
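
Dispatch is first-match-wins over the registration order above, which is why GPT4AllChatAdapter slots in before the proxy fallback. A minimal sketch of that lookup (the registry list name llm_model_chat_adapters is an assumption; only get_llm_chat_adapter's signature appears in this diff):

    def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
        for adapter in llm_model_chat_adapters:  # in registration order
            if adapter.match(model_path):
                return adapter
        raise ValueError(f"Invalid model adapter for {model_path}")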

View File

@@ -39,9 +39,13 @@ class ModelWorker:
         )

         if not isinstance(self.model, str):
-            if hasattr(self.model.config, "max_sequence_length"):
+            if hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_sequence_length"
+            ):
                 self.context_len = self.model.config.max_sequence_length
-            elif hasattr(self.model.config, "max_position_embeddings"):
+            elif hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_position_embeddings"
+            ):
                 self.context_len = self.model.config.max_position_embeddings
         else:
@@ -69,7 +73,10 @@ class ModelWorker:
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
-                print("output: ", output)
+                # Do not enable this print in production!
+                # The gpt4all worker thread shares stdout with the parent
+                # process, so printing here can corrupt the frontend stream.
+                # print("output: ", output)
                 ret = {
                     "text": output,
                     "error_code": 0,

View File

@@ -49,6 +49,7 @@ llama-index==0.5.27
 pymysql
 unstructured==0.6.3
 grpcio==1.47.5
+gpt4all==0.3.0
 auto-gpt-plugin-template
 pymdown-extensions