Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-31 15:47:05 +00:00)

Commit 4ec47a5efd
Binary file not shown (image changed; size 136 KiB before, 169 KiB after).
@@ -233,18 +233,10 @@ class GorillaAdapter(BaseLLMAdaper):
         return model, tokenizer


-class CodeGenAdapter(BaseLLMAdaper):
-    pass
-
-
-class StarCoderAdapter(BaseLLMAdaper):
-    pass
-
-
-class T5CodeAdapter(BaseLLMAdaper):
-    pass
-
-
 class KoalaLLMAdapter(BaseLLMAdaper):
     """Koala LLM Adapter which Based LLaMA"""
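The three adapters deleted above were empty placeholders. For reference, a working adapter in this hierarchy pairs a substring match on the model path with a Hugging Face loader that returns a (model, tokenizer) pair; the sketch below is illustrative only, with a hypothetical class name and match string, not code from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer


class ExampleAdapter(BaseLLMAdaper):
    """Hypothetical adapter: selected when the model path names the model."""

    def match(self, model_path: str):
        return "example-model" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return model, tokenizer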
@@ -270,7 +262,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
     """

     def match(self, model_path: str):
-        return "gpt4all" in model_path
+        return "gptj-6b" in model_path

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         import gpt4all
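The hunk above narrows the match trigger from "gpt4all" to "gptj-6b" and starts a loader built on the gpt4all package. A minimal sketch of how such a loader could complete, assuming model_path points at a local GGML model file; the path handling here is an assumption, not this commit's code:

import os


def loader(model_path: str, from_pretrained_kwargs: dict):
    import gpt4all

    # gpt4all bundles its own tokenizer, so no separate tokenizer is returned.
    directory, model_file = os.path.split(model_path)
    model = gpt4all.GPT4All(model_name=model_file, model_path=directory)
    return model, None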
@@ -1,23 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
-import threading
-import sys
-import time


 def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     stop = params.get("stop", "###")
     prompt = params["prompt"]
-    role, query = prompt.split(stop)[1].split(":")
+    role, query = prompt.split(stop)[0].split(":")
     print(f"gpt4all, role: {role}, query: {query}")

-    def worker():
-        model.generate(prompt=query, streaming=True)
-
-    t = threading.Thread(target=worker)
-    t.start()
-
-    while t.is_alive():
-        yield sys.stdout.output
-        time.sleep(0.01)
-    t.join()
+    yield model.generate(prompt=query, streaming=True)
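The deleted threaded version polled sys.stdout.output, an attribute that does not exist on standard streams, so it could never surface tokens. With streaming=True, gpt4all's generate returns a token iterator, which the rewritten function yields as a single item; the caller must then iterate over it. If token-by-token streaming were wanted instead, one alternative sketch (not this commit's code) would accumulate inside the loop:

def gpt4all_generate_stream_incremental(model, tokenizer, params, device, max_position_embeddings):
    # Same prompt parsing as above; yields a growing string per token
    # instead of handing the raw token iterator back to the caller.
    stop = params.get("stop", "###")
    role, query = params["prompt"].split(stop)[0].split(":")
    output = ""
    for token in model.generate(prompt=query, streaming=True):
        output += token
        yield output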
@@ -148,28 +148,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return chatglm_generate_stream


-class CodeT5ChatAdapter(BaseChatAdpter):
-    """Model chat adapter for CodeT5"""
-
-    def match(self, model_path: str):
-        return "codet5" in model_path
-
-    def get_generate_stream_func(self, model_path: str):
-        # TODO
-        pass
-
-
-class CodeGenChatAdapter(BaseChatAdpter):
-    """Model chat adapter for CodeGen"""
-
-    def match(self, model_path: str):
-        return "codegen" in model_path
-
-    def get_generate_stream_func(self, model_path: str):
-        # TODO
-        pass
-
-
 class GuanacoChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco"""
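Both chat adapters removed above were TODO stubs that never returned a stream function. For comparison, a complete adapter in this hierarchy resolves to a concrete callable; a minimal sketch with a hypothetical class name and match string, reusing the real gpt4all stream function imported later in this diff:

class ExampleChatAdapter(BaseChatAdpter):
    """Hypothetical chat adapter: maps a path match to a stream function."""

    def match(self, model_path: str):
        return "example-model" in model_path

    def get_generate_stream_func(self, model_path: str):
        # The worker calls the returned function with
        # (model, tokenizer, params, device, max_position_embeddings).
        from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream

        return gpt4all_generate_stream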
@@ -216,7 +194,7 @@ class GorillaChatAdapter(BaseChatAdpter):

 class GPT4AllChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
-        return "gpt4all" in model_path
+        return "gptj-6b" in model_path

     def get_generate_stream_func(self, model_path: str):
         from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
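As in the LLM adapter, match is a plain substring test against the configured model path, so this change narrows which local checkpoints select the GPT4All chat adapter; for instance (paths are hypothetical):

adapter = GPT4AllChatAdapter()
adapter.match("/models/ggml-gpt4all-j-v1.3-groovy")  # False after this change
adapter.match("/models/gptj-6b")                     # True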
@@ -90,6 +90,7 @@ class ModelWorker:
         params, model_context = self.llm_chat_adapter.model_adaptation(
             params, self.ml.model_path, prompt_template=self.ml.prompt_template
         )
+
         for output in self.generate_stream_func(
             self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
         ):
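For context, the adapted params feed straight into the streaming loop resolved through the chat adapter. A condensed sketch of the surrounding worker method; the method name and the forwarding yield are assumptions, since the hunk does not show them:

def generate_stream_gate(self, params):
    # Let the chat adapter rewrite params (prompt template, stop words, ...).
    params, model_context = self.llm_chat_adapter.model_adaptation(
        params, self.ml.model_path, prompt_template=self.ml.prompt_template
    )
    for output in self.generate_stream_func(
        self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
    ):
        # Each chunk is a partial completion; forward it to the client.
        yield output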