From 370e327bf3ebb037dc5e275fe58005c175463382 Mon Sep 17 00:00:00 2001
From: csunny
Date: Sat, 20 May 2023 16:23:07 +0800
Subject: [PATCH] add chatglm model

---
 pilot/model/adapter.py       |  1 +
 pilot/server/chat_adapter.py | 71 +++++++++++++++++++++++++++++++++++-
 2 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 84cd699ac..bf0e291ce 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -100,6 +100,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
 
 
 register_llm_model_adapters(VicunaLLMAdapater)
+register_llm_model_adapters(ChatGLMAdapater)
 # TODO Default support vicuna, other model need to tests and Evaluate
 
 register_llm_model_adapters(BaseLLMAdaper)
\ No newline at end of file
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index 9c32c911d..ded0a1b19 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+from typing import List
+from functools import cache
+from pilot.model.inference import generate_stream
 
 class BaseChatAdpter:
     """The Base class for chat with llm models. it will match the model,
@@ -10,4 +13,70 @@ class BaseChatAdpter:
         return True
 
     def get_generate_stream_func(self):
-        pass
\ No newline at end of file
+        """Return the generate stream handler func"""
+        pass
+
+
+llm_model_chat_adapters: List[BaseChatAdpter] = []
+
+
+def register_llm_model_chat_adapter(cls):
+    """Register a chat adapter"""
+    llm_model_chat_adapters.append(cls())
+
+
+@cache
+def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
+    """Get a chat generate func for a model"""
+    for adapter in llm_model_chat_adapters:
+        if adapter.match(model_path):
+            return adapter
+
+    raise ValueError(f"Invalid model for chat adapter {model_path}")
+
+
+class VicunaChatAdapter(BaseChatAdpter):
+
+    """ Model chat Adapter for vicuna"""
+    def match(self, model_path: str):
+        return "vicuna" in model_path
+
+    def get_generate_stream_func(self):
+        return generate_stream
+
+
+class ChatGLMChatAdapter(BaseChatAdpter):
+    """ Model chat Adapter for ChatGLM"""
+    def match(self, model_path: str):
+        return "chatglm" in model_path
+
+    def get_generate_stream_func(self):
+        from pilot.model.chatglm_llm import chatglm_generate_stream
+        return chatglm_generate_stream
+
+
+class CodeT5ChatAdapter(BaseChatAdpter):
+
+    """ Model chat adapter for CodeT5 """
+    def match(self, model_path: str):
+        return "codet5" in model_path
+
+    def get_generate_stream_func(self):
+        # TODO
+        pass
+
+class CodeGenChatAdapter(BaseChatAdpter):
+
+    """ Model chat adapter for CodeGen """
+    def match(self, model_path: str):
+        return "codegen" in model_path
+
+    def get_generate_stream_func(self):
+        # TODO
+        pass
+
+
+register_llm_model_chat_adapter(VicunaChatAdapter)
+register_llm_model_chat_adapter(ChatGLMChatAdapter)
+
+register_llm_model_chat_adapter(BaseChatAdpter)
\ No newline at end of file
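
Usage sketch (not part of the patch): a minimal example of how the new chat adapter registry resolves a model path to its streaming generation function. The checkpoint path "models/chatglm-6b" is illustrative only, and the example assumes pilot.model.chatglm_llm is importable at call time.

    from pilot.server.chat_adapter import get_llm_chat_adapter

    # "chatglm" appears in the path, so the registry returns ChatGLMChatAdapter.
    # Adapters are tried in registration order; BaseChatAdpter is the catch-all.
    adapter = get_llm_chat_adapter("models/chatglm-6b")  # illustrative path

    # For ChatGLM this lazily imports chatglm_generate_stream from
    # pilot.model.chatglm_llm and returns it as the streaming handler.
    generate_stream_func = adapter.get_generate_stream_func()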