Mirror of https://github.com/csunny/DB-GPT.git, synced 2026-01-24 14:32:53 +00:00.
Commit: "add chatglm model". This commit is contained in:
@@ -100,6 +100,7 @@ class GPT4AllAdapter(BaseLLMAdaper):
# Register the model adapters. Vicuna is the primary supported model;
# ChatGLM support was added by this change.
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
# TODO Default support vicuna, other model need to tests and Evaluate

# NOTE(review): BaseLLMAdaper presumably matches any model path, so it is
# registered last as the fallback — confirm against its match() implementation.
register_llm_model_adapters(BaseLLMAdaper)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from typing import List
from functools import cache

from pilot.model.inference import generate_stream
class BaseChatAdpter:
|
||||
"""The Base class for chat with llm models. it will match the model,
|
||||
@@ -10,4 +13,70 @@ class BaseChatAdpter:
|
||||
return True
|
||||
|
||||
def get_generate_stream_func(self):
    """Return the generate stream handler func"""
    # Base implementation is a stub; concrete adapters override this to
    # return the model-specific streaming generation function.
    pass
# Global registry of chat adapter instances, searched in registration order
# by get_llm_chat_adapter().
llm_model_chat_adapters: List[BaseChatAdpter] = []
def register_llm_model_chat_adapter(cls):
    """Register a chat adapter

    Instantiates *cls* and appends it to the global registry; adapters are
    matched in registration order, so more specific adapters must be
    registered before any catch-all.
    """
    llm_model_chat_adapters.append(cls())
@cache
def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
    """Get a chat generate func for a model

    Scans the registered adapters in order and returns the first whose
    ``match`` accepts *model_path*. Cached, so adapters registered after
    the first lookup for a given path are not re-considered for it.

    Raises:
        ValueError: if no registered adapter matches *model_path*.
    """
    for adapter in llm_model_chat_adapters:
        if adapter.match(model_path):
            return adapter

    raise ValueError(f"Invalid model for chat adapter {model_path}")
class VicunaChatAdapter(BaseChatAdpter):
    """Model chat Adapter for vicuna"""

    def match(self, model_path: str):
        # Substring check: any path containing "vicuna" is claimed.
        return "vicuna" in model_path

    def get_generate_stream_func(self):
        # Vicuna uses the default streaming generator.
        return generate_stream
class ChatGLMChatAdapter(BaseChatAdpter):
    """Model chat Adapter for ChatGLM"""

    def match(self, model_path: str):
        # Substring check: any path containing "chatglm" is claimed.
        return "chatglm" in model_path

    def get_generate_stream_func(self):
        # Imported lazily so ChatGLM-specific dependencies are only loaded
        # when a ChatGLM model is actually selected.
        from pilot.model.chatglm_llm import chatglm_generate_stream

        return chatglm_generate_stream
class CodeT5ChatAdapter(BaseChatAdpter):
    """Model chat adapter for CodeT5"""

    def match(self, model_path: str):
        # Substring check: any path containing "codet5" is claimed.
        return "codet5" in model_path

    def get_generate_stream_func(self):
        # TODO: CodeT5 streaming generation is not implemented yet;
        # returns None until a generator is provided.
        pass
class CodeGenChatAdapter(BaseChatAdpter):
    """Model chat adapter for CodeGen"""

    def match(self, model_path: str):
        # Substring check: any path containing "codegen" is claimed.
        return "codegen" in model_path

    def get_generate_stream_func(self):
        # TODO: CodeGen streaming generation is not implemented yet;
        # returns None until a generator is provided.
        pass
# Register concrete adapters first — lookup is first-match in
# registration order.
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)

# NOTE(review): BaseChatAdpter appears to match any model path (its match
# seems to return True unconditionally — confirm), so it is registered last
# as the catch-all fallback.
register_llm_model_chat_adapter(BaseChatAdpter)
Reference in New Issue
Block a user