Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-08 11:47:44 +00:00)
Commit 4ec47a5efd
Binary file not shown (image changed: 136 KiB before, 169 KiB after).
@@ -233,18 +233,10 @@ class GorillaAdapter(BaseLLMAdaper):
         return model, tokenizer


-class CodeGenAdapter(BaseLLMAdaper):
-    pass
-
-
 class StarCoderAdapter(BaseLLMAdaper):
     pass


-class T5CodeAdapter(BaseLLMAdaper):
-    pass
-
-
 class KoalaLLMAdapter(BaseLLMAdaper):
     """Koala LLM Adapter which Based LLaMA"""
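The two adapters deleted above were never implemented (`pass` bodies only). For reference, a concrete adapter follows the match()/loader() contract visible elsewhere in this file; the sketch below is illustrative only and is not code from this commit — StarCoderAdapter remains a `pass` stub here, and the transformers calls are assumptions.

# Illustrative sketch only: what fleshing out StarCoderAdapter could look
# like under the match()/loader() contract seen in this file. The
# transformers calls are assumptions, not part of this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer


class StarCoderAdapter(BaseLLMAdaper):
    def match(self, model_path: str):
        # Select this adapter when the model directory name mentions starcoder.
        return "starcoder" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, **from_pretrained_kwargs
        )
        return model, tokenizer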
@@ -270,7 +262,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
     """

     def match(self, model_path: str):
-        return "gpt4all" in model_path
+        return "gptj-6b" in model_path

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
        import gpt4all
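The match string changes from "gpt4all" to "gptj-6b", presumably so the adapter is selected by the directory name actually used for the GPT4All-J weights. A typical selection loop over registered adapters would look like the following sketch; the registry list and helper name are assumptions for illustration, not code from this diff.

# Hypothetical adapter lookup: the names below (llm_model_adapters,
# get_llm_model_adapter) are assumed for illustration only.
llm_model_adapters = [GPT4AllAdapter(), StarCoderAdapter(), KoalaLLMAdapter()]


def get_llm_model_adapter(model_path: str):
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")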
@@ -1,23 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
-import threading
-import sys
-import time


 def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     stop = params.get("stop", "###")
     prompt = params["prompt"]
-    role, query = prompt.split(stop)[1].split(":")
+    role, query = prompt.split(stop)[0].split(":")
     print(f"gpt4all, role: {role}, query: {query}")
-
-    def worker():
-        model.generate(prompt=query, streaming=True)
-
-    t = threading.Thread(target=worker)
-    t.start()
-
-    while t.is_alive():
-        yield sys.stdout.output
-        time.sleep(0.01)
-    t.join()
+    yield model.generate(prompt=query, streaming=True)
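This rewrite drops the background thread that polled `sys.stdout.output` (a standard stdout object has no such attribute, so the old loop could not have worked unless stdout had been replaced elsewhere) and instead yields the generator that the gpt4all bindings return when `streaming=True` is passed. It also switches from `split(stop)[1]` to `split(stop)[0]`, so the role/query pair is now taken from the text before the first stop marker. A caller therefore receives a single yielded value that is itself an iterator of tokens; a minimal consumption sketch, with an illustrative model file name and prompt shape:

# Minimal usage sketch; the GPT4All model name is illustrative and the
# prompt follows the "role:query###..." shape the parser above expects.
import gpt4all

model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
params = {"prompt": "user:hello###", "stop": "###"}

for chunk in gpt4all_generate_stream(model, None, params, "cpu", 2048):
    # chunk is the streaming generator returned by model.generate(...)
    for token in chunk:
        print(token, end="", flush=True)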
@@ -148,28 +148,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return chatglm_generate_stream


-class CodeT5ChatAdapter(BaseChatAdpter):
-    """Model chat adapter for CodeT5"""
-
-    def match(self, model_path: str):
-        return "codet5" in model_path
-
-    def get_generate_stream_func(self, model_path: str):
-        # TODO
-        pass
-
-
-class CodeGenChatAdapter(BaseChatAdpter):
-    """Model chat adapter for CodeGen"""
-
-    def match(self, model_path: str):
-        return "codegen" in model_path
-
-    def get_generate_stream_func(self, model_path: str):
-        # TODO
-        pass
-
-
 class GuanacoChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco"""

@@ -216,7 +194,7 @@ class GorillaChatAdapter(BaseChatAdpter):

 class GPT4AllChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
-        return "gpt4all" in model_path
+        return "gptj-6b" in model_path

     def get_generate_stream_func(self, model_path: str):
         from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
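The hunk above is cut off after the import. Judging by ChatGLMChatAdapter's get_generate_stream_func, which returns the stream function it imports, the method presumably ends by returning gpt4all_generate_stream. The presumed full class follows; the trailing return is an assumption, not shown in this diff.

class GPT4AllChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "gptj-6b" in model_path

    def get_generate_stream_func(self, model_path: str):
        from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream

        # Assumed: return the imported stream function, as the other
        # adapters in this file do.
        return gpt4all_generate_stream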
@@ -90,6 +90,7 @@ class ModelWorker:
         params, model_context = self.llm_chat_adapter.model_adaptation(
             params, self.ml.model_path, prompt_template=self.ml.prompt_template
         )

         for output in self.generate_stream_func(
             self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
         ):
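For context, the worker first lets the chat adapter rewrite the request parameters (model_adaptation), then iterates whatever the model-specific stream function yields. A sketch of how a loop like this might forward chunks to a client; the function, serialization format, and field names here are assumptions for illustration, not code from this diff (DEVICE and CFG are the module globals seen above).

import json


def stream_worker_outputs(worker, params):
    # Assumed packaging: serialize each streamed chunk as JSON,
    # NUL-delimited, a common convention in model-worker servers.
    for output in worker.generate_stream_func(
        worker.model, worker.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
    ):
        ret = {"text": output, "error_code": 0}
        yield json.dumps(ret).encode() + b"\0"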