diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 31684d6a7..05f9ffdcb 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -232,6 +232,7 @@ class GorillaAdapter(BaseLLMAdaper):
         )
         return model, tokenizer
 
+
 class StarCoderAdapter(BaseLLMAdaper):
     pass
 
diff --git a/pilot/model/llm_out/gpt4all_llm.py b/pilot/model/llm_out/gpt4all_llm.py
index 4954a9933..3ea6b8206 100644
--- a/pilot/model/llm_out/gpt4all_llm.py
+++ b/pilot/model/llm_out/gpt4all_llm.py
@@ -1,10 +1,10 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+
 def gpt4all_generate_stream(model, tokenizer, params, device, max_position_embeddings):
     stop = params.get("stop", "###")
     prompt = params["prompt"]
     role, query = prompt.split(stop)[0].split(":")
     print(f"gpt4all, role: {role}, query: {query}")
     yield model.generate(prompt=query, streaming=True)
-
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index d2af0733e..0bc53e8fd 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -191,6 +191,7 @@ class GorillaChatAdapter(BaseChatAdpter):
         return generate_stream
 
 
+
 class GPT4AllChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):
         return "gptj-6b" in model_path
@@ -199,7 +200,7 @@ class GPT4AllChatAdapter(BaseChatAdpter):
         from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
 
         return gpt4all_generate_stream
-    
+
 
 class Llama2ChatAdapter(BaseChatAdpter):
     def match(self, model_path: str):