From 12c4cf8de5f48b962c091016ee0fc30a0e7a3cc3 Mon Sep 17 00:00:00 2001
From: Aries-ckt <916701291@qq.com>
Date: Tue, 15 Aug 2023 12:34:46 +0800
Subject: [PATCH 1/4] doc: Update faq.md
---
docs/faq.md | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/docs/faq.md b/docs/faq.md
index 0255e55f2..ddd60bd43 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -51,15 +51,18 @@ Normal:
##### Q4: When I use OpenAI (MODEL_SERVER=proxyllm) to chat
-
+
##### A4: Make sure your OpenAI API_KEY is valid and available
##### Q5: When I use Chat Data and Chat Meta Data, I get the following error
-
-
+
+
+
+
##### A5: You have not created your database and table
1. Create your database.
@@ -103,7 +106,7 @@ VECTOR_STORE_TYPE=Chroma
```
##### Q7: When I use vicuna-13b, I see some illegal characters like this.
-
+
##### A7: Set KNOWLEDGE_SEARCH_TOP_SIZE or KNOWLEDGE_CHUNK_SIZE to a smaller value, then restart the server.
From b5fd5d2a3a5ded36c62dc616df281de963ff2cb7 Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Tue, 15 Aug 2023 18:58:15 +0800
Subject: [PATCH 2/4] feat: Support llama.cpp
---
.env.template | 8 ++
docker/base/Dockerfile | 32 +++--
pilot/configs/model_config.py | 1 +
pilot/model/adapter.py | 91 +++++++++++++-
pilot/model/llm/llama_cpp/__init__.py | 0
pilot/model/llm/llama_cpp/llama_cpp.py | 145 ++++++++++++++++++++++
pilot/model/llm_out/llama_cpp_llm.py | 8 ++
pilot/model/loader.py | 104 ++++++++++------
pilot/model/parameter.py | 151 +++++++++++++++++++++++
pilot/scene/chat_dashboard/out_parser.py | 2 +
pilot/scene/chat_dashboard/prompt.py | 2 +-
pilot/server/chat_adapter.py | 44 +++++--
pilot/server/llmserver.py | 15 +--
requirements.txt | 1 +
setup.py | 121 +++++++++++++++++-
15 files changed, 652 insertions(+), 73 deletions(-)
create mode 100644 pilot/model/llm/llama_cpp/__init__.py
create mode 100644 pilot/model/llm/llama_cpp/llama_cpp.py
create mode 100644 pilot/model/llm_out/llama_cpp_llm.py
create mode 100644 pilot/model/parameter.py
diff --git a/.env.template b/.env.template
index f53bf31f7..d3187d0a1 100644
--- a/.env.template
+++ b/.env.template
@@ -37,6 +37,14 @@ QUANTIZE_8bit=True
## "PROXYLLM_BACKEND" is the model they actually deploy. We can use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
# PROXYLLM_BACKEND=
+### You can configure parameters for a specific model with {model name}_{config key}=xxx
+### See /pilot/model/parameter.py
+## prompt template for current model
+# llama_cpp_prompt_template=vicuna_v1.1
+## n_gqa must be 8 for llama-2-70b
+# llama_cpp_n_gqa=8
+## Model path
+# llama_cpp_model_path=/data/models/TheBloke/vicuna-7B-v1.5-GGML/vicuna-7b-v1.5.ggmlv3.q4_0.bin
#*******************************************************************#
#** EMBEDDING SETTINGS **#
diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile
index 7df9e2d4e..4075444a9 100644
--- a/docker/base/Dockerfile
+++ b/docker/base/Dockerfile
@@ -11,19 +11,30 @@ ARG LANGUAGE="en"
ARG PIP_INDEX_URL="https://pypi.org/simple"
ENV PIP_INDEX_URL=$PIP_INDEX_URL
+RUN mkdir -p /app
+
# COPY only requirements.txt first to leverage Docker cache
-COPY ./requirements.txt /tmp/requirements.txt
+COPY ./requirements.txt /app/requirements.txt
+COPY ./setup.py /app/setup.py
+COPY ./README.md /app/README.md
WORKDIR /app
+# RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \
+# && (if [ "${BUILD_LOCAL_CODE}" = "false" ]; \
+# # if not build local code, clone latest code from git, and rename to /app, TODO: download by version, like: https://github.com/eosphoros-ai/DB-GPT/archive/refs/tags/$DBGPT_VERSION.zip
+# then git clone https://github.com/eosphoros-ai/DB-GPT.git /app \
+# && cp /app/requirements.txt /tmp/requirements.txt; \
+# fi;) \
+# && pip3 install -r /tmp/requirements.txt -i $PIP_INDEX_URL --no-cache-dir \
+# && rm /tmp/requirements.txt
+
RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \
- && (if [ "${BUILD_LOCAL_CODE}" = "false" ]; \
- # if not build local code, clone latest code from git, and rename to /app, TODO: download by version, like: https://github.com/eosphoros-ai/DB-GPT/archive/refs/tags/$DBGPT_VERSION.zip
- then git clone https://github.com/eosphoros-ai/DB-GPT.git /app \
- && cp /app/requirements.txt /tmp/requirements.txt; \
- fi;) \
- && pip3 install -r /tmp/requirements.txt -i $PIP_INDEX_URL --no-cache-dir \
- && rm /tmp/requirements.txt
+ && cd /app && pip3 install -i $PIP_INDEX_URL .
+
+# ENV CMAKE_ARGS="-DLLAMA_CUBLAS=ON -DLLAMA_AVX2=OFF -DLLAMA_F16C=OFF -DLLAMA_FMA=OFF"
+# ENV FORCE_CMAKE=1
+RUN cd /app && pip3 install -i $PIP_INDEX_URL .[llama_cpp]
RUN (if [ "${LANGUAGE}" = "zh" ]; \
# language is zh, download zh_core_web_sm from github
@@ -37,12 +48,11 @@ RUN (if [ "${LANGUAGE}" = "zh" ]; \
ARG BUILD_LOCAL_CODE="false"
# COPY the rest of the app
-COPY . /tmp/app
+COPY . /app
# TODO: Need to find a better way to determine whether to build docker image with local code.
RUN (if [ "${BUILD_LOCAL_CODE}" = "true" ]; \
- then mv /tmp/app / && rm -rf /app/logs && rm -rf /app/pilot/data && rm -rf /app/pilot/message; \
- else rm -rf /tmp/app; \
+ then rm -rf /app/logs && rm -rf /app/pilot/data && rm -rf /app/pilot/message; \
fi;)
ARG LOAD_EXAMPLES="true"
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index 9ca56824a..044572ca9 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -64,6 +64,7 @@ LLM_MODEL_CONFIG = {
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
+ "llama-cpp": os.path.join(MODEL_PATH, "ggml-model-q4_0.bin"),
}
# Load model config
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 1c8562e78..d97d2cb2b 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -3,7 +3,9 @@
import torch
import os
-from typing import List
+import re
+from pathlib import Path
+from typing import List, Tuple
from functools import cache
from transformers import (
AutoModel,
@@ -12,8 +14,10 @@ from transformers import (
LlamaTokenizer,
BitsAndBytesConfig,
)
+from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
from pilot.configs.model_config import DEVICE
from pilot.configs.config import Config
+from pilot.logs import logger
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
@@ -24,6 +28,14 @@ bnb_config = BitsAndBytesConfig(
CFG = Config()
+class ModelType:
+ """ "Type of model"""
+
+ HF = "huggingface"
+ LLAMA_CPP = "llama.cpp"
+ # TODO, support more model type
+
+
class BaseLLMAdaper:
"""The Base class for multi model, in our project.
We will support those model, which performance resemble ChatGPT"""
@@ -31,8 +43,17 @@ class BaseLLMAdaper:
def use_fast_tokenizer(self) -> bool:
return False
+ def model_type(self) -> str:
+ return ModelType.HF
+
+ def model_param_class(self, model_type: str = None) -> ModelParameters:
+ model_type = model_type if model_type else self.model_type()
+ if model_type == ModelType.LLAMA_CPP:
+ return LlamaCppModelParameters
+ return ModelParameters
+
def match(self, model_path: str):
- return True
+ return False
def loader(self, model_path: str, from_pretrained_kwargs: dict):
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
@@ -52,12 +73,25 @@ def register_llm_model_adapters(cls):
@cache
-def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
+def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
+ # Prefer using model name matching
for adapter in llm_model_adapters:
- if adapter.match(model_path):
+ if adapter.match(model_name):
+ logger.info(
+ f"Found llm model adapter with model name: {model_name}, {adapter}"
+ )
return adapter
- raise ValueError(f"Invalid model adapter for {model_path}")
+ for adapter in llm_model_adapters:
+ if adapter.match(model_path):
+ logger.info(
+ f"Found llm model adapter with model path: {model_path}, {adapter}"
+ )
+ return adapter
+
+ raise ValueError(
+ f"Invalid model adapter for model name {model_name} and model path {model_path}"
+ )
# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?
@@ -296,6 +330,52 @@ class WizardLMAdapter(BaseLLMAdaper):
return "wizardlm" in model_path.lower()
+class LlamaCppAdapater(BaseLLMAdaper):
+ @staticmethod
+ def _parse_model_path(model_path: str) -> Tuple[bool, str]:
+ path = Path(model_path)
+ if not path.exists():
+ # Just support local model
+ return False, None
+ if not path.is_file():
+ model_paths = list(path.glob("*ggml*.bin"))
+ if not model_paths:
+ return False, None
+ model_path = str(model_paths[0])
+ logger.warn(
+ f"Model path {model_path} is not single file, use first *gglm*.bin model file: {model_path}"
+ )
+ if not re.fullmatch(r".*ggml.*\.bin", model_path):
+ return False, None
+ return True, model_path
+
+ def model_type(self) -> ModelType:
+ return ModelType.LLAMA_CPP
+
+ def match(self, model_path: str):
+ """
+ https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML
+ """
+ if "llama-cpp" == model_path:
+ return True
+ is_match, _ = LlamaCppAdapater._parse_model_path(model_path)
+ return is_match
+
+ def loader(self, model_path: str, from_pretrained_kwargs: dict):
+ # TODO: not supported yet
+ _, model_path = LlamaCppAdapater._parse_model_path(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path, trust_remote_code=True, use_fast=False
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ **from_pretrained_kwargs,
+ )
+ return model, tokenizer
+
+
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
@@ -305,6 +385,7 @@ register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
+register_llm_model_adapters(LlamaCppAdapater)
# TODO Default support vicuna, other model need to tests and Evaluate
# just for test_py, remove this later
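For reference, the two-pass resolution that `get_llm_model_adapter` now performs (match on model name first, then on model path) can be illustrated with a minimal, self-contained sketch; `KeywordAdapter` here is a hypothetical stand-in for the real `BaseLLMAdaper` subclasses:

```python
from typing import List


class KeywordAdapter:
    """Hypothetical stand-in for a BaseLLMAdaper subclass."""

    def __init__(self, keyword: str):
        self.keyword = keyword

    def match(self, name_or_path: str) -> bool:
        return self.keyword in name_or_path.lower()


adapters: List[KeywordAdapter] = [KeywordAdapter("vicuna"), KeywordAdapter("llama-cpp")]


def find_adapter(model_name: str, model_path: str) -> KeywordAdapter:
    # Mirror of get_llm_model_adapter: try the model name first, then the path.
    for candidate in (model_name, model_path):
        for adapter in adapters:
            if adapter.match(candidate):
                return adapter
    raise ValueError(
        f"Invalid model adapter for model name {model_name} and model path {model_path}"
    )


print(find_adapter("llama-cpp", "/data/models/ggml-model-q4_0.bin").keyword)
```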
diff --git a/pilot/model/llm/llama_cpp/__init__.py b/pilot/model/llm/llama_cpp/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/model/llm/llama_cpp/llama_cpp.py b/pilot/model/llm/llama_cpp/llama_cpp.py
new file mode 100644
index 000000000..5029d9192
--- /dev/null
+++ b/pilot/model/llm/llama_cpp/llama_cpp.py
@@ -0,0 +1,145 @@
+"""
+Fork from text-generation-webui https://github.com/oobabooga/text-generation-webui/blob/main/modules/llamacpp_model.py
+"""
+import re
+from typing import Dict, Any
+import torch
+import llama_cpp
+
+from pilot.model.parameter import LlamaCppModelParameters
+from pilot.logs import logger
+
+if torch.cuda.is_available() and not torch.version.hip:
+ try:
+ import llama_cpp_cuda
+ except ImportError:
+ llama_cpp_cuda = None
+else:
+ llama_cpp_cuda = None
+
+
+def llama_cpp_lib(prefer_cpu: bool = False):
+ if prefer_cpu or llama_cpp_cuda is None:
+ logger.info(f"Llama.cpp use cpu")
+ return llama_cpp
+ else:
+ return llama_cpp_cuda
+
+
+def ban_eos_logits_processor(eos_token, input_ids, logits):
+ logits[eos_token] = -float("inf")
+ return logits
+
+
+def get_params(model_path: str, model_params: LlamaCppModelParameters) -> Dict:
+ return {
+ "model_path": model_path,
+ "n_ctx": model_params.max_context_size,
+ "seed": model_params.seed,
+ "n_threads": model_params.n_threads,
+ "n_batch": model_params.n_batch,
+ "use_mmap": True,
+ "use_mlock": False,
+ "low_vram": False,
+ "n_gpu_layers": 0 if model_params.prefer_cpu else model_params.n_gpu_layers,
+ "n_gqa": model_params.n_gqa,
+ "logits_all": True,
+ "rms_norm_eps": model_params.rms_norm_eps,
+ }
+
+
+class LlamaCppModel:
+ def __init__(self):
+ self.initialized = False
+ self.model = None
+ self.verbose = True
+
+ def __del__(self):
+ if self.model:
+ self.model.__del__()
+
+ @classmethod
+ def from_pretrained(self, model_path, model_params: LlamaCppModelParameters):
+ Llama = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).Llama
+ LlamaCache = llama_cpp_lib(prefer_cpu=model_params.prefer_cpu).LlamaCache
+
+ result = self()
+ cache_capacity = 0
+ cache_capacity_str = model_params.cache_capacity
+ if cache_capacity_str is not None:
+ if "GiB" in cache_capacity_str:
+ cache_capacity = (
+ int(re.sub("[a-zA-Z]", "", cache_capacity_str)) * 1000 * 1000 * 1000
+ )
+ elif "MiB" in cache_capacity_str:
+ cache_capacity = (
+ int(re.sub("[a-zA-Z]", "", cache_capacity_str)) * 1000 * 1000
+ )
+ else:
+ cache_capacity = int(cache_capacity_str)
+
+ params = get_params(model_path, model_params)
+ logger.info("Cache capacity is " + str(cache_capacity) + " bytes")
+ logger.info(f"Load LLama model with params: {params}")
+
+ result.model = Llama(**params)
+ result.verbose = model_params.verbose
+ if cache_capacity > 0:
+ result.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
+
+ # This is ugly, but the model and the tokenizer are the same object in this library.
+ return result, result
+
+ def encode(self, string):
+ if type(string) is str:
+ string = string.encode()
+
+ return self.model.tokenize(string)
+
+ def decode(self, tokens):
+ return self.model.detokenize(tokens)
+
+ def generate_streaming(self, params, context_len: int):
+ # LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
+
+ # Read parameters
+ prompt = params["prompt"]
+ if self.verbose:
+ print(f"Prompt of model: \n{prompt}")
+
+ temperature = float(params.get("temperature", 1.0))
+ repetition_penalty = float(params.get("repetition_penalty", 1.1))
+ top_p = float(params.get("top_p", 1.0))
+ top_k = int(params.get("top_k", -1)) # -1 means disable
+ max_new_tokens = int(params.get("max_new_tokens", 2048))
+ echo = bool(params.get("echo", True))
+
+ max_src_len = context_len - max_new_tokens
+ # Handle truncation
+ prompt = self.encode(prompt)
+ prompt = prompt[-max_src_len:]
+ prompt = self.decode(prompt).decode("utf-8")
+
+ # TODO: Compared with the original llama model, llama.cpp's Chinese output quality is mediocre and still needs debugging
+ completion_chunks = self.model.create_completion(
+ prompt=prompt,
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ repeat_penalty=repetition_penalty,
+ # tfs_z=params['tfs'],
+ # mirostat_mode=int(params['mirostat_mode']),
+ # mirostat_tau=params['mirostat_tau'],
+ # mirostat_eta=params['mirostat_eta'],
+ stream=True,
+ echo=echo,
+ logits_processor=None,
+ )
+
+ output = ""
+ for completion_chunk in completion_chunks:
+ text = completion_chunk["choices"][0]["text"]
+ output += text
+ # print(output)
+ yield output
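A hedged usage sketch for the `LlamaCppModel` wrapper above. It assumes `llama-cpp-python` (and torch) are installed and that the GGML file path, which is a placeholder here, points at a model that actually exists locally; the streaming generator yields the accumulated text on every chunk:

```python
from pilot.model.llm.llama_cpp.llama_cpp import LlamaCppModel
from pilot.model.parameter import LlamaCppModelParameters

# Hypothetical local GGML file; adjust to a model that exists on your disk.
params = LlamaCppModelParameters(
    device="cpu",
    model_name="llama-cpp",
    model_path="/data/models/ggml-model-q4_0.bin",
    prefer_cpu=True,
)
# The model and the tokenizer are the same object in llama-cpp-python.
model, tokenizer = LlamaCppModel.from_pretrained(params.model_path, params)

output = ""
for output in model.generate_streaming(
    {"prompt": "Hello, who are you?", "temperature": 0.7, "max_new_tokens": 64},
    context_len=2048,
):
    pass  # each iteration yields the text generated so far
print(output)
```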
diff --git a/pilot/model/llm_out/llama_cpp_llm.py b/pilot/model/llm_out/llama_cpp_llm.py
new file mode 100644
index 000000000..921670065
--- /dev/null
+++ b/pilot/model/llm_out/llama_cpp_llm.py
@@ -0,0 +1,8 @@
+from typing import Dict
+import torch
+
+
+@torch.inference_mode()
+def generate_stream(model, tokenizer, params: Dict, device: str, context_len: int):
+ # Only LlamaCppModel is supported here
+ return model.generate_streaming(params=params, context_len=context_len)
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index c1af72e6e..170f7c460 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -2,48 +2,24 @@
# -*- coding: utf-8 -*-
from typing import Optional, Dict
-import dataclasses
import torch
from pilot.configs.model_config import DEVICE
-from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
+from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.compression import compress_module
+from pilot.model.parameter import (
+ EnvArgumentParser,
+ ModelParameters,
+ LlamaCppModelParameters,
+ _genenv_ignoring_key_case,
+)
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
from pilot.singleton import Singleton
from pilot.utils import get_gpu_memory
from pilot.logs import logger
-class ModelType:
- """ "Type of model"""
-
- HF = "huggingface"
- LLAMA_CPP = "llama.cpp"
- # TODO, support more model type
-
-
-@dataclasses.dataclass
-class ModelParams:
- device: str
- model_name: str
- model_path: str
- model_type: Optional[str] = ModelType.HF
- num_gpus: Optional[int] = None
- max_gpu_memory: Optional[str] = None
- cpu_offloading: Optional[bool] = False
- load_8bit: Optional[bool] = True
- load_4bit: Optional[bool] = False
- # quantization datatypes, `fp4` (four bit float) and `nf4` (normal four bit float)
- quant_type: Optional[str] = "nf4"
- # Nested quantization is activated through `use_double_quant``
- use_double_quant: Optional[bool] = True
- # "bfloat16", "float16", "float32"
- compute_dtype: Optional[str] = None
- debug: Optional[bool] = False
- trust_remote_code: Optional[bool] = True
-
-
-def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams):
+def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
# TODO: vicuna-v1.5 8-bit quantization info is slow
# TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
model_name = model_params.model_name.lower()
@@ -51,7 +27,7 @@ def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams):
return any(m in model_name for m in supported_models)
-def _check_quantization(model_params: ModelParams):
+def _check_quantization(model_params: ModelParameters):
model_name = model_params.model_name.lower()
has_quantization = any([model_params.load_8bit or model_params.load_4bit])
if has_quantization:
@@ -69,6 +45,21 @@ def _check_quantization(model_params: ModelParams):
return has_quantization
+def _get_model_real_path(model_name, default_model_path) -> str:
+ """Get model real path by model name
+ priority from high to low:
+ 1. environment variable with key: {model_name}_model_path
+ 2. environment variable with key: model_path
+ 3. default_model_path
+ """
+ env_prefix = model_name + "_"
+ env_prefix = env_prefix.replace("-", "_")
+ env_model_path = _genenv_ignoring_key_case("model_path", env_prefix=env_prefix)
+ if env_model_path:
+ return env_model_path
+ return _genenv_ignoring_key_case("model_path", default_value=default_model_path)
+
+
class ModelLoader(metaclass=Singleton):
"""Model loader is a class for model load
@@ -83,6 +74,7 @@ class ModelLoader(metaclass=Singleton):
self.device = DEVICE
self.model_path = model_path
self.model_name = model_name
+ self.prompt_template: str = None
self.kwargs = {
"torch_dtype": torch.float16,
"device_map": "auto",
@@ -97,7 +89,18 @@ class ModelLoader(metaclass=Singleton):
cpu_offloading=False,
max_gpu_memory: Optional[str] = None,
):
- model_params = ModelParams(
+ llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
+ model_type = llm_adapter.model_type()
+ param_cls = llm_adapter.model_param_class(model_type)
+
+ args_parser = EnvArgumentParser()
+ # Read model parameters from environment variables prefixed with the model name; this currently has the highest priority,
+ # e.g. vicuna_13b_max_gpu_memory=13GiB or VICUNA_13B_MAX_GPU_MEMORY=13GiB
+ env_prefix = self.model_name + "_"
+ env_prefix = env_prefix.replace("-", "_")
+ model_params = args_parser.parse_args_into_dataclass(
+ param_cls,
+ env_prefix=env_prefix,
device=self.device,
model_path=self.model_path,
model_name=self.model_name,
@@ -105,14 +108,21 @@ class ModelLoader(metaclass=Singleton):
cpu_offloading=cpu_offloading,
load_8bit=load_8bit,
load_4bit=load_4bit,
- debug=debug,
+ verbose=debug,
)
+ self.prompt_template = model_params.prompt_template
+
logger.info(f"model_params:\n{model_params}")
- llm_adapter = get_llm_model_adapter(model_params.model_path)
- return huggingface_loader(llm_adapter, model_params)
+
+ if model_type == ModelType.HF:
+ return huggingface_loader(llm_adapter, model_params)
+ elif model_type == ModelType.LLAMA_CPP:
+ return llamacpp_loader(llm_adapter, model_params)
+ else:
+ raise Exception(f"Unknown model type {model_type}")
-def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
+def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
device = model_params.device
max_memory = None
@@ -175,14 +185,14 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
pass
except AttributeError:
pass
- if model_params.debug:
+ if model_params.verbose:
print(model)
return model, tokenizer
def load_huggingface_quantization_model(
llm_adapter: BaseLLMAdaper,
- model_params: ModelParams,
+ model_params: ModelParameters,
kwargs: Dict,
max_memory: Dict[int, str],
):
@@ -312,3 +322,17 @@ def load_huggingface_quantization_model(
)
return model, tokenizer
+
+
+def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParameters):
+ try:
+ from pilot.model.llm.llama_cpp.llama_cpp import LlamaCppModel
+ except ImportError as exc:
+ raise ValueError(
+ "Could not import python package: llama-cpp-python "
+ "Please install db-gpt llama support with `cd $DB-GPT-DIR && pip install .[llama_cpp]` "
+ "or install llama-cpp-python with `pip install llama-cpp-python`"
+ ) from exc
+ model_path = model_params.model_path
+ model, tokenizer = LlamaCppModel.from_pretrained(model_path, model_params)
+ return model, tokenizer
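The path-resolution priority documented in `_get_model_real_path` can be exercised standalone. The snippet below mirrors that logic rather than importing it, so it runs without DB-GPT installed; the environment key names follow the `{model name}_model_path` convention above and the paths are hypothetical:

```python
import os


def resolve_model_path(model_name: str, default_model_path: str) -> str:
    # Mirrors _get_model_real_path: {model_name}_model_path beats model_path,
    # which beats the default path from LLM_MODEL_CONFIG. Lookups ignore case.
    prefix = model_name.replace("-", "_") + "_"
    for key in (prefix + "model_path", "model_path"):
        for candidate in (key, key.upper(), key.lower()):
            value = os.getenv(candidate)
            if value:
                return value
    return default_model_path


os.environ["llama_cpp_model_path"] = "/data/models/vicuna-7b-v1.5.ggmlv3.q4_0.bin"
print(resolve_model_path("llama-cpp", "/app/models/ggml-model-q4_0.bin"))
```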
diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py
new file mode 100644
index 000000000..6ede45d9b
--- /dev/null
+++ b/pilot/model/parameter.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import os
+
+from typing import Any, Optional, Type
+from dataclasses import dataclass, field, fields
+
+from pilot.model.conversation import conv_templates
+
+supported_prompt_templates = ",".join(conv_templates.keys())
+
+
+def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
+ """Get the value from the environment variable, ignoring the case of the key"""
+ if env_prefix:
+ env_key = env_prefix + env_key
+ return os.getenv(
+ env_key, os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))
+ )
+
+
+class EnvArgumentParser:
+ def parse_args_into_dataclass(
+ self, dataclass_type: Type, env_prefix: str = None, **kwargs
+ ) -> Any:
+ for field in fields(dataclass_type):
+ env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
+ if env_var_value:
+ env_var_value = env_var_value.strip()
+ if field.type is int or field.type == Optional[int]:
+ env_var_value = int(env_var_value)
+ elif field.type is float or field.type == Optional[float]:
+ env_var_value = float(env_var_value)
+ elif field.type is bool or field.type == Optional[bool]:
+ env_var_value = env_var_value.lower() == "true"
+ elif field.type is str or field.type == Optional[str]:
+ pass
+ else:
+ raise ValueError(f"Unsupported parameter type {field.type}")
+ kwargs[field.name] = env_var_value
+ return dataclass_type(**kwargs)
+
+
+@dataclass
+class ModelParameters:
+ device: str = field(metadata={"help": "Device to run model"})
+ model_name: str = field(metadata={"help": "Model name"})
+ model_path: str = field(metadata={"help": "Model path"})
+ model_type: Optional[str] = field(
+ default="huggingface", metadata={"help": "Model type, huggingface or llama.cpp"}
+ )
+ prompt_template: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
+ },
+ )
+ max_context_size: Optional[int] = field(
+ default=4096, metadata={"help": "Maximum context size"}
+ )
+
+ num_gpus: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "The number of gpus you expect to use, if it is empty, use all of them as much as possible"
+ },
+ )
+ max_gpu_memory: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "The maximum memory limit of each GPU, only valid in multi-GPU configuration"
+ },
+ )
+ cpu_offloading: Optional[bool] = field(
+ default=False, metadata={"help": "CPU offloading"}
+ )
+ load_8bit: Optional[bool] = field(
+ default=False, metadata={"help": "8-bit quantization"}
+ )
+ load_4bit: Optional[bool] = field(
+ default=False, metadata={"help": "4-bit quantization"}
+ )
+ quant_type: Optional[str] = field(
+ default="nf4",
+ metadata={
+ "valid_values": ["nf4", "fp4"],
+ "help": "Quantization datatypes, `fp4` (four bit float) and `nf4` (normal four bit float), only valid when load_4bit=True",
+ },
+ )
+ use_double_quant: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Nested quantization, only valid when load_4bit=True"},
+ )
+ # "bfloat16", "float16", "float32"
+ compute_dtype: Optional[str] = field(
+ default=None,
+ metadata={
+ "valid_values": ["bfloat16", "float16", "float32"],
+ "help": "Model compute type",
+ },
+ )
+ trust_remote_code: Optional[bool] = field(
+ default=True, metadata={"help": "Trust remote code"}
+ )
+ verbose: Optional[bool] = field(
+ default=False, metadata={"help": "Show verbose output."}
+ )
+
+
+@dataclass
+class LlamaCppModelParameters(ModelParameters):
+ seed: Optional[int] = field(
+ default=-1, metadata={"help": "Random seed for llama-cpp models. -1 for random"}
+ )
+ n_threads: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "Number of threads to use. If None, the number of threads is automatically determined"
+ },
+ )
+ n_batch: Optional[int] = field(
+ default=512,
+ metadata={
+ "help": "Maximum number of prompt tokens to batch together when calling llama_eval"
+ },
+ )
+ n_gpu_layers: Optional[int] = field(
+ default=1000000000,
+ metadata={
+ "help": "Number of layers to offload to the GPU, Set this to 1000000000 to offload all layers to the GPU."
+ },
+ )
+ n_gqa: Optional[int] = field(
+ default=None,
+ metadata={"help": "Grouped-query attention. Must be 8 for llama-2 70b."},
+ )
+ rms_norm_eps: Optional[float] = field(
+ default=5e-06, metadata={"help": "5e-6 is a good value for llama-2 models."}
+ )
+ cache_capacity: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed. "
+ },
+ )
+ prefer_cpu: Optional[bool] = field(
+ default=False,
+ metadata={
+ "help": "If a GPU is available, it will be preferred by default, unless prefer_cpu=False is configured."
+ },
+ )
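A small, hedged example of `EnvArgumentParser` in action. It assumes DB-GPT is installed, uses the same environment keys as `.env.template`, and the model path is a hypothetical placeholder; note that lookups work regardless of key case:

```python
import os

from pilot.model.parameter import EnvArgumentParser, LlamaCppModelParameters

# Hypothetical environment, mirroring the .env.template keys above.
os.environ["llama_cpp_n_gqa"] = "8"
os.environ["LLAMA_CPP_PROMPT_TEMPLATE"] = "vicuna_v1.1"

params = EnvArgumentParser().parse_args_into_dataclass(
    LlamaCppModelParameters,
    env_prefix="llama_cpp_",
    device="cpu",
    model_name="llama-cpp",
    model_path="/data/models/ggml-model-q4_0.bin",  # hypothetical path
)
print(params.n_gqa, params.prompt_template)  # 8 vicuna_v1.1
```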
diff --git a/pilot/scene/chat_dashboard/out_parser.py b/pilot/scene/chat_dashboard/out_parser.py
index 00bdc4179..7f799f585 100644
--- a/pilot/scene/chat_dashboard/out_parser.py
+++ b/pilot/scene/chat_dashboard/out_parser.py
@@ -29,6 +29,8 @@ class ChatDashboardOutputParser(BaseOutputParser):
print("clean prompt response:", clean_str)
response = json.loads(clean_str)
chart_items: List[ChartItem] = []
+ if not isinstance(response, list):
+ response = [response]
for item in response:
chart_items.append(
ChartItem(
diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py
index 72c429ea7..8f4b85385 100644
--- a/pilot/scene/chat_dashboard/prompt.py
+++ b/pilot/scene/chat_dashboard/prompt.py
@@ -20,7 +20,7 @@ The output data of the analysis cannot exceed 4 columns, and do not use columns
According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for data display, chart type:
{supported_chat_type}
-Pay attention to the length of the output content of the analysis result, do not exceed 4000tokens
+Pay attention to the length of the analysis output; do not exceed 4000 tokens
Give the correct {dialect} analysis SQL (don't use unprovided values such as 'paid'), analysis title, display method and summary of brief analysis thinking, and respond in the following json format:
{response}
diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py
index ab3aec94e..07b44b28c 100644
--- a/pilot/server/chat_adapter.py
+++ b/pilot/server/chat_adapter.py
@@ -13,7 +13,7 @@ class BaseChatAdpter:
and fetch output from model"""
def match(self, model_path: str):
- return True
+ return False
def get_generate_stream_func(self, model_path: str):
"""Return the generate stream handler func"""
@@ -24,7 +24,9 @@ class BaseChatAdpter:
def get_conv_template(self, model_path: str) -> Conversation:
return None
- def model_adaptation(self, params: Dict, model_path: str) -> Tuple[Dict, Dict]:
+ def model_adaptation(
+ self, params: Dict, model_path: str, prompt_template: str = None
+ ) -> Tuple[Dict, Dict]:
"""Params adaptation"""
conv = self.get_conv_template(model_path)
messages = params.get("messages")
@@ -39,6 +41,10 @@ class BaseChatAdpter:
]
params["messages"] = messages
+ if prompt_template:
+ print(f"Use prompt template {prompt_template} from config")
+ conv = get_conv_template(prompt_template)
+
if not conv or not messages:
# Nothing to do
print(
@@ -94,14 +100,19 @@ def register_llm_model_chat_adapter(cls):
@cache
-def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
+def get_llm_chat_adapter(model_name: str, model_path: str) -> BaseChatAdpter:
"""Get a chat generate func for a model"""
for adapter in llm_model_chat_adapters:
- if adapter.match(model_path):
- print(f"Get model path: {model_path} adapter {adapter}")
+ if adapter.match(model_name):
+ print(f"Get model chat adapter with model name {model_name}, {adapter}")
return adapter
-
- raise ValueError(f"Invalid model for chat adapter {model_path}")
+ for adapter in llm_model_chat_adapters:
+ if adapter.match(model_path):
+ print(f"Get model chat adapter with model path {model_path}, {adapter}")
+ return adapter
+ raise ValueError(
+ f"Invalid model for chat adapter with model name {model_name} and model path {model_path}"
+ )
class VicunaChatAdapter(BaseChatAdpter):
@@ -239,6 +250,24 @@ class WizardLMChatAdapter(BaseChatAdpter):
return get_conv_template("vicuna_v1.1")
+class LlamaCppChatAdapter(BaseChatAdpter):
+ def match(self, model_path: str):
+ from pilot.model.adapter import LlamaCppAdapater
+
+ if "llama-cpp" == model_path:
+ return True
+ is_match, _ = LlamaCppAdapater._parse_model_path(model_path)
+ return is_match
+
+ def get_conv_template(self, model_path: str) -> Conversation:
+ return get_conv_template("llama-2")
+
+ def get_generate_stream_func(self, model_path: str):
+ from pilot.model.llm_out.llama_cpp_llm import generate_stream
+
+ return generate_stream
+
+
register_llm_model_chat_adapter(VicunaChatAdapter)
register_llm_model_chat_adapter(ChatGLMChatAdapter)
register_llm_model_chat_adapter(GuanacoChatAdapter)
@@ -248,6 +277,7 @@ register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)
+register_llm_model_chat_adapter(LlamaCppChatAdapter)
# Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
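The `prompt_template` argument added to `model_adaptation` means a template configured through the environment (for example `llama_cpp_prompt_template=vicuna_v1.1`) takes precedence over the conversation template the adapter infers from the model path. A condensed, standalone sketch of that precedence:

```python
# Condensed sketch: an explicitly configured prompt template wins over the
# template the chat adapter would otherwise pick from the model path.
def choose_conv_template(adapter_template: str, configured_template: str = None) -> str:
    if configured_template:
        print(f"Use prompt template {configured_template} from config")
        return configured_template
    return adapter_template


print(choose_conv_template("llama-2"))                  # adapter default
print(choose_conv_template("llama-2", "vicuna_v1.1"))   # config override
```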
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 21789b9c8..10cdb0299 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -23,7 +23,7 @@ sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import *
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
-from pilot.model.loader import ModelLoader
+from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.server.chat_adapter import get_llm_chat_adapter
from pilot.scene.base_message import ModelMessage
@@ -34,12 +34,13 @@ class ModelWorker:
def __init__(self, model_path, model_name, device):
if model_path.endswith("/"):
model_path = model_path[:-1]
- self.model_name = model_name or model_path.split("/")[-1]
+ model_path = _get_model_real_path(model_name, model_path)
+ # self.model_name = model_name or model_path.split("/")[-1]
self.device = device
- print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
- self.ml: ModelLoader = ModelLoader(
- model_path=model_path, model_name=self.model_name
+ print(
+ f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......"
)
+ self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name)
self.model, self.tokenizer = self.ml.loader(
load_8bit=CFG.IS_LOAD_8BIT,
load_4bit=CFG.IS_LOAD_4BIT,
@@ -60,7 +61,7 @@ class ModelWorker:
else:
self.context_len = 2048
- self.llm_chat_adapter = get_llm_chat_adapter(model_path)
+ self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
model_path
)
@@ -86,7 +87,7 @@ class ModelWorker:
try:
# params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation(
- params, self.ml.model_path
+ params, self.ml.model_path, prompt_template=self.ml.prompt_template
)
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
diff --git a/requirements.txt b/requirements.txt
index 008560eb3..6340dcac0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -46,6 +46,7 @@ gradio-client==0.0.8
wandb
llama-index==0.5.27
+# TODO move bitsandbytes to optional
bitsandbytes
accelerate>=0.20.3
diff --git a/setup.py b/setup.py
index 5c4a4d7f4..b33979f78 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,11 @@
-from typing import List
+from typing import List, Tuple
import setuptools
+import platform
+import subprocess
+import os
+from enum import Enum
+
from setuptools import find_packages
with open("README.md", "r") as fh:
@@ -16,6 +21,117 @@ def parse_requirements(file_name: str) -> List[str]:
]
+class SetupSpec:
+ def __init__(self) -> None:
+ self.extras: dict = {}
+
+
+setup_spec = SetupSpec()
+
+
+class AVXType(Enum):
+ BASIC = "basic"
+ AVX = "AVX"
+ AVX2 = "AVX2"
+ AVX512 = "AVX512"
+
+ @staticmethod
+ def of_type(avx: str):
+ for item in AVXType:
+ if item._value_ == avx:
+ return item
+ return None
+
+
+class OSType(Enum):
+ WINDOWS = "win"
+ LINUX = "linux"
+ DARWIN = "darwin"
+ OTHER = "other"
+
+
+def get_cpu_avx_support() -> Tuple[OSType, AVXType]:
+ system = platform.system()
+ os_type = OSType.OTHER
+ cpu_avx = AVXType.BASIC
+ env_cpu_avx = AVXType.of_type(os.getenv("DBGPT_LLAMA_CPP_AVX"))
+
+ cmds = ["lscpu"]
+ if system == "Windows":
+ cmds = ["coreinfo"]
+ os_type = OSType.WINDOWS
+ elif system == "Linux":
+ cmds = ["lscpu"]
+ os_type = OSType.LINUX
+ elif system == "Darwin":
+ cmds = ["sysctl", "-a"]
+ os_type = OSType.DARWIN
+ else:
+ os_type = OSType.OTHER
+ print("Unsupported OS to get cpu avx, use default")
+ return os_type, env_cpu_avx if env_cpu_avx else cpu_avx
+ result = subprocess.run(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ output = result.stdout.decode()
+ if "avx512" in output.lower():
+ cpu_avx = AVXType.AVX512
+ elif "avx2" in output.lower():
+ cpu_avx = AVXType.AVX2
+ elif "avx " in output.lower():
+ # cpu_avx = AVXType.AVX
+ pass
+ return os_type, env_cpu_avx if env_cpu_avx else cpu_avx
+
+
+def get_cuda_version() -> str:
+ try:
+ import torch
+
+ return torch.version.cuda
+ except Exception:
+ return None
+
+
+def llama_cpp_python_cuda_requires():
+ cuda_version = get_cuda_version()
+ device = "cpu"
+ if not cuda_version:
+ print("CUDA not support, use cpu version")
+ return
+ device = "cu" + cuda_version.replace(".", "")
+ os_type, cpu_avx = get_cpu_avx_support()
+ supported_os = [OSType.WINDOWS, OSType.LINUX]
+ if os_type not in supported_os:
+ print(
+ f"llama_cpp_python_cuda just support in os: {[r._value_ for r in supported_os]}"
+ )
+ return
+ cpu_avx = cpu_avx._value_
+ base_url = "https://github.com/jllllll/llama-cpp-python-cuBLAS-wheels/releases/download/textgen-webui"
+ llama_cpp_version = "0.1.77"
+ py_version = "cp310"
+ os_pkg_name = "linux_x86_64" if os_type == OSType.LINUX else "win_amd64"
+ extra_index_url = f"{base_url}/llama_cpp_python_cuda-{llama_cpp_version}+{device}{cpu_avx}-{py_version}-{py_version}-{os_pkg_name}.whl"
+ print(f"Install llama_cpp_python_cuda from {extra_index_url}")
+
+ setup_spec.extras["llama_cpp"].append(f"llama_cpp_python_cuda @ {extra_index_url}")
+
+
+def llama_cpp_requires():
+ setup_spec.extras["llama_cpp"] = ["llama-cpp-python"]
+ llama_cpp_python_cuda_requires()
+
+
+def all_requires():
+ requires = set()
+ for _, pkgs in setup_spec.extras.items():
+ for pkg in pkgs:
+ requires.add(pkg)
+ setup_spec.extras["all"] = list(requires)
+
+
+llama_cpp_requires()
+all_requires()
+
setuptools.setup(
name="db-gpt",
packages=find_packages(),
@@ -27,9 +143,10 @@ setuptools.setup(
long_description=long_description,
long_description_content_type="text/markdown",
install_requires=parse_requirements("requirements.txt"),
- url="https://github.com/csunny/DB-GPT",
+ url="https://github.com/eosphoros-ai/DB-GPT",
license="https://opensource.org/license/mit/",
python_requires=">=3.10",
+ extras_require=setup_spec.extras,
entry_points={
"console_scripts": [
"dbgpt_server=pilot.server:webserver",
From e5b03c8ab4015da3ce585a0d37abdbf46b5e4378 Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Wed, 16 Aug 2023 04:16:09 +0800
Subject: [PATCH 3/4] fix: fix issue #445, sqlite unable to open database file
---
docker/base/Dockerfile | 24 +++++++++----------
pilot/connections/rdbms/conn_sqlite.py | 4 ++--
.../rdbms/tests/test_conn_sqlite.py | 14 +++++++++++
3 files changed, 28 insertions(+), 14 deletions(-)
diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile
index 4075444a9..3cb063e47 100644
--- a/docker/base/Dockerfile
+++ b/docker/base/Dockerfile
@@ -11,6 +11,16 @@ ARG LANGUAGE="en"
ARG PIP_INDEX_URL="https://pypi.org/simple"
ENV PIP_INDEX_URL=$PIP_INDEX_URL
+RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \
+ && (if [ "${LANGUAGE}" = "zh" ]; \
+ # language is zh, download zh_core_web_sm from github
+ then wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.5.0/zh_core_web_sm-3.5.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
+ && pip3 install /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl -i $PIP_INDEX_URL \
+ && rm /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl; \
+ # not zh, download directly
+ else python3 -m spacy download zh_core_web_sm; \
+ fi;)
+
RUN mkdir -p /app
# COPY only requirements.txt first to leverage Docker cache
@@ -29,21 +39,11 @@ WORKDIR /app
# && pip3 install -r /tmp/requirements.txt -i $PIP_INDEX_URL --no-cache-dir \
# && rm /tmp/requirements.txt
-RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \
- && cd /app && pip3 install -i $PIP_INDEX_URL .
+RUN pip3 install -i $PIP_INDEX_URL .
# ENV CMAKE_ARGS="-DLLAMA_CUBLAS=ON -DLLAMA_AVX2=OFF -DLLAMA_F16C=OFF -DLLAMA_FMA=OFF"
# ENV FORCE_CMAKE=1
-RUN cd /app && pip3 install -i $PIP_INDEX_URL .[llama_cpp]
-
-RUN (if [ "${LANGUAGE}" = "zh" ]; \
- # language is zh, download zh_core_web_sm from github
- then wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.5.0/zh_core_web_sm-3.5.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
- && pip3 install /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl -i $PIP_INDEX_URL \
- && rm /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl; \
- # not zh, download directly
- else python3 -m spacy download zh_core_web_sm; \
- fi;) \
+RUN pip3 install -i $PIP_INDEX_URL .[llama_cpp] \
&& rm -rf `pip3 cache dir`
ARG BUILD_LOCAL_CODE="false"
diff --git a/pilot/connections/rdbms/conn_sqlite.py b/pilot/connections/rdbms/conn_sqlite.py
index 339af025a..1740537cf 100644
--- a/pilot/connections/rdbms/conn_sqlite.py
+++ b/pilot/connections/rdbms/conn_sqlite.py
@@ -70,10 +70,10 @@ class SQLiteConnect(RDBMSDatabase):
def _sync_tables_from_db(self) -> Iterable[str]:
table_results = self.session.execute(
- "SELECT name FROM sqlite_master WHERE type='table'"
+ text("SELECT name FROM sqlite_master WHERE type='table'")
)
view_results = self.session.execute(
- "SELECT name FROM sqlite_master WHERE type='view'"
+ text("SELECT name FROM sqlite_master WHERE type='view'")
)
table_results = set(row[0] for row in table_results)
view_results = set(row[0] for row in view_results)
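Wrapping the raw SQL in `text()` is what newer SQLAlchemy releases require: `Session.execute()` no longer accepts bare strings in 2.x. A minimal standalone illustration:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")
with Session(engine) as session:
    rows = session.execute(
        text("SELECT name FROM sqlite_master WHERE type='table'")
    )
    print([row[0] for row in rows])  # [] for a fresh in-memory database
```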
diff --git a/pilot/connections/rdbms/tests/test_conn_sqlite.py b/pilot/connections/rdbms/tests/test_conn_sqlite.py
index efe4ddf76..01ef51878 100644
--- a/pilot/connections/rdbms/tests/test_conn_sqlite.py
+++ b/pilot/connections/rdbms/tests/test_conn_sqlite.py
@@ -121,3 +121,17 @@ def test_get_database_list(db):
def test_get_database_names(db):
assert db.get_database_names() == []
+
+
+def test_db_dir_exist_dir():
+ with tempfile.TemporaryDirectory() as temp_dir:
+ new_dir = os.path.join(temp_dir, "new_dir")
+ file_path = os.path.join(new_dir, "sqlite.db")
+ db = SQLiteConnect.from_file_path(file_path)
+ assert os.path.exists(new_dir)
+ assert list(db.get_table_names()) == []
+ with tempfile.TemporaryDirectory() as existing_dir:
+ file_path = os.path.join(existing_dir, "sqlite.db")
+ db = SQLiteConnect.from_file_path(file_path)
+ assert os.path.exists(existing_dir)
+ assert list(db.get_table_names()) == []
From 303efb9d4e4a73f9bbe7646811ae196ab7c1e523 Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Wed, 16 Aug 2023 04:23:50 +0800
Subject: [PATCH 4/4] feat: Split some main dependencies into optional
dependencies
---
pilot/openapi/api_v1/__init__.py | 0
pilot/vector_store/connector.py | 10 ++++++--
requirements.txt | 38 ++++++++++--------------------
requirements/test-requirements.txt | 10 ++++++++
setup.py | 26 +++++++++++++++++++-
5 files changed, 56 insertions(+), 28 deletions(-)
create mode 100644 pilot/openapi/api_v1/__init__.py
create mode 100644 requirements/test-requirements.txt
diff --git a/pilot/openapi/api_v1/__init__.py b/pilot/openapi/api_v1/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/vector_store/connector.py b/pilot/vector_store/connector.py
index ca56986c8..eaa202e72 100644
--- a/pilot/vector_store/connector.py
+++ b/pilot/vector_store/connector.py
@@ -1,9 +1,15 @@
from pilot.vector_store.chroma_store import ChromaStore
-from pilot.vector_store.milvus_store import MilvusStore
from pilot.vector_store.weaviate_store import WeaviateStore
-connector = {"Chroma": ChromaStore, "Milvus": MilvusStore, "Weaviate": WeaviateStore}
+connector = {"Chroma": ChromaStore, "Weaviate": WeaviateStore}
+
+try:
+ from pilot.vector_store.milvus_store import MilvusStore
+
+ connector["Milvus"] = MilvusStore
+except ImportError:
+ pass
class VectorStoreConnector:
diff --git a/requirements.txt b/requirements.txt
index 6340dcac0..55fdbadfb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,8 +5,8 @@ async-timeout==4.0.2
attrs==22.2.0
cchardet==2.1.7
chardet==5.1.0
-contourpy==1.0.7
-cycler==0.11.0
+# contourpy==1.0.7
+# cycler==0.11.0
filelock==3.9.0
fonttools==4.38.0
frozenlist==1.3.3
@@ -14,20 +14,20 @@ huggingface-hub==0.14.1
importlib-resources==5.12.0
sqlparse==0.4.4
-kiwisolver==1.4.4
-matplotlib==3.7.1
+# kiwisolver==1.4.4
+# matplotlib==3.7.1
multidict==6.0.4
packaging==23.0
psutil==5.9.4
-pycocotools==2.0.6
-pyparsing==3.0.9
+# pycocotools==2.0.6
+# pyparsing==3.0.9
python-dateutil==2.8.2
pyyaml==6.0
tokenizers==0.13.2
tqdm==4.64.1
transformers>=4.31.0
transformers_stream_generator
-timm==0.6.13
+# timm==0.6.13
spacy==3.5.3
webdataset==0.2.48
yarl==1.8.2
@@ -40,18 +40,17 @@ peft
pycocoevalcap
cpm_kernels
umap-learn
-notebook
+# notebook
gradio==3.23
gradio-client==0.0.8
-wandb
-llama-index==0.5.27
+# wandb
+# llama-index==0.5.27
# TODO move bitsandbytes to optional
bitsandbytes
accelerate>=0.20.3
unstructured==0.6.3
-grpcio==1.47.5
gpt4all==0.3.0
diskcache==5.6.1
@@ -61,7 +60,7 @@ gTTS==2.3.1
langchain
nltk
python-dotenv==1.0.0
-pymilvus==2.2.1
+
vcrpy
chromadb==0.3.22
markdown2
@@ -74,18 +73,7 @@ bardapi==0.1.29
# database
+# TODO: move to optional dependencies
pymysql
duckdb
-duckdb-engine
-pymssql
-
-# Testing dependencies
-pytest
-asynctest
-pytest-asyncio
-pytest-benchmark
-pytest-cov
-pytest-integration
-pytest-mock
-pytest-recording
-pytesseract==0.3.10
+duckdb-engine
\ No newline at end of file
diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt
new file mode 100644
index 000000000..c2fb321a5
--- /dev/null
+++ b/requirements/test-requirements.txt
@@ -0,0 +1,10 @@
+# Testing dependencies
+pytest
+asynctest
+pytest-asyncio
+pytest-benchmark
+pytest-cov
+pytest-integration
+pytest-mock
+pytest-recording
+pytesseract==0.3.10
\ No newline at end of file
diff --git a/setup.py b/setup.py
index b33979f78..5136f4fb8 100644
--- a/setup.py
+++ b/setup.py
@@ -117,10 +117,30 @@ def llama_cpp_python_cuda_requires():
def llama_cpp_requires():
+ """
+ pip install "db-gpt[llama_cpp]"
+ """
setup_spec.extras["llama_cpp"] = ["llama-cpp-python"]
llama_cpp_python_cuda_requires()
+def all_vector_store_requires():
+ """
+ pip install "db-gpt[vstore]"
+ """
+ setup_spec.extras["vstore"] = [
+ "grpcio==1.47.5", # maybe delete it
+ "pymilvus==2.2.1",
+ ]
+
+
+def all_datasource_requires():
+ """
+ pip install "db-gpt[datasource]"
+ """
+ setup_spec.extras["datasource"] = ["pymssql", "pymysql"]
+
+
def all_requires():
requires = set()
for _, pkgs in setup_spec.extras.items():
@@ -130,11 +150,15 @@ def all_requires():
llama_cpp_requires()
+all_vector_store_requires()
+all_datasource_requires()
+
+# must be last
all_requires()
setuptools.setup(
name="db-gpt",
- packages=find_packages(),
+ packages=find_packages(exclude=("tests", "*.tests", "*.tests.*", "examples")),
version="0.3.5",
author="csunny",
author_email="cfqcsunny@gmail.com",