feat(model): Support Yi-34B-Chat (#837)
@@ -135,6 +135,7 @@ At present, we have introduced several key features to showcase our current capa
   - [openchat_3.5](https://huggingface.co/openchat/openchat_3.5)
   - [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha)
   - [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
+  - [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
 
 - Support API Proxy LLMs
   - [x] [ChatGPT](https://api.openai.com/)
@@ -133,6 +133,7 @@ DB-GPT是一个开源的数据库领域大模型框架。目的是构建大模
   - [openchat_3.5](https://huggingface.co/openchat/openchat_3.5)
   - [zephyr-7b-alpha](https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha)
   - [mistral-7b-instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
+  - [Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)
 
 - 支持在线代理模型
   - [x] [ChatGPT](https://api.openai.com/)
@@ -126,6 +126,9 @@ LLM_MODEL_CONFIG = {
     "xwin-lm-13b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-13B-V0.1"),
     # https://huggingface.co/Xwin-LM/Xwin-LM-70B-V0.1
     "xwin-lm-70b-v0.1": os.path.join(MODEL_PATH, "Xwin-LM-70B-V0.1"),
+    # https://huggingface.co/01-ai/Yi-34B-Chat
+    "yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
+    "yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
 }
 
 EMBEDDING_MODEL_CONFIG = {
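
For context, a minimal sketch (not part of the commit) of how an entry in LLM_MODEL_CONFIG is meant to be used: the key is the model name selected at startup and the value is the local directory where the weights are expected to live. The MODEL_PATH default and the lookup helper below are assumptions for illustration only.

import os

# Assumption: MODEL_PATH points at the directory holding downloaded checkpoints.
MODEL_PATH = os.getenv("MODEL_PATH", "./models")

LLM_MODEL_CONFIG = {
    # Entries added by this commit.
    "yi-34b-chat": os.path.join(MODEL_PATH, "Yi-34B-Chat"),
    "yi-6b-chat": os.path.join(MODEL_PATH, "Yi-6B-Chat"),
}


def resolve_model_path(model_name: str) -> str:
    """Map a configured model name to its expected local path (illustrative helper)."""
    try:
        return LLM_MODEL_CONFIG[model_name]
    except KeyError as exc:
        raise ValueError(f"Model {model_name!r} is not configured") from exc


if __name__ == "__main__":
    print(resolve_model_path("yi-34b-chat"))  # ./models/Yi-34B-Chat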
@@ -253,6 +253,7 @@ class DefaultModelWorker(ModelWorker):
             params,
             self.model_name,
             self.model_path,
+            self.tokenizer,
             prompt_template=self.ml.prompt_template,
         )
         stream_type = ""
@@ -269,7 +270,9 @@ class DefaultModelWorker(ModelWorker):
                 self.model, self.model_path
             )
             str_prompt = params.get("prompt")
-            print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n")
+            print(
+                f"llm_adapter: {str(self.llm_adapter)}\n\nmodel prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n"
+            )
 
             generate_stream_func_str_name = "{}.{}".format(
                 generate_stream_func.__module__, generate_stream_func.__name__
pilot/model/llm_out/hf_chat_llm.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+import logging
+import torch
+from threading import Thread
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
+
+logger = logging.getLogger(__name__)
+
+
+@torch.inference_mode()
+def huggingface_chat_generate_stream(
+    model: AutoModelForCausalLM,
+    tokenizer: AutoTokenizer,
+    params,
+    device,
+    context_len=4096,
+):
+    prompt = params["prompt"]
+    temperature = float(params.get("temperature", 0.7))
+    top_p = float(params.get("top_p", 1.0))
+    echo = params.get("echo", False)
+    max_new_tokens = int(params.get("max_new_tokens", 2048))
+
+    input_ids = tokenizer(prompt).input_ids
+    # input_ids = input_ids.to(device)
+    if model.config.is_encoder_decoder:
+        max_src_len = context_len
+    else:  # truncate
+        max_src_len = context_len - max_new_tokens - 1
+    input_ids = input_ids[-max_src_len:]
+    input_echo_len = len(input_ids)
+    input_ids = torch.as_tensor([input_ids], device=device)
+
+    # messages = params["messages"]
+    # messages = ModelMessage.to_openai_messages(messages)
+    # input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
+    # input_ids = input_ids.to(device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, skip_prompt=not echo, skip_special_tokens=True
+    )
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "max_length": context_len,
+        "temperature": temperature,
+        "streamer": streamer,
+    }
+
+    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+    out = ""
+    for new_text in streamer:
+        out += new_text
+        yield out
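
A hypothetical usage sketch of the new streaming function follows; the model id, device handling, and the raw prompt (which in DB-GPT would normally be produced by the adapter's get_str_prompt) are assumptions, not part of the commit. Note that each yielded value is the accumulated output so far, not an incremental delta.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from pilot.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream

# Assumption: a small chat model is available locally or on the Hugging Face Hub.
model_path = "01-ai/Yi-6B-Chat"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype="auto", trust_remote_code=True
).to(device)

# Build the prompt the same way get_str_prompt does: via the tokenizer's chat template.
messages = [{"role": "user", "content": "Hello, who are you?"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

params = {"prompt": prompt, "temperature": 0.7, "max_new_tokens": 128}

# Each yielded string is the full generation so far, as produced by the
# TextIteratorStreamer loop in hf_chat_llm.py.
for partial_output in huggingface_chat_generate_stream(
    model, tokenizer, params, device, context_len=4096
):
    print(partial_output)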
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING
+from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING, Any, Optional
 import dataclasses
 import logging
 import threading
@@ -41,6 +41,7 @@ logger = logging.getLogger(__name__)
 thread_local = threading.local()
 _IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
 
+
 _OLD_MODELS = [
     "llama-cpp",
     "proxyllm",
@@ -51,6 +52,14 @@ _OLD_MODELS = [
     "codellama-13b",
 ]
 
+_NEW_HF_CHAT_MODELS = [
+    "yi-34b",
+    "yi-6b",
+]
+
+# The implementation of some models in fastchat will affect the DB-GPT loading model and will be temporarily added to the blacklist.
+_BLACK_LIST_MODLE_PROMPT = ["OpenHermes-2.5-Mistral-7B"]
+
 
 class LLMModelAdaper:
     """New Adapter for DB-GPT LLM models"""
@@ -99,26 +108,25 @@ class LLMModelAdaper:
         """Get the default conv template"""
         raise NotImplementedError
 
-    def model_adaptation(
+    def get_str_prompt(
         self,
         params: Dict,
+        messages: List[ModelMessage],
+        tokenizer: Any,
+        prompt_template: str = None,
+    ) -> Optional[str]:
+        return None
+
+    def get_prompt_with_template(
+        self,
+        params: Dict,
+        messages: List[ModelMessage],
         model_name: str,
         model_path: str,
+        model_context: Dict,
         prompt_template: str = None,
-    ) -> Tuple[Dict, Dict]:
-        """Params adaptation"""
+    ):
         conv = self.get_default_conv_template(model_name, model_path)
-        messages = params.get("messages")
-        # Some model scontext to dbgpt server
-        model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
-
-        if messages:
-            # Dict message to ModelMessage
-            messages = [
-                m if isinstance(m, ModelMessage) else ModelMessage(**m)
-                for m in messages
-            ]
-            params["messages"] = messages
 
         if prompt_template:
             logger.info(f"Use prompt template {prompt_template} from config")
@@ -128,7 +136,7 @@ class LLMModelAdaper:
             logger.info(
                 f"No conv from model_path {model_path} or no messages in params, {self}"
             )
-            return params, model_context
+            return None, None, None
 
         conv = conv.copy()
         system_messages = []
@@ -180,6 +188,41 @@ class LLMModelAdaper:
         # Add a blank message for the assistant.
         conv.append_message(conv.roles[1], None)
         new_prompt = conv.get_prompt()
+        return new_prompt, conv.stop_str, conv.stop_token_ids
+
+    def model_adaptation(
+        self,
+        params: Dict,
+        model_name: str,
+        model_path: str,
+        tokenizer: Any,
+        prompt_template: str = None,
+    ) -> Tuple[Dict, Dict]:
+        """Params adaptation"""
+        messages = params.get("messages")
+        # Some model scontext to dbgpt server
+        model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
+        if messages:
+            # Dict message to ModelMessage
+            messages = [
+                m if isinstance(m, ModelMessage) else ModelMessage(**m)
+                for m in messages
+            ]
+            params["messages"] = messages
+
+        new_prompt = self.get_str_prompt(params, messages, tokenizer, prompt_template)
+        conv_stop_str, conv_stop_token_ids = None, None
+        if not new_prompt:
+            (
+                new_prompt,
+                conv_stop_str,
+                conv_stop_token_ids,
+            ) = self.get_prompt_with_template(
+                params, messages, model_name, model_path, model_context, prompt_template
+            )
+        if not new_prompt:
+            return params, model_context
+
         # Overwrite the original prompt
         # TODO remote bos token and eos token from tokenizer_config.json of model
         prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
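
In condensed form (not the actual class), the refactored model_adaptation tries the tokenizer chat template first and only falls back to a fastchat-style conversation template when no string prompt was produced; a sketch of that selection order, with an illustrative helper name:

# Illustrative sketch of the prompt-selection order; "adapter" stands for any
# LLMModelAdaper subclass, and build_prompt is not a real DB-GPT function.
def build_prompt(adapter, params, messages, model_name, model_path, tokenizer,
                 model_context, prompt_template=None):
    # 1. Prefer a prompt rendered by the tokenizer's chat template (NewHFChatModelAdapter).
    new_prompt = adapter.get_str_prompt(params, messages, tokenizer, prompt_template)
    stop_str, stop_token_ids = None, None
    if not new_prompt:
        # 2. Fall back to the fastchat-style conversation template.
        new_prompt, stop_str, stop_token_ids = adapter.get_prompt_with_template(
            params, messages, model_name, model_path, model_context, prompt_template
        )
    return new_prompt, stop_str, stop_token_ids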
@@ -192,8 +235,8 @@ class LLMModelAdaper:
         custom_stop_token_ids = params.get("stop_token_ids")
 
         # Prefer the value passed in from the input parameter
-        params["stop"] = custom_stop or conv.stop_str
-        params["stop_token_ids"] = custom_stop_token_ids or conv.stop_token_ids
+        params["stop"] = custom_stop or conv_stop_str
+        params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids
 
         return params, model_context
 
@@ -270,6 +313,69 @@ class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
         )
 
 
+class NewHFChatModelAdapter(LLMModelAdaper):
+    def load(self, model_path: str, from_pretrained_kwargs: dict):
+        try:
+            import transformers
+            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
+        except ImportError as exc:
+            raise ValueError(
+                "Could not import depend python package "
+                "Please install it with `pip install transformers`."
+            ) from exc
+        if not transformers.__version__ >= "4.34.0":
+            raise ValueError(
+                "Current model (Load by HFNewChatAdapter) require transformers.__version__>=4.34.0"
+            )
+        revision = from_pretrained_kwargs.get("revision", "main")
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_path,
+                use_fast=self.use_fast_tokenizer,
+                revision=revision,
+                trust_remote_code=True,
+            )
+        except TypeError:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_path, use_fast=False, revision=revision, trust_remote_code=True
+            )
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+            )
+        except NameError:
+            model = AutoModel.from_pretrained(
+                model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+            )
+        # tokenizer.use_default_system_prompt = False
+        return model, tokenizer
+
+    def get_generate_stream_function(self, model, model_path: str):
+        """Get the generate stream function of the model"""
+        from pilot.model.llm_out.hf_chat_llm import huggingface_chat_generate_stream
+
+        return huggingface_chat_generate_stream
+
+    def get_str_prompt(
+        self,
+        params: Dict,
+        messages: List[ModelMessage],
+        tokenizer: Any,
+        prompt_template: str = None,
+    ) -> Optional[str]:
+        from transformers import AutoTokenizer
+
+        if not tokenizer:
+            raise ValueError("tokenizer is is None")
+        tokenizer: AutoTokenizer = tokenizer
+
+        messages = ModelMessage.to_openai_messages(messages)
+        str_prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        return str_prompt
+
+
 def get_conv_template(name: str) -> "Conversation":
     """Get a conversation template."""
     from fastchat.conversation import get_conv_template
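
The get_str_prompt path above leans entirely on the tokenizer's built-in chat template, which is why the adapter checks for transformers>=4.34.0. A minimal standalone sketch of what that call produces; the model id and messages are illustrative assumptions:

from transformers import AutoTokenizer

# Assumption: the Yi chat tokenizer ships a chat_template in its tokenizer config.
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B-Chat", trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is DB-GPT?"},
]

# tokenize=False returns the rendered prompt string instead of token ids;
# add_generation_prompt=True appends the marker that starts the assistant turn.
str_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(str_prompt)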
@@ -298,6 +404,11 @@ def get_llm_model_adapter(
         logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
         return VLLMModelAdaperWrapper()
 
+    use_new_hf_chat_models = any(m in model_name.lower() for m in _NEW_HF_CHAT_MODELS)
+    if use_new_hf_chat_models:
+        logger.info(f"Current model {model_name} use NewHFChatModelAdapter")
+        return NewHFChatModelAdapter()
+
     must_use_old = any(m in model_name for m in _OLD_MODELS)
     if use_fastchat and not must_use_old:
         logger.info("Use fastcat adapter")
@@ -334,6 +445,7 @@ def _get_fastchat_model_adapter(
         if use_fastchat_monkey_patch:
             model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
         thread_local.model_name = model_name
+        _remove_black_list_model_of_fastchat()
         if caller:
             return caller(model_path)
     finally:
@@ -377,6 +489,24 @@ def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
     )
 
 
+@cache
+def _remove_black_list_model_of_fastchat():
+    from fastchat.model.model_adapter import model_adapters
+
+    black_list_models = []
+    for adapter in model_adapters:
+        try:
+            if (
+                adapter.get_default_conv_template("/data/not_exist_model_path").name
+                in _BLACK_LIST_MODLE_PROMPT
+            ):
+                black_list_models.append(adapter)
+        except Exception:
+            pass
+    for adapter in black_list_models:
+        model_adapters.remove(adapter)
+
+
 def _dynamic_model_parser() -> Callable[[None], List[Type]]:
     from pilot.utils.parameter_utils import _SimpleArgParser
     from pilot.model.parameter import (
@@ -113,6 +113,34 @@ class ModelMessage(BaseModel):
                 raise ValueError(f"Unknown role: {msg_role}")
         return result
 
+    @staticmethod
+    def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
+        """Convert to OpenAI message format and
+        hugggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
+        """
+        history = []
+        # Add history conversation
+        for message in messages:
+            if message.role == ModelMessageRoleType.HUMAN:
+                history.append({"role": "user", "content": message.content})
+            elif message.role == ModelMessageRoleType.SYSTEM:
+                history.append({"role": "system", "content": message.content})
+            elif message.role == ModelMessageRoleType.AI:
+                history.append({"role": "assistant", "content": message.content})
+            else:
+                pass
+        # Move the last user's information to the end
+        temp_his = history[::-1]
+        last_user_input = None
+        for m in temp_his:
+            if m["role"] == "user":
+                last_user_input = m
+                break
+        if last_user_input:
+            history.remove(last_user_input)
+            history.append(last_user_input)
+        return history
+
     @staticmethod
     def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
         return list(map(lambda m: m.dict(), messages))
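
A hypothetical usage sketch of the new helper follows; the message contents are invented for illustration. The reordering only matters when the latest user turn is not already the last element of the history.

from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

# Invented conversation history for illustration.
history = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a SQL assistant."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Show me all users."),
    ModelMessage(role=ModelMessageRoleType.AI, content="SELECT * FROM users;"),
]

openai_messages = ModelMessage.to_openai_messages(history)
# The most recent user turn is moved to the end so the chat template renders it as
# the message the model should answer:
# [{"role": "system", "content": "You are a SQL assistant."},
#  {"role": "assistant", "content": "SELECT * FROM users;"},
#  {"role": "user", "content": "Show me all users."}]
print(openai_messages)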
setup.py (17 changed lines)
@@ -18,6 +18,10 @@ BUILD_NO_CACHE = os.getenv("BUILD_NO_CACHE", "true").lower() == "true"
 LLAMA_CPP_GPU_ACCELERATION = (
     os.getenv("LLAMA_CPP_GPU_ACCELERATION", "true").lower() == "true"
 )
+BUILD_FROM_SOURCE = os.getenv("BUILD_FROM_SOURCE", "false").lower() == "true"
+BUILD_FROM_SOURCE_URL_FAST_CHAT = os.getenv(
+    "BUILD_FROM_SOURCE_URL_FAST_CHAT", "git+https://github.com/lm-sys/FastChat.git"
+)
 
 
 def parse_requirements(file_name: str) -> List[str]:
@@ -298,7 +302,6 @@ def core_requires():
     ]
 
     setup_spec.extras["framework"] = [
-        "fschat",
         "coloredlogs",
         "httpx",
         "sqlparse==0.4.4",
@@ -315,7 +318,8 @@ def core_requires():
         "duckdb-engine",
         "jsonschema",
         # TODO move transformers to default
-        "transformers>=4.31.0",
+        # "transformers>=4.31.0",
+        "transformers>=4.34.0",
         "alembic==1.12.0",
         # for excel
         "openpyxl==3.1.2",
@@ -324,6 +328,12 @@ def core_requires():
         # for cache, TODO pympler has not been updated for a long time and needs to find a new toolkit.
         "pympler",
     ]
+    if BUILD_FROM_SOURCE:
+        setup_spec.extras["framework"].append(
+            f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}"
+        )
+    else:
+        setup_spec.extras["framework"].append("fschat")
 
 
 def knowledge_requires():
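
A small standalone sketch of what the two new environment switches resolve to; the helper function is illustrative, not part of setup.py.

import os

BUILD_FROM_SOURCE = os.getenv("BUILD_FROM_SOURCE", "false").lower() == "true"
BUILD_FROM_SOURCE_URL_FAST_CHAT = os.getenv(
    "BUILD_FROM_SOURCE_URL_FAST_CHAT", "git+https://github.com/lm-sys/FastChat.git"
)


def fschat_requirement() -> str:
    """Return the requirement string appended to the 'framework' extra (illustrative)."""
    if BUILD_FROM_SOURCE:
        return f"fschat @ {BUILD_FROM_SOURCE_URL_FAST_CHAT}"
    return "fschat"


print(fschat_requirement())
# Default: "fschat" (from PyPI); with BUILD_FROM_SOURCE=true:
# "fschat @ git+https://github.com/lm-sys/FastChat.git"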
@@ -426,7 +436,8 @@ def default_requires():
         pip install "db-gpt[default]"
     """
     setup_spec.extras["default"] = [
-        "tokenizers==0.13.3",
+        # "tokenizers==0.13.3",
+        "tokenizers>=0.14",
         "accelerate>=0.20.3",
         "sentence-transformers",
         "protobuf==3.20.3",