fix: gpt4all output bug (#478)

Close #443
This commit is contained in:
Aries-ckt 2023-08-23 17:22:39 +08:00 committed by GitHub
commit 4ec47a5efd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 5 additions and 47 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 136 KiB

After

Width:  |  Height:  |  Size: 169 KiB

View File

@ -233,18 +233,10 @@ class GorillaAdapter(BaseLLMAdaper):
return model, tokenizer
class CodeGenAdapter(BaseLLMAdaper):
pass
class StarCoderAdapter(BaseLLMAdaper):
pass
class T5CodeAdapter(BaseLLMAdaper):
pass
class KoalaLLMAdapter(BaseLLMAdaper):
"""Koala LLM Adapter which Based LLaMA"""
@ -270,7 +262,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
"""
def match(self, model_path: str):
return "gpt4all" in model_path
return "gptj-6b" in model_path
def loader(self, model_path: str, from_pretrained_kwargs: dict):
import gpt4all

View File

@ -1,23 +1,10 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import threading
import sys
import time
def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
stop = params.get("stop", "###")
prompt = params["prompt"]
role, query = prompt.split(stop)[1].split(":")
role, query = prompt.split(stop)[0].split(":")
print(f"gpt4all, role: {role}, query: {query}")
def worker():
model.generate(prompt=query, streaming=True)
t = threading.Thread(target=worker)
t.start()
while t.is_alive():
yield sys.stdout.output
time.sleep(0.01)
t.join()
yield model.generate(prompt=query, streaming=True)

View File

@ -148,28 +148,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
return chatglm_generate_stream
class CodeT5ChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeT5"""
def match(self, model_path: str):
return "codet5" in model_path
def get_generate_stream_func(self, model_path: str):
# TODO
pass
class CodeGenChatAdapter(BaseChatAdpter):
"""Model chat adapter for CodeGen"""
def match(self, model_path: str):
return "codegen" in model_path
def get_generate_stream_func(self, model_path: str):
# TODO
pass
class GuanacoChatAdapter(BaseChatAdpter):
"""Model chat adapter for Guanaco"""
@ -216,7 +194,7 @@ class GorillaChatAdapter(BaseChatAdpter):
class GPT4AllChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "gpt4all" in model_path
return "gptj-6b" in model_path
def get_generate_stream_func(self, model_path: str):
from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream

View File

@ -90,6 +90,7 @@ class ModelWorker:
params, model_context = self.llm_chat_adapter.model_adaptation(
params, self.ml.model_path, prompt_template=self.ml.prompt_template
)
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):