add gpt4all

ykgong 2023-06-09 11:35:06 +08:00
parent 44bb5135cd
commit 545c323216
6 changed files with 51 additions and 12 deletions

View File

@@ -37,6 +37,7 @@ LLM_MODEL_CONFIG = {
    "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
    "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
    "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
    "ggml-gpt4all-j-v1.3-groovy": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
    "proxyllm": "proxyllm",
}
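Note that the new entry, unlike its neighbours, points at a single quantized .bin file rather than a model directory. A minimal sketch of how such a key resolves to a path on disk (the MODEL_PATH value here is an assumption, not part of this commit):

    import os

    MODEL_PATH = "./models"  # assumed location; the real value comes from the config module

    LLM_MODEL_CONFIG = {
        # gpt4all maps to a single quantized .bin file, not a directory
        "ggml-gpt4all-j-v1.3-groovy": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
    }

    print(LLM_MODEL_CONFIG["ggml-gpt4all-j-v1.3-groovy"])
    # ./models/ggml-gpt4all-j-v1.3-groovy.bin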

View File

@@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
import torch
import os
from functools import cache
from typing import List
from functools import cache
from transformers import (
@@ -185,18 +187,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
class GPT4AllAdapter(BaseLLMAdaper):
    """
    A light version for someone who wants to practise LLMs on a laptop.
    For the full list of model names, see: https://gpt4all.io/models/models.json
    """

    def match(self, model_path: str):
        return "gpt4all" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        import gpt4all

        if model_path is None and from_pretrained_kwargs.get('model_name') is None:
            model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
        else:
            path, file = os.path.split(model_path)
            model = gpt4all.GPT4All(model_path=path, model_name=file)
        return model, None
class ProxyllmAdapter(BaseLLMAdaper):
    """The model adapter for local proxy"""

    def match(self, model_path: str):
@@ -211,6 +221,7 @@ register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)

# TODO: vicuna is supported by default; other models still need testing and evaluation
# just for test, remove this later
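For orientation, here is a sketch of the two code paths the new loader takes. The os.path.split behaviour is standard; the gpt4all constructor calls mirror the ones in the diff above, and are left commented out so the snippet runs without the gpt4all package installed:

    import os

    # A path configured in LLM_MODEL_CONFIG: gpt4all wants it split into
    # a directory (model_path) and a file name (model_name).
    model_path = "./models/ggml-gpt4all-j-v1.3-groovy.bin"
    path, file = os.path.split(model_path)
    print(path)  # ./models
    print(file)  # ggml-gpt4all-j-v1.3-groovy.bin

    # As in the loader above: with only a name, gpt4all downloads the default
    # model; with a path and name, it loads the local file.
    # import gpt4all
    # model = gpt4all.GPT4All(model_path=path, model_name=file)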

View File pilot/model/llm_out/gpt4all_llm.py (new file)

@@ -0,0 +1,17 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
    stop = params.get("stop", "###")
    prompt = params["prompt"]
    role, query = prompt.split(stop)[1].split(":")
    print(f"gpt4all, role: {role}, query: {query}")
    messages = [{"role": "user", "content": query}]
    res = model.chat_completion(messages)
    if res.get('choices') and len(res.get('choices')) > 0 and \
            res.get('choices')[0].get('message') and \
            res.get('choices')[0].get('message').get('content'):
        yield res.get('choices')[0].get('message').get('content')
    else:
        yield "error response"

View File

@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
        """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
        """
        if data.get('error_code', 0) == 0:
            if "vicuna" in CFG.LLM_MODEL:
                # output = data["text"][skip_echo_len + 11:].strip()
                output = data["text"][skip_echo_len:].strip()
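The switch from data["error_code"] to data.get('error_code', 0) lets the parser tolerate responses that omit the key entirely, treating them as success instead of raising KeyError. A sketch (the response dict is hypothetical):

    # Hypothetical response from a backend that sets no error_code,
    # such as the new gpt4all stream, which only yields text.
    data = {"text": "hello from gpt4all"}

    # Old behaviour: data["error_code"] raises KeyError here.
    # New behaviour: a missing error_code defaults to 0 (success).
    if data.get("error_code", 0) == 0:
        print("treated as a successful response")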

View File

@@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
class VicunaChatAdapter(BaseChatAdpter):
    """Model chat Adapter for vicuna"""
    def match(self, model_path: str):

@@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
class CodeT5ChatAdapter(BaseChatAdpter):
    """Model chat adapter for CodeT5"""
    def match(self, model_path: str):

@@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter):
class CodeGenChatAdapter(BaseChatAdpter):
    """Model chat adapter for CodeGen"""
    def match(self, model_path: str):
@@ -127,11 +124,23 @@ class GorillaChatAdapter(BaseChatAdpter):
        return generate_stream

class GPT4AllChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "gpt4all" in model_path

    def get_generate_stream_func(self):
        from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream

        return gpt4all_generate_stream
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
register_llm_model_chat_adapter(FalconChatAdapter)
register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)

# Proxy model for testing and development; it's cheap for us for now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
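For context, chat adapters are matched by substring against the model path, in registration order. A self-contained sketch of that lookup pattern (the registry internals here are assumed, not copied from this file):

    adapters = []

    def register_chat_adapter_sketch(cls):
        # assumed registry behaviour: instantiate and append in order
        adapters.append(cls())

    class GPT4AllChatAdapterSketch:
        def match(self, model_path: str):
            return "gpt4all" in model_path

    register_chat_adapter_sketch(GPT4AllChatAdapterSketch)

    # "gpt4all" is a substring of the groovy .bin file name, so this matches:
    model_path = "./models/ggml-gpt4all-j-v1.3-groovy.bin"
    adapter = next(a for a in adapters if a.match(model_path))
    print(type(adapter).__name__)  # GPT4AllChatAdapterSketch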

View File

@@ -39,9 +39,9 @@ class ModelWorker:
        )
        if not isinstance(self.model, str):
            if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"):
                self.context_len = self.model.config.max_sequence_length
            elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"):
                self.context_len = self.model.config.max_position_embeddings
            else:
@@ -66,6 +66,7 @@ class ModelWorker:
    def generate_stream_gate(self, params):
        try:
            print(f"llmserver params: {params}, self: {self}")
            for output in self.generate_stream_func(
                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):
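The extra hasattr(self.model, "config") guard matters because a gpt4all model object is not a HuggingFace model and has no .config attribute; without the guard, the old code raised AttributeError before it could fall through to the default context length. A sketch with a stand-in object:

    class GPT4AllLikeModel:
        # hypothetical stand-in: like gpt4all models, it has no .config
        pass

    model = GPT4AllLikeModel()
    context_len = 2048  # assumed default from the else branch

    if hasattr(model, "config") and hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    elif hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"):
        context_len = model.config.max_position_embeddings

    print(context_len)  # 2048: falls through instead of raising AttributeError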