add chatglm model

This commit is contained in:
csunny
2023-05-20 16:23:07 +08:00
parent cbf1d0662a
commit 370e327bf3
2 changed files with 71 additions and 1 deletion

View File

@@ -100,6 +100,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
# TODO: Vicuna is supported by default; other models still need to be tested and evaluated
register_llm_model_adapters(BaseLLMAdaper)
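For context on the registration order above: BaseLLMAdaper's match() is a catch-all, so it is registered last as the fallback. Below is a minimal sketch of how the registry is presumably consumed; the module-level llm_model_adapters list and the helper name get_llm_model_adapter are assumptions, not shown in this hunk.

# Illustrative sketch, not part of this commit; the names below are assumed.
def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    """Return the first registered adapter whose match() accepts the model path."""
    for adapter in llm_model_adapters:
        # e.g. ChatGLMAdapater presumably matches paths containing "chatglm";
        # BaseLLMAdaper matches everything, so it acts as the final fallback.
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"Invalid model adapter for {model_path}")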

View File

@@ -1,6 +1,9 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from typing import List
from functools import cache

from pilot.model.inference import generate_stream


class BaseChatAdpter:
    """The Base class for chat with llm models. it will match the model,
@@ -10,4 +13,70 @@ class BaseChatAdpter:
        return True

    def get_generate_stream_func(self):
        """Return the generate stream handler func"""
        pass


llm_model_chat_adapters: List[BaseChatAdpter] = []


def register_llm_model_chat_adapter(cls):
    """Register a chat adapter"""
    llm_model_chat_adapters.append(cls())


@cache
def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
    """Get a chat generate func for a model"""
    for adapter in llm_model_chat_adapters:
        if adapter.match(model_path):
            return adapter

    raise ValueError(f"Invalid model for chat adapter {model_path}")
class VicunaChatAdapter(BaseChatAdpter):
    """Model chat Adapter for vicuna"""

    def match(self, model_path: str):
        return "vicuna" in model_path

    def get_generate_stream_func(self):
        return generate_stream


class ChatGLMChatAdapter(BaseChatAdpter):
    """Model chat Adapter for ChatGLM"""

    def match(self, model_path: str):
        return "chatglm" in model_path

    def get_generate_stream_func(self):
        from pilot.model.chatglm_llm import chatglm_generate_stream

        return chatglm_generate_stream


class CodeT5ChatAdapter(BaseChatAdpter):
    """Model chat adapter for CodeT5"""

    def match(self, model_path: str):
        return "codet5" in model_path

    def get_generate_stream_func(self):
        # TODO
        pass


class CodeGenChatAdapter(BaseChatAdpter):
    """Model chat adapter for CodeGen"""

    def match(self, model_path: str):
        return "codegen" in model_path

    def get_generate_stream_func(self):
        # TODO
        pass


register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(BaseChatAdpter)
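BaseChatAdpter is registered last as the catch-all, since its match() always returns True, mirroring the model adapter registration in the first file. To illustrate how this registry is presumably wired into the serving code, here is a rough sketch; the model path and the call site are assumptions for illustration, not part of this commit.

# Illustrative sketch only, not part of this commit; the path is a made-up example.
model_path = "/data/models/chatglm-6b"

adapter = get_llm_chat_adapter(model_path)                  # ChatGLMChatAdapter: "chatglm" is in the path
generate_stream_func = adapter.get_generate_stream_func()   # resolves to chatglm_generate_stream

# The serving loop would then stream tokens through generate_stream_func;
# its exact signature lives alongside pilot.model.chatglm_llm and pilot.model.inference.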