Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-30 23:28:35 +00:00), commit 948a93be32
@@ -34,4 +34,4 @@ provider = "hf"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# path = "the-model-path-in-the-local-file-system"
path = "models/BAAI/glm-4-9b-chat-hf"
path = "models/BAAI/bge-large-zh-v1.5"
@@ -285,8 +285,23 @@ uv run dbgpt start webserver --config configs/dbgpt-local-vllm.toml
```

</TabItem>
<TabItem value="llama_cpp" label="LLAMA_CPP(local)">
<TabItem value="llama_cpp" label="LLAMA_CPP(local)">

If you have an NVIDIA GPU, you can enable CUDA support by setting the environment variable `CMAKE_ARGS="-DGGML_CUDA=ON"`.

```bash
# Use uv to install dependencies needed for llama-cpp
# Install core dependencies and select desired extensions
CMAKE_ARGS="-DGGML_CUDA=ON" uv sync --all-packages \
--extra "base" \
--extra "llama_cpp" \
--extra "rag" \
--extra "storage_chromadb" \
--extra "quant_bnb" \
--extra "dbgpts"
```

Otherwise, run the following command to install dependencies without CUDA support.
```bash
# Use uv to install dependencies needed for llama-cpp
# Install core dependencies and select desired extensions
uv sync --all-packages \
--extra "base" \
--extra "llama_cpp" \
--extra "rag" \
--extra "storage_chromadb" \
--extra "quant_bnb" \
--extra "dbgpts"
```
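After either install, a quick import check can confirm whether the llama-cpp build picked up GPU offload support. This is a minimal sketch; `llama_supports_gpu_offload` is only exposed by recent llama-cpp-python releases, so the check is guarded.

```python
# Quick post-install sanity check (sketch; guards against llama-cpp-python
# builds that do not expose llama_supports_gpu_offload).
import llama_cpp

print("llama-cpp-python version:", llama_cpp.__version__)
try:
    print("GPU offload supported:", llama_cpp.llama_supports_gpu_offload())
except AttributeError:
    print("This build does not expose llama_supports_gpu_offload")
```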
packages/dbgpt-accelerator/dbgpt-acc-auto/README.md (new file, 4 lines)
@@ -0,0 +1,4 @@
# DB-GPT Accelerator Module

Building across multiple platforms and hardware targets is complex; the DB-GPT Accelerator aims to handle this compatibility, offering as consistent an interface as possible to the other core modules.
@@ -1,5 +1,5 @@
[project]
name = "dbgpt-accelerator"
name = "dbgpt-acc-auto"
version = "0.7.0"
description = "Add your description here"
authors = [
@@ -16,21 +16,6 @@ Documentation = "http://docs.dbgpt.cn/docs/overview"
Repository = "https://github.com/eosphoros-ai/DB-GPT.git"
Issues = "https://github.com/eosphoros-ai/DB-GPT/issues"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/dbgpt_accelerator"]
exclude = [
"src/dbgpt_accelerator/**/tests",
"src/dbgpt_accelerator/**/tests/*",
"src/dbgpt_accelerator/tests",
"src/dbgpt_accelerator/tests/*",
"src/dbgpt_accelerator/**/examples",
"src/dbgpt_accelerator/**/examples/*"
]

[project.optional-dependencies]
# Auto install dependencies
auto = [
@@ -76,10 +61,10 @@ vllm = [
# Just support GPU version on Linux
"vllm>=0.7.0; sys_platform == 'linux'",
]
#vllm_pascal = [
# vllm_pascal = [
# # https://github.com/sasha0552/pascal-pkgs-ci
# "vllm-pascal==0.7.2; sys_platform == 'linux'"
#]
# ]
quant_bnb = [
"bitsandbytes>=0.39.0; sys_platform == 'win32' or sys_platform == 'linux'",
"accelerate"
@@ -103,6 +88,10 @@ quant_gptq = [
"optimum",
"auto-gptq",
]
flash_attn = [
# "torch>=2.2.1",
"dbgpt-acc-flash-attn"
]

[dependency-groups]
auto = [
@@ -0,0 +1,3 @@
# DB-GPT-Accelerator for Flash Attention

Wrapper for the Flash Attention module in the DB-GPT-Accelerator.
@@ -0,0 +1,24 @@
# Install the flash-attn package for uv
# https://github.com/astral-sh/uv/issues/2252#issuecomment-2624150395
[project]
name = "dbgpt-acc-flash-attn"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = []

[dependency-groups]
build = [
"setuptools>=75.8.0",
]
direct = [
"torch>=2.2.1",
]
main = [
"flash-attn>=2.5.8",
]

[tool.uv]
default-groups = ["build", "direct", "main"]
no-build-isolation-package = ["flash-attn"]
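Because flash-attn is built without build isolation here, a quick import check after `uv sync` confirms the wheel compiled against the installed torch. A minimal sketch:

```python
# Post-install check (sketch): verify torch and flash-attn import together and
# report their versions.
import torch
import flash_attn

print("torch:", torch.__version__, "CUDA available:", torch.cuda.is_available())
print("flash-attn:", flash_attn.__version__)
```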
@@ -10,7 +10,7 @@ readme = "README.md"
requires-python = ">= 3.10"

dependencies = [
"dbgpt-accelerator",
"dbgpt-acc-auto",
"dbgpt",
"dbgpt-ext",
"dbgpt-serve",
@@ -102,6 +102,15 @@ class LLMDeployModelParameters(BaseDeployModelParameters, RegisterParameters):
)
},
)
reasoning_model: Optional[bool] = field(
default=None,
metadata={
"help": _(
"Whether the model is a reasoning model. If None, it is "
"automatically determined from model."
)
},
)

@property
def real_provider_model_name(self) -> str:
@@ -202,8 +211,10 @@ class BitsandbytesQuantization(BaseHFQuantization):
real_cls = cls
if load_in_8bits:
real_cls = BitsandbytesQuantization8bits
data["type"] = BitsandbytesQuantization8bits.__type__
if load_in_4bits:
real_cls = BitsandbytesQuantization4bits
data["type"] = BitsandbytesQuantization4bits.__type__
real_data = prepare_data_func(real_cls, data)
return real_cls(**real_data)
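The change above tags the parsed data with the concrete quantization type before instantiating it. A standalone sketch of the same dispatch pattern, with illustrative names rather than the DB-GPT API:

```python
# Sketch: choose a more specific dataclass based on flags in the raw config,
# and record the chosen type back into the data before constructing it.
from dataclasses import dataclass, fields


@dataclass
class Quantization:
    type: str = "bnb"


@dataclass
class Quantization8bits(Quantization):
    type: str = "bnb_8bits"


@dataclass
class Quantization4bits(Quantization):
    type: str = "bnb_4bits"


def from_dict(data: dict) -> Quantization:
    real_cls = Quantization
    if data.pop("load_in_8bits", False):
        real_cls = Quantization8bits
        data["type"] = Quantization8bits.type
    if data.pop("load_in_4bits", False):
        real_cls = Quantization4bits
        data["type"] = Quantization4bits.type
    # keep only the fields the chosen class actually declares
    allowed = {f.name for f in fields(real_cls)}
    return real_cls(**{k: v for k, v in data.items() if k in allowed})


print(from_dict({"load_in_4bits": True}))  # Quantization4bits(type='bnb_4bits')
```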
@@ -251,6 +251,27 @@ class LLMModelAdapter(ABC):
"""Load the model and tokenizer according to the given parameters"""
raise NotImplementedError

def is_reasoning_model(
self,
deploy_model_params: LLMDeployModelParameters,
lower_model_name_or_path: Optional[str] = None,
) -> bool:
"""Whether the model is a reasoning model"""
if (
deploy_model_params.reasoning_model is not None
and deploy_model_params.reasoning_model
):
return True
return (
lower_model_name_or_path
and "deepseek" in lower_model_name_or_path
and (
"r1" in lower_model_name_or_path
or "reasoning" in lower_model_name_or_path
or "reasoner" in lower_model_name_or_path
)
)

def support_async(self) -> bool:
"""Whether the loaded model supports asynchronous calls"""
return False
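In short: an explicit `reasoning_model` flag wins, otherwise the adapter falls back to a name heuristic for DeepSeek reasoning variants. A standalone sketch of that decision, with illustrative names rather than the DB-GPT API:

```python
# Sketch of the detection heuristic: explicit flag first, then name matching.
from typing import Optional


def is_reasoning_model(flag: Optional[bool], lower_name: Optional[str]) -> bool:
    if flag:  # explicitly marked as a reasoning model
        return True
    return bool(
        lower_name
        and "deepseek" in lower_name
        and ("r1" in lower_name or "reasoning" in lower_name or "reasoner" in lower_name)
    )


print(is_reasoning_model(None, "deepseek-r1-distill-qwen-7b"))  # True
print(is_reasoning_model(None, "glm-4-9b-chat"))                # False
print(is_reasoning_model(True, "glm-4-9b-chat"))                # True
```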
@@ -88,6 +88,15 @@ class HFLLMDeployModelParameters(LLMDeployModelParameters):
"valid_values": ["auto", "float16", "bfloat16", "float", "float32"],
},
)
attn_implementation: Optional[str] = field(
default=None,
metadata={
"help": _(
"The attention implementation, only valid in multi-GPU configuration"
),
"valid_values": ["flash_attention_2"],
},
)

@property
def real_model_path(self) -> Optional[str]:
@@ -301,15 +301,15 @@ class LlamaServerParameters(LLMDeployModelParameters):
config_dict[fd.name] = curr_config[fd.name]

if (
"device" in config_dict
and config_dict["device"] == "cuda"
self.real_device
and self.real_device == "cuda"
and ("n_gpu_layers" not in config_dict or not config_dict["n_gpu_layers"])
):
# Set n_gpu_layers to a large number to use all layers
logger.info("Set n_gpu_layers to a large number to use all layers")
config_dict["n_gpu_layers"] = 1000000000
config_dict["model_alias"] = self.name
config_dict["model_file"] = self.path
config_dict["model_file"] = self._resolve_root_path(self.path)
model_file = config_dict.get("model_file")
model_url = config_dict.get("model_url")
model_hf_repo = config_dict.get("model_hf_repo")
@@ -143,6 +143,8 @@ def huggingface_loader(
if "device_map" in kwargs and "low_cpu_mem_usage" not in kwargs:
# Must set low_cpu_mem_usage to True when device_map is set
kwargs["low_cpu_mem_usage"] = True
if model_params.attn_implementation:
kwargs["attn_implementation"] = model_params.attn_implementation

model, tokenizer = _hf_try_load_default_quantization_model(
model_path, llm_adapter, device, num_gpus, model_params, kwargs
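The loader simply forwards `attn_implementation` to Hugging Face Transformers. A hedged sketch of what that amounts to at load time; the model path and dtype are placeholders:

```python
# Sketch: loading an HF model with FlashAttention-2, mirroring the kwargs the
# loader assembles above. Requires the flash-attn extra to be installed.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "models/BAAI/glm-4-9b-chat-hf",           # placeholder local path
    torch_dtype=torch.bfloat16,               # example dtype
    device_map="auto",
    low_cpu_mem_usage=True,                   # required when device_map is set
    attn_implementation="flash_attention_2",  # the new configurable option
)
```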
@@ -62,6 +62,9 @@ class VLLMDeployModelParameters(LLMDeployModelParameters):
model = data.get("path", None)
if not model:
model = data.get("name", None)
else:
# Path is specified, so we use it as the model
model = self._resolve_root_path(model)
if not model:
raise ValueError(
"Model is required, please specify the model path or name."
@@ -421,6 +421,10 @@ class DefaultModelWorker(ModelWorker):
span_params["messages"] = list(
map(lambda m: m.dict(), span_params["messages"])
)
if self.llm_adapter.is_reasoning_model(
self._model_params, self.model_name.lower()
):
params["is_reasoning_model"] = True

metadata = {
"is_async_func": self.support_async(),
@@ -613,39 +613,41 @@ class LocalWorkerManager(WorkerManager):
)

async def _start_all_worker(
self, apply_req: WorkerApplyRequest
self, apply_req: WorkerApplyRequest, parallel_num: int = 1
) -> WorkerApplyOutput:
from httpx import TimeoutException, TransportError

# TODO avoid start twice
start_time = time.time()
logger.info(f"Begin start all worker, apply_req: {apply_req}")
semaphore = asyncio.Semaphore(parallel_num)

async def _start_worker(worker_run_data: WorkerRunData):
_start_time = time.time()
info = worker_run_data._to_print_key()
out = WorkerApplyOutput("")
try:
await self.run_blocking_func(
worker_run_data.worker.start,
worker_run_data.command_args,
)
worker_run_data.stop_event.clear()
if worker_run_data.worker_params.register and self.register_func:
# Register worker to controller
await self.register_func(worker_run_data)
if (
worker_run_data.worker_params.send_heartbeat
and self.send_heartbeat_func
):
asyncio.create_task(
_async_heartbeat_sender(
worker_run_data,
worker_run_data.worker_params.heartbeat_interval,
self.send_heartbeat_func,
async with semaphore:
await self.run_blocking_func(
worker_run_data.worker.start,
worker_run_data.command_args,
)
worker_run_data.stop_event.clear()
if worker_run_data.worker_params.register and self.register_func:
# Register worker to controller
await self.register_func(worker_run_data)
if (
worker_run_data.worker_params.send_heartbeat
and self.send_heartbeat_func
):
asyncio.create_task(
_async_heartbeat_sender(
worker_run_data,
worker_run_data.worker_params.heartbeat_interval,
self.send_heartbeat_func,
)
)
)
out.message = f"{info} start successfully"
out.message = f"{info} start successfully"
except TimeoutException:
out.success = False
out.message = (
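The new `parallel_num` argument caps how many workers start concurrently via an `asyncio.Semaphore`. A standalone sketch of the pattern, with illustrative names rather than the DB-GPT API:

```python
# Sketch: start several blocking worker.start() calls in parallel, but allow at
# most parallel_num of them to run at the same time.
import asyncio
import time


def fake_worker_start(name: str) -> None:
    time.sleep(0.1)  # stands in for a slow, blocking model load
    print(f"{name} started")


async def start_all(names, parallel_num: int = 1):
    semaphore = asyncio.Semaphore(parallel_num)

    async def start_one(name: str):
        async with semaphore:
            # run the blocking start in a worker thread, as the manager does
            await asyncio.to_thread(fake_worker_start, name)

    await asyncio.gather(*(start_one(n) for n in names))


asyncio.run(start_all(["worker-1", "worker-2", "worker-3"], parallel_num=2))
```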
@@ -13,7 +13,11 @@ from dbgpt.core import ModelOutput
from dbgpt.model.adapter.llama_cpp_py_adapter import LlamaCppModelParameters
from dbgpt.model.utils.llm_utils import parse_model_request

from ...utils.parse_utils import ParsedChatMessage, parse_chat_message
from ...utils.parse_utils import (
_DEFAULT_THINK_START_TOKEN,
ParsedChatMessage,
parse_chat_message,
)

logger = logging.getLogger(__name__)

@@ -113,6 +117,8 @@ class LlamaCppModel:
messages = request.to_common_messages()
repetition_penalty = float(params.get("repetition_penalty", 1.1))
top_k = int(params.get("top_k", -1)) # -1 means disable
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
# Handle truncation
completion_chunks = self.model.create_chat_completion(
messages=messages,
@@ -129,6 +135,7 @@ class LlamaCppModel:
usage = None
msg = ParsedChatMessage()
finish_reason: Optional[str] = None
is_first = True
for r in completion_chunks:
if not r.get("choices"):
continue
@@ -136,11 +143,16 @@ class LlamaCppModel:
if delta.get("content") is not None:
content = delta["content"]
text += content
msg, _ = parse_chat_message(
content,
extract_reasoning=True,
is_streaming=True,
streaming_state=msg,
if (
is_reasoning_model
and not text.startswith(think_start_token)
and is_first
):
text = think_start_token + "\n" + text
is_first = False
msg = parse_chat_message(
text,
extract_reasoning=is_reasoning_model,
)
finish_reason = delta.get("finish_reason")
if text:
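A recurring trick in these streaming paths: when a reasoning model starts thinking without emitting the opening think token, the worker injects it on the first chunk so the parser can still split reasoning from the final answer. A standalone sketch (illustrative, not the DB-GPT API):

```python
# Sketch: inject the think-start token once, on the first chunk, when a
# reasoning model omits it, then let the parser split reasoning from content.
_THINK_START = "<think>"


def normalize_stream(chunks, is_reasoning_model: bool):
    text, is_first = "", True
    for chunk in chunks:
        text += chunk
        if is_reasoning_model and is_first and not text.startswith(_THINK_START):
            text = _THINK_START + "\n" + text
        is_first = False
        yield text


final = list(normalize_stream(["Thinking about it...", "</think>The answer."], True))[-1]
print(final)  # "<think>\nThinking about it...</think>The answer."
```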
@@ -17,7 +17,11 @@ from typing import Dict, Optional

from dbgpt.core import ModelOutput

from ...utils.parse_utils import ParsedChatMessage, parse_chat_message
from ...utils.parse_utils import (
_DEFAULT_THINK_START_TOKEN,
ParsedChatMessage,
parse_chat_message,
)

logger = logging.getLogger(__name__)

@@ -78,7 +82,10 @@ def chat_generate_stream(
):
req = _build_chat_completion_request(params, stream=True)
text = ""
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
msg = ParsedChatMessage()
is_first = True
for r in model.stream_chat_completion(req):
if len(r.choices) == 0:
continue
@@ -86,13 +93,17 @@ def chat_generate_stream(
if r.choices[0] is not None and r.choices[0].delta is None:
continue
content = r.choices[0].delta.content
if content is None:
continue

text += content
if is_reasoning_model and not text.startswith(think_start_token) and is_first:
text = think_start_token + "\n" + text
is_first = False

msg = parse_chat_message(text, extract_reasoning=is_reasoning_model)
finish_reason = _parse_finish_reason(r.choices[0].finish_reason)

if content is not None:
text += content
msg, _ = parse_chat_message(
content, extract_reasoning=True, is_streaming=True, streaming_state=msg
)
yield ModelOutput.build(
msg.content,
msg.reasoning_content,
@@ -6,7 +6,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from dbgpt.core import ModelOutput

from ...utils.parse_utils import ParsedChatMessage, parse_chat_message
from ...utils.parse_utils import (
_DEFAULT_THINK_START_TOKEN,
ParsedChatMessage,
parse_chat_message,
)

logger = logging.getLogger(__name__)

@@ -27,6 +31,8 @@ def huggingface_chat_generate_stream(
stop_token_ids = params.get("stop_token_ids", [])
do_sample = params.get("do_sample", True)
custom_stop_words = params.get("custom_stop_words", [])
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)

input_ids = tokenizer(prompt).input_ids
# input_ids = input_ids.to(device)
@@ -65,15 +71,23 @@ def huggingface_chat_generate_stream(
text = ""
usage = None
msg = ParsedChatMessage()
is_first = True
for new_text in streamer:
text += new_text
msg, _ = parse_chat_message(
new_text, extract_reasoning=True, is_streaming=True, streaming_state=msg
)
if custom_stop_words:
for stop_word in custom_stop_words:
if text.endswith(stop_word):
text = text[: -len(stop_word)]

if (
prompt.rstrip().endswith(think_start_token)
and is_reasoning_model
and is_first
):
text = think_start_token + "\n" + text
is_first = False

msg = parse_chat_message(text, extract_reasoning=is_reasoning_model)
yield ModelOutput.build(
msg.content,
msg.reasoning_content,
@@ -31,6 +31,7 @@ async def generate_stream(
best_of = params.get("best_of", None)
stop_str = params.get("stop", None)
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
# think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)

stop_token_ids = params.get("stop_token_ids", None) or []
@@ -104,11 +105,11 @@ async def generate_stream(
)
if text_outputs:
# Tempora
if prompt.rstrip().endswith(think_start_token):
if prompt.rstrip().endswith(think_start_token) and is_reasoning_model:
text_outputs = think_start_token + "\n" + text_outputs
msg = parse_chat_message(
text_outputs,
extract_reasoning=True,
extract_reasoning=is_reasoning_model,
)
yield ModelOutput.build(
msg.content,
@@ -9,13 +9,13 @@ from dbgpt.core.awel.flow import (
ResourceCategory,
auto_register_resource,
)
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.proxy.base import (
AsyncGenerateStreamFunction,
GenerateStreamFunction,
ProxyLLMClient,
register_proxy_model_adapter,
)
from dbgpt.model.proxy.llms.chatgpt import OpenAICompatibleDeployModelParameters
from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
from dbgpt.util.i18n_utils import _

@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
show_in_ui=False,
)
@dataclass
class OllamaDeployModelParameters(OpenAICompatibleDeployModelParameters):
class OllamaDeployModelParameters(LLMDeployModelParameters):
"""Deploy model parameters for Ollama."""

provider: str = "proxy/ollama"
@@ -128,13 +128,13 @@ def process_streaming_chunk(
if end_marker in remaining_chunk:
end_idx = remaining_chunk.find(end_marker)
# Output reasoning content event
if end_idx > 0:
reasoning_part = remaining_chunk[:end_idx]
events.append(
StreamingEvent(type="reasoning_content", content=reasoning_part)
)
# Append reasoning content instead of replacing
msg.reasoning_content += reasoning_part
# if end_idx > 0:
reasoning_part = remaining_chunk[:end_idx]
events.append(
StreamingEvent(type="reasoning_content", content=reasoning_part)
)
# Append reasoning content instead of replacing
msg.reasoning_content += reasoning_part

# Output reasoning end event
events.append(StreamingEvent(type="reasoning_end", content=""))
@@ -220,6 +220,49 @@ def process_streaming_chunk(
remaining_chunk = ""
continue

# Check for reasoning end markers without matching start markers
# This is the special case to handle
found_end_marker = False
for pattern in reasoning_patterns:
start_marker = pattern["start"]
end_marker = pattern["end"]
if end_marker in remaining_chunk and not state["in_reasoning"]:
end_idx = remaining_chunk.find(end_marker)
start_idx = 0
if start_marker in remaining_chunk:
start_idx = remaining_chunk.find(start_marker) + len(start_marker)

# This is content that should be treated as reasoning but didn't have a
# start tag
# if end_idx > 0:
reasoning_part = remaining_chunk[start_idx:end_idx]
# Clear regular content
reasoning_part = msg.content + reasoning_part
msg.content = ""

# First, emit a reasoning_start event
events.append(StreamingEvent(type="reasoning_start", content=""))

# Then emit the content as reasoning content
events.append(
StreamingEvent(type="reasoning_content", content=reasoning_part)
)

# Add to reasoning content
msg.reasoning_content += reasoning_part

# Emit the reasoning_end event
events.append(StreamingEvent(type="reasoning_end", content=""))
# Move past the end marker
remaining_chunk = remaining_chunk[end_idx + len(end_marker) :]
found_end_marker = True
state["reasoning_pattern"] = None
break

# If we found an end marker, continue to the next iteration
if found_end_marker:
continue

# Check for reasoning start markers
reasoning_start_found = False
for pattern in reasoning_patterns:
@@ -228,10 +271,10 @@ def process_streaming_chunk(
start_idx = remaining_chunk.find(start_marker)

# Output regular content before the marker
if start_idx > 0:
content_part = remaining_chunk[:start_idx]
events.append(StreamingEvent(type="content", content=content_part))
msg.content += content_part
# if start_idx > 0:
content_part = remaining_chunk[:start_idx]
events.append(StreamingEvent(type="content", content=content_part))
msg.content += content_part

# Output reasoning start event
events.append(StreamingEvent(type="reasoning_start", content=""))
@@ -257,12 +300,10 @@ def process_streaming_chunk(
start_idx = remaining_chunk.find(start_marker)

# Output regular content before the marker
if start_idx > 0:
content_part = remaining_chunk[:start_idx]
events.append(
StreamingEvent(type="content", content=content_part)
)
msg.content += content_part
# if start_idx > 0:
content_part = remaining_chunk[:start_idx]
events.append(StreamingEvent(type="content", content=content_part))
msg.content += content_part

# Output tool call start event
events.append(StreamingEvent(type="tool_call_start", content=""))
@@ -355,6 +396,7 @@ def parse_chat_message(
reasoning_content = ""
content = input_text

# First check for the normal case with proper start and end markers
for pattern in reasoning_patterns:
start_marker = pattern["start"]
end_marker = pattern["end"]
@@ -372,6 +414,44 @@ def parse_chat_message(
content = content[:start_idx] + content[end_idx + len(end_marker) :]
break

# If no reasoning content was found with the standard pattern, check for the
# special case
# where content starts with reasoning but has no start marker
if not reasoning_content:
for pattern in reasoning_patterns:
start_marker = pattern["start"]
end_marker = pattern["end"]

if end_marker in content:
# Check if this is at the beginning of the content or
# if there's no matching start marker before it
end_idx = content.find(end_marker)
start_marker = pattern["start"]
start_idx = content.find(start_marker)

# If no start marker or end marker appears before start marker
if start_idx == -1 or end_idx < start_idx:
# This is our special case - treat the content up to the end marker
# as reasoning
reasoning_content = string_strip(content[:end_idx])

# Remove reasoning part from original content
if extract_reasoning:
content = content[end_idx + len(end_marker) :]
break
elif start_marker in content:
# If there's a start marker but no end marker, treat the content
# as reasoning content
start_idx = content.find(start_marker)
reasoning_content = string_strip(
content[start_idx + len(start_marker) :]
)

# Remove reasoning part from original content
if extract_reasoning:
content = ""
break

# Parse tool calls
tool_calls_text = ""
@@ -1,5 +1,3 @@
import pytest

from ..parse_utils import parse_chat_message

@@ -357,8 +355,8 @@ def test_streaming_mode_without_tool_calls():

# Verify final message should include tool call markers
assert (
"I will search for data.<|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>Search complete."  # noqa
== msg.content
"I will search for data.<|tool▁calls▁begin|>Tool call content"
"<|tool▁calls▁end|>Search complete." == msg.content  # noqa
)
assert len(msg.tool_calls) == 0
@@ -449,48 +447,6 @@ def test_incomplete_markers():
assert "Tool content" in msg2.streaming_state.get("tool_call_text", "")


def test_multiple_special_sections():
"""Test handling multiple special sections"""
input_text = """<think>Reasoning content 1</think>Regular content 1
<|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>
Regular content 2<think>Reasoning content 2</think>End"""

result = parse_chat_message(input_text, extract_tool_calls=True)

# Verify only first reasoning content is extracted
assert "Reasoning content 1" == result.reasoning_content
assert (
"Regular content 1\n\nRegular content 2<think>Reasoning content 2</think>End"
== result.content
)

# Use streaming processing to handle multiple reasoning parts
chunks = [
"<think>Reasoning content 1</think>Regular content 1\n",
"<|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>\n",
"Regular content 2<think>Reasoning content 2</think>End",
]

msg = None
all_events = []

for chunk in chunks:
msg, events = parse_chat_message(
chunk, is_streaming=True, streaming_state=msg, extract_tool_calls=True
)
all_events.extend(events)

# Verify event sequence contains two reasoning sections
reasoning_start_counts = sum(1 for e in all_events if e.type == "reasoning_start")
reasoning_end_counts = sum(1 for e in all_events if e.type == "reasoning_end")

assert reasoning_start_counts == 2
assert reasoning_end_counts == 2

# In streaming mode, reasoning content should accumulate
assert "Reasoning content 1Reasoning content 2" == msg.reasoning_content


def test_custom_streaming_patterns():
"""Test custom streaming pattern markers"""
custom_reasoning = [{"start": "{{thinking}}", "end": "{{/thinking}}"}]
@@ -530,41 +486,239 @@ def test_custom_streaming_patterns():
assert "tool_call_end" in event_types


def test_alternating_reasoning_and_tool_calls():
"""Test alternating between reasoning and tool calls in a single message"""
# Use streaming to capture all sections
def test_missing_start_token_non_streaming():
"""Test parsing messages with missing start token but having end token
(non-streaming mode)
"""
input_text = """Model reasoning content without start token.
</think>
This is the regular content part."""

result = parse_chat_message(input_text, extract_reasoning=True)

assert "This is the regular content part." == result.content
assert "Model reasoning content without start token." == result.reasoning_content


def test_missing_start_token_streaming():
"""Test parsing messages with missing start token but having end token
(streaming mode)
"""
chunks = [
"<think>First reasoning block</think>Content 1\n",
"<|tool▁calls▁begin|>Tool call 1<|tool▁calls▁end|>\n",
"Content 2<think>Second reasoning block</think>\n",
"<|tool▁calls▁begin|>Tool call 2<|tool▁calls▁end|>\n",
"Final content",
"Model reasoning content ",
"without start token.</think>",
"This is the regular content part.",
]

msg = None
all_events = []

for chunk in chunks:
msg, events = parse_chat_message(
chunk, is_streaming=True, streaming_state=msg, extract_tool_calls=True
)
msg, events = parse_chat_message(chunk, is_streaming=True, streaming_state=msg)
all_events.extend(events)

# Verify content is parsed correctly - note the double newlines
assert "Content 1\n\nContent 2\n\nFinal content" == msg.content
assert "First reasoning blockSecond reasoning block" == msg.reasoning_content
# Verify final message - match the expected result in test
assert "This is the regular content part." == msg.content
assert "Model reasoning content without start token." == msg.reasoning_content

# Count events by type
event_counts = {}
for e in all_events:
event_counts[e.type] = event_counts.get(e.type, 0) + 1

assert event_counts.get("reasoning_start", 0) == 2
assert event_counts.get("reasoning_end", 0) == 2
assert event_counts.get("tool_call_start", 0) == 2
assert event_counts.get("tool_call_end", 0) == 2
# Verify event sequence contains correct reasoning events
event_types = [e.type for e in all_events]
assert "reasoning_start" in event_types
assert "reasoning_content" in event_types
assert "reasoning_end" in event_types


if __name__ == "__main__":
# Run tests
pytest.main(["-v", "test_parse_utils.py"])
def test_missing_start_token_deepseek_chinese():
"""Test the DeepSeek example with Chinese content missing start token"""
input_text = """您好!我是由中国的深度求索(DeepSeek)公司开发的智能助手\
DeepSeek-R1。有关模型和产品的详细内容请参考官方文档。
</think>
您好!我是由中国的深度求索(DeepSeek)公司开发的智能助手DeepSeek-R1。有关模型\
和产品的详细内容请参考官方文档。"""

result = parse_chat_message(input_text, extract_reasoning=True)

assert (
"您好!我是由中国的深度求索(DeepSeek)公司开发的智能助手DeepSeek-R1。"
"有关模型和产品的详细内容请参考官方文档。" == result.content
)
assert (
"您好!我是由中国的深度求索(DeepSeek)公司开发的智能助手DeepSeek-R1。"
"有关模型和产品的详细内容请参考官方文档。" == result.reasoning_content
)


def test_multiple_missing_start_tokens():
"""Test multiple occurrences of missing start tokens in the same message"""
input_text = """First reasoning section.
</think>
Some regular content.
Second reasoning section.
</reasoning>
More regular content."""

result = parse_chat_message(input_text, extract_reasoning=True)

# Note: In non-streaming mode, only the first matching reasoning content is
# extracted
assert (
"Some regular content.\nSecond reasoning section.\n</reasoning>\nMore regular "
"content." == result.content
)
assert "First reasoning section." == result.reasoning_content

# Use streaming to capture all sections
chunks = [
"First reasoning section.\n</think>\n",
"Some regular content.\n",
"Second reasoning section.\n</reasoning>\n",
"More regular content.",
]

msg = None
all_events = []

for chunk in chunks:
msg, events = parse_chat_message(chunk, is_streaming=True, streaming_state=msg)
all_events.extend(events)

# In streaming mode, reasoning content should match the expected format
assert (
"First reasoning section.\n\nSome regular content.\nSecond reasoning section.\n"
== msg.reasoning_content
)

# Verify event sequence contains two reasoning sections
reasoning_start_counts = sum(1 for e in all_events if e.type == "reasoning_start")
reasoning_end_counts = sum(1 for e in all_events if e.type == "reasoning_end")

assert reasoning_start_counts == 2
assert reasoning_end_counts == 2


def test_missing_start_token_with_tools():
"""Test missing reasoning start token with tool calls"""
input_text = """Analyzing user request to query weather information.
</think>
I'll look up the weather data for you.

<|tool▁calls▁begin|>
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{
"location": "Beijing",
"date": "2023-05-20"
}
```
<|tool▁call▁end|>
<|tool▁calls▁end|>"""

result = parse_chat_message(
input_text, extract_reasoning=True, extract_tool_calls=True
)

assert "I'll look up the weather data for you." in result.content
assert (
"Analyzing user request to query weather information."
== result.reasoning_content
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0]["name"] == "get_weather"
assert result.tool_calls[0]["arguments"]["location"] == "Beijing"


def test_mixed_language_missing_start_token():
"""Test mixed Chinese and English content with missing start token"""
input_text = """这是一段中英文混合的思考内容 with both languages mixed together.
</think>
Here's the regular content with 中文 mixed in."""

result = parse_chat_message(input_text, extract_reasoning=True)

assert "Here's the regular content with 中文 mixed in." == result.content
assert (
"这是一段中英文混合的思考内容 with both languages mixed together."
== result.reasoning_content
)


def test_streaming_mixed_language_missing_start():
"""Test streaming mixed language content with missing start token"""
chunks = [
"Analysis 分析: The user needs ",
"information about 关于天气的信息。</reasoning>",
"I'll provide weather information 我将提供天气信息。",
]

msg = None
all_events = []

for chunk in chunks:
msg, events = parse_chat_message(chunk, is_streaming=True, streaming_state=msg)
all_events.extend(events)

# Verify final message includes mixed language content properly parsed
assert "I'll provide weather information 我将提供天气信息。" == msg.content
assert (
"Analysis 分析: The user needs information about 关于天气的信息。"
== msg.reasoning_content
)

# Verify events sequence
reasoning_events = [e for e in all_events if e.type.startswith("reasoning_")]
assert len(reasoning_events) >= 3  # At least start, content, and end events


def test_chinese_pattern_missing_start():
"""Test Chinese pattern with missing start token"""
input_text = """这里是模型的思考内容,但是没有开始标记。
</思考>
这是正常的响应内容。"""

result = parse_chat_message(input_text, extract_reasoning=True)

assert "这是正常的响应内容。" == result.content
assert "这里是模型的思考内容,但是没有开始标记。" == result.reasoning_content


#
# def test_multiple_special_sections():
# """Test handling multiple special sections"""
# input_text = """<think>Reasoning content 1</think>Regular content 1
# <|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>
# Regular content 2<think>Reasoning content 2</think>End"""
#
# result = parse_chat_message(input_text, extract_tool_calls=True)
#
# # Verify only first reasoning content is extracted
# assert "Reasoning content 1" == result.reasoning_content
# assert (
# "Regular content 1\n\nRegular content 2<think>Reasoning content 2</think>End"
# == result.content
# )
#
# # Use streaming processing to handle multiple reasoning parts
# chunks = [
# "<think>Reasoning content 1</think>Regular content 1\n",
# "<|tool▁calls▁begin|>Tool call content<|tool▁calls▁end|>\n",
# "Regular content 2<think>Reasoning content 2</think>End",
# ]
#
# msg = None
# all_events = []
#
# for chunk in chunks:
# msg, events = parse_chat_message(
# chunk, is_streaming=True, streaming_state=msg, extract_tool_calls=True
# )
# all_events.extend(events)
#
# # Verify event sequence contains two reasoning sections
# reasoning_start_counts = sum(1 for e in all_events if e.type == "reasoning_start")
# reasoning_end_counts = sum(1 for e in all_events if e.type == "reasoning_end")
#
# assert reasoning_start_counts == 2
# assert reasoning_end_counts == 2
#
# # In streaming mode, reasoning content should match the expected format
# assert "Reasoning content 1Reasoning content 2" == msg.reasoning_content
@@ -113,7 +113,7 @@ class HFEmbeddingDeployModelParameters(EmbeddingDeployModelParameters):
@property
def real_provider_model_name(self) -> str:
"""Get the real provider model name."""
return self.path or self.name
return self.real_model_path or self.name

@property
def real_model_path(self) -> Optional[str]:
@@ -40,8 +40,7 @@ class ClickhouseParameters(BaseDatasourceParameters):
user: str = field(metadata={"help": _("Database user to connect")})
database: str = field(metadata={"help": _("Database name")})
engine: str = field(
default="MergeTree",
metadata={"help": _("Storage engine, e.g., MergeTree")}
default="MergeTree", metadata={"help": _("Storage engine, e.g., MergeTree")}
)
password: str = field(
default="${env:DBGPT_DB_PASSWORD}",
@@ -13,16 +13,22 @@ readme = "README.md"
requires-python = ">= 3.10"

[tool.uv.sources]
dbgpt-accelerator = { workspace = true }
dbgpt = { workspace = true }
dbgpt-client = { workspace = true }
dbgpt-ext = { workspace = true }
dbgpt-serve = { workspace = true }
dbgpt-app = { workspace = true }
dbgpt-acc-auto = { workspace = true }
dbgpt-acc-flash-attn = { workspace = true }

[tool.uv.workspace]
members = [
"packages/dbgpt-*"
"packages/dbgpt-app",
"packages/dbgpt-client",
"packages/dbgpt-core",
"packages/dbgpt-ext",
"packages/dbgpt-serve",
"packages/dbgpt-accelerator/*"
]

[tool.uv]
@@ -65,4 +71,4 @@ select = ["E", "F", "I"]

[tool.ruff.isort]
# Specify the local modules (first-party)
known-first-party = ["dbgpt", "dbgpt_accelerator", "dbgpt_client", "dbgpt_ext", "dbgpt_serve", "dbgpt_app"]
known-first-party = ["dbgpt", "dbgpt_acc_auto", "dbgpt_client", "dbgpt_ext", "dbgpt_serve", "dbgpt_app"]