diff --git a/.env.template b/.env.template index f53bf31f7..d3187d0a1 100644 --- a/.env.template +++ b/.env.template @@ -37,6 +37,14 @@ QUANTIZE_8bit=True ## "PROXYLLM_BACKEND" is the model they actually deploy. We can use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene. # PROXYLLM_BACKEND= +### You can configure parameters for a specific model with {model name}_{config key}=xxx +### See /pilot/model/parameter.py +## prompt template for current model +# llama_cpp_prompt_template=vicuna_v1.1 +## llama-2-70b must be 8 +# llama_cpp_n_gqa=8 +## Model path +# llama_cpp_model_path=/data/models/TheBloke/vicuna-7B-v1.5-GGML/vicuna-7b-v1.5.ggmlv3.q4_0.bin #*******************************************************************# #** EMBEDDING SETTINGS **# diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile index 7df9e2d4e..4075444a9 100644 --- a/docker/base/Dockerfile +++ b/docker/base/Dockerfile @@ -11,19 +11,30 @@ ARG LANGUAGE="en" ARG PIP_INDEX_URL="https://pypi.org/simple" ENV PIP_INDEX_URL=$PIP_INDEX_URL +RUN mkdir -p /app + # COPY only requirements.txt first to leverage Docker cache -COPY ./requirements.txt /tmp/requirements.txt +COPY ./requirements.txt /app/requirements.txt +COPY ./setup.py /app/setup.py +COPY ./README.md /app/README.md WORKDIR /app +# RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \ +# && (if [ "${BUILD_LOCAL_CODE}" = "false" ]; \ +# # if not build local code, clone latest code from git, and rename to /app, TODO: download by version, like: https://github.com/eosphoros-ai/DB-GPT/archive/refs/tags/$DBGPT_VERSION.zip +# then git clone https://github.com/eosphoros-ai/DB-GPT.git /app \ +# && cp /app/requirements.txt /tmp/requirements.txt; \ +# fi;) \ +# && pip3 install -r /tmp/requirements.txt -i $PIP_INDEX_URL --no-cache-dir \ +# && rm /tmp/requirements.txt + RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \ - && (if [ "${BUILD_LOCAL_CODE}" = "false" ]; \ - # if not build local code, clone latest code from git, and rename to /app, TODO: download by version, like: https://github.com/eosphoros-ai/DB-GPT/archive/refs/tags/$DBGPT_VERSION.zip - then git clone https://github.com/eosphoros-ai/DB-GPT.git /app \ - && cp /app/requirements.txt /tmp/requirements.txt; \ - fi;) \ - && pip3 install -r /tmp/requirements.txt -i $PIP_INDEX_URL --no-cache-dir \ - && rm /tmp/requirements.txt + && cd /app && pip3 install -i $PIP_INDEX_URL . + +# ENV CMAKE_ARGS="-DLLAMA_CUBLAS=ON -DLLAMA_AVX2=OFF -DLLAMA_F16C=OFF -DLLAMA_FMA=OFF" +# ENV FORCE_CMAKE=1 +RUN cd /app && pip3 install -i $PIP_INDEX_URL .[llama_cpp] RUN (if [ "${LANGUAGE}" = "zh" ]; \ # language is zh, download zh_core_web_sm from github @@ -37,12 +48,11 @@ RUN (if [ "${LANGUAGE}" = "zh" ]; \ ARG BUILD_LOCAL_CODE="false" # COPY the rest of the app -COPY . /tmp/app +COPY . /app # TODO:Need to find a better way to determine whether to build docker image with local code. 
RUN (if [ "${BUILD_LOCAL_CODE}" = "true" ]; \ - then mv /tmp/app / && rm -rf /app/logs && rm -rf /app/pilot/data && rm -rf /app/pilot/message; \ - else rm -rf /tmp/app; \ + then rm -rf /app/logs && rm -rf /app/pilot/data && rm -rf /app/pilot/message; \ fi;) ARG LOAD_EXAMPLES="true" diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 9ca56824a..044572ca9 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -64,6 +64,7 @@ LLM_MODEL_CONFIG = { "baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"), # (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2 "wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"), + "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"), } # Load model config diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 1c8562e78..d97d2cb2b 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -3,7 +3,9 @@ import torch import os -from typing import List +import re +from pathlib import Path +from typing import List, Tuple from functools import cache from transformers import ( AutoModel, @@ -12,8 +14,10 @@ from transformers import ( LlamaTokenizer, BitsAndBytesConfig, ) +from pilot.model.parameter import ModelParameters, LlamaCppModelParameters from pilot.configs.model_config import DEVICE from pilot.configs.config import Config +from pilot.logs import logger bnb_config = BitsAndBytesConfig( load_in_4bit=True, @@ -24,6 +28,14 @@ bnb_config = BitsAndBytesConfig( CFG = Config() +class ModelType: + """ "Type of model""" + + HF = "huggingface" + LLAMA_CPP = "llama.cpp" + # TODO, support more model type + + class BaseLLMAdaper: """The Base class for multi model, in our project. We will support those model, which performance resemble ChatGPT""" @@ -31,8 +43,17 @@ class BaseLLMAdaper: def use_fast_tokenizer(self) -> bool: return False + def model_type(self) -> str: + return ModelType.HF + + def model_param_class(self, model_type: str = None) -> ModelParameters: + model_type = model_type if model_type else self.model_type() + if model_type == ModelType.LLAMA_CPP: + return LlamaCppModelParameters + return ModelParameters + def match(self, model_path: str): - return True + return False def loader(self, model_path: str, from_pretrained_kwargs: dict): tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) @@ -52,12 +73,25 @@ def register_llm_model_adapters(cls): @cache -def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper: +def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper: + # Prefer using model name matching for adapter in llm_model_adapters: - if adapter.match(model_path): + if adapter.match(model_name): + logger.info( + f"Found llm model adapter with model name: {model_name}, {adapter}" + ) return adapter - raise ValueError(f"Invalid model adapter for {model_path}") + for adapter in llm_model_adapters: + if adapter.match(model_path): + logger.info( + f"Found llm model adapter with model path: {model_path}, {adapter}" + ) + return adapter + + raise ValueError( + f"Invalid model adapter for model name {model_name} and model path {model_path}" + ) # TODO support cpu? for practise we support gpt4all or chatglm-6b-int4? 
@@ -296,6 +330,52 @@ class WizardLMAdapter(BaseLLMAdaper): return "wizardlm" in model_path.lower() +class LlamaCppAdapater(BaseLLMAdaper): + @staticmethod + def _parse_model_path(model_path: str) -> Tuple[bool, str]: + path = Path(model_path) + if not path.exists(): + # Just support local model + return False, None + if not path.is_file(): + model_paths = list(path.glob("*ggml*.bin")) + if not model_paths: + return False, None + model_path = str(model_paths[0]) + logger.warn( + f"Model path {path} is not a single file, use the first *ggml*.bin model file: {model_path}" + ) + if not re.fullmatch(r".*ggml.*\.bin", model_path): + return False, None + return True, model_path + + def model_type(self) -> ModelType: + return ModelType.LLAMA_CPP + + def match(self, model_path: str): + """ + https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML + """ + if "llama-cpp" == model_path: + return True + is_match, _ = LlamaCppAdapater._parse_model_path(model_path) + return is_match + + def loader(self, model_path: str, from_pretrained_kwargs: dict): + # TODO: not supported yet + _, model_path = LlamaCppAdapater._parse_model_path(model_path) + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, use_fast=False + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + low_cpu_mem_usage=True, + **from_pretrained_kwargs, + ) + return model, tokenizer + + register_llm_model_adapters(VicunaLLMAdapater) register_llm_model_adapters(ChatGLMAdapater) register_llm_model_adapters(GuanacoAdapter) @@ -305,6 +385,7 @@ register_llm_model_adapters(GPT4AllAdapter) register_llm_model_adapters(Llama2Adapter) register_llm_model_adapters(BaichuanAdapter) register_llm_model_adapters(WizardLMAdapter) +register_llm_model_adapters(LlamaCppAdapater) # TODO Default support vicuna, other model need to tests and Evaluate # just for test_py, remove this later diff --git a/pilot/model/llm/llama_cpp/__init__.py b/pilot/model/llm/llama_cpp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/model/llm/llama_cpp/llama_cpp.py b/pilot/model/llm/llama_cpp/llama_cpp.py new file mode 100644 index 000000000..5029d9192 --- /dev/null +++ b/pilot/model/llm/llama_cpp/llama_cpp.py @@ -0,0 +1,145 @@ +""" +Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py +""" +import re +from typing import Dict, Any +import torch +import llama_cpp + +from pilot.model.parameter import LlamaCppModelParameters +from pilot.logs import logger + +if torch.cuda.is_available() and not torch.version.hip: + try: + import llama_cpp_cuda + except: + llama_cpp_cuda = None +else: + llama_cpp_cuda = None + + +def llama_cpp_lib(prefer_cpu: bool = False): + if prefer_cpu or llama_cpp_cuda is None: + logger.info("llama.cpp is using CPU") + return llama_cpp + else: + return llama_cpp_cuda + + +def ban_eos_logits_processor(eos_token, input_ids, logits): + logits[eos_token] = -float("inf") + return logits + + +def get_params(model_path: str, model_params: LlamaCppModelParameters) -> Dict: + return { + "model_path": model_path, + "n_ctx": model_params.max_context_size, + "seed": model_params.seed, + "n_threads": model_params.n_threads, + "n_batch": model_params.n_batch, + "use_mmap": True, + "use_mlock": False, + "low_vram": False, + "n_gpu_layers": 0 if model_params.prefer_cpu else model_params.n_gpu_layers, + "n_gqa": model_params.n_gqa, + "logits_all": True, + "rms_norm_eps": model_params.rms_norm_eps, + } + + +class LlamaCppModel: + def
__init__(self): + self.initialized = False + self.model = None + self.verbose = True + + def __del__(self): + if self.model: + self.model.__del__() + + @classmethod + def from_pretrained(self, model_path, model_params: LlamaCppModelParameters): + Llama = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).Llama + LlamaCache = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).LlamaCache + + result = self() + cache_capacity = 0 + cache_capacity_str = model_params.cache_capacity + if cache_capacity_str is not None: + if "GiB" in cache_capacity_str: + cache_capacity = ( + int(re.sub("[a-zA-Z]", "", cache_capacity_str)) * 1000 * 1000 * 1000 + ) + elif "MiB" in cache_capacity_str: + cache_capacity = ( + int(re.sub("[a-zA-Z]", "", cache_capacity_str)) * 1000 * 1000 + ) + else: + cache_capacity = int(cache_capacity_str) + + params = get_params(model_path, model_params) + logger.info("Cache capacity is " + str(cache_capacity) + " bytes") + logger.info(f"Load LLaMA model with params: {params}") + + result.model = Llama(**params) + result.verbose = model_params.verbose + if cache_capacity > 0: + result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity)) + + # This is ugly, but the model and the tokenizer are the same object in this library. + return result, result + + def encode(self, string): + if type(string) is str: + string = string.encode() + + return self.model.tokenize(string) + + def decode(self, tokens): + return self.model.detokenize(tokens) + + def generate_streaming(self, params, context_len: int): + # LogitsProcessorList = llama_cpp_lib().LogitsProcessorList + + # Read parameters + prompt = params["prompt"] + if self.verbose: + print(f"Prompt of model: \n{prompt}") + + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.1)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_new_tokens", 2048)) + echo = bool(params.get("echo", True)) + + max_src_len = context_len - max_new_tokens + # Handle truncation + prompt = self.encode(prompt) + prompt = prompt[-max_src_len:] + prompt = self.decode(prompt).decode("utf-8") + + # TODO: Compared with the original llama model, llama.cpp produces noticeably worse Chinese output; this still needs debugging + completion_chunks = self.model.create_completion( + prompt=prompt, + max_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + repeat_penalty=repetition_penalty, + # tfs_z=params['tfs'], + # mirostat_mode=int(params['mirostat_mode']), + # mirostat_tau=params['mirostat_tau'], + # mirostat_eta=params['mirostat_eta'], + stream=True, + echo=echo, + logits_processor=None, + ) + + output = "" + for completion_chunk in completion_chunks: + text = completion_chunk["choices"][0]["text"] + output += text + # print(output) + yield output diff --git a/pilot/model/llm_out/llama_cpp_llm.py b/pilot/model/llm_out/llama_cpp_llm.py new file mode 100644 index 000000000..921670065 --- /dev/null +++ b/pilot/model/llm_out/llama_cpp_llm.py @@ -0,0 +1,8 @@ +from typing import Dict +import torch + + +@torch.inference_mode() +def generate_stream(model, tokenizer, params: Dict, device: str, context_len: int): + # Only supports LlamaCppModel + return model.generate_streaming(params=params, context_len=context_len) diff --git a/pilot/model/loader.py b/pilot/model/loader.py index c1af72e6e..170f7c460 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -2,48 +2,24 @@ # -*- coding: utf-8 -*-
from typing import Optional, Dict -import dataclasses import torch from pilot.configs.model_config import DEVICE -from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper +from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType from pilot.model.compression import compress_module +from pilot.model.parameter import ( + EnvArgumentParser, + ModelParameters, + LlamaCppModelParameters, + _genenv_ignoring_key_case, +) from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations from pilot.singleton import Singleton from pilot.utils import get_gpu_memory from pilot.logs import logger -class ModelType: - """ "Type of model""" - - HF = "huggingface" - LLAMA_CPP = "llama.cpp" - # TODO, support more model type - - -@dataclasses.dataclass -class ModelParams: - device: str - model_name: str - model_path: str - model_type: Optional[str] = ModelType.HF - num_gpus: Optional[int] = None - max_gpu_memory: Optional[str] = None - cpu_offloading: Optional[bool] = False - load_8bit: Optional[bool] = True - load_4bit: Optional[bool] = False - # quantization datatypes, `fp4` (four bit float) and `nf4` (normal four bit float) - quant_type: Optional[str] = "nf4" - # Nested quantization is activated through `use_double_quant`` - use_double_quant: Optional[bool] = True - # "bfloat16", "float16", "float32" - compute_dtype: Optional[str] = None - debug: Optional[bool] = False - trust_remote_code: Optional[bool] = True - - -def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams): +def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters): # TODO: vicuna-v1.5 8-bit quantization info is slow # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5 model_name = model_params.model_name.lower() @@ -51,7 +27,7 @@ def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams): return any(m in model_name for m in supported_models) -def _check_quantization(model_params: ModelParams): +def _check_quantization(model_params: ModelParameters): model_name = model_params.model_name.lower() has_quantization = any([model_params.load_8bit or model_params.load_4bit]) if has_quantization: @@ -69,6 +45,21 @@ def _check_quantization(model_params: ModelParams): return has_quantization +def _get_model_real_path(model_name, default_model_path) -> str: + """Get model real path by model name + priority from high to low: + 1. environment variable with key: {model_name}_model_path + 2. environment variable with key: model_path + 3. 
default_model_path + """ + env_prefix = model_name + "_" + env_prefix = env_prefix.replace("-", "_") + env_model_path = _genenv_ignoring_key_case("model_path", env_prefix=env_prefix) + if env_model_path: + return env_model_path + return _genenv_ignoring_key_case("model_path", default_value=default_model_path) + + class ModelLoader(metaclass=Singleton): """Model loader is a class for model load @@ -83,6 +74,7 @@ class ModelLoader(metaclass=Singleton): self.device = DEVICE self.model_path = model_path self.model_name = model_name + self.prompt_template: str = None self.kwargs = { "torch_dtype": torch.float16, "device_map": "auto", @@ -97,7 +89,18 @@ class ModelLoader(metaclass=Singleton): cpu_offloading=False, max_gpu_memory: Optional[str] = None, ): - model_params = ModelParams( + llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) + model_type = llm_adapter.model_type() + param_cls = llm_adapter.model_param_class(model_type) + + args_parser = EnvArgumentParser() + # Read model parameters from environment variables prefixed with the model name; these currently have the highest priority + # e.g. vicuna_13b_max_gpu_memory=13GiB or VICUNA_13B_MAX_GPU_MEMORY=13GiB + env_prefix = self.model_name + "_" + env_prefix = env_prefix.replace("-", "_") + model_params = args_parser.parse_args_into_dataclass( + param_cls, + env_prefix=env_prefix, device=self.device, model_path=self.model_path, model_name=self.model_name, @@ -105,14 +108,21 @@ class ModelLoader(metaclass=Singleton): cpu_offloading=cpu_offloading, load_8bit=load_8bit, load_4bit=load_4bit, - debug=debug, + verbose=debug, ) + self.prompt_template = model_params.prompt_template + logger.info(f"model_params:\n{model_params}") - llm_adapter = get_llm_model_adapter(model_params.model_path) - return huggingface_loader(llm_adapter, model_params) + + if model_type == ModelType.HF: + return huggingface_loader(llm_adapter, model_params) + elif model_type == ModelType.LLAMA_CPP: + return llamacpp_loader(llm_adapter, model_params) + else: + raise Exception(f"Unknown model type {model_type}") -def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams): +def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters): device = model_params.device max_memory = None @@ -175,14 +185,14 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams): pass except AttributeError: pass - if model_params.debug: + if model_params.verbose: print(model) return model, tokenizer def load_huggingface_quantization_model( llm_adapter: BaseLLMAdaper, - model_params: ModelParams, + model_params: ModelParameters, kwargs: Dict, max_memory: Dict[int, str], ): @@ -312,3 +322,17 @@ def load_huggingface_quantization_model( ) return model, tokenizer + + +def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParameters): + try: + from pilot.model.llm.llama_cpp.llama_cpp import LlamaCppModel + except ImportError as exc: + raise ValueError( + "Could not import python package: llama-cpp-python " + "Please install db-gpt llama support with `cd $DB-GPT-DIR && pip install .[llama_cpp]` " + "or install llama-cpp-python with `pip install llama-cpp-python`" + ) from exc + model_path = model_params.model_path + model, tokenizer = LlamaCppModel.from_pretrained(model_path, model_params) + return model, tokenizer diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py new file mode 100644 index 000000000..6ede45d9b --- /dev/null +++ b/pilot/model/parameter.py @@ -0,0 +1,151 @@
+#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import os + +from typing import Any, Optional, Type +from dataclasses import dataclass, field, fields + +from pilot.model.conversation import conv_templates + +suported_prompt_templates = ",".join(conv_templates.keys()) + + +def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None): + """Get the value from the environment variable, ignoring the case of the key""" + if env_prefix: + env_key = env_prefix + env_key + return os.getenv( + env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value)) + ) + + +class EnvArgumentParser: + def parse_args_into_dataclass( + self, dataclass_type: Type, env_prefix: str = None, **kwargs + ) -> Any: + for field in fields(dataclass_type): + env_var_value = _genenv_ignoring_key_case(field.name, env_prefix) + if env_var_value: + env_var_value = env_var_value.strip() + if field.type is int or field.type == Optional[int]: + env_var_value = int(env_var_value) + elif field.type is float or field.type == Optional[float]: + env_var_value = float(env_var_value) + elif field.type is bool or field.type == Optional[bool]: + env_var_value = env_var_value.lower() == "true" + elif field.type is str or field.type == Optional[str]: + pass + else: + raise ValueError(f"Unsupported parameter type {field.type}") + kwargs[field.name] = env_var_value + return dataclass_type(**kwargs) + + +@dataclass +class ModelParameters: + device: str = field(metadata={"help": "Device to run model"}) + model_name: str = field(metadata={"help": "Model name"}) + model_path: str = field(metadata={"help": "Model path"}) + model_type: Optional[str] = field( + default="huggingface", metadata={"help": "Model type, huggingface or llama.cpp"} + ) + prompt_template: Optional[str] = field( + default=None, + metadata={ + "help": f"Prompt template. 
If None, the prompt template is automatically determined from model path, supported templates: {suported_prompt_templates}" + }, + ) + max_context_size: Optional[int] = field( + default=4096, metadata={"help": "Maximum context size"} + ) + + num_gpus: Optional[int] = field( + default=None, + metadata={ + "help": "The number of GPUs you expect to use; if empty, use as many as possible" + }, + ) + max_gpu_memory: Optional[str] = field( + default=None, + metadata={ + "help": "The maximum memory limit of each GPU, only valid in multi-GPU configuration" + }, + ) + cpu_offloading: Optional[bool] = field( + default=False, metadata={"help": "CPU offloading"} + ) + load_8bit: Optional[bool] = field( + default=False, metadata={"help": "8-bit quantization"} + ) + load_4bit: Optional[bool] = field( + default=False, metadata={"help": "4-bit quantization"} + ) + quant_type: Optional[str] = field( + default="nf4", + metadata={ + "valid_values": ["nf4", "fp4"], + "help": "Quantization datatypes, `fp4` (four bit float) and `nf4` (normal four bit float), only valid when load_4bit=True", + }, + ) + use_double_quant: Optional[bool] = field( + default=True, + metadata={"help": "Nested quantization, only valid when load_4bit=True"}, + ) + # "bfloat16", "float16", "float32" + compute_dtype: Optional[str] = field( + default=None, + metadata={ + "valid_values": ["bfloat16", "float16", "float32"], + "help": "Model compute type", + }, + ) + trust_remote_code: Optional[bool] = field( + default=True, metadata={"help": "Trust remote code"} + ) + verbose: Optional[bool] = field( + default=False, metadata={"help": "Show verbose output."} + ) + + +@dataclass +class LlamaCppModelParameters(ModelParameters): + seed: Optional[int] = field( + default=-1, metadata={"help": "Random seed for llama-cpp models. -1 for random"} + ) + n_threads: Optional[int] = field( + default=None, + metadata={ + "help": "Number of threads to use. If None, the number of threads is automatically determined" + }, + ) + n_batch: Optional[int] = field( + default=512, + metadata={ + "help": "Maximum number of prompt tokens to batch together when calling llama_eval" + }, + ) + n_gpu_layers: Optional[int] = field( + default=1000000000, + metadata={ + "help": "Number of layers to offload to the GPU. Set this to 1000000000 to offload all layers to the GPU." + }, + ) + n_gqa: Optional[int] = field( + default=None, + metadata={"help": "Grouped-query attention. Must be 8 for llama-2 70b."}, + ) + rms_norm_eps: Optional[float] = field( + default=5e-06, metadata={"help": "5e-6 is a good value for llama-2 models."} + ) + cache_capacity: Optional[str] = field( + default=None, + metadata={ + "help": "Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed. " + }, + ) + prefer_cpu: Optional[bool] = field( + default=False, + metadata={ + "help": "If a GPU is available, it will be preferred by default, unless prefer_cpu=True is configured."
+ }, + ) diff --git a/pilot/scene/chat_dashboard/out_parser.py b/pilot/scene/chat_dashboard/out_parser.py index 00bdc4179..7f799f585 100644 --- a/pilot/scene/chat_dashboard/out_parser.py +++ b/pilot/scene/chat_dashboard/out_parser.py @@ -29,6 +29,8 @@ class ChatDashboardOutputParser(BaseOutputParser): print("clean prompt response:", clean_str) response = json.loads(clean_str) chart_items: List[ChartItem] = [] + if not isinstance(response, list): + response = [response] for item in response: chart_items.append( ChartItem( diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index 72c429ea7..8f4b85385 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -20,7 +20,7 @@ The output data of the analysis cannot exceed 4 columns, and do not use columns According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for data display, chart type: {supported_chat_type} -Pay attention to the length of the output content of the analysis result, do not exceed 4000tokens +Pay attention to the length of the output content of the analysis result, do not exceed 4000 tokens Give the correct {dialect} analysis SQL (don't use unprovided values such as 'paid'), analysis title, display method and summary of brief analysis thinking, and respond in the following json format: {response} diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index ab3aec94e..07b44b28c 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -13,7 +13,7 @@ class BaseChatAdpter: and fetch output from model""" def match(self, model_path: str): - return True + return False def get_generate_stream_func(self, model_path: str): """Return the generate stream handler func""" @@ -24,7 +24,9 @@ class BaseChatAdpter: def get_conv_template(self, model_path: str) -> Conversation: return None - def model_adaptation(self, params: Dict, model_path: str) -> Tuple[Dict, Dict]: + def model_adaptation( + self, params: Dict, model_path: str, prompt_template: str = None + ) -> Tuple[Dict, Dict]: """Params adaptation""" conv = self.get_conv_template(model_path) messages = params.get("messages") @@ -39,6 +41,10 @@ class BaseChatAdpter: ] params["messages"] = messages + if prompt_template: + print(f"Use prompt template {prompt_template} from config") + conv = get_conv_template(prompt_template) + if not conv or not messages: # Nothing to do print( @@ -94,14 +100,19 @@ def register_llm_model_chat_adapter(cls): @cache -def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter: +def get_llm_chat_adapter(model_name: str, model_path: str) -> BaseChatAdpter: """Get a chat generate func for a model""" for adapter in llm_model_chat_adapters: - if adapter.match(model_path): - print(f"Get model path: {model_path} adapter {adapter}") + if adapter.match(model_name): + print(f"Get model chat adapter with model name {model_name}, {adapter}") return adapter - - raise ValueError(f"Invalid model for chat adapter {model_path}") + for adapter in llm_model_chat_adapters: + if adapter.match(model_path): + print(f"Get model chat adapter with model path {model_path}, {adapter}") + return adapter + raise ValueError( + f"Invalid model for chat adapter with model name {model_name} and model path {model_path}" + ) class VicunaChatAdapter(BaseChatAdpter): @@ -239,6 +250,24 @@ class WizardLMChatAdapter(BaseChatAdpter): return get_conv_template("vicuna_v1.1") +class LlamaCppChatAdapter(BaseChatAdpter): + def 
match(self, model_path: str): + from pilot.model.adapter import LlamaCppAdapater + + if "llama-cpp" == model_path: + return True + is_match, _ = LlamaCppAdapater._parse_model_path(model_path) + return is_match + + def get_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("llama-2") + + def get_generate_stream_func(self, model_path: str): + from pilot.model.llm_out.llama_cpp_llm import generate_stream + + return generate_stream + + register_llm_model_chat_adapter(VicunaChatAdapter) register_llm_model_chat_adapter(ChatGLMChatAdapter) register_llm_model_chat_adapter(GuanacoChatAdapter) @@ -248,6 +277,7 @@ register_llm_model_chat_adapter(GPT4AllChatAdapter) register_llm_model_chat_adapter(Llama2ChatAdapter) register_llm_model_chat_adapter(BaichuanChatAdapter) register_llm_model_chat_adapter(WizardLMChatAdapter) +register_llm_model_chat_adapter(LlamaCppChatAdapter) # Proxy model for test and develop, it's cheap for us now. register_llm_model_chat_adapter(ProxyllmChatAdapter) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 21789b9c8..10cdb0299 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -23,7 +23,7 @@ sys.path.append(ROOT_PATH) from pilot.configs.config import Config from pilot.configs.model_config import * from pilot.model.llm_out.vicuna_base_llm import get_embeddings -from pilot.model.loader import ModelLoader +from pilot.model.loader import ModelLoader, _get_model_real_path from pilot.server.chat_adapter import get_llm_chat_adapter from pilot.scene.base_message import ModelMessage @@ -34,12 +34,13 @@ class ModelWorker: def __init__(self, model_path, model_name, device): if model_path.endswith("/"): model_path = model_path[:-1] - self.model_name = model_name or model_path.split("/")[-1] + model_path = _get_model_real_path(model_name, model_path) + # self.model_name = model_name or model_path.split("/")[-1] self.device = device - print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......") - self.ml: ModelLoader = ModelLoader( - model_path=model_path, model_name=self.model_name + print( + f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......" 
) + self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name) self.model, self.tokenizer = self.ml.loader( load_8bit=CFG.IS_LOAD_8BIT, load_4bit=CFG.IS_LOAD_4BIT, @@ -60,7 +61,7 @@ class ModelWorker: else: self.context_len = 2048 - self.llm_chat_adapter = get_llm_chat_adapter(model_path) + self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path) self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func( model_path ) @@ -86,7 +87,7 @@ class ModelWorker: try: # params adaptation params, model_context = self.llm_chat_adapter.model_adaptation( - params, self.ml.model_path + params, self.ml.model_path, prompt_template=self.ml.prompt_template ) for output in self.generate_stream_func( self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS diff --git a/requirements.txt b/requirements.txt index 008560eb3..6340dcac0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,6 +46,7 @@ gradio-client==0.0.8 wandb llama-index==0.5.27 +# TODO move bitsandbytes to optional bitsandbytes accelerate>=0.20.3 diff --git a/setup.py b/setup.py index 5c4a4d7f4..b33979f78 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,11 @@ -from typing import List +from typing import List, Tuple import setuptools +import platform +import subprocess +import os +from enum import Enum + from setuptools import find_packages with open("README.md", "r") as fh: @@ -16,6 +21,117 @@ def parse_requirements(file_name: str) -> List[str]: ] + +class SetupSpec: + def __init__(self) -> None: + self.extras: dict = {} + + +setup_spec = SetupSpec() + + +class AVXType(Enum): + BASIC = "basic" + AVX = "AVX" + AVX2 = "AVX2" + AVX512 = "AVX512" + + @staticmethod + def of_type(avx: str): + for item in AVXType: + if item._value_ == avx: + return item + return None + + +class OSType(Enum): + WINDOWS = "win" + LINUX = "linux" + DARWIN = "darwin" + OTHER = "other" + + +def get_cpu_avx_support() -> Tuple[OSType, AVXType]: + system = platform.system() + os_type = OSType.OTHER + cpu_avx = AVXType.BASIC + env_cpu_avx = AVXType.of_type(os.getenv("DBGPT_LLAMA_CPP_AVX")) + + cmds = ["lscpu"] + if system == "Windows": + cmds = ["coreinfo"] + os_type = OSType.WINDOWS + elif system == "Linux": + cmds = ["lscpu"] + os_type = OSType.LINUX + elif system == "Darwin": + cmds = ["sysctl", "-a"] + os_type = OSType.DARWIN + else: + os_type = OSType.OTHER + print("Unsupported OS for CPU AVX detection, using default") + return os_type, env_cpu_avx if env_cpu_avx else cpu_avx + result = subprocess.run(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output = result.stdout.decode() + if "avx512" in output.lower(): + cpu_avx = AVXType.AVX512 + elif "avx2" in output.lower(): + cpu_avx = AVXType.AVX2 + elif "avx " in output.lower(): + # cpu_avx = AVXType.AVX + pass + return os_type, env_cpu_avx if env_cpu_avx else cpu_avx + + +def get_cuda_version() -> str: + try: + import torch + + return torch.version.cuda + except Exception: + return None + + +def llama_cpp_python_cuda_requires(): + cuda_version = get_cuda_version() + device = "cpu" + if not cuda_version: + print("CUDA not supported, using CPU version") + return + device = "cu" + cuda_version.replace(".", "") + os_type, cpu_avx = get_cpu_avx_support() + supported_os = [OSType.WINDOWS, OSType.LINUX] + if os_type not in supported_os: + print( + f"llama_cpp_python_cuda only supports OS: {[r._value_ for r in supported_os]}" + ) + return + cpu_avx = cpu_avx._value_ + base_url =
"https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui" + llama_cpp_version = "0.1.77" + py_version = "cp310" + os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64" + extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl" + print(f"Install llama_cpp_python_cuda from {extra_index_url}") + + setup_spec.extras["llama_cpp"].append(f"llama_cpp_python_cuda @ {extra_index_url}") + + +def llama_cpp_requires(): + setup_spec.extras["llama_cpp"] = ["llama-cpp-python"] + llama_cpp_python_cuda_requires() + + +def all_requires(): + requires = set() + for _, pkgs in setup_spec.extras.items(): + for pkg in pkgs: + requires.add(pkg) + setup_spec.extras["all"] = list(requires) + + +llama_cpp_requires() +all_requires() + setuptools.setup( name="db-gpt", packages=find_packages(), @@ -27,9 +143,10 @@ setuptools.setup( long_description=long_description, long_description_content_type="text/markdown", install_requires=parse_requirements("requirements.txt"), - url="https://github.com/csunny/DB-GPT", + url="https://github.com/eosphoros-ai/DB-GPT", license="https://opensource.org/license/mit/", python_requires=">=3.10", + extras_require=setup_spec.extras, entry_points={ "console_scripts": [ "dbgpt_server=pilot.server:webserver",