From 545c3232161b11a1e6dfccaa3bb393718a745bc5 Mon Sep 17 00:00:00 2001
From: ykgong
Date: Fri, 9 Jun 2023 11:35:06 +0800
Subject: [PATCH 1/7] add gpt4all

---
 pilot/configs/model_config.py      |  1 +
 pilot/model/adapter.py             | 23 +++++++++++++++++------
 pilot/model/llm_out/gpt4all_llm.py | 17 +++++++++++++++++
 pilot/out_parser/base.py           |  2 +-
 pilot/server/chat_adapter.py       | 15 ++++++++++++---
 pilot/server/llmserver.py          |  5 +++--
 6 files changed, 51 insertions(+), 12 deletions(-)
 create mode 100644 pilot/model/llm_out/gpt4all_llm.py

diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 36d615043..b85fe6b7b 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -37,6 +37,7 @@ LLM_MODEL_CONFIG = {
     "guanaco-33b-merged": os.path.join(MODEL_PATH, "guanaco-33b-merged"),
     "falcon-40b": os.path.join(MODEL_PATH, "falcon-40b"),
     "gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
+    "ggml-gpt4all-j-v1.3-groovy": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
     "proxyllm": "proxyllm",
 }

diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 7892e4b1b..89ea55ec2 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -2,6 +2,8 @@
 # -*- coding: utf-8 -*-
 
 import torch
+import os
+from functools import cache
 from typing import List
 from functools import cache
 from transformers import (
@@ -92,8 +94,8 @@ class ChatGLMAdapater(BaseLLMAdaper):
             AutoModel.from_pretrained(
                 model_path, trust_remote_code=True, **from_pretrained_kwargs
             )
-            .half()
-            .cuda()
+        .half()
+        .cuda()
         )
         return model, tokenizer
 
@@ -185,18 +187,26 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
 
 
 class GPT4AllAdapter(BaseLLMAdaper):
-    """A light version for someone who want practise LLM use laptop."""
+    """
+    A lightweight option for anyone who wants to practice with an LLM on a laptop.
+    See https://gpt4all.io/models/models.json for all available model names.
+    """
 
     def match(self, model_path: str):
         return "gpt4all" in model_path
 
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
-        # TODO
-        pass
+        import gpt4all
+
+        if model_path is None and from_pretrained_kwargs.get('model_name') is None:
+            model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
+        else:
+            path, file = os.path.split(model_path)
+            model = gpt4all.GPT4All(model_path=path, model_name=file)
+        return model, None
 
 
 class ProxyllmAdapter(BaseLLMAdaper):
-
     """The model adapter for local proxy"""
 
     def match(self, model_path: str):
@@ -211,6 +221,7 @@ register_llm_model_adapters(ChatGLMAdapater)
 register_llm_model_adapters(GuanacoAdapter)
 register_llm_model_adapters(FalconAdapater)
 register_llm_model_adapters(GorillaAdapter)
+register_llm_model_adapters(GPT4AllAdapter)
 
 # TODO Default support vicuna, other model need to tests and Evaluate
 # just for test, remove this later
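
Stripped of the adapter plumbing, the loader above reduces to a few lines against the gpt4all package. A minimal sketch, assuming gpt4all==0.3.0 (the version pinned later in this series) and an illustrative ./models directory:

    import os

    import gpt4all

    # The value that LLM_MODEL_CONFIG resolves to for this model.
    model_path = "./models/ggml-gpt4all-j-v1.3-groovy.bin"

    # The adapter splits the configured file path into the directory/file
    # pair that the gpt4all package expects.
    path, file = os.path.split(model_path)
    model = gpt4all.GPT4All(model_path=path, model_name=file)

    # With no explicit path configured, the adapter instead loads (and, if
    # permitted, downloads) the default model by name:
    # model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
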
diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py
new file mode 100644
index 000000000..4cc1f067f
--- /dev/null
+++ b/pilot/model/llm_out/gpt4all_llm.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
+    stop = params.get("stop", "###")
+    prompt = params["prompt"]
+    role, query = prompt.split(stop)[1].split(":", 1)
+    print(f"gpt4all, role: {role}, query: {query}")
+
+    messages = [{"role": "user", "content": query}]
+    res = model.chat_completion(messages)
+    choices = res.get("choices")
+    if choices and choices[0].get("message", {}).get("content"):
+        yield choices[0]["message"]["content"]
+    else:
+        yield "error response"
+
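
The prompt that reaches gpt4all_generate_stream is a single flattened conversation string, so the function recovers the user turn by splitting on the stop token and then on the first colon only (splitting on every colon would crash with a ValueError whenever the user message itself contains one). A worked example of that parsing, with an invented prompt:

    stop = "###"
    prompt = (
        "A chat between a curious user and an AI assistant."
        "###Human: What is DB-GPT?###Assistant:"
    )

    # Segment [0] is the system preamble; segment [1] is the first
    # "role: content" turn.
    role, query = prompt.split(stop)[1].split(":", 1)
    assert role == "Human"
    assert query == " What is DB-GPT?"

    # model.chat_completion(...) then returns an OpenAI-style dict, and the
    # reply is read from res["choices"][0]["message"]["content"].
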
""" - if data["error_code"] == 0: + if data.get('error_code', 0) == 0: if "vicuna" in CFG.LLM_MODEL: # output = data["text"][skip_echo_len + 11:].strip() output = data["text"][skip_echo_len:].strip() diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index e4f57cf46..3598b16b3 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -37,7 +37,6 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter: class VicunaChatAdapter(BaseChatAdpter): - """Model chat Adapter for vicuna""" def match(self, model_path: str): @@ -60,7 +59,6 @@ class ChatGLMChatAdapter(BaseChatAdpter): class CodeT5ChatAdapter(BaseChatAdpter): - """Model chat adapter for CodeT5""" def match(self, model_path: str): @@ -72,7 +70,6 @@ class CodeT5ChatAdapter(BaseChatAdpter): class CodeGenChatAdapter(BaseChatAdpter): - """Model chat adapter for CodeGen""" def match(self, model_path: str): @@ -127,11 +124,23 @@ class GorillaChatAdapter(BaseChatAdpter): return generate_stream +class GPT4AllChatAdapter(BaseChatAdpter): + + def match(self, model_path: str): + return "gpt4all" in model_path + + def get_generate_stream_func(self): + from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream + + return gpt4all_generate_stream + + register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter) register_llm_model_chat_adapter(FalconChatAdapter) register_llm_model_chat_adapter(GorillaChatAdapter) +register_llm_model_chat_adapter(GPT4AllChatAdapter) # Proxy model for test and develop, it's cheap for us now. register_llm_model_chat_adapter(ProxyllmChatAdapter) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index d2730e0d5..e71872d64 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -39,9 +39,9 @@ class ModelWorker: ) if not isinstance(self.model, str): - if hasattr(self.model.config, "max_sequence_length"): + if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"): self.context_len = self.model.config.max_sequence_length - elif hasattr(self.model.config, "max_position_embeddings"): + elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"): self.context_len = self.model.config.max_position_embeddings else: @@ -66,6 +66,7 @@ class ModelWorker: def generate_stream_gate(self, params): try: + print(f"llmserver params: {params}, self: {self}") for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): From cacab625cf1b3ea09f83233816d7dd996a24919a Mon Sep 17 00:00:00 2001 From: ykgong Date: Fri, 9 Jun 2023 11:39:51 +0800 Subject: [PATCH 2/7] rm log --- pilot/model/adapter.py | 1 - pilot/server/llmserver.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 89ea55ec2..435be1142 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -3,7 +3,6 @@ import torch import os -from functools import cache from typing import List from functools import cache from transformers import ( diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index e71872d64..ad4627afa 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -66,7 +66,6 @@ class ModelWorker: def generate_stream_gate(self, params): try: - print(f"llmserver params: {params}, self: {self}") for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, 
From fc4fd6aa773df80d4234629156d2ad3750851024 Mon Sep 17 00:00:00 2001
From: ykgong
Date: Fri, 9 Jun 2023 14:05:49 +0800
Subject: [PATCH 4/7] requirements.txt

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index 9238751ca..c6434c3ab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -49,6 +49,7 @@ llama-index==0.5.27
 pymysql
 unstructured==0.6.3
 grpcio==1.47.5
+gpt4all==0.3.0
 auto-gpt-plugin-template
 pymdown-extensions
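
The pin matches the API surface this series depends on (chat_completion here, and generate(..., streaming=True) in a later patch). A quick post-install smoke test, a sketch assuming the model is allowed to download on first use:

    # pip install gpt4all==0.3.0
    import gpt4all

    model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")  # downloads if missing
    res = model.chat_completion([{"role": "user", "content": "Say hello"}])
    print(res["choices"][0]["message"]["content"])
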
From a88e8aa51bf27bf5540655d73bd579cf7cfae7bd Mon Sep 17 00:00:00 2001
From: sheri528
Date: Tue, 13 Jun 2023 14:11:23 +0800
Subject: [PATCH 5/7] update stream output

---
 pilot/model/adapter.py             |  6 +++---
 pilot/model/llm_out/gpt4all_llm.py | 20 +++++++++++++-------
 pilot/out_parser/base.py           |  2 +-
 pilot/server/chat_adapter.py       |  1 -
 pilot/server/llmserver.py          | 11 ++++++++---
 5 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 435be1142..407d11127 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -187,8 +187,8 @@ class RWKV4LLMAdapter(BaseLLMAdaper):
 
 class GPT4AllAdapter(BaseLLMAdaper):
     """
-    A lightweight option for anyone who wants to practice with an LLM on a laptop. 
-    See https://gpt4all.io/models/models.json for all available model names. 
+    A lightweight option for anyone who wants to practice with an LLM on a laptop.
+    See https://gpt4all.io/models/models.json for all available model names.
     """
 
     def match(self, model_path: str):
@@ -197,7 +197,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
     def loader(self, model_path: str, from_pretrained_kwargs: dict):
         import gpt4all
 
-        if model_path is None and from_pretrained_kwargs.get('model_name') is None:
+        if model_path is None and from_pretrained_kwargs.get("model_name") is None:
             model = gpt4all.GPT4All("ggml-gpt4all-j-v1.3-groovy")
         else:
             path, file = os.path.split(model_path)
diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py
index 4cc1f067f..5ea72f911 100644
--- a/pilot/model/llm_out/gpt4all_llm.py
+++ b/pilot/model/llm_out/gpt4all_llm.py
@@ -1,5 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
+import threading
+import sys
+import time
+
 
 def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     stop = params.get("stop", "###")
@@ -7,11 +11,13 @@ def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embed
     role, query = prompt.split(stop)[1].split(":", 1)
     print(f"gpt4all, role: {role}, query: {query}")
 
-    messages = [{"role": "user", "content": query}]
-    res = model.chat_completion(messages)
-    choices = res.get("choices")
-    if choices and choices[0].get("message", {}).get("content"):
-        yield choices[0]["message"]["content"]
-    else:
-        yield "error response"
+    def worker():
+        model.generate(prompt=query, streaming=True)
+
+    t = threading.Thread(target=worker)
+    t.start()
+
+    while t.is_alive():
+        yield sys.stdout.output
+        time.sleep(0.1)
+    t.join()
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 6f08d93fe..6406f30dd 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -51,7 +51,7 @@ class BaseOutputParser(ABC):
         """
         TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
         """
-        if data.get('error_code', 0) == 0:
+        if data.get("error_code", 0) == 0:
             if "vicuna" in CFG.LLM_MODEL:
                 # output = data["text"][skip_echo_len + 11:].strip()
                 output = data["text"][skip_echo_len:].strip()
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 3598b16b3..ebab2d2d4 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -125,7 +124,6 @@ class GorillaChatAdapter(BaseChatAdpter):
 
 
 class GPT4AllChatAdapter(BaseChatAdpter):
-
     def match(self, model_path: str):
         return "gpt4all" in model_path
 
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index ad4627afa..66180a406 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -39,9 +39,13 @@ class ModelWorker:
         )
 
         if not isinstance(self.model, str):
-            if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"):
+            if hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_sequence_length"
+            ):
                 self.context_len = self.model.config.max_sequence_length
-            elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"):
+            elif hasattr(self.model, "config") and hasattr(
+                self.model.config, "max_position_embeddings"
+            ):
                 self.context_len = self.model.config.max_position_embeddings
 
         else:
@@ -69,7 +73,8 @@ class ModelWorker:
             for output in self.generate_stream_func(
                 self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
-                print("output: ", output)
+                # Do not enable this print in production! The gpt4all thread shares stdout with the parent process, and enabling it interferes with the frontend output.
+                # print("output: ", output)
                 ret = {
                     "text": output,
                     "error_code": 0,
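
The new streaming implementation relies on the server process having replaced sys.stdout with an object that accumulates writes in an .output attribute: gpt4all prints tokens to stdout from the worker thread, and the generator polls that buffer while the thread is alive. A self-contained sketch of the mechanism; the BufferedStdout class below is a stand-in for the wrapper this patch assumes but does not show:

    import sys
    import threading
    import time


    class BufferedStdout:
        """Stand-in for the assumed stdout wrapper: keeps everything written."""

        def __init__(self):
            self.output = ""

        def write(self, text):
            self.output += text

        def flush(self):
            pass


    def fake_generate():
        # Plays the role of model.generate(prompt=query, streaming=True),
        # which prints tokens to stdout as they are produced.
        for token in ["Hello", ",", " world", "!"]:
            print(token, end="", flush=True)
            time.sleep(0.05)


    sys.stdout = BufferedStdout()
    t = threading.Thread(target=fake_generate)
    t.start()

    snapshots = []
    while t.is_alive():
        snapshots.append(sys.stdout.output)  # the real generator yields this
        time.sleep(0.01)
    t.join()

    result = sys.stdout.output
    sys.stdout = sys.__stdout__  # restore before printing normally again
    print(result)  # -> Hello, world!
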
""" - if data.get('error_code', 0) == 0: + if data.get("error_code", 0) == 0: if "vicuna" in CFG.LLM_MODEL: # output = data["text"][skip_echo_len + 11:].strip() output = data["text"][skip_echo_len:].strip() diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 3598b16b3..ebab2d2d4 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -125,7 +125,6 @@ class GorillaChatAdapter(BaseChatAdpter): class GPT4AllChatAdapter(BaseChatAdpter): - def match(self, model_path: str): return "gpt4all" in model_path diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index ad4627afa..66180a406 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -39,9 +39,13 @@ class ModelWorker: ) if not isinstance(self.model, str): - if hasattr(self.model, "config") and hasattr(self.model.config, "max_sequence_length"): + if hasattr(self.model, "config") and hasattr( + self.model.config, "max_sequence_length" + ): self.context_len = self.model.config.max_sequence_length - elif hasattr(self.model, "config") and hasattr(self.model.config, "max_position_embeddings"): + elif hasattr(self.model, "config") and hasattr( + self.model.config, "max_position_embeddings" + ): self.context_len = self.model.config.max_position_embeddings else: @@ -69,7 +73,8 @@ class ModelWorker: for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): - print("output: ", output) + # 生产请不要打开输出!gpt4all线程与父进程共享stdout, 打开会影响前端输出 + # print("output: ", output) ret = { "text": output, "error_code": 0, From 5da4b38964c8662f61e2e409cf1476ae4839d346 Mon Sep 17 00:00:00 2001 From: sheri528 Date: Tue, 13 Jun 2023 14:22:55 +0800 Subject: [PATCH 6/7] format code --- pilot/model/adapter.py | 4 ++-- pilot/server/llmserver.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 407d11127..01d05837b 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -93,8 +93,8 @@ class ChatGLMAdapater(BaseLLMAdaper): AutoModel.from_pretrained( model_path, trust_remote_code=True, **from_pretrained_kwargs ) - .half() - .cuda() + .half() + .cuda() ) return model, tokenizer diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 66180a406..30653a16e 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -73,7 +73,9 @@ class ModelWorker: for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): - # 生产请不要打开输出!gpt4all线程与父进程共享stdout, 打开会影响前端输出 + # Please do not open the output in production! + # The gpt4all thread shares stdout with the parent process, + # and opening it may affect the frontend output. # print("output: ", output) ret = { "text": output, From cc3c2d779936e471ca76e85424c0754cd040c90f Mon Sep 17 00:00:00 2001 From: sheri528 Date: Tue, 13 Jun 2023 14:35:14 +0800 Subject: [PATCH 7/7] update sleep interval --- pilot/model/llm_out/gpt4all_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py index 5ea72f911..7a39a8012 100644 --- a/pilot/model/llm_out/gpt4all_llm.py +++ b/pilot/model/llm_out/gpt4all_llm.py @@ -19,5 +19,5 @@ def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embed while t.is_alive(): yield sys.stdout.output - time.sleep(0.1) + time.sleep(0.01) t.join()