fix: gpt4all output bug (#478)

Close #443
commit 4ec47a5efd
Aries-ckt, committed via GitHub on 2023-08-23 17:22:39 +08:00
5 changed files with 5 additions and 47 deletions

Binary file changed, not shown (image; 136 KiB before, 169 KiB after).

pilot/model/adapter.py

@@ -233,18 +233,10 @@ class GorillaAdapter(BaseLLMAdaper):
         return model, tokenizer


-class CodeGenAdapter(BaseLLMAdaper):
-    pass
-
-
 class StarCoderAdapter(BaseLLMAdaper):
     pass


-class T5CodeAdapter(BaseLLMAdaper):
-    pass
-
-
 class KoalaLLMAdapter(BaseLLMAdaper):
     """Koala LLM Adapter which Based LLaMA"""

@@ -270,7 +262,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
     """

     def match(self, model_path: str):
-        return "gpt4all" in model_path
+        return "gptj-6b" in model_path

     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         import gpt4all

pilot/model/llm_out/gpt4all_llm.py

@@ -1,23 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-

-import threading
-import sys
-import time
-

 def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     stop = params.get("stop", "###")
     prompt = params["prompt"]
-    role, query = prompt.split(stop)[1].split(":")
+    role, query = prompt.split(stop)[0].split(":")
     print(f"gpt4all, role: {role}, query: {query}")
-
-    def worker():
-        model.generate(prompt=query, streaming=True)
-
-    t = threading.Thread(target=worker)
-    t.start()
-    while t.is_alive():
-        yield sys.stdout.output
-        time.sleep(0.01)
-    t.join()
+    yield model.generate(prompt=query, streaming=True)
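This hunk is the substance of the output bug. The old code split the prompt on the wrong side of the stop token (`[1]` instead of `[0]`), ran generation on a background thread, and polled a `sys.stdout.output` attribute that standard streams do not provide, so no usable text ever came back. The replacement leans on the gpt4all Python bindings, where `generate(..., streaming=True)` returns a token iterator. A hedged sketch of draining that iterator into growing partial text (a common generate-stream convention; this helper is illustrative, not repo code):

```python
import gpt4all

# Illustrative helper, not repo code: accumulate streamed tokens and
# yield the partial response so far, assuming a gpt4all version where
# generate(..., streaming=True) returns an iterator of tokens.
def stream_partial_text(model: gpt4all.GPT4All, query: str):
    response = ""
    for token in model.generate(prompt=query, streaming=True):
        response += token
        yield response
```

Note that the committed one-liner yields the iterator itself in a single step; whoever consumes the stream is then responsible for draining it.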

pilot/server/chat_adapter.py

@@ -148,28 +148,6 @@ class ChatGLMChatAdapter(BaseChatAdpter):
         return chatglm_generate_stream


-class CodeT5ChatAdapter(BaseChatAdpter):
-    """Model chat adapter for CodeT5"""
-
-    def match(self, model_path: str):
-        return "codet5" in model_path
-
-    def get_generate_stream_func(self, model_path: str):
-        # TODO
-        pass
-
-
-class CodeGenChatAdapter(BaseChatAdpter):
-    """Model chat adapter for CodeGen"""
-
-    def match(self, model_path: str):
-        return "codegen" in model_path
-
-    def get_generate_stream_func(self, model_path: str):
-        # TODO
-        pass
-
-
 class GuanacoChatAdapter(BaseChatAdpter):
     """Model chat adapter for Guanaco"""

@@ -216,7 +194,7 @@ class GorillaChatAdapter(BaseChatAdpter):
 class GPT4AllChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
-        return "gpt4all" in model_path
+        return "gptj-6b" in model_path

     def get_generate_stream_func(self, model_path: str):
         from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
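The chat-side pattern has to stay in lockstep with the model-side `GPT4AllAdapter.match` above; if the two disagree, the model can load while streaming dispatch falls through to a different adapter. A hypothetical sanity check (not in the repo; import paths are assumptions inferred from the diff) makes the invariant explicit:

```python
# Hypothetical test, not in the repo; import paths are assumptions
# inferred from the diff.
from pilot.model.adapter import GPT4AllAdapter
from pilot.server.chat_adapter import GPT4AllChatAdapter

def test_gpt4all_adapters_agree():
    # Both registries must recognize the same model key.
    key = "gptj-6b"
    assert GPT4AllAdapter().match(key)
    assert GPT4AllChatAdapter().match(key)
```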

pilot/server/llmserver.py

@@ -90,6 +90,7 @@ class ModelWorker:
         params, model_context = self.llm_chat_adapter.model_adaptation(
             params, self.ml.model_path, prompt_template=self.ml.prompt_template
         )
+
         for output in self.generate_stream_func(
             self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
         ):
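One consequence of the new `gpt4all_generate_stream`: it yields the token iterator itself, so for this backend `output` in the loop above is an iterator rather than ready-made text. A hedged sketch of normalizing both shapes into growing partial text (this helper is an assumption, not part of the commit):

```python
from typing import Iterable, Iterator, Union

# Assumed helper, not part of the commit: make a generate-stream result
# uniform for callers, whether it yields strings directly or a nested
# token iterator (as the new gpt4all path does).
def iter_partial_text(stream: Iterable[Union[str, Iterator[str]]]):
    for output in stream:
        if isinstance(output, str):
            yield output
        else:
            text = ""
            for token in output:
                text += token
                yield text
```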