update stream output

commit a88e8aa51b
parent fc4fd6aa77
https://github.com/csunny/DB-GPT.git
@@ -187,8 +187,8 @@ class RWKV4LLMAdapter(BaseLLMAdaper):

 class GPT4AllAdapter(BaseLLMAdaper):
     """
     A light version for someone who want practise LLM use laptop.
     All model names see: https://gpt4all.io/models/models.json
     """

     def match(self, model_path: str):
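The docstring above points at gpt4all's hosted model catalog. As a hedged usage sketch of the bindings this adapter wraps (exact gpt4all signatures vary by version; the model name is the default used in loader() below):

```python
# Usage sketch for the gpt4all Python bindings. Any model name listed at
# https://gpt4all.io/models/models.json should work; the file is downloaded
# on first use. Treat the exact API as version-dependent.
import gpt4all

model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
print(model.generate("Say hello in one sentence."))
```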
@@ -197,7 +197,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         import gpt4all

-        if model_path is None and from_pretrained_kwargs.get('model_name') is None:
+        if model_path is None and from_pretrained_kwargs.get("model_name") is None:
             model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
         else:
             path, file = os.path.split(model_path)
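Pieced together, the loader branch above likely resolves to something like the sketch below. The `model_name=`/`model_path=` keywords in the final constructor call are an assumption about the gpt4all bindings; that line falls outside the hunk's context:

```python
import os
import gpt4all

def load_gpt4all(model_path=None, **from_pretrained_kwargs):
    if model_path is None and from_pretrained_kwargs.get("model_name") is None:
        # No explicit model: fall back to the small default model.
        return gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
    # Split an explicit path into directory + filename, as the hunk does.
    path, file = os.path.split(model_path)
    return gpt4all.GPT4All(model_name=file, model_path=path)  # assumed kwargs
```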
@@ -1,5 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+import threading
+import sys
+import time
+

 def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     stop = params.get("stop", "###")
@@ -7,11 +11,13 @@ def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     role, query = prompt.split(stop)[1].split(":")
     print(f"gpt4all, role: {role}, query: {query}")

-    messages = [{"role": "user", "content": query}]
-    res = model.chat_completion(messages)
-    if res.get('choices') and len(res.get('choices')) > 0 and res.get('choices')[0].get('message') and \
-            res.get('choices')[0].get('message').get('content'):
-        yield res.get('choices')[0].get('message').get('content')
-    else:
-        yield "error response"
+    def worker():
+        model.generate(prompt=query, streaming=True)
+
+    t = threading.Thread(target=worker)
+    t.start()
+
+    while t.is_alive():
+        yield sys.stdout.output
+        time.sleep(0.1)
+    t.join()
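The rewritten generator polls `sys.stdout.output`, which only exists if `sys.stdout` has been replaced elsewhere with a capturing object; plain `sys.stdout` has no `output` attribute. That replacement is not part of this diff. A hypothetical wrapper with the right shape:

```python
import sys

class CapturingStdout:
    """Hypothetical stdout replacement exposing the `output` attribute
    that the streaming loop reads back via `yield sys.stdout.output`."""

    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.output = ""  # everything the gpt4all thread has written so far

    def write(self, text):
        self.output += text  # accumulate instead of printing

    def flush(self):
        self.wrapped.flush()

sys.stdout = CapturingStdout(sys.stdout)
```

Because `output` accumulates, each `yield` returns the full text so far rather than a delta, which is consistent with the `skip_echo_len` trimming in the output parser below. It also explains the warning in the worker hunk further down: the gpt4all thread and the parent process share the same replaced stdout, so any extra print() pollutes the captured stream.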
@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):

         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
-        if data.get('error_code', 0) == 0:
+        if data.get("error_code", 0) == 0:
             if "vicuna" in CFG.LLM_MODEL:
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()
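Condensed, the branch in this hunk amounts to the sketch below (illustrative names; `data` is assumed to be the worker's decoded chunk and `skip_echo_len` the length of the echoed prompt):

```python
def parse_model_server_out(data: dict, skip_echo_len: int) -> str:
    # Hypothetical condensation of the hunk above; names are illustrative.
    if data.get("error_code", 0) == 0:
        # vicuna-style models echo the prompt back, so trim it off.
        return data["text"][skip_echo_len:].strip()
    return data.get("text", "")
```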
@@ -125,7 +125,6 @@ class GorillaChatAdapter(BaseChatAdpter):


 class GPT4AllChatAdapter(BaseChatAdpter):
-
     def match(self, model_path: str):
         return "gpt4all" in model_path

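`match()` is the hook for a registry-style adapter lookup. A self-contained sketch of that pattern (the registry tuple and lookup helper are hypothetical; DB-GPT's actual wiring may differ):

```python
class BaseChatAdpter:
    # Spelling kept as in the source ("Adpter").
    def match(self, model_path: str) -> bool:
        return False

class GPT4AllChatAdapter(BaseChatAdpter):
    def match(self, model_path: str) -> bool:
        return "gpt4all" in model_path

def get_chat_adapter(model_path: str) -> BaseChatAdpter:
    # Hypothetical lookup: the first adapter whose match() accepts the path wins.
    for adapter in (GPT4AllChatAdapter(),):
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"no chat adapter matches {model_path!r}")
```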
@@ -39,9 +39,13 @@ class ModelWorker:
         )

         if not isinstance(self.model, str):
-            if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"):
+            if hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_sequence_length"
+            ):
                 self.context_len = self.model.config.max_sequence_length
-            elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"):
+            elif hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_position_embeddings"
+            ):
                 self.context_len = self.model.config.max_position_embeddings

         else:
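The reformatted `hasattr` chain probes the Hugging Face model config for a context-window size. The same logic as a standalone helper (the helper name and default value are illustrative):

```python
def infer_context_len(model, default: int = 2048) -> int:
    # Mirrors the hunk above: prefer max_sequence_length, then
    # max_position_embeddings, else fall back to a default.
    config = getattr(model, "config", None)
    if config is not None:
        for attr in ("max_sequence_length", "max_position_embeddings"):
            if hasattr(config, attr):
                return getattr(config, attr)
    return default
```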
@@ -69,7 +73,8 @@ class ModelWorker:
         for output in self.generate_stream_func(
             self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
         ):
-            print("output: ", output)
+            # Do not enable this output in production! The gpt4all thread shares stdout with the parent process; printing here corrupts the front-end output.
+            # print("output: ", output)
             ret = {
                 "text": output,
                 "error_code": 0,
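Each chunk the worker yields is a small dict with `text` and `error_code`. A sketch of one common way such chunks are framed for a streaming HTTP response (the JSON-plus-NUL framing is an assumption borrowed from FastChat-style workers, not shown in this excerpt):

```python
import json

def frame_chunks(outputs):
    # Hypothetical framing: serialize each chunk and terminate with NUL so a
    # streaming HTTP client can split the byte stream back into chunks.
    for output in outputs:
        ret = {"text": output, "error_code": 0}
        yield json.dumps(ret).encode() + b"\0"
```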