mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-05 10:29:36 +00:00
Add: multi model support
This commit is contained in:
parent
a68e164a5f
commit
4302ae9087
96
pilot/model/adapter.py
Normal file
96
pilot/model/adapter.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import List
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoTokenizer,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoModel
|
||||||
|
)
|
||||||
|
|
||||||
|
class BaseLLMAdaper:
|
||||||
|
"""The Base class for multi model, in our project.
|
||||||
|
We will support those model, which performance resemble ChatGPT """
|
||||||
|
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
|
||||||
|
)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
llm_model_adapters = List[BaseLLMAdaper] = []
|
||||||
|
|
||||||
|
# Register llm models to adapters, by this we can use multi models.
|
||||||
|
def register_llm_model_adapters(cls):
|
||||||
|
"""Register a llm model adapter."""
|
||||||
|
llm_model_adapters.append(cls())
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
|
||||||
|
for adapter in llm_model_adapters:
|
||||||
|
if adapter.match(model_path):
|
||||||
|
return adapter
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid model adapter for {model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
|
||||||
|
|
||||||
|
class VicunaLLMAdapater(BaseLLMAdaper):
|
||||||
|
"""Vicuna Adapter """
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "vicuna" in model_path
|
||||||
|
|
||||||
|
def loader(self, model_path: str, from_pretrained_kwagrs: dict):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
**from_pretrained_kwagrs
|
||||||
|
)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
class ChatGLMAdapater(BaseLLMAdaper):
|
||||||
|
"""LLM Adatpter for THUDM/chatglm-6b"""
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "chatglm" in model_path
|
||||||
|
|
||||||
|
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
model = AutoModel.from_pretrained(
|
||||||
|
model_path, trust_remote_code=True, **from_pretrained_kwargs
|
||||||
|
).half().cuda()
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
class KoalaLLMAdapter(BaseLLMAdaper):
|
||||||
|
"""Koala LLM Adapter which Based LLaMA """
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "koala" in model_path
|
||||||
|
|
||||||
|
|
||||||
|
class RWKV4LLMAdapter(BaseLLMAdaper):
|
||||||
|
"""LLM Adapter for RwKv4 """
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "RWKV-4" in model_path
|
||||||
|
|
||||||
|
def loader(self, model_path: str, from_pretrained_kwargs: dict):
|
||||||
|
# TODO
|
||||||
|
pass
|
||||||
|
|
||||||
|
class GPT4AllAdapter(BaseLLMAdaper):
|
||||||
|
"""A light version for someone who want practise LLM use laptop."""
|
||||||
|
def match(self, model_path: str):
|
||||||
|
return "gpt4all" in model_path
|
||||||
|
|
||||||
|
|
||||||
|
register_llm_model_adapters(VicunaLLMAdapater)
|
||||||
|
# TODO Default support vicuna, other model need to tests and Evaluate
|
||||||
|
|
||||||
|
register_llm_model_adapters(BaseLLMAdaper)
|
@ -2,21 +2,19 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import warnings
|
||||||
from pilot.singleton import Singleton
|
from pilot.singleton import Singleton
|
||||||
|
|
||||||
from transformers import (
|
|
||||||
AutoTokenizer,
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoModel
|
|
||||||
)
|
|
||||||
|
|
||||||
from pilot.model.compression import compress_module
|
from pilot.model.compression import compress_module
|
||||||
|
from pilot.model.adapter import get_llm_model_adapter
|
||||||
|
|
||||||
|
|
||||||
class ModelLoader(metaclass=Singleton):
|
class ModelLoader(metaclass=Singleton):
|
||||||
"""Model loader is a class for model load
|
"""Model loader is a class for model load
|
||||||
|
|
||||||
Args: model_path
|
Args: model_path
|
||||||
|
|
||||||
|
TODO: multi model support.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
@ -31,9 +29,11 @@ class ModelLoader(metaclass=Singleton):
|
|||||||
"device_map": "auto",
|
"device_map": "auto",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# TODO multi gpu support
|
||||||
def loader(self, num_gpus, load_8bit=False, debug=False):
|
def loader(self, num_gpus, load_8bit=False, debug=False):
|
||||||
if self.device == "cpu":
|
if self.device == "cpu":
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
elif self.device == "cuda":
|
elif self.device == "cuda":
|
||||||
kwargs = {"torch_dtype": torch.float16}
|
kwargs = {"torch_dtype": torch.float16}
|
||||||
if num_gpus == "auto":
|
if num_gpus == "auto":
|
||||||
@ -46,18 +46,20 @@ class ModelLoader(metaclass=Singleton):
|
|||||||
"max_memory": {i: "13GiB" for i in range(num_gpus)},
|
"max_memory": {i: "13GiB" for i in range(num_gpus)},
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
|
# Todo Support mps for practise
|
||||||
raise ValueError(f"Invalid device: {self.device}")
|
raise ValueError(f"Invalid device: {self.device}")
|
||||||
|
|
||||||
if "chatglm" in self.model_path:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
|
llm_adapter = get_llm_model_adapter(self.model_path)
|
||||||
model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
|
model, tokenizer = llm_adapter.loader(self.model_path, kwargs)
|
||||||
else:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=False)
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(self.model_path,
|
|
||||||
low_cpu_mem_usage=True, **kwargs)
|
|
||||||
|
|
||||||
if load_8bit:
|
if load_8bit:
|
||||||
compress_module(model, self.device)
|
if num_gpus != 1:
|
||||||
|
warnings.warn(
|
||||||
|
"8-bit quantization is not supported for multi-gpu inference"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
compress_module(model, self.device)
|
||||||
|
|
||||||
if (self.device == "cuda" and num_gpus == 1):
|
if (self.device == "cuda" and num_gpus == 1):
|
||||||
model.to(self.device)
|
model.to(self.device)
|
||||||
|
Loading…
Reference in New Issue
Block a user