Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-15 06:53:12 +00:00)

commit 545c323216 (parent 44bb5135cd)

    add gpt4all
@@ -37,6 +37,7 @@ LLM_MODEL_CONFIG = {
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
+    "ggml-gpt4all-j-v1.3-groovy": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
     "proxyllm": "proxyllm",
 }
 
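For orientation, a minimal sketch of what the new config entry resolves to (MODEL_PATH below is a placeholder; the real value comes from the surrounding config module). Note that the key omits the .bin suffix while the value points at the actual file, which GPT4AllAdapter.loader later splits back into directory and file name:

import os

MODEL_PATH = "./models"  # placeholder; the real directory comes from the config module

LLM_MODEL_CONFIG = {
    "ggml-gpt4all-j-v1.3-groovy": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
}

# Selecting the gpt4all model resolves to e.g. ./models/ggml-gpt4all-j-v1.3-groovy.bin,
# which the adapter below splits with os.path.split().
print(LLM_MODEL_CONFIG["ggml-gpt4all-j-v1.3-groovy"])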
@@ -2,6 +2,8 @@
 # -*- coding: utf-8 -*-
 
 import torch
+import os
+from functools import cache
 from typing import List
 from functools import cache
 from transformers import (
@@ -185,18 +187,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
 
 
 class GPT4AllAdapter(BaseLLMAdaper):
-    """A light version for someone who want practise LLM use laptop."""
+    """
+    A light version for those who want to practise LLMs on a laptop.
+    All model names: see https://gpt4all.io/models/models.json
+    """
 
     def match(self, model_path: str):
         return "gpt4all" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # TODO
-        pass
+        import gpt4all
+        if model_path is None and from_pretrained_kwargs.get("model_name") is None:
+            model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
+        else:
+            path, file = os.path.split(model_path)
+            model = gpt4all.GPT4All(model_path=path, model_name=file)
+        return model, None
 
 
 class ProxyllmAdapter(BaseLLMAdaper):
 
     """The model adapter for local proxy"""
 
     def match(self, model_path: str):
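As a quick sanity check of the new loader path, a minimal sketch that mirrors what GPT4AllAdapter.loader does (it assumes the gpt4all Python bindings are installed and that the .bin file lives under ./models; the path is only a placeholder):

import os
import gpt4all

model_path = "./models/ggml-gpt4all-j-v1.3-groovy.bin"  # placeholder location

# Same split the adapter performs: directory -> model_path, file name -> model_name.
path, file = os.path.split(model_path)
model = gpt4all.GPT4All(model_path=path, model_name=file)

# Same call the new gpt4all_generate_stream makes downstream.
res = model.chat_completion([{"role": "user", "content": "Hello"}])
print(res["choices"][0]["message"]["content"])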
@@ -211,6 +221,7 @@ register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
 register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
+register_llm_model_adapters(GPT4AllAdapter)
 # TODO Vicuna is supported by default; other models still need to be tested and evaluated
 
 # just for test, remove this later
pilot/model/llm_out/gpt4all_llm.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    role, query = prompt.split(stop)[1].split(":")
+    print(f"gpt4all, role: {role}, query: {query}")
+
+    messages = [{"role": "user", "content": query}]
+    res = model.chat_completion(messages)
+    if res.get('choices') and len(res.get('choices')) > 0 and res.get('choices')[0].get('message') and \
+            res.get('choices')[0].get('message').get('content'):
+        yield res.get('choices')[0].get('message').get('content')
+    else:
+        yield "error response"
+
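The new stream function expects the prompt to carry a single "role: query" segment between stop markers, and it ignores tokenizer, device and max_position_embeddings. A minimal driver sketch (the prompt string and params are made-up placeholders; model is the gpt4all model returned by the loader sketch above):

from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream

params = {
    "prompt": "A chat###user: What is DB-GPT?",  # placeholder; split on "###", then on ":"
    "stop": "###",
}

# tokenizer, device and max_position_embeddings are not used by this function.
for output in gpt4all_generate_stream(model, None, params, "cpu", 4096):
    print(output)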
@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
|
|||||||
|
|
||||||
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
|
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
|
||||||
"""
|
"""
|
||||||
if data["error_code"] == 0:
|
if data.get('error_code', 0) == 0:
|
||||||
if "vicuna" in CFG.LLM_MODEL:
|
if "vicuna" in CFG.LLM_MODEL:
|
||||||
# output = data["text"][skip_echo_len + 11:].strip()
|
# output = data["text"][skip_echo_len + 11:].strip()
|
||||||
output = data["text"][skip_echo_len:].strip()
|
output = data["text"][skip_echo_len:].strip()
|
||||||
|
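A small illustration of why the lookup switches to .get (the payload dicts below are made-up, not actual worker responses): a response that never sets error_code is now treated as success instead of raising KeyError:

ok_with_code = {"error_code": 0, "text": "response with an explicit error_code"}
ok_without_code = {"text": "response from a backend that never sets error_code"}

for data in (ok_with_code, ok_without_code):
    # Old check: data["error_code"] == 0  -> KeyError on the second payload.
    # New check:
    if data.get("error_code", 0) == 0:
        print("treated as a successful response")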
@@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
 
-
 class VicunaChatAdapter(BaseChatAdpter):
 
     """Model chat Adapter for vicuna"""
 
     def match(self, model_path: str):
@@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
 
-
 class CodeT5ChatAdapter(BaseChatAdpter):
 
     """Model chat adapter for CodeT5"""
 
     def match(self, model_path: str):
@@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter):
 
-
 class CodeGenChatAdapter(BaseChatAdpter):
 
     """Model chat adapter for CodeGen"""
 
     def match(self, model_path: str):
@@ -127,11 +124,23 @@ class GorillaChatAdapter(BaseChatAdpter):
         return generate_stream
 
 
+class GPT4AllChatAdapter(BaseChatAdpter):
+
+    def match(self, model_path: str):
+        return "gpt4all" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
+
+        return gpt4all_generate_stream
+
+
 register_llm_model_chat_adapter(VicunaChatAdapter)
 register_llm_model_chat_adapter(ChatGLMChatAdapter)
 register_llm_model_chat_adapter(GuanacoChatAdapter)
 register_llm_model_chat_adapter(FalconChatAdapter)
 register_llm_model_chat_adapter(GorillaChatAdapter)
+register_llm_model_chat_adapter(GPT4AllChatAdapter)
 
 # Proxy model for test and development; it's cheap for us now.
 register_llm_model_chat_adapter(ProxyllmChatAdapter)
@@ -39,9 +39,9 @@ class ModelWorker:
         )
 
         if not isinstance(self.model, str):
-            if hasattr(self.model.config, "max_sequence_length"):
+            if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"):
                 self.context_len = self.model.config.max_sequence_length
-            elif hasattr(self.model.config, "max_position_embeddings"):
+            elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"):
                 self.context_len = self.model.config.max_position_embeddings
 
         else:
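A minimal sketch of the guard pattern this hunk introduces (the classes below are stand-ins, not DB-GPT code): the extra hasattr(model, "config") check keeps ModelWorker from failing on models, such as a gpt4all instance, that presumably carry no transformers-style config object:

class FakeHFConfig:
    max_position_embeddings = 4096


class FakeHFModel:
    config = FakeHFConfig()


class FakeGPT4AllModel:
    pass  # no .config attribute at all


def resolve_context_len(model, default=2048):
    # Mirrors the guarded lookups added to ModelWorker.__init__; the default is illustrative.
    if hasattr(model, "config") and hasattr(model.config, "max_sequence_length"):
        return model.config.max_sequence_length
    elif hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"):
        return model.config.max_position_embeddings
    return default


print(resolve_context_len(FakeHFModel()))       # 4096
print(resolve_context_len(FakeGPT4AllModel()))  # 2048, falls back to the default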
@@ -66,6 +66,7 @@ class ModelWorker:
 
     def generate_stream_gate(self, params):
         try:
+            print(f"llmserver params: {params}, self: {self}")
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):