Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-07 20:10:08 +00:00)
Add: multi-model support
pilot/model/adapter.py (new file, 96 lines added)
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from typing import List
from functools import cache

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel
)


class BaseLLMAdaper:
    """The base class for multi-model support in our project.

    We will support models whose performance resembles ChatGPT's."""

    def match(self, model_path: str):
        # Matches everything: the base adapter doubles as a catch-all fallback.
        return True

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
        )
        return model, tokenizer


llm_model_adapters: List[BaseLLMAdaper] = []


# Register llm model adapters here; this is how we support multiple models.
def register_llm_model_adapters(cls):
    """Register an LLM model adapter."""
    llm_model_adapters.append(cls())
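

# NOTE: adapters are tried in registration order, so the first adapter whose
# match() returns True wins; @cache memoizes the lookup per model_path.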
@cache
def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    for adapter in llm_model_adapters:
        if adapter.match(model_path):
            return adapter

    raise ValueError(f"Invalid model adapter for {model_path}")


# TODO: support CPU? For practice we could support gpt4all or chatglm-6b-int4.


class VicunaLLMAdapater(BaseLLMAdaper):
    """Vicuna Adapter"""

    def match(self, model_path: str):
        return "vicuna" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **from_pretrained_kwargs
        )
        return model, tokenizer


class ChatGLMAdapater(BaseLLMAdaper):
    """LLM Adapter for THUDM/chatglm-6b"""

    def match(self, model_path: str):
        return "chatglm" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        # ChatGLM ships custom modeling code, hence trust_remote_code=True.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Load in half precision and move the model to a CUDA GPU.
        model = AutoModel.from_pretrained(
            model_path, trust_remote_code=True, **from_pretrained_kwargs
        ).half().cuda()
        return model, tokenizer


class KoalaLLMAdapter(BaseLLMAdaper):
    """Koala LLM Adapter, based on LLaMA"""

    def match(self, model_path: str):
        return "koala" in model_path


class RWKV4LLMAdapter(BaseLLMAdaper):
    """LLM Adapter for RWKV-4"""

    def match(self, model_path: str):
        return "RWKV-4" in model_path

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        # TODO: implement RWKV-4 model loading.
        pass


class GPT4AllAdapter(BaseLLMAdaper):
    """A lightweight option for anyone who wants to try an LLM on a laptop."""

    def match(self, model_path: str):
        return "gpt4all" in model_path


register_llm_model_adapters(VicunaLLMAdapater)
# TODO: Vicuna is supported by default; other models still need testing and evaluation.

# Register the catch-all base adapter last, so it only matches as a fallback.
register_llm_model_adapters(BaseLLMAdaper)
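
A minimal usage sketch of the registry this commit adds (illustration only, not part of the diff; the checkpoint path and empty kwargs dict are hypothetical):

    from pilot.model.adapter import get_llm_model_adapter

    model_path = "/data/models/vicuna-13b"        # hypothetical local checkpoint
    adapter = get_llm_model_adapter(model_path)   # "vicuna" in the path -> VicunaLLMAdapater
    model, tokenizer = adapter.loader(model_path, {})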